suncongcong committed on
Commit
575a9d2
·
verified ·
1 Parent(s): 97d6718

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -90
app.py CHANGED
@@ -1,127 +1,283 @@
1
  import gradio as gr
 
2
  import torch
 
3
  import torch.nn.functional as F
 
4
  import numpy as np
 
5
  from transformers import CLIPImageProcessor
 
6
  from modeling_ast import ASTForRestoration
 
7
  from PIL import Image
 
8
  import requests
 
9
  from io import BytesIO
 
10
  from torchvision.transforms.functional import to_pil_image, to_tensor
 
11
  from tqdm import tqdm
 
12
  import math
13
 
 
 
14
  # --- 1. 配置 ---
 
15
  MODEL_IDS = {
16
- "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
17
- "去雨 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
18
- "去雾 (Dehaze)": "Suncongcong/AST_Dehazing"
 
 
 
 
19
  }
 
20
  EXAMPLE_IMAGES = {
21
- "去雨痕 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]],
22
- "去雨 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]],
23
- "去雾 (Dehaze)": []
 
 
 
 
24
  }
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
26
  print(f"正在使用的设备: {device}")
27
 
 
 
28
  # --- 2. 加载所有模型和处理器 ---
 
29
  MODELS = {}
 
30
  PROCESSOR = None
 
31
  print("正在加载所有模型和处理器...")
 
32
  try:
33
- for task_name, repo_id in MODEL_IDS.items():
34
- print(f"正在加载模型: {task_name} ({repo_id})")
35
- if PROCESSOR is None:
36
- PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id)
37
- print("✅ 处理器加载成功。")
38
- model = ASTForRestoration.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
39
- MODELS[task_name] = model
40
- print(f"✅ 模型 '{task_name}' 加载成功。")
 
 
 
 
 
 
 
 
 
41
  except Exception as e:
42
- print(f"加载模型时出错: {e}")
43
- def load_error_func(*args, **kwargs):
44
- raise gr.Error(f"模型加载失败! 错误: {e}")
45
- MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
 
 
 
 
 
46
  print("所有模型加载完毕,准备就绪!")
47
 
48
 
 
49
  # --- 3. 定义不同任务的处理函数 ---
 
50
  def process_with_pad_to_square(model, img_tensor):
