suncongcong commited on
Commit
1f6f9df
·
verified ·
1 Parent(s): 942c73d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -37
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  import numpy as np
4
  from transformers import CLIPImageProcessor
5
  from modeling_ast import ASTForRestoration
@@ -12,20 +13,21 @@ from tqdm import tqdm
12
  # --- 1. 配置 ---
13
  repo_id = "suncongcong/AST_DeRain"
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
- PATCH_SIZE = 256 # 模型期望的输入尺寸
16
- OVERLAP = 64 # 裁切块之间的重叠区域,可以调整
17
 
18
  print(f"正在使用的设备: {device}")
19
 
20
  # --- 2. 加载模型和处理器 ---
21
  print(f"正在从 '{repo_id}' 加载模型和处理器...")
22
  processor = CLIPImageProcessor.from_pretrained(repo_id)
23
- # 注意:我们不再修改处理器的尺寸,因为我们会手动裁切
24
- print(f"图像处理器加载完成。")
 
25
  model = ASTForRestoration.from_pretrained(
26
  repo_id,
27
  trust_remote_code=True
28
- ).to(device).eval() # 设置为评估模式
29
  print("✅ 模型加载成功,准备就绪!")
30
 
31
 
@@ -33,70 +35,63 @@ print("✅ 模型加载成功,准备就绪!")
33
  def derain_image_Tiled(input_image: Image.Image, progress=gr.Progress(track_tqdm=True)):
34
  if input_image is None:
35
  return None
36
-
37
  img = input_image.convert("RGB")
38
  img_tensor = to_tensor(img).unsqueeze(0).to(device)
39
  b, c, h, w = img_tensor.shape
40
 
41
- # 创建一个空的画布用于存放结果,和一个用于计算平均值的权重图
42
  output_canvas = torch.zeros_like(img_tensor).to(device)
43
  weight_map = torch.zeros_like(img_tensor).to(device)
44
-
45
  stride = PATCH_SIZE - OVERLAP
46
-
47
- # 计算需要裁切的块数,用于进度条
48
- total_patches = len(range(0, h, stride)) * len(range(0, w, stride))
49
-
50
- # 使用tqdm来创建进度条
 
51
  pbar = tqdm(total=total_patches, desc="正在处理图像块...")
52
 
53
  for y in range(0, h, stride):
54
  for x in range(0, w, stride):
55
- # 1. 裁切 (Crop)
56
  y_end = min(y + PATCH_SIZE, h)
57
  x_end = min(x + PATCH_SIZE, w)
58
  patch_in = img_tensor[:, :, y:y_end, x:x_end]
59
-
60
- # 如果边缘块尺寸不够,进行填充 (padding)
61
  ph, pw = patch_in.shape[2:]
62
  pad_h = PATCH_SIZE - ph
63
  pad_w = PATCH_SIZE - pw
64
- patch_padded = F.pad(patch_in, (0, pad_w, 0, pad_h), 'reflect')
65
-
66
- # 2. 推理 (Inference)
 
 
67
  with torch.no_grad():
68
- # 注意:这里我们不再使用processor,因为已经手动处理了
69
- # 直接将 (1, 3, 256, 256) 的 tensor 送入模型
70
  outputs = model(patch_padded)
71
-
72
  patch_out = outputs[0] if isinstance(outputs, tuple) else outputs
73
  patch_out = torch.clamp(patch_out, 0, 1)
74
 
75
- # 移除填充部分
76
  patch_out_unpadded = patch_out[:, :, :ph, :pw]
77
-
78
- # 3. 合并 (Merge)
79
- # 将处理后的块加到输出画布上,并更新权重图
80
  output_canvas[:, :, y:y_end, x:x_end] += patch_out_unpadded
81
  weight_map[:, :, y:y_end, x:x_end] += 1
82
-
83
- pbar.update(1) # 更新进度条
84
 
85
  pbar.close()
86
 
87
- # 4. 平均 (Average)
88
- # 用输出画布除以权重图,得到重叠区域的平均像素值
89
  restored_tensor = output_canvas / weight_map
90
-
91
  restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
92
-
93
  return restored_image
94
 
95
  # --- 4. 创建并启动 Gradio 界面 ---
