suncongcong committed on
Commit
6a6f453
·
verified ·
1 Parent(s): 575a9d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -244
app.py CHANGED
# NOTE(review): this span is the deleted (pre-change) side of a rendered
# commit diff — the previous revision of app.py, reconstructed here with
# conventional formatting. Only comments were added/translated.
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from transformers import CLIPImageProcessor
from modeling_ast import ASTForRestoration
from PIL import Image
import requests
from io import BytesIO
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm
import math

# --- 1. Configuration ---
# Task label -> Hugging Face repo id of the matching AST checkpoint.
MODEL_IDS = {
    "去雨 (Derain)": "Suncongcong/AST_DeRain",
    "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
    "去雾 (Dehaze)": "Suncongcong/AST_Dehazing"
}
# Task label -> gr.Examples rows (one image path per row).
EXAMPLE_IMAGES = {
    "去雨 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]],
    "去雨滴 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]],
    "去雾 (Dehaze)": [["dehaze_example1.jpg"], ["dehaze_example2.jpg"], ["dehaze_example3.jpg"]]
}
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"正在使用的设备: {device}")

# --- 2. Load all models and the shared processor ---
MODELS = {}
PROCESSOR = None
print("正在加载所有模型和处理器...")
try:
    for task_name, repo_id in MODEL_IDS.items():
        print(f"正在加载模型: {task_name} ({repo_id})")
        if PROCESSOR is None:
            # All repos appear to share one preprocessing config; load once.
            PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id)
            print("✅ 处理器加载成功。")
        model = ASTForRestoration.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
        MODELS[task_name] = model
        print(f"✅ 模型 '{task_name}' 加载成功。")
except Exception as e:
    print(f"加载模型时出错: {e}")
    # NOTE(review): `e` is unbound once this except block exits (PEP 3110),
    # so calling load_error_func later raises NameError, not gr.Error.
    def load_error_func(*args, **kwargs):
        raise gr.Error(f"模型加载失败! 错误: {e}")
    # Replace every entry with the error-raising callable.
    MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
print("所有模型加载完毕,准备就绪!")

# --- 3. Per-task processing functions ---
def process_with_pad_to_square(model, img_tensor):
    # Pad (centered) to a square whose side is a multiple of 128, run the
    # model once, then cut the original window back out via the mask.
    def expand2square(timg, factor=128.0):
        _, _, h, w = timg.size()
        X = int(math.ceil(max(h, w) / factor) * factor)
        img_padded = torch.zeros(1, 3, X, X).type_as(timg)
        mask = torch.zeros(1, 1, X, X).type_as(timg)
        img_padded[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)] = timg
        mask[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)].fill_(1)
        return img_padded, mask
    original_h, original_w = img_tensor.shape[2:]
    padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
    with torch.no_grad():
        restored_padded = model(padded_input)
    # masked_select flattens; the mask marks exactly the original H*W window.
    restored_tensor = torch.masked_select(
        restored_padded, mask.bool()
    ).reshape(1, 3, original_h, original_w)
    return restored_tensor


