suncongcong commited on
Commit
ae5af4b
·
verified ·
1 Parent(s): e774921

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -26
app.py CHANGED
@@ -20,7 +20,7 @@ MODEL_IDS = {
20
  EXAMPLE_IMAGES = {
21
  "去雨痕 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]],
22
  "去雨滴 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]],
23
- "去雾 (Dehaze)": [["dehaze_example1.jpg"],["dehaze_example2.jpg"],["dehaze_example3.jpg"]]
24
  }
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  print(f"正在使用的设备: {device}")
@@ -45,41 +45,68 @@ except Exception as e:
45
  MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
46
  print("所有模型加载完毕,准备就绪!")
47
 
48
-
49
- # --- 3. 定义统一的处理函数 ---
50
- def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
51
- if input_image is None: return None
52
-
53
- model = MODELS[task_name]
54
- if isinstance(model, str): raise gr.Error(model)
55
-
56
- print(f"执行任务: {task_name}, 使用 Pad-to-Square 策略")
57
- img_tensor = to_tensor(input_image.convert("RGB")).unsqueeze(0)
58
-
59
- # Pad-to-Square 逻辑
60
  def expand2square(timg, factor=128.0):
61
- _, _, h, w = timg.size(); X = int(math.ceil(max(h, w) / factor) * factor)
 
62
  img_padded = torch.zeros(1, 3, X, X).type_as(timg)
63
  mask = torch.zeros(1, 1, X, X).type_as(timg)
64
  img_padded[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)] = timg
65
  mask[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)].fill_(1)
66
  return img_padded, mask
67
-
68
  original_h, original_w = img_tensor.shape[2:]
69
  padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
70
-
71
  with torch.no_grad():
72
  restored_padded = model(padded_input)
73
-
74
  restored_tensor = torch.masked_select(
75
  restored_padded, mask.bool()
76
  ).reshape(1, 3, original_h, original_w)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
 
 
 
 
 
78
  restored_tensor = torch.clamp(restored_tensor, 0, 1)
79
- return to_pil_image(restored_tensor.cpu().squeeze(0))
80
-
81
 
