suncongcong commited on
Commit
9ae093c
·
verified ·
1 Parent(s): 1f6f9df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -17
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- import torch.nn.functional as F
4
  import numpy as np
5
  from transformers import CLIPImageProcessor
6
  from modeling_ast import ASTForRestoration
@@ -35,21 +35,20 @@ print("✅ 模型加载成功,准备就绪!")
35
  def derain_image_Tiled(input_image: Image.Image, progress=gr.Progress(track_tqdm=True)):
36
  if input_image is None:
37
  return None
38
-
39
  img = input_image.convert("RGB")
40
  img_tensor = to_tensor(img).unsqueeze(0).to(device)
41
  b, c, h, w = img_tensor.shape
42
 
43
  output_canvas = torch.zeros_like(img_tensor).to(device)
44
  weight_map = torch.zeros_like(img_tensor).to(device)
45
-
46
  stride = PATCH_SIZE - OVERLAP
47
-
48
- # 计算需要裁切的块数
49
  h_steps = len(range(0, h, stride))
50
  w_steps = len(range(0, w, stride))
51
  total_patches = h_steps * w_steps
52
-
53
  pbar = tqdm(total=total_patches, desc="正在处理图像块...")
54
 
55
  for y in range(0, h, stride):
@@ -57,34 +56,34 @@ def derain_image_Tiled(input_image: Image.Image, progress=gr.Progress(track_tqdm
57
  y_end = min(y + PATCH_SIZE, h)
58
  x_end = min(x + PATCH_SIZE, w)
59
  patch_in = img_tensor[:, :, y:y_end, x:x_end]
60
-
61
  ph, pw = patch_in.shape[2:]
62
  pad_h = PATCH_SIZE - ph
63
  pad_w = PATCH_SIZE - pw
64
  if pad_h > 0 or pad_w > 0:
65
- patch_padded = F.pad(patch_in, (0, pad_w, 0, pad_h), 'reflect')
66
  else:
67
  patch_padded = patch_in
68
-
69
  with torch.no_grad():
70
  outputs = model(patch_padded)
71
-
72
  patch_out = outputs[0] if isinstance(outputs, tuple) else outputs
73
  patch_out = torch.clamp(patch_out, 0, 1)
74
 
75
  patch_out_unpadded = patch_out[:, :, :ph, :pw]
76
-
77
  output_canvas[:, :, y:y_end, x:x_end] += patch_out_unpadded
78
  weight_map[:, :, y:y_end, x:x_end] += 1
79
-
80
  pbar.update(1)
81
 
82
  pbar.close()
83
 
84
  restored_tensor = output_canvas / weight_map
85
-
86
  restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
87
-
88
  return restored_image
89
 
90
  # --- 4. 创建并启动 Gradio 界面 ---
@@ -99,9 +98,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
99
  with gr.Row():
100
  input_img = gr.Image(type="pil", label="输入带雨图片 (Input Rainy Image)")
101
  output_img = gr.Image(type="pil", label="输出清晰图片 (Output Deraided Image)")
102
-
103
  submit_btn = gr.Button("开始去雨 (Start Deraining)", variant="primary")
104
-
105
  submit_btn.click(fn=derain_image_Tiled, inputs=input_img, outputs=output_img)
106
-
107
  demo.launch(server_name="0.0.0.0")
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn.functional as F
4
  import numpy as np
5
  from transformers import CLIPImageProcessor
6
  from modeling_ast import ASTForRestoration
 
35
  def derain_image_Tiled(input_image: Image.Image, progress=gr.Progress(track_tqdm=True)):
36
  if input_image is None:
37
  return None
38
+
39
  img = input_image.convert("RGB")
40
  img_tensor = to_tensor(img).unsqueeze(0).to(device)
41
  b, c, h, w = img_tensor.shape
42
 
43
  output_canvas = torch.zeros_like(img_tensor).to(device)
44
  weight_map = torch.zeros_like(img_tensor).to(device)
45
+
46
  stride = PATCH_SIZE - OVERLAP
47
+
 
48
  h_steps = len(range(0, h, stride))
49
  w_steps = len(range(0, w, stride))
50
  total_patches = h_steps * w_steps
51
+
52
  pbar = tqdm(total=total_patches, desc="正在处理图像块...")
53
 
54
  for y in range(0, h, stride):
 
56
  y_end = min(y + PATCH_SIZE, h)
57
  x_end = min(x + PATCH_SIZE, w)
58
  patch_in = img_tensor[:, :, y:y_end, x:x_end]
59
+
60
  ph, pw = patch_in.shape[2:]
61
  pad_h = PATCH_SIZE - ph
62
  pad_w = PATCH_SIZE - pw
63
  if pad_h > 0 or pad_w > 0:
64
+ patch_padded = F.pad(patch_in, (0, pad_w, 0, pad_h), 'replicate') # <-- 最终修正
65
  else:
66
  patch_padded = patch_in
67
+
68
  with torch.no_grad():
69
  outputs = model(patch_padded)
70
+
71
  patch_out = outputs[0] if isinstance(outputs, tuple) else outputs
72
  patch_out = torch.clamp(patch_out, 0, 1)
73
 
74
  patch_out_unpadded = patch_out[:, :, :ph, :pw]
75
+
76
  output_canvas[:, :, y:y_end, x:x_end] += patch_out_unpadded
77
  weight_map[:, :, y:y_end, x:x_end] += 1
78
+
79
  pbar.update(1)
80
 
81
  pbar.close()
82
 
83
  restored_tensor = output_canvas / weight_map
84
+
85
  restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
86
+
87
  return restored_image
88
 
89
  # --- 4. 创建并启动 Gradio 界面 ---
 
98
  with gr.Row():
99
  input_img = gr.Image(type="pil", label="输入带雨图片 (Input Rainy Image)")
100
  output_img = gr.Image(type="pil", label="输出清晰图片 (Output Deraided Image)")
101
+
102
  submit_btn = gr.Button("开始去雨 (Start Deraining)", variant="primary")
103
+
104
  submit_btn.click(fn=derain_image_Tiled, inputs=input_img, outputs=output_img)
105
+
106
  demo.launch(server_name="0.0.0.0")