suncongcong commited on
Commit
8525111
·
verified ·
1 Parent(s): 3e22329

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -31
app.py CHANGED
@@ -11,31 +11,62 @@ from torchvision.transforms.functional import to_pil_image, to_tensor
11
  from tqdm import tqdm
12
 
13
  # --- 1. 配置 ---
14
- repo_id = "suncongcong/AST_DeRain"
 
 
 
 
 
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  PATCH_SIZE = 256
17
  OVERLAP = 64
18
 
19
  print(f"正在使用的设备: {device}")
20
 
21
- # --- 2. 加载模型和处理器 ---
22
- print(f"正在从 '{repo_id}' 加载模型和处理器...")
23
- processor = CLIPImageProcessor.from_pretrained(repo_id)
24
- processor.size = {"height": 256, "width": 256}
25
- processor.crop_size = {"height": 256, "width": 256}
26
- print(f"图像处理器尺寸已强制设置为: {processor.size}")
27
- model = ASTForRestoration.from_pretrained(
28
- repo_id,
29
- trust_remote_code=True
30
- ).to(device).eval()
31
- print("✅ 模型加载成功,准备就绪!")
32
-
33
-
34
- # --- 3. 定义“裁切-推理-合并”的核心处理函数 ---
35
- def derain_image_Tiled(input_image: Image.Image, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  if input_image is None:
37
  return None
38
-
 
 
 
 
 
 
 
 
39
  img = input_image.convert("RGB")
40
  img_tensor = to_tensor(img).unsqueeze(0).to(device)
41
  b, c, h, w = img_tensor.shape
@@ -49,7 +80,7 @@ def derain_image_Tiled(input_image: Image.Image, progress=gr.Progress(track_tqdm
49
  w_steps = len(range(0, w, stride))
50
  total_patches = h_steps * w_steps
51
 
52
- pbar = tqdm(total=total_patches, desc="正在处理图像块...")
53
 
54
  for y in range(0, h, stride):
55
  for x in range(0, w, stride):
@@ -61,7 +92,7 @@ def derain_image_Tiled(input_image: Image.Image, progress=gr.Progress(track_tqdm
61
  pad_h = PATCH_SIZE - ph
62
  pad_w = PATCH_SIZE - pw
63
  if pad_h > 0 or pad_w > 0:
64
- patch_padded = F.pad(patch_in, (0, pad_w, 0, pad_h), 'replicate') # <-- 最终修正
65
  else:
66
  patch_padded = patch_in
67
 
@@ -81,26 +112,35 @@ def derain_image_Tiled(input_image: Image.Image, progress=gr.Progress(track_tqdm
81
  pbar.close()
82
 
83
  restored_tensor = output_canvas / weight_map
84
-
85
  restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
86
 
87
  return restored_image
88
 
89
- # --- 4. 创建并启动 Gradio 界面 ---
90
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
91
  gr.Markdown(
92
  """
93
- #AST 图像去雨模型在线演示
94
- 上传任意尺寸带雨图片,模型将会分块处理并拼接成完整的高清输出
95
- 模型仓库地址: [suncongcong/AST_DeRain](https://huggingface.co/suncongcong/AST_DeRain)
96
  """
97
  )
98
- with gr.Row():
99
- input_img = gr.Image(type="pil", label="输入带雨图片 (Input Rainy Image)")
100
- output_img = gr.Image(type="pil", label="输出清晰图片 (Output Deraided Image)")
101
-
102
- submit_btn = gr.Button("开始去雨 (Start Deraining)", variant="primary")
103
-
104
- submit_btn.click(fn=derain_image_Tiled, inputs=input_img, outputs=output_img)
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  demo.launch(server_name="0.0.0.0")
 
11
  from tqdm import tqdm
12
 
13
  # --- 1. 配置 ---
14
+ # 使用您提供的准确的模型仓库ID
15
+ MODEL_IDS = {
16
+ "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
17
+ "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
18
+ "去雾 (Dehaze)": "Suncongcong/AST_Dehazing"
19
+ }
20
+
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  PATCH_SIZE = 256
23
  OVERLAP = 64
24
 
25
  print(f"正在使用的设备: {device}")
26
 
27
+ # --- 2. 加载所有模型和处理器 ---
28
+ MODELS = {}
29
+ PROCESSOR = None
30
+
31
+ print("正在加载所有模型和处理器...")
32
+ # 使用 try-except 来增加鲁棒性
33
+ try:
34
+ for task_name, repo_id in MODEL_IDS.items():
35
+ print(f"正在加载模型: {task_name} ({repo_id})")
36
+ if PROCESSOR is None:
37
+ PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id)
38
+ print("✅ 处理器加载成功。")
39
+
40
+ model = ASTForRestoration.from_pretrained(
41
+ repo_id,
42
+ trust_remote_code=True
43
+ ).to(device).eval()
44
+ MODELS[task_name] = model
45
+ print(f"✅ 模型 '{task_name}' 加载成功。")
46
+ except Exception as e:
47
+ print(f"加载模型时出错: {e}")
48
+ # 创建一个占位符函数,以便在模型加载失败时 Gradio 仍能启动并显示错误
49
+ def load_error_func(*args, **kwargs):
50
+ raise gr.Error(f"模型加载失败! 错误: {e}")
51
+ MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
52
+
53
+
54
+ print("所有模型加载完毕,准备就绪!")
55
+
56
+
57
+ # --- 3. 定义统一的、可选择模型的处理函数 ---
58
+ def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
59
  if input_image is None:
60
  return None
61
+
62
+ # 根据传入的任务名称,选择对应的模型
63
+ model = MODELS[task_name]
64
+ print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")
65
+
66
+ # 检查模型是否加载成功
67
+ if not isinstance(model, torch.nn.Module):
68
+ model() # 这会触发上面定义的错误函数
69
+
70
  img = input_image.convert("RGB")
71
  img_tensor = to_tensor(img).unsqueeze(0).to(device)
72
  b, c, h, w = img_tensor.shape
 
80
  w_steps = len(range(0, w, stride))
81
  total_patches = h_steps * w_steps
82
 
83
+ pbar = tqdm(total=total_patches, desc=f"正在执行 {task_name}...")
84
 
85
  for y in range(0, h, stride):
86
  for x in range(0, w, stride):
 
92
  pad_h = PATCH_SIZE - ph
93
  pad_w = PATCH_SIZE - pw
94
  if pad_h > 0 or pad_w > 0:
95
+ patch_padded = F.pad(patch_in, (0, pad_w, 0, pad_h), 'replicate')
96
  else:
97
  patch_padded = patch_in
98
 
 
112
  pbar.close()
113
 
114
  restored_tensor = output_canvas / weight_map
 
115
  restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
116
 
117
  return restored_image
118
 
119
+ # --- 4. 创建并启动带选项卡的 Gradio 界面 ---
120
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
121
  gr.Markdown(
122
  """
123
+ # 🖼️ 多功能图像复原工具 (AST 模型)
124
+ 请选择一个任务,然后上传对应的图片进行处理。
 
125
  """
126
  )
 
 
 
 
 
 
 
127
 
128
+ with gr.Tabs():
129
+ # 根据 MODEL_IDS 字典自动创建选项卡
130
+ for task_name in MODEL_IDS.keys():
131
+ with gr.TabItem(task_name, id=task_name):
132
+ with gr.Row():
133
+ input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
134
+ output_img = gr.Image(type="pil", label="输出图片 (Output)")
135
+
136
+ task_id_box = gr.Textbox(task_name, visible=False)
137
+
138
+ submit_btn = gr.Button("开始处理 (Process)", variant="primary")
139
+
140
+ submit_btn.click(
141
+ fn=process_image,
142
+ inputs=[input_img, task_id_box],
143
+ outputs=output_img
144
+ )
145
+
146
  demo.launch(server_name="0.0.0.0")