82
- # --- 4. 创建并启动 Gradio 界面 ---
83
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
84
  gr.Markdown(
85
  """
@@ -94,19 +121,24 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
94
  input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
95
  output_img = gr.Image(type="pil", label="输出图片 (Output)")
96
 
 
97
  submit_btn = gr.Button("开始处理 (Process)", variant="primary")
98
 
99
- # 使用 gr.State 来传递 task_name,这是更稳健的方式
100
- task_state = gr.State(task_name)
101
-
102
- submit_btn.click(fn=process_image, inputs=[input_img, task_state], outputs=output_img)
103
 
 
 
 
 
 
 
104
  if EXAMPLE_IMAGES.get(task_name):
105
  gr.Examples(
106
  examples=EXAMPLE_IMAGES.get(task_name, []),
107
- inputs=[input_img],
108
  outputs=output_img,
109
- fn=lambda img: process_image(img, task_name=task_name),
110
  cache_examples=True
111
  )
112
 
 
20
  EXAMPLE_IMAGES = {
21
  "去雨痕 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]],
22
  "去雨滴 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]],
23
+ "去雾 (Dehaze)": [["dehaze_example1.jpg"],["dehaze_example2.jpg"],["dehaze_example3.jpg"]]
24
  }
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  print(f"正在使用的设备: {device}")
 
45
  MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
46
  print("所有模型加载完毕,准备就绪!")
47
 
48
+ # --- 3. 定义不同任务的处理函数 ---
49
+ def process_with_pad_to_square(model, img_tensor):
 
 
 
 
 
 
 
 
 
 
50
  def expand2square(timg, factor=128.0):
51
+ _, _, h, w = timg.size()
52
+ X = int(math.ceil(max(h, w) / factor) * factor)
53
  img_padded = torch.zeros(1, 3, X, X).type_as(timg)
54
  mask = torch.zeros(1, 1, X, X).type_as(timg)
55
  img_padded[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)] = timg
56
  mask[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)].fill_(1)
57
  return img_padded, mask
 
58
  original_h, original_w = img_tensor.shape[2:]
59
  padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
 
60
  with torch.no_grad():
61
  restored_padded = model(padded_input)
 
62
  restored_tensor = torch.masked_select(
63
  restored_padded, mask.bool()
64
  ).reshape(1, 3, original_h, original_w)
65
+ return restored_tensor
66
+
67
+ def process_with_dehaze_tiling(model, img_tensor, progress):
68
+ CROP_SIZE, OVERLAP = 1152, 384
69
+ b, c, h_orig, w_orig = img_tensor.shape
70
+ stride = CROP_SIZE - OVERLAP
71
+ h_pad = (stride - (h_orig - OVERLAP) % stride) % stride
72
+ w_pad = (stride - (w_orig - OVERLAP) % stride) % stride
73
+ img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
74
+ _, _, h_padded, w_padded = img_padded.shape
75
+ output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
76
+ weight_map = torch.zeros_like(output_canvas)
77
+ h_steps = len(range(0, h_padded - OVERLAP, stride)) if h_padded > OVERLAP else 1
78
+ w_steps = len(range(0, w_padded - OVERLAP, stride)) if w_padded > OVERLAP else 1
79
+ pbar = tqdm(total=h_steps * w_steps, desc=f"正在执行去雾...")
80
+ for y in range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0]:
81
+ for x in range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0]:
82
+ patch_in = img_padded[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE]
83
+ with torch.no_grad():
84
+ patch_out = model(patch_in.to(device)).cpu()
85
+ output_canvas[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += patch_out
86
+ weight_map[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += 1
87
+ pbar.update(1)
88
+ pbar.close()
89
+ restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
90
+ return restored_padded_tensor[:, :, :h_orig, :w_orig]
91
+
92
+ def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
93
+ if input_image is None: return None
94
+ model = MODELS[task_name]
95
+ print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")
96
+ if not isinstance(model, torch.nn.Module): model()
97
+ img = input_image.convert("RGB")
98
+ img_tensor = to_tensor(img).unsqueeze(0)
99
 
100
+ # 关键修正:在 process_image 函数内部也进行判断
101
+ if task_name == "去雾 (Dehaze)":
102
+ restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
103
+ else:
104
+ restored_tensor = process_with_pad_to_square(model, img_tensor)
105
+
106
  restored_tensor = torch.clamp(restored_tensor, 0, 1)
107
+ restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
108
+ return restored_image
109
 
 
110
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
111
  gr.Markdown(
112
  """
 
121
  input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
122
  output_img = gr.Image(type="pil", label="输出图片 (Output)")
123
 
124
+ task_id_box = gr.Textbox(task_name, visible=False)
125
  submit_btn = gr.Button("开始处理 (Process)", variant="primary")
126
 
127
+ # “提交”按钮的点击事件,保持不变
128
+ submit_btn.click(fn=process_image, inputs=[input_img, task_id_box], outputs=output_img)
 
 
129
 
130
+ # --- 最终修正 ---
131
+ # 重新构造 lambda 函数,确保它总是传递正确的 task_name
132
+ # 我们不再依赖外部的 task_name 变量,而是直接使用在循环中定义的那个
133
+ def create_example_fn(current_task_name):
134
+ return lambda img, prog: process_image(img, current_task_name, prog)
135
+
136
  if EXAMPLE_IMAGES.get(task_name):
137
  gr.Examples(
138
  examples=EXAMPLE_IMAGES.get(task_name, []),
139
+ inputs=input_img,
140
  outputs=output_img,
141
+ fn=create_example_fn(task_name), # 关键:为每个循环的 task_name 创建一个独立的函数
142
  cache_examples=True
143
  )
144