suncongcong committed on
Commit
7ec5312
·
verified ·
1 Parent(s): ae5af4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -58
app.py CHANGED
@@ -11,7 +11,7 @@ from torchvision.transforms.functional import to_pil_image, to_tensor
11
  from tqdm import tqdm
12
  import math
13
 
14
- # --- 1. 配置 ---
15
  MODEL_IDS = {
16
  "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
17
  "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
@@ -25,7 +25,7 @@ EXAMPLE_IMAGES = {
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  print(f"正在使用的设备: {device}")
27
 
28
- # --- 2. 加载所有模型和处理器 ---
29
  MODELS = {}
30
  PROCESSOR = None
31
  print("正在加载所有模型和处理器...")
@@ -45,67 +45,141 @@ except Exception as e:
45
  MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
46
  print("所有模型加载完毕,准备就绪!")
47
 
 
48
  # --- 3. 定义不同任务的处理函数 ---
 
49
  def process_with_pad_to_square(model, img_tensor):
 
50
  def expand2square(timg, factor=128.0):
 
 
51
  _, _, h, w = timg.size()
52
  X = int(math.ceil(max(h, w) / factor) * factor)
53
- img_padded = torch.zeros(1, 3, X, X).type_as(timg)
54
- mask = torch.zeros(1, 1, X, X).type_as(timg)
55
- img_padded[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)] = timg
56
- mask[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)].fill_(1)
 
 
 
57
  return img_padded, mask
 
58
  original_h, original_w = img_tensor.shape[2:]
59
  padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
 
60
  with torch.no_grad():
61
  restored_padded = model(padded_input)
 
 
 
 
62
  restored_tensor = torch.masked_select(
63
- restored_padded, mask.bool()
64
  ).reshape(1, 3, original_h, original_w)
 
65
  return restored_tensor
66
 
 
67
  def process_with_dehaze_tiling(model, img_tensor, progress):
68
- CROP_SIZE, OVERLAP = 1152, 384
 
 
 
 
69
  b, c, h_orig, w_orig = img_tensor.shape
70
  stride = CROP_SIZE - OVERLAP
71
- h_pad = (stride - (h_orig - OVERLAP) % stride) % stride
72
- w_pad = (stride - (w_orig - OVERLAP) % stride) % stride
 
 
 
73
  img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
74
- _, _, h_padded, w_padded = img_padded.shape
 
 
75
  output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
76
  weight_map = torch.zeros_like(output_canvas)
77
- h_steps = len(range(0, h_padded - OVERLAP, stride)) if h_padded > OVERLAP else 1
78
- w_steps = len(range(0, w_padded - OVERLAP, stride)) if w_padded > OVERLAP else 1
79
- pbar = tqdm(total=h_steps * w_steps, desc=f"正在执行去雾...")
80
- for y in range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0]:
81
- for x in range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0]:
 
 
 
 
82
  patch_in = img_padded[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE]
83
  with torch.no_grad():
84
  patch_out = model(patch_in.to(device)).cpu()
 
