Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,7 +11,7 @@ from torchvision.transforms.functional import to_pil_image, to_tensor
|
|
| 11 |
from tqdm import tqdm
|
| 12 |
import math
|
| 13 |
|
| 14 |
-
# --- 1. 配置 ---
|
| 15 |
MODEL_IDS = {
|
| 16 |
"去雨痕 (Derain)": "Suncongcong/AST_DeRain",
|
| 17 |
"去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
|
|
@@ -25,7 +25,7 @@ EXAMPLE_IMAGES = {
|
|
| 25 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
print(f"正在使用的设备: {device}")
|
| 27 |
|
| 28 |
-
# --- 2. 加载所有模型和处理器 ---
|
| 29 |
MODELS = {}
|
| 30 |
PROCESSOR = None
|
| 31 |
print("正在加载所有模型和处理器...")
|
|
@@ -45,67 +45,141 @@ except Exception as e:
|
|
| 45 |
MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
|
| 46 |
print("所有模型加载完毕,准备就绪!")
|
| 47 |
|
|
|
|
| 48 |
# --- 3. 定义不同任务的处理函数 ---
|
|
|
|
| 49 |
def process_with_pad_to_square(model, img_tensor):
|
|
|
|
| 50 |
def expand2square(timg, factor=128.0):
|
|
|
|
|
|
|
| 51 |
_, _, h, w = timg.size()
|
| 52 |
X = int(math.ceil(max(h, w) / factor) * factor)
|
| 53 |
-
img_padded = torch.zeros(1, 3, X, X
|
| 54 |
-
mask = torch.zeros(1, 1, X, X
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
| 57 |
return img_padded, mask
|
|
|
|
| 58 |
original_h, original_w = img_tensor.shape[2:]
|
| 59 |
padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
|
|
|
|
| 60 |
with torch.no_grad():
|
| 61 |
restored_padded = model(padded_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
restored_tensor = torch.masked_select(
|
| 63 |
-
restored_padded,
|
| 64 |
).reshape(1, 3, original_h, original_w)
|
|
|
|
| 65 |
return restored_tensor
|
| 66 |
|
|
|
|
| 67 |
def process_with_dehaze_tiling(model, img_tensor, progress):
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
b, c, h_orig, w_orig = img_tensor.shape
|
| 70 |
stride = CROP_SIZE - OVERLAP
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
| 73 |
img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
|
| 76 |
weight_map = torch.zeros_like(output_canvas)
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
patch_in = img_padded[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE]
|
| 83 |
with torch.no_grad():
|
| 84 |
patch_out = model(patch_in.to(device)).cpu()
|
|
|
|
| 85 |
output_canvas[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += patch_out
|
| 86 |
weight_map[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += 1
|
| 87 |
-
|
| 88 |
-
pbar.close()
|
| 89 |
restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
|
| 90 |
return restored_padded_tensor[:, :, :h_orig, :w_orig]
|
| 91 |
|
|
|
|
| 92 |
def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 111 |
gr.Markdown(
|
|
@@ -116,30 +190,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 116 |
)
|
| 117 |
with gr.Tabs():
|
| 118 |
for task_name in MODEL_IDS.keys():
|
| 119 |
-
|
| 120 |
-
with gr.Row():
|
| 121 |
-
input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
|
| 122 |
-
output_img = gr.Image(type="pil", label="输出图片 (Output)")
|
| 123 |
-
|
| 124 |
-
task_id_box = gr.Textbox(task_name, visible=False)
|
| 125 |
-
submit_btn = gr.Button("开始处理 (Process)", variant="primary")
|
| 126 |
-
|
| 127 |
-
# “提交”按钮的点击事件,保持不变
|
| 128 |
-
submit_btn.click(fn=process_image, inputs=[input_img, task_id_box], outputs=output_img)
|
| 129 |
-
|
| 130 |
-
# --- 最终修正 ---
|
| 131 |
-
# 重新构造 lambda 函数,确保它总是传递正确的 task_name
|
| 132 |
-
# 我们不再依赖外部的 task_name 变量,而是直接使用在循环中定义的那个
|
| 133 |
-
def create_example_fn(current_task_name):
|
| 134 |
-
return lambda img, prog: process_image(img, current_task_name, prog)
|
| 135 |
-
|
| 136 |
-
if EXAMPLE_IMAGES.get(task_name):
|
| 137 |
-
gr.Examples(
|
| 138 |
-
examples=EXAMPLE_IMAGES.get(task_name, []),
|
| 139 |
-
inputs=input_img,
|
| 140 |
-
outputs=output_img,
|
| 141 |
-
fn=create_example_fn(task_name), # 关键:为每个循环的 task_name 创建一个独立的函数
|
| 142 |
-
cache_examples=True
|
| 143 |
-
)
|
| 144 |
|
| 145 |
demo.launch(server_name="0.0.0.0")
|
|
|
|
| 11 |
from tqdm import tqdm
|
| 12 |
import math
|
| 13 |
|
| 14 |
+
# --- 1. 配置 (无变化) ---
|
| 15 |
MODEL_IDS = {
|
| 16 |
"去雨痕 (Derain)": "Suncongcong/AST_DeRain",
|
| 17 |
"去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
|
|
|
|
| 25 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
print(f"正在使用的设备: {device}")
|
| 27 |
|
| 28 |
+
# --- 2. 加载所有模型和处理器 (无变化) ---
|
| 29 |
MODELS = {}
|
| 30 |
PROCESSOR = None
|
| 31 |
print("正在加载所有模型和处理器...")
|
|
|
|
| 45 |
MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
|
| 46 |
print("所有模型加载完毕,准备就绪!")
|
| 47 |
|
| 48 |
+
|
| 49 |
# --- 3. 定义不同任务的处理函数 ---
|
| 50 |
+
|
| 51 |
def process_with_pad_to_square(model, img_tensor):
|
| 52 |
+
"""将图片填充为正方形后进行处理,适用于去雨/去雨滴任务。"""
|
| 53 |
def expand2square(timg, factor=128.0):
|
| 54 |
+
# ✨ 优化点: 增加注释,解释 factor 的作用
|
| 55 |
+
# factor: 模型的网络结构要求输入的尺寸最好是该值的整数倍
|
| 56 |
_, _, h, w = timg.size()
|
| 57 |
X = int(math.ceil(max(h, w) / factor) * factor)
|
| 58 |
+
img_padded = torch.zeros(1, 3, X, X, device=timg.device, dtype=timg.dtype)
|
| 59 |
+
mask = torch.zeros(1, 1, X, X, device=timg.device, dtype=timg.dtype)
|
| 60 |
+
|
| 61 |
+
pad_h = (X - h) // 2
|
| 62 |
+
pad_w = (X - w) // 2
|
| 63 |
+
img_padded[:, :, pad_h:pad_h + h, pad_w:pad_w + w] = timg
|
| 64 |
+
mask[:, :, pad_h:pad_h + h, pad_w:pad_w + w].fill_(1)
|
| 65 |
return img_padded, mask
|
| 66 |
+
|
| 67 |
original_h, original_w = img_tensor.shape[2:]
|
| 68 |
padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
|
| 69 |
+
|
| 70 |
with torch.no_grad():
|
| 71 |
restored_padded = model(padded_input)
|
| 72 |
+
|
| 73 |
+
# ✨ 优化点: 确保 mask 和 restored_padded 在同一设备上
|
| 74 |
+
mask_bool = mask.bool().to(restored_padded.device)
|
| 75 |
+
|
| 76 |
restored_tensor = torch.masked_select(
|
| 77 |
+
restored_padded, mask_bool
|
| 78 |
).reshape(1, 3, original_h, original_w)
|
| 79 |
+
|
| 80 |
return restored_tensor
|
| 81 |
|
| 82 |
+
|
| 83 |
def process_with_dehaze_tiling(model, img_tensor, progress):
|
| 84 |
+
"""使用重叠分块策略处理图像,适用于去雾任务。"""
|
| 85 |
+
# ✨ 优化点: 将“魔法数字”定义为常量并添加注释
|
| 86 |
+
CROP_SIZE = 1152 # 每个图块的尺寸
|
| 87 |
+
OVERLAP = 384 # 图块之间的重叠区域大小,以避免边缘效应
|
| 88 |
+
|
| 89 |
b, c, h_orig, w_orig = img_tensor.shape
|
| 90 |
stride = CROP_SIZE - OVERLAP
|
| 91 |
+
|
| 92 |
+
# 计算需要填充的尺寸
|
| 93 |
+
h_pad = (stride - (h_orig - OVERLAP) % stride) % stride if h_orig > CROP_SIZE else 0
|
| 94 |
+
w_pad = (stride - (w_orig - OVERLAP) % stride) % stride if w_orig > CROP_SIZE else 0
|
| 95 |
+
|
| 96 |
img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
|
| 97 |
+
b, c, h_padded, w_padded = img_padded.shape
|
| 98 |
+
|
| 99 |
+
# 使用CPU来存储最终结果,避免占用大量显存
|
| 100 |
output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
|
| 101 |
weight_map = torch.zeros_like(output_canvas)
|
| 102 |
+
|
| 103 |
+
h_steps_range = range(0, h_padded - CROP_SIZE + 1, stride) if h_padded > CROP_SIZE else [0]
|
| 104 |
+
w_steps_range = range(0, w_padded - CROP_SIZE + 1, stride) if w_padded > CROP_SIZE else [0]
|
| 105 |
+
|
| 106 |
+
total_steps = len(h_steps_range) * len(w_steps_range)
|
| 107 |
+
|
| 108 |
+
# ✨ 优化点: 使用Gradio的进度条,而不是手动的tqdm
|
| 109 |
+
for y in progress.tqdm(h_steps_range, desc="正在分块去雾..."):
|
| 110 |
+
for x in w_steps_range:
|
| 111 |
patch_in = img_padded[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE]
|
| 112 |
with torch.no_grad():
|
| 113 |
patch_out = model(patch_in.to(device)).cpu()
|
| 114 |
+
|
| 115 |
output_canvas[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += patch_out
|
| 116 |
weight_map[:, :, y:y+CROP_SIZE, x:x+CROP_SIZE] += 1
|
| 117 |
+
|
|
|
|
| 118 |
restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
|
| 119 |
return restored_padded_tensor[:, :, :h_orig, :w_orig]
|
| 120 |
|
| 121 |
+
|
| 122 |
def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
|
| 123 |
+
"""主处理函数,根据任务名分派到不同的处理流程。"""
|
| 124 |
+
if input_image is None:
|
| 125 |
+
gr.Warning("请输入一张图片!")
|
| 126 |
+
return None
|
| 127 |
+
|
| 128 |
+
# ✨ 优化点: 增加完整的运行时错误捕获
|
| 129 |
+
try:
|
| 130 |
+
model = MODELS[task_name]
|
| 131 |
+
print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")
|
| 132 |
+
|
| 133 |
+
# 检查模型是否加载成功
|
| 134 |
+
if not isinstance(model, torch.nn.Module):
|
| 135 |
+
model() # 如果加载失败,这里会触发 load_error_func 并抛出异常
|
| 136 |
+
|
| 137 |
+
img = input_image.convert("RGB")
|
| 138 |
+
img_tensor = to_tensor(img).unsqueeze(0)
|
| 139 |
+
|
| 140 |
+
if task_name == "去雾 (Dehaze)":
|
| 141 |
+
restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
|
| 142 |
+
else:
|
| 143 |
+
restored_tensor = process_with_pad_to_square(model, img_tensor)
|
| 144 |
+
|
| 145 |
+
restored_tensor = torch.clamp(restored_tensor, 0, 1)
|
| 146 |
+
restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
|
| 147 |
+
return restored_image
|
| 148 |
+
|
| 149 |
+
except Exception as e:
|
| 150 |
+
print(f"处理图片时发生错误: {e}")
|
| 151 |
+
# 在UI上给用户一个清晰的错误提示
|
| 152 |
+
gr.Error(f"处理失败!错误信息: {e}")
|
| 153 |
+
# 返回原始图像,而不是空着或保留上一次的结果
|
| 154 |
+
return input_image
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# --- 4. Gradio UI ---
|
| 158 |
+
|
| 159 |
+
# ✨ 优化点 3: 将创建Tab的逻辑封装成函数,使UI代码更干净
|
| 160 |
+
def create_task_tab(task_name: str):
|
| 161 |
+
with gr.TabItem(task_name, id=task_name):
|
| 162 |
+
with gr.Row():
|
| 163 |
+
input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
|
| 164 |
+
output_img = gr.Image(type="pil", label="输出图片 (Output)")
|
| 165 |
+
|
| 166 |
+
submit_btn = gr.Button("开始处理 (Process)", variant="primary")
|
| 167 |
+
|
| 168 |
+
# ✨ 优化点 1: 创建一个处理函数,它已经“知道”自己的任务名
|
| 169 |
+
# 这样就不再需要隐藏的 Textbox 来传递 task_name
|
| 170 |
+
def specific_process_fn(img, prog):
|
| 171 |
+
return process_image(img, task_name, prog)
|
| 172 |
+
|
| 173 |
+
submit_btn.click(fn=specific_process_fn, inputs=[input_img], outputs=output_img)
|
| 174 |
|
| 175 |
+
if EXAMPLE_IMAGES.get(task_name):
|
| 176 |
+
gr.Examples(
|
| 177 |
+
examples=EXAMPLE_IMAGES.get(task_name, []),
|
| 178 |
+
inputs=input_img,
|
| 179 |
+
outputs=output_img,
|
| 180 |
+
fn=specific_process_fn, # 复用上面为按钮创建的处理函数
|
| 181 |
+
cache_examples=True,
|
| 182 |
+
)
|
| 183 |
|
| 184 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 185 |
gr.Markdown(
|
|
|
|
| 190 |
)
|
| 191 |
with gr.Tabs():
|
| 192 |
for task_name in MODEL_IDS.keys():
|
| 193 |
+
create_task_tab(task_name) # 调用函数创建每个Tab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
demo.launch(server_name="0.0.0.0")
|