suncongcong commited on
Commit
9367fe4
·
verified ·
1 Parent(s): 7ec5312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -23
app.py CHANGED
@@ -11,7 +11,7 @@ from torchvision.transforms.functional import to_pil_image, to_tensor
11
  from tqdm import tqdm
12
  import math
13
 
14
- # --- 1. 配置 (无变化) ---
15
  MODEL_IDS = {
16
  "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
17
  "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
@@ -25,7 +25,7 @@ EXAMPLE_IMAGES = {
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  print(f"正在使用的设备: {device}")
27
 
28
- # --- 2. 加载所有模型和处理器 (无变化) ---
29
  MODELS = {}
30
  PROCESSOR = None
31
  print("正在加载所有模型和处理器...")
@@ -51,10 +51,10 @@ print("所有模型加载完毕,准备就绪!")
51
  def process_with_pad_to_square(model, img_tensor):
52
  """将图片填充为正方形后进行处理,适用于去雨/去雨滴任务。"""
53
  def expand2square(timg, factor=128.0):
54
- # ✨ 优化点: 增加注释,解释 factor 的作用
55
  # factor: 模型的网络结构要求输入的尺寸最好是该值的整数倍
56
  _, _, h, w = timg.size()
57
  X = int(math.ceil(max(h, w) / factor) * factor)
 
58
  img_padded = torch.zeros(1, 3, X, X, device=timg.device, dtype=timg.dtype)
59
  mask = torch.zeros(1, 1, X, X, device=timg.device, dtype=timg.dtype)
60
 
@@ -70,7 +70,7 @@ def process_with_pad_to_square(model, img_tensor):
70
  with torch.no_grad():
71
  restored_padded = model(padded_input)
72
 
73
- # ✨ 优化点: 确保 mask 和 restored_padded 在同一设备上
74
  mask_bool = mask.bool().to(restored_padded.device)
75
 
76
  restored_tensor = torch.masked_select(
@@ -82,7 +82,7 @@ def process_with_pad_to_square(model, img_tensor):
82
 
83
  def process_with_dehaze_tiling(model, img_tensor, progress):
84
  """使用重叠分块策略处理图像,适用于去雾任务。"""
85
- # ✨ 优化点: 将“魔法数字”定义为常量并添加注释
86
  CROP_SIZE = 1152 # 每个图块的尺寸
87
  OVERLAP = 384 # 图块之间的重叠区域大小,以避免边缘效应
88
 
@@ -90,8 +90,8 @@ def process_with_dehaze_tiling(model, img_tensor, progress):
90
  stride = CROP_SIZE - OVERLAP
91
 
92
  # 计算需要填充的尺寸
93
- h_pad = (stride - (h_orig - OVERLAP) % stride) % stride if h_orig > CROP_SIZE else 0
94
- w_pad = (stride - (w_orig - OVERLAP) % stride) % stride if w_orig > CROP_SIZE else 0
95
 
96
  img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
97
  b, c, h_padded, w_padded = img_padded.shape
@@ -100,20 +100,22 @@ def process_with_dehaze_tiling(model, img_tensor, progress):
100
  output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
101
  weight_map = torch.zeros_like(output_canvas)
102
 
103
- h_steps_range = range(0, h_padded - CROP_SIZE + 1, stride) if h_padded > CROP_SIZE else [0]
104
- w_steps_range = range(0, w_padded - CROP_SIZE + 1, stride) if w_padded > CROP_SIZE else [0]
105
 
106
- total_steps = len(h_steps_range) * len(w_steps_range)
107
-
108
- # ✨ 优化点: 使用Gradio的进度条,而不是手动的tqdm
109
  for y in progress.tqdm(h_steps_range, desc="正在分块去雾..."):
110
  for x in w_steps_range:
111
- patch_in = img_padded[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE]
 
 
 
 
112
  with torch.no_grad():
113
  patch_out = model(patch_in.to(device)).cpu()
114
 
115
- output_canvas[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += patch_out
116
- weight_map[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += 1
117
 
118
  restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
119
  return restored_padded_tensor[:, :, :h_orig, :w_orig]
@@ -125,7 +127,7 @@ def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress
125
  gr.Warning("请输入一张图片!")
126
  return None
127
 
128
- # ✨ 优化点: 增加完整的运行时错误捕获
129
  try:
130
  model = MODELS[task_name]
131
  print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")
@@ -156,8 +158,8 @@ def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress
156
 
157
  # --- 4. Gradio UI ---
158
 
159
- # ✨ 优化点 3: 将创建Tab的逻辑封装成函数,使UI代码更干净
160
  def create_task_tab(task_name: str):
 
161
  with gr.TabItem(task_name, id=task_name):
162
  with gr.Row():
163
  input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
@@ -165,11 +167,14 @@ def create_task_tab(task_name: str):
165
 
166
  submit_btn = gr.Button("开始处理 (Process)", variant="primary")
167
 
168
- # ✨ 优化点 1: 创建一个处理函数,它已经“知道”自己的任务名
169
- # 这样就不再需要隐藏的 Textbox 来传递 task_name
170
- def specific_process_fn(img, prog):
171
- return process_image(img, task_name, prog)
172
-
 
 
 
173
  submit_btn.click(fn=specific_process_fn, inputs=[input_img], outputs=output_img)
174
 
175
  if EXAMPLE_IMAGES.get(task_name):
@@ -181,6 +186,7 @@ def create_task_tab(task_name: str):
181
  cache_examples=True,
182
  )
183
 
 
184
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
185
  gr.Markdown(
186
  """
@@ -190,6 +196,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
190
  )
191
  with gr.Tabs():
192
  for task_name in MODEL_IDS.keys():
193
- create_task_tab(task_name) # 调用函数创建每个Tab
194
 
 
195
  demo.launch(server_name="0.0.0.0")
 
11
  from tqdm import tqdm
12
  import math
13
 
14
+ # --- 1. 配置 ---
15
  MODEL_IDS = {
16
  "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
17
  "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  print(f"正在使用的设备: {device}")
27
 
28
+ # --- 2. 加载所有模型和处理器 ---
29
  MODELS = {}
30
  PROCESSOR = None
31
  print("正在加载所有模型和处理器...")
 
51
  def process_with_pad_to_square(model, img_tensor):
52
  """将图片填充为正方形后进行处理,适用于去雨/去雨滴任务。"""
53
  def expand2square(timg, factor=128.0):
 
54
  # factor: 模型的网络结构要求输入的尺寸最好是该值的整数倍
55
  _, _, h, w = timg.size()
56
  X = int(math.ceil(max(h, w) / factor) * factor)
57
+ # 确保创建的张量在正确的设备上
58
  img_padded = torch.zeros(1, 3, X, X, device=timg.device, dtype=timg.dtype)
59
  mask = torch.zeros(1, 1, X, X, device=timg.device, dtype=timg.dtype)
60
 
 
70
  with torch.no_grad():
71
  restored_padded = model(padded_input)
72
 
73
+ # 确保 mask 和 restored_padded 在同一设备上
74
  mask_bool = mask.bool().to(restored_padded.device)
75
 
76
  restored_tensor = torch.masked_select(
 
82
 
83
  def process_with_dehaze_tiling(model, img_tensor, progress):
84
  """使用重叠分块策略处理图像,适用于去雾任务。"""
85
+ # 将“魔法数字”定义为常量并添加注释
86
  CROP_SIZE = 1152 # 每个图块的尺寸
87
  OVERLAP = 384 # 图块之间的重叠区域大小,以避免边缘效应
88
 
 
90
  stride = CROP_SIZE - OVERLAP
91
 
92
  # 计算需要填充的尺寸
93
+ h_pad = (stride - (h_orig - OVERLAP) % stride) % stride if h_orig > OVERLAP else 0
94
+ w_pad = (stride - (w_orig - OVERLAP) % stride) % stride if w_orig > OVERLAP else 0
95
 
96
  img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
97
  b, c, h_padded, w_padded = img_padded.shape
 
100
  output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
101
  weight_map = torch.zeros_like(output_canvas)
102
 
103
+ h_steps_range = range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0]
104
+ w_steps_range = range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0]
105
 
106
+ # 使用Gradio的进度条
 
 
107
  for y in progress.tqdm(h_steps_range, desc="正在分块去雾..."):
108
  for x in w_steps_range:
109
+ # 确保切片范围正确
110
+ y_end = min(y + CROP_SIZE, h_padded)
111
+ x_end = min(x + CROP_SIZE, w_padded)
112
+ patch_in = img_padded[:, :, y:y_end, x:x_end]
113
+
114
  with torch.no_grad():
115
  patch_out = model(patch_in.to(device)).cpu()
116
 
117
+ output_canvas[:, :, y:y_end, x:x_end] += patch_out
118
+ weight_map[:, :, y:y_end, x:x_end] += 1
119
 
120
  restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
121
  return restored_padded_tensor[:, :, :h_orig, :w_orig]
 
127
  gr.Warning("请输入一张图片!")
128
  return None
129
 
130
+ # 增加完整的运行时错误捕获
131
  try:
132
  model = MODELS[task_name]
133
  print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")
 
158
 
159
  # --- 4. Gradio UI ---
160
 
 
161
  def create_task_tab(task_name: str):
162
+ """动态创建每个任务的UI选项卡。"""
163
  with gr.TabItem(task_name, id=task_name):
164
  with gr.Row():
165
  input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
 
167
 
168
  submit_btn = gr.Button("开始处理 (Process)", variant="primary")
169
 
170
+ # ✨ 修正后的处理函数
171
+ # 这个函数只接收它能从 inputs 中得到的 `img` 参数。
172
+ def specific_process_fn(img):
173
+ # 调用 process_image 时不传递 progress 参数,
174
+ # 从而让 process_image 自动使用其函数定义中的默认值: progress=gr.Progress(...)
175
+ return process_image(img, task_name)
176
+
177
+ # click 事件的 inputs 列表只有一个元素,对应 specific_process_fn 的 img 参数
178
  submit_btn.click(fn=specific_process_fn, inputs=[input_img], outputs=output_img)
179
 
180
  if EXAMPLE_IMAGES.get(task_name):
 
186
  cache_examples=True,
187
  )
188
 
189
+ # 创建应用主界面
190
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
191
  gr.Markdown(
192
  """
 
196
  )
197
  with gr.Tabs():
198
  for task_name in MODEL_IDS.keys():
199
+ create_task_tab(task_name) # 调用函数为每个任务创建Tab
200
 
201
+ # 启动应用
202
  demo.launch(server_name="0.0.0.0")