85
  output_canvas[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += patch_out
86
  weight_map[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += 1
87
- pbar.update(1)
88
- pbar.close()
89
  restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
90
  return restored_padded_tensor[:, :, :h_orig, :w_orig]
91
 
 
92
  def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
93
- if input_image is None: return None
94
- model = MODELS[task_name]
95
- print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")
96
- if not isinstance(model, torch.nn.Module): model()
97
- img = input_image.convert("RGB")
98
- img_tensor = to_tensor(img).unsqueeze(0)
99
-
100
- # 关键修正:在 process_image 函数内部也进行判断
101
- if task_name == "去雾 (Dehaze)":
102
- restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
103
- else:
104
- restored_tensor = process_with_pad_to_square(model, img_tensor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- restored_tensor = torch.clamp(restored_tensor, 0, 1)
107
- restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
108
- return restored_image
 
 
 
 
 
109
 
110
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
111
  gr.Markdown(
@@ -116,30 +190,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
116
  )
117
  with gr.Tabs():
118
  for task_name in MODEL_IDS.keys():
119
- with gr.TabItem(task_name, id=task_name):
120
- with gr.Row():
121
- input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
122
- output_img = gr.Image(type="pil", label="输出图片 (Output)")
123
-
124
- task_id_box = gr.Textbox(task_name, visible=False)
125
- submit_btn = gr.Button("开始处理 (Process)", variant="primary")
126
-
127
- # “提交”按钮的点击事件,保持不变
128
- submit_btn.click(fn=process_image, inputs=[input_img, task_id_box], outputs=output_img)
129
-
130
- # --- 最终修正 ---
131
- # 重新构造 lambda 函数,确保它总是传递正确的 task_name
132
- # 我们不再依赖外部的 task_name 变量,而是直接使用在循环中定义的那个
133
- def create_example_fn(current_task_name):
134
- return lambda img, prog: process_image(img, current_task_name, prog)
135
-
136
- if EXAMPLE_IMAGES.get(task_name):
137
- gr.Examples(
138
- examples=EXAMPLE_IMAGES.get(task_name, []),
139
- inputs=input_img,
140
- outputs=output_img,
141
- fn=create_example_fn(task_name), # 关键:为每个循环的 task_name 创建一个独立的函数
142
- cache_examples=True
143
- )
144
 
145
  demo.launch(server_name="0.0.0.0")
 
11
  from tqdm import tqdm
12
  import math
13
 
14
+ # --- 1. 配置 (无变化) ---
15
  MODEL_IDS = {
16
  "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
17
  "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  print(f"正在使用的设备: {device}")
27
 
28
+ # --- 2. 加载所有模型和处理器 (无变化) ---
29
  MODELS = {}
30
  PROCESSOR = None
31
  print("正在加载所有模型和处理器...")
 
45
  MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
46
  print("所有模型加载完毕,准备就绪!")
47
 
48
+
49
  # --- 3. 定义不同任务的处理函数 ---
50
+
51
def process_with_pad_to_square(model, img_tensor):
    """Pad the image to a square multiple of 128, run the model, crop back.

    Used for the derain / deraindrop tasks, whose networks expect spatial
    dimensions that are multiples of 128.

    Args:
        model: restoration network, called as ``model(tensor)``.
        img_tensor: image batch of shape (B, C, H, W) — assumed to be in
            [0, 1] range (TODO confirm against caller, which uses to_tensor).

    Returns:
        Restored tensor with the same (B, C, H, W) shape as the input.
    """
    def expand2square(timg, factor=128.0):
        # The network requires spatial dims that are multiples of `factor`,
        # so pad up to a square of side ceil(max(h, w) / factor) * factor.
        b, c, h, w = timg.size()
        side = int(math.ceil(max(h, w) / factor) * factor)

        # Fix: derive batch/channel counts from the input instead of the
        # hard-coded (1, 3), so batches > 1 and non-RGB inputs also work.
        img_padded = torch.zeros(b, c, side, side, device=timg.device, dtype=timg.dtype)
        mask = torch.zeros(b, 1, side, side, device=timg.device, dtype=timg.dtype)

        top = (side - h) // 2
        left = (side - w) // 2
        img_padded[:, :, top:top + h, left:left + w] = timg
        mask[:, :, top:top + h, left:left + w].fill_(1)
        return img_padded, mask

    batch, channels, original_h, original_w = img_tensor.shape
    padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)

    with torch.no_grad():
        restored_padded = model(padded_input)

    # Keep the mask on the same device as the model output before selecting.
    mask_bool = mask.bool().to(restored_padded.device)

    # masked_select flattens the selection, so reshape it back to the
    # original geometry (the mask marks exactly the un-padded region).
    restored_tensor = torch.masked_select(
        restored_padded, mask_bool
    ).reshape(batch, channels, original_h, original_w)
    return restored_tensor
81
 
82
+
83
def process_with_dehaze_tiling(model, img_tensor, progress):
    """Dehaze a (possibly very large) image using overlapping tiles.

    The image is padded, split into CROP_SIZE tiles overlapping by OVERLAP
    pixels, restored tile by tile, and blended by averaging the overlaps.

    Args:
        model: dehazing network, called as ``model(tensor)``.
        img_tensor: image batch of shape (B, C, H, W).
        progress: Gradio progress tracker (provides ``.tqdm``).

    Returns:
        Restored tensor cropped back to the original (B, C, H, W).
    """
    CROP_SIZE = 1152   # side length of each tile
    OVERLAP = 384      # overlap between neighbouring tiles, hides seam artifacts

    b, c, h_orig, w_orig = img_tensor.shape
    stride = CROP_SIZE - OVERLAP

    # Pad so the tile grid covers the image exactly; no padding is needed
    # when a single tile already covers the whole dimension.
    h_pad = (stride - (h_orig - OVERLAP) % stride) % stride if h_orig > CROP_SIZE else 0
    w_pad = (stride - (w_orig - OVERLAP) % stride) % stride if w_orig > CROP_SIZE else 0

    img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
    b, c, h_padded, w_padded = img_padded.shape

    # Accumulate on the CPU so GPU memory stays bounded by one tile.
    output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
    weight_map = torch.zeros_like(output_canvas)

    h_steps = range(0, h_padded - CROP_SIZE + 1, stride) if h_padded > CROP_SIZE else [0]
    w_steps = range(0, w_padded - CROP_SIZE + 1, stride) if w_padded > CROP_SIZE else [0]

    # Fix: dropped the unused `total_steps` local; progress.tqdm already
    # reports per-row progress in the Gradio UI.
    for y in progress.tqdm(h_steps, desc="正在分块去雾..."):
        for x in w_steps:
            patch_in = img_padded[:, :, y:y + CROP_SIZE, x:x + CROP_SIZE]
            with torch.no_grad():
                patch_out = model(patch_in.to(device)).cpu()
            output_canvas[:, :, y:y + CROP_SIZE, x:x + CROP_SIZE] += patch_out
            weight_map[:, :, y:y + CROP_SIZE, x:x + CROP_SIZE] += 1

    # Average the overlapping regions; clamp guards against division by zero.
    restored = output_canvas / torch.clamp(weight_map, min=1)
    return restored[:, :, :h_orig, :w_orig]
120
 
121
+
122
def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
    """Dispatch an input image to the task-specific restoration pipeline.

    Args:
        input_image: PIL image from the Gradio widget; None when empty.
        task_name: key into MODEL_IDS / MODELS selecting the model.
        progress: Gradio progress tracker, injected automatically.

    Returns:
        The restored PIL image, or None when no image was supplied.

    Raises:
        gr.Error: when the model failed to load or processing fails;
            Gradio catches this and shows the message in the UI.
    """
    if input_image is None:
        gr.Warning("请输入一张图片!")
        return None

    try:
        model = MODELS[task_name]
        print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")

        # If loading failed, MODELS holds a placeholder callable that raises
        # a descriptive error when invoked.
        if not isinstance(model, torch.nn.Module):
            model()

        img = input_image.convert("RGB")
        img_tensor = to_tensor(img).unsqueeze(0)

        # Dehaze tiles huge inputs; every other task pads to a square.
        if task_name == "去雾 (Dehaze)":
            restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
        else:
            restored_tensor = process_with_pad_to_square(model, img_tensor)

        restored_tensor = torch.clamp(restored_tensor, 0, 1)
        return to_pil_image(restored_tensor.cpu().squeeze(0))

    except Exception as e:
        print(f"处理图片时发生错误: {e}")
        # Fix: gr.Error is an exception class — it only reaches the UI when
        # raised. The previous code merely instantiated it, so the user never
        # saw the error toast and silently got the original image back.
        raise gr.Error(f"处理失败!错误信息: {e}") from e
155
+
156
+
157
+ # --- 4. Gradio UI ---
158
+
159
+ # ✨ 优化点 3: 将创建Tab的逻辑封装成函数,使UI代码更干净
160
def create_task_tab(task_name: str):
    """Build one Gradio tab (image widgets, button, examples) for a task.

    The closure over ``task_name`` replaces the old hidden-Textbox trick for
    routing the task name into process_image.
    """
    with gr.TabItem(task_name, id=task_name):
        with gr.Row():
            input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
            output_img = gr.Image(type="pil", label="输出图片 (Output)")

        submit_btn = gr.Button("开始处理 (Process)", variant="primary")

        # Fix: the progress parameter must default to gr.Progress() —
        # Gradio only injects the tracker for a parameter with that default.
        # A bare second positional parameter would make the click call fail,
        # since only `input_img` is wired as an input.
        def specific_process_fn(img, progress=gr.Progress(track_tqdm=True)):
            return process_image(img, task_name, progress)

        submit_btn.click(fn=specific_process_fn, inputs=[input_img], outputs=output_img)

        if EXAMPLE_IMAGES.get(task_name):
            gr.Examples(
                examples=EXAMPLE_IMAGES.get(task_name, []),
                inputs=input_img,
                outputs=output_img,
                fn=specific_process_fn,  # reuse the bound click handler
                cache_examples=True,
            )
183
 
184
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
185
  gr.Markdown(
 
190
  )
191
  with gr.Tabs():
192
  for task_name in MODEL_IDS.keys():
193
+ create_task_tab(task_name) # 调用函数创建每个Tab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  demo.launch(server_name="0.0.0.0")