96
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
97
  gr.Markdown(
98
  """
99
- #AST 图像去雨模型在线演示 (裁切/合并策略)
100
  上传任意尺寸的带雨图片,模型将会分块处理并拼接成完整的高清输出。
101
  模型仓库地址: [suncongcong/AST_DeRain](https://huggingface.co/suncongcong/AST_DeRain)
102
  """
@@ -104,10 +99,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
104
  with gr.Row():
105
  input_img = gr.Image(type="pil", label="输入带雨图片 (Input Rainy Image)")
106
  output_img = gr.Image(type="pil", label="输出清晰图片 (Output Deraided Image)")
107
-
108
  submit_btn = gr.Button("开始去雨 (Start Deraining)", variant="primary")
109
-
110
- # 将新的处理函数绑定到按钮
111
  submit_btn.click(fn=derain_image_Tiled, inputs=input_img, outputs=output_img)
112
-
113
  demo.launch(server_name="0.0.0.0")
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn.functional as F
4
  import numpy as np
5
  from transformers import CLIPImageProcessor
6
  from modeling_ast import ASTForRestoration
 
13
  # --- 1. 配置 ---
14
  repo_id = "suncongcong/AST_DeRain"
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ PATCH_SIZE = 256
17
+ OVERLAP = 64
18
 
19
  print(f"正在使用的设备: {device}")
20
 
21
  # --- 2. 加载模型和处理器 ---
22
  print(f"正在从 '{repo_id}' 加载模型和处理器...")
23
  processor = CLIPImageProcessor.from_pretrained(repo_id)
24
+ processor.size = {"height": 256, "width": 256}
25
+ processor.crop_size = {"height": 256, "width": 256}
26
+ print(f"图像处理器尺寸已强制设置为: {processor.size}")
27
  model = ASTForRestoration.from_pretrained(
28
  repo_id,
29
  trust_remote_code=True
30
+ ).to(device).eval()
31
  print("✅ 模型加载成功,准备就绪!")
32
 
33
 
 
35
  def derain_image_Tiled(input_image: Image.Image, progress=gr.Progress(track_tqdm=True)):
36
  if input_image is None:
37
  return None
38
+
39
  img = input_image.convert("RGB")
40
  img_tensor = to_tensor(img).unsqueeze(0).to(device)
41
  b, c, h, w = img_tensor.shape
42
 
 
43
  output_canvas = torch.zeros_like(img_tensor).to(device)
44
  weight_map = torch.zeros_like(img_tensor).to(device)
45
+
46
  stride = PATCH_SIZE - OVERLAP
47
+
48
+ # 计算需要裁切的块数
49
+ h_steps = len(range(0, h, stride))
50
+ w_steps = len(range(0, w, stride))
51
+ total_patches = h_steps * w_steps
52
+
53
  pbar = tqdm(total=total_patches, desc="正在处理图像块...")
54
 
55
  for y in range(0, h, stride):
56
  for x in range(0, w, stride):
 
57
  y_end = min(y + PATCH_SIZE, h)
58
  x_end = min(x + PATCH_SIZE, w)
59
  patch_in = img_tensor[:, :, y:y_end, x:x_end]
60
+
 
61
  ph, pw = patch_in.shape[2:]
62
  pad_h = PATCH_SIZE - ph
63
  pad_w = PATCH_SIZE - pw
64
+ if pad_h > 0 or pad_w > 0:
65
+ patch_padded = F.pad(patch_in, (0, pad_w, 0, pad_h), 'reflect')
66
+ else:
67
+ patch_padded = patch_in
68
+
69
  with torch.no_grad():
 
 
70
  outputs = model(patch_padded)
71
+
72
  patch_out = outputs[0] if isinstance(outputs, tuple) else outputs
73
  patch_out = torch.clamp(patch_out, 0, 1)
74
 
 
75
  patch_out_unpadded = patch_out[:, :, :ph, :pw]
76
+
 
 
77
  output_canvas[:, :, y:y_end, x:x_end] += patch_out_unpadded
78
  weight_map[:, :, y:y_end, x:x_end] += 1
79
+
80
+ pbar.update(1)
81
 
82
  pbar.close()
83
 
 
 
84
  restored_tensor = output_canvas / weight_map
85
+
86
  restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
87
+
88
  return restored_image
89
 
90
  # --- 4. 创建并启动 Gradio 界面 ---
91
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
92
  gr.Markdown(
93
  """
94
+ # 🖼️ AST 图像去雨模型在线演示 (裁切/合并策略)
95
  上传任意尺寸的带雨图片,模型将会分块处理并拼接成完整的高清输出。
96
  模型仓库地址: [suncongcong/AST_DeRain](https://huggingface.co/suncongcong/AST_DeRain)
97
  """
 
99
  with gr.Row():
100
  input_img = gr.Image(type="pil", label="输入带雨图片 (Input Rainy Image)")
101
  output_img = gr.Image(type="pil", label="输出清晰图片 (Output Deraided Image)")
102
+
103
  submit_btn = gr.Button("开始去雨 (Start Deraining)", variant="primary")
104
+
 
105
  submit_btn.click(fn=derain_image_Tiled, inputs=input_img, outputs=output_img)
106
+
107
  demo.launch(server_name="0.0.0.0")