51
- def expand2square(timg, factor=128.0):
52
- _, _, h, w = timg.size(); X = int(math.ceil(max(h, w) / factor) * factor)
53
- img_padded = torch.zeros(1, 3, X, X).type_as(timg); mask = torch.zeros(1, 1, X, X).type_as(timg)
54
- img_padded[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)] = timg
55
- mask[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)].fill_(1)
56
- return img_padded, mask
57
- original_h, original_w = img_tensor.shape[2:]
58
- padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
59
- with torch.no_grad():
60
- restored_padded = model(padded_input)
61
- restored_tensor = torch.masked_select(restored_padded, mask.bool()).reshape(1, 3, original_h, original_w)
62
- return restored_tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def process_with_dehaze_tiling(model, img_tensor, progress):
65
- CROP_SIZE, OVERLAP = 1152, 384
66
- b, c, h_orig, w_orig = img_tensor.shape; stride = CROP_SIZE - OVERLAP
67
- h_pad = (stride - (h_orig - OVERLAP) % stride) % stride; w_pad = (stride - (w_orig - OVERLAP) % stride) % stride
68
- img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
69
- _, _, h_padded, w_padded = img_padded.shape
70
- output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu'); weight_map = torch.zeros_like(output_canvas)
71
- h_steps = len(range(0, h_padded - OVERLAP, stride)) if h_padded > OVERLAP else 1
72
- w_steps = len(range(0, w_padded - OVERLAP, stride)) if w_padded > OVERLAP else 1
73
- pbar = tqdm(total=h_steps * w_steps, desc=f"正在执行去雾...")
74
- for y in range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0]:
75
- for x in range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0]:
76
- patch_in = img_padded[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE]
77
- with torch.no_grad(): patch_out = model(patch_in.to(device)).cpu()
78
- output_canvas[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += patch_out
79
- weight_map[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += 1
80
- pbar.update(1)
81
- # --- 最终修正:只有在 progress 对象存在时才更新它 ---
82
- if progress:
83
- progress(pbar.n / pbar.total, desc=pbar.desc)
84
- pbar.close()
85
- restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
86
- return restored_padded_tensor[:, :, :h_orig, :w_orig]
87
-
88
- # 主调度函数
89
- def run_task(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
90
- if input_image is None: return None
91
- model = MODELS[task_name]
92
- if isinstance(model, str): raise gr.Error(model)
93
- img_tensor = to_tensor(input_image.convert("RGB")).unsqueeze(0)
94
- print(f"执行任务: {task_name}")
95
- if task_name == "去雾 (Dehaze)":
96
- restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
97
- else:
98
- restored_tensor = process_with_pad_to_square(model, img_tensor)
99
- restored_tensor = torch.clamp(restored_tensor, 0, 1)
100
- return to_pil_image(restored_tensor.cpu().squeeze(0))
101
-
102
- # --- 4. 创建并启动 Gradio 界面 ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
104
- gr.Markdown(
105
- """
106
- # 🖼️ 多功能图像复原工具 (AST 模型)
107
- 请选择一个任务,然后上传对应的图片或点击下方的示例图片进行处理。
108
- """
109
- )
110
- with gr.Tabs():
111
- for task_name in MODEL_IDS.keys():
112
- with gr.TabItem(task_name, id=task_name):
113
- with gr.Row():
114
- input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
115
- output_img = gr.Image(type="pil", label="输出图片 (Output)")
116
- submit_btn = gr.Button("开始处理 (Process)", variant="primary")
117
- submit_btn.click(fn=run_task, inputs=[input_img, gr.State(task_name)], outputs=output_img)
118
- if EXAMPLE_IMAGES.get(task_name):
119
- gr.Examples(
120
- examples=EXAMPLE_IMAGES.get(task_name, []),
121
- inputs=input_img,
122
- outputs=output_img,
123
- fn=lambda img: run_task(img, task_name=task_name), # 示例点击时不传递 progress
124
- cache_examples=True
125
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  demo.launch(server_name="0.0.0.0")
 
1
  import gradio as gr
2
+
3
  import torch
4
+
5
  import torch.nn.functional as F
6
+
7
  import numpy as np
8
+
9
  from transformers import CLIPImageProcessor
10
+
11
  from modeling_ast import ASTForRestoration
12
+
13
  from PIL import Image
14
+
15
  import requests
16
+
17
  from io import BytesIO
18
+
19
  from torchvision.transforms.functional import to_pil_image, to_tensor
20
+
21
  from tqdm import tqdm
22
+
23
  import math
24
 
25
+
26
+
27
# --- 1. Configuration ---
# Task display name -> Hugging Face Hub repo id of the restoration model.
MODEL_IDS = {
    "去雨 (Derain)": "Suncongcong/AST_DeRain",
    "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
    "去雾 (Dehaze)": "Suncongcong/AST_Dehazing"
}
# Task display name -> example rows for gr.Examples (one image path per row).
EXAMPLE_IMAGES = {
    "去雨 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]],
    "去雨滴 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]],
    "去雾 (Dehaze)": [["dehaze_example1.jpg"],["dehaze_example2.jpg"],["dehaze_example3.jpg"]]
}
# Prefer the GPU when one is visible to torch; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"正在使用的设备: {device}")
52
 
53
+
54
+
55
# --- 2. Load all models and the shared processor ---
MODELS = {}
PROCESSOR = None
print("正在加载所有模型和处理器...")
try:
    for task_name, repo_id in MODEL_IDS.items():
        print(f"正在加载模型: {task_name} ({repo_id})")
        # All repos ship the same processor config, so fetch it only once.
        if PROCESSOR is None:
            PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id)
            print("✅ 处理器加载成功。")
        model = ASTForRestoration.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
        MODELS[task_name] = model
        print(f"✅ 模型 '{task_name}' 加载成功。")
except Exception as e:
    print(f"加载模型时出错: {e}")
    # BUG FIX: Python unbinds the `except ... as e` name when the handler
    # exits (PEP 3110), so a closure referring to `e` would raise NameError
    # the moment a user clicks "Process". Capture the message in a variable
    # that survives the except block and close over that instead.
    error_message = str(e)
    def load_error_func(*args, **kwargs):
        # Stub stored in place of each model: raises the load failure lazily.
        raise gr.Error(f"模型加载失败! 错误: {error_message}")
    MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
print("所有模型加载完毕,准备就绪!")
92
 
93
 
94
+
95
# --- 3. Per-task processing helpers ---
def process_with_pad_to_square(model, img_tensor):
    """Restore an image by centring it on a zero-padded square canvas whose
    side is the smallest multiple of 128 covering the longer edge, running the
    model once, then cropping the valid region back out.

    Args:
        model: restoration network, invoked as ``model(batch)`` on a
            (1, 3, side, side) tensor.
        img_tensor: (1, 3, H, W) input tensor.

    Returns:
        (1, 3, H, W) restored tensor with the original spatial size.
    """
    orig_h, orig_w = img_tensor.shape[2:]

    # Smallest multiple of 128 that fits the longer edge.
    factor = 128.0
    side = int(math.ceil(max(orig_h, orig_w) / factor) * factor)

    # Offsets that centre the image on the square canvas.
    top = (side - orig_h) // 2
    left = (side - orig_w) // 2

    src = img_tensor.to(device)
    canvas = torch.zeros(1, 3, side, side).type_as(src)
    valid = torch.zeros(1, 1, side, side).type_as(src)
    canvas[:, :, top:top + orig_h, left:left + orig_w] = src
    valid[:, :, top:top + orig_h, left:left + orig_w] = 1

    with torch.no_grad():
        restored = model(canvas)

    # The 1-channel mask broadcasts over RGB, selecting exactly 3*H*W values.
    return torch.masked_select(restored, valid.bool()).reshape(1, 3, orig_h, orig_w)
130
+
131
+
132
 
133
def process_with_dehaze_tiling(model, img_tensor, progress):
    """Restore a (possibly very large) hazy image with sliding-window tiling.

    The input is replicate-padded so that overlapping CROP_SIZE tiles tile it
    exactly, each tile is restored independently, and overlapping predictions
    are averaged.

    Args:
        model: restoration network, invoked as ``model(tile)`` per tile.
        img_tensor: (b, c, H, W) input tensor.
        progress: gradio progress handle; not called directly here — the tqdm
            bar is surfaced via gr.Progress(track_tqdm=True) in the caller.

    Returns:
        (b, c, H, W) restored tensor cropped back to the original size.
    """
    CROP_SIZE, OVERLAP = 1152, 384
    b, c, h_orig, w_orig = img_tensor.shape
    stride = CROP_SIZE - OVERLAP
    # Pad each dimension so (dim - OVERLAP) becomes a multiple of stride,
    # which makes the tile grid cover the padded image exactly.
    h_pad = (stride - (h_orig - OVERLAP) % stride) % stride
    w_pad = (stride - (w_orig - OVERLAP) % stride) % stride
    img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
    _, _, h_padded, w_padded = img_padded.shape
    # Accumulate on CPU so GPU memory stays bounded regardless of image size.
    output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
    weight_map = torch.zeros_like(output_canvas)
    # Step counts only size the tqdm bar; the loops below recompute the grid.
    h_steps = len(range(0, h_padded - OVERLAP, stride)) if h_padded > OVERLAP else 1
    w_steps = len(range(0, w_padded - OVERLAP, stride)) if w_padded > OVERLAP else 1
    pbar = tqdm(total=h_steps * w_steps, desc=f"正在执行去雾...")
    # Degenerate case: a dimension no larger than OVERLAP gets a single tile
    # at offset 0 (the [0] branch).
    for y in range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0]:
        for x in range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0]:
            patch_in = img_padded[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE]
            with torch.no_grad():
                patch_out = model(patch_in.to(device)).cpu()
            # Sum overlapping predictions; weight_map counts contributions per
            # pixel so overlaps can be averaged afterwards.
            output_canvas[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += patch_out
            weight_map[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += 1
            pbar.update(1)
    pbar.close()
    # Average the overlaps; clamp guards any unvisited pixel against /0.
    restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
    return restored_padded_tensor[:, :, :h_orig, :w_orig]
180
+
181
+
182
+
183
def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
    """Dispatch an uploaded image to the model registered for *task_name* and
    return the restored PIL image, or None when no image was supplied."""
    if input_image is None:
        return None

    model = MODELS[task_name]
    print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")

    # When loading failed, MODELS holds an error stub instead of a module;
    # invoking it raises gr.Error carrying the original load failure.
    if not isinstance(model, torch.nn.Module):
        model()

    tensor_in = to_tensor(input_image.convert("RGB")).unsqueeze(0)

    # Dehazing tiles large images; every other task pads to a square.
    if task_name == "去雾 (Dehaze)":
        restored = process_with_dehaze_tiling(model, tensor_in, progress)
    else:
        restored = process_with_pad_to_square(model, tensor_in)

    restored = torch.clamp(restored, 0, 1)
    return to_pil_image(restored.cpu().squeeze(0))
216
+
217
+
218
+
219
# --- 4. Build and launch the Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ 多功能图像复原工具 (AST 模型)
        请选择一个任务,然后上传对应的图片或点击下方的示例图片进行处理。
        """
    )
    with gr.Tabs():
        for task_name in MODEL_IDS.keys():
            with gr.TabItem(task_name, id=task_name):
                with gr.Row():
                    input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
                    output_img = gr.Image(type="pil", label="输出图片 (Output)")

                # Hidden textbox pins this tab's task name so the click
                # handler receives it as a regular input value.
                task_id_box = gr.Textbox(task_name, visible=False)
                submit_btn = gr.Button("开始处理 (Process)", variant="primary")

                submit_btn.click(fn=process_image, inputs=[input_img, task_id_box], outputs=output_img)

                # Factory binds task_name per loop iteration (avoids the
                # classic late-binding closure bug).
                # BUG FIX: gr.Examples calls fn with exactly the declared
                # `inputs` (just the image), so the wrapper must take ONE
                # argument — the previous two-argument lambda
                # `(img, prog)` failed whenever an example was clicked or
                # cached. `progress` falls back to its default in
                # process_image.
                def create_example_fn(current_task_name):
                    return lambda img: process_image(img, current_task_name)

                if EXAMPLE_IMAGES.get(task_name):
                    gr.Examples(
                        examples=EXAMPLE_IMAGES.get(task_name, []),
                        inputs=input_img,
                        outputs=output_img,
                        fn=create_example_fn(task_name),
                        cache_examples=True
                    )

demo.launch(server_name="0.0.0.0")