def process_with_dehaze_tiling(model, img_tensor, progress):
    # Process in overlapping 1152px tiles; overlaps are averaged.
    CROP_SIZE, OVERLAP = 1152, 384
    b, c, h_orig, w_orig = img_tensor.shape
    stride = CROP_SIZE - OVERLAP
    # Pad so the tile grid covers the image with no remainder.
    h_pad = (stride - (h_orig - OVERLAP) % stride) % stride
    w_pad = (stride - (w_orig - OVERLAP) % stride) % stride
    img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
    _, _, h_padded, w_padded = img_padded.shape
    # Accumulate on CPU so GPU memory stays bounded to one tile.
    output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
    weight_map = torch.zeros_like(output_canvas)
    h_steps = len(range(0, h_padded - OVERLAP, stride)) if h_padded > OVERLAP else 1
    w_steps = len(range(0, w_padded - OVERLAP, stride)) if w_padded > OVERLAP else 1
    pbar = tqdm(total=h_steps * w_steps, desc=f"正在执行去雾...")
    for y in range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0]:
        for x in range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0]:
            patch_in = img_padded[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE]
            with torch.no_grad():
                patch_out = model(patch_in.to(device)).cpu()
            output_canvas[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += patch_out
            weight_map[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += 1
            pbar.update(1)
    pbar.close()
    # Average overlapping contributions; clamp guards divide-by-zero.
    restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
    return restored_padded_tensor[:, :, :h_orig, :w_orig]


def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
    if input_image is None: return None
    model = MODELS[task_name]
    print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")
    # A failed load stores an error-raising callable instead of a module.
    if not isinstance(model, torch.nn.Module): model()
    img = input_image.convert("RGB")
    img_tensor = to_tensor(img).unsqueeze(0)

    # Key fix: also branch on the task inside process_image.
    if task_name == "去雾 (Dehaze)":
        restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
    else:
        restored_tensor = process_with_pad_to_square(model, img_tensor)

    restored_tensor = torch.clamp(restored_tensor, 0, 1)
    restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
    return restored_image


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ 多功能图像复原工具 (AST 模型)
        请选择一个任务,然后上传对应的图片或点击下方的示例图片进行处理。
        """
    )
    with gr.Tabs():
        for task_name in MODEL_IDS.keys():
            with gr.TabItem(task_name, id=task_name):
                with gr.Row():
                    input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
                    output_img = gr.Image(type="pil", label="输出图片 (Output)")

                task_id_box = gr.Textbox(task_name, visible=False)
                submit_btn = gr.Button("开始处理 (Process)", variant="primary")

                # "Submit" button click event, unchanged.
                submit_btn.click(fn=process_image, inputs=[input_img, task_id_box], outputs=output_img)

                def create_example_fn(current_task_name):
                    return lambda img, prog: process_image(img, current_task_name, prog)

                if EXAMPLE_IMAGES.get(task_name):
                    gr.Examples(
                        examples=EXAMPLE_IMAGES.get(task_name, []),
                        inputs=input_img,
                        outputs=output_img,
                        fn=create_example_fn(task_name),  # Key: bind an independent function per loop iteration's task_name.
                        cache_examples=True
                    )

demo.launch(server_name="0.0.0.0")
1
  import gradio as gr
 
2
  import torch
 
3
  import torch.nn.functional as F
 
4
  import numpy as np
 
5
  from transformers import CLIPImageProcessor
 
6
  from modeling_ast import ASTForRestoration
 
7
  from PIL import Image
 
8
  import requests
 
9
  from io import BytesIO
 
10
  from torchvision.transforms.functional import to_pil_image, to_tensor
 
11
  from tqdm import tqdm
 
12
  import math
13
 
 
 
# --- 1. Configuration ---
# Task label -> Hugging Face repo id of the matching AST checkpoint.
# Dict order matters: it drives the tab order in the UI below.
MODEL_IDS = {
    "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
    "去雨 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
    "去雾 (Dehaze)": "Suncongcong/AST_Dehazing",
}

# Task label -> gr.Examples rows (one image path per row). An empty list
# means "render no example gallery for this tab".
EXAMPLE_IMAGES = {
    "去雨痕 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]],
    "去雨 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]],
    "去雾 (Dehaze)": [],
}

# Prefer the GPU when one is present; everything below respects `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"正在使用的设备: {device}")
27
 
 
 
# --- 2. Load all models and the shared processor ---
# On success MODELS maps every task label to an eval-mode nn.Module on
# `device`; on any failure every entry becomes a callable that raises a
# gr.Error with the original load error.
MODELS = {}
PROCESSOR = None
print("正在加载所有模型和处理器...")
try:
    for task_name, repo_id in MODEL_IDS.items():
        print(f"正在加载模型: {task_name} ({repo_id})")
        if PROCESSOR is None:
            # All repos share one preprocessing config; load it only once.
            PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id)
            print("✅ 处理器加载成功。")
        model = ASTForRestoration.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
        MODELS[task_name] = model
        print(f"✅ 模型 '{task_name}' 加载成功。")
except Exception as e:
    print(f"加载模型时出错: {e}")
    # BUG FIX: the `except ... as e` binding is deleted when this except
    # block exits (PEP 3110), so a closure reading `e` later raised
    # NameError instead of the intended gr.Error. Capture the message in a
    # plain variable that outlives the except block.
    _load_error_message = str(e)
    def load_error_func(*args, **kwargs):
        raise gr.Error(f"模型加载失败! 错误: {_load_error_message}")
    MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
print("所有模型加载完毕,准备就绪!")
47
 
 
 
# --- 3. Per-task processing functions ---
def process_with_pad_to_square(model, img_tensor):
    """Restore an image by zero-padding it to a model-friendly square.

    The input is centered on an X-by-X canvas, where X is the smallest
    multiple of 128 covering both sides, run through the model in a single
    forward pass, and the original window is cut back out of the result.

    Args:
        model: callable mapping a (B, C, X, X) tensor to a like-shaped tensor.
        img_tensor: (B, C, H, W) image tensor (moved to `device` internally).

    Returns:
        (B, C, H, W) restored tensor on `device`.
    """
    def expand2square(timg, factor=128.0):
        # Generalized from the original hard-coded (1, 3, ...) canvas so the
        # helper also works for other batch sizes / channel counts; behavior
        # is unchanged for the (1, 3, H, W) tensors this app produces.
        b, c, h, w = timg.size()
        X = int(math.ceil(max(h, w) / factor) * factor)
        img_padded = torch.zeros(b, c, X, X).type_as(timg)
        mask = torch.zeros(b, 1, X, X).type_as(timg)
        top, left = (X - h) // 2, (X - w) // 2
        img_padded[:, :, top:top + h, left:left + w] = timg
        mask[:, :, top:top + h, left:left + w].fill_(1)
        return img_padded, mask

    b, c = img_tensor.shape[:2]
    original_h, original_w = img_tensor.shape[2:]
    padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
    with torch.no_grad():
        restored_padded = model(padded_input)
    # masked_select flattens (broadcasting the 1-channel mask over C); the
    # mask marks exactly the original H*W window, so the reshape restores
    # the un-padded geometry.
    restored_tensor = torch.masked_select(
        restored_padded, mask.bool()
    ).reshape(b, c, original_h, original_w)
    return restored_tensor
66
 
def process_with_dehaze_tiling(model, img_tensor, progress):
    """Dehaze an arbitrarily large image by running the model on
    overlapping tiles and averaging the overlapped regions.

    Args:
        model: restoration network; maps a (B, C, h, w) tensor to a
            like-shaped tensor.
        img_tensor: (B, C, H, W) input image tensor.
        progress: Gradio progress handle; not used directly — the tqdm bar
            below is what Gradio tracks via track_tqdm.

    Returns:
        (B, C, H, W) restored tensor on CPU.
    """
    CROP_SIZE, OVERLAP = 1152, 384
    b, c, h_orig, w_orig = img_tensor.shape
    stride = CROP_SIZE - OVERLAP
    # Pad so that (padded_size - OVERLAP) is an exact multiple of stride,
    # i.e. the tile grid covers the padded image with no remainder.
    h_pad = (stride - (h_orig - OVERLAP) % stride) % stride
    w_pad = (stride - (w_orig - OVERLAP) % stride) % stride
    img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
    _, _, h_padded, w_padded = img_padded.shape
    # Accumulate on the CPU so GPU memory stays bounded to a single tile.
    output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
    weight_map = torch.zeros_like(output_canvas)
    # Images no larger than one tile degenerate to a single step per axis.
    h_steps = len(range(0, h_padded - OVERLAP, stride)) if h_padded > OVERLAP else 1
    w_steps = len(range(0, w_padded - OVERLAP, stride)) if w_padded > OVERLAP else 1
    pbar = tqdm(total=h_steps * w_steps, desc=f"正在执行去雾...")
    for y in range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0]:
        for x in range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0]:
            patch_in = img_padded[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE]
            with torch.no_grad():
                patch_out = model(patch_in.to(device)).cpu()
            # Sum tile outputs; weight_map counts contributions per pixel.
            output_canvas[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += patch_out
            weight_map[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += 1
            pbar.update(1)
    pbar.close()
    # Average overlapping contributions; the clamp guards against a
    # division by zero in any (unreachable) uncovered cell.
    restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
    # Crop the replicate-padding back off.
    return restored_padded_tensor[:, :, :h_orig, :w_orig]
91
 
def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
    """Run the restoration model selected by `task_name` on one image.

    Args:
        input_image: uploaded PIL image, or None when nothing was uploaded.
        task_name: key into MODELS / MODEL_IDS (the active tab's label).
        progress: Gradio progress tracker (forwards tqdm updates to the UI).

    Returns:
        The restored PIL image, or None for empty input.
    """
    if input_image is None:
        return None

    model = MODELS[task_name]
    print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")

    # When loading failed, MODELS holds an error-raising callable instead
    # of a module; invoking it surfaces the load error to the UI.
    if not isinstance(model, torch.nn.Module):
        model()

    batch = to_tensor(input_image.convert("RGB")).unsqueeze(0)

    # Dehazing tiles the image; every other task pads to a square instead.
    if task_name == "去雾 (Dehaze)":
        output = process_with_dehaze_tiling(model, batch, progress)
    else:
        output = process_with_pad_to_square(model, batch)

    output = torch.clamp(output, 0, 1)
    return to_pil_image(output.cpu().squeeze(0))
109
 
110
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
111
+ gr.Markdown(
112
+ """
113
+ # 🖼️ 多功能图像复原工具 (AST 模型)
114
+ 请选择一个任务,然后上传对应的图片或点击下方的示例图片进行处理。
115
+ """
116
+ )
117
+ with gr.Tabs():
118
+ for task_name in MODEL_IDS.keys():
119
+ with gr.TabItem(task_name, id=task_name):
120
+ with gr.Row():
121
+ input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
122
+ output_img = gr.Image(type="pil", label="输出图片 (Output)")
123
+
124
+ task_id_box = gr.Textbox(task_name, visible=False)
125
+ submit_btn = gr.Button("开始处理 (Process)", variant="primary")
126
+
127
+ # “提交”按钮的点击事件,保持不变
128
+ submit_btn.click(fn=process_image, inputs=[input_img, task_id_box], outputs=output_img)
129
+
130
+ def create_example_fn(current_task_name):
131
+ return lambda img, prog: process_image(img, current_task_name, prog)
132
+
133
+ if EXAMPLE_IMAGES.get(task_name):
134
+ gr.Examples(
135
+ examples=EXAMPLE_IMAGES.get(task_name, []),
136
+ inputs=input_img,
137
+ outputs=output_img,
138
+ fn=create_example_fn(task_name), # 关键:为每个循环的 task_name 创建一个独立的函数
139
+ cache_examples=True
140
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  demo.launch(server_name="0.0.0.0")