Spaces:
Runtime error
Runtime error
File size: 8,271 Bytes
a143c1a 9ae093c 942c73d a143c1a 942c73d d5dd056 a143c1a 9367fe4 8525111 6a6f453 8525111 e10fd9a 6a6f453 ae5af4b e10fd9a a143c1a 9367fe4 8525111 6a6f453 8525111 6a6f453 8525111 7ec5312 ae5af4b 7ec5312 ae5af4b 7ec5312 6a6f453 7ec5312 ae5af4b 9367fe4 7ec5312 6a6f453 7ec5312 6a6f453 7ec5312 6a6f453 7ec5312 9367fe4 7ec5312 6a6f453 7ec5312 6a6f453 7ec5312 ae5af4b 7ec5312 ae5af4b 7ec5312 9367fe4 7ec5312 ae5af4b 7ec5312 9367fe4 7ec5312 ae5af4b 7ec5312 ae5af4b 7ec5312 9367fe4 7ec5312 9367fe4 7ec5312 9367fe4 ae5af4b 7ec5312 9367fe4 7ec5312 ae5af4b 7ec5312 ae5af4b 7ec5312 9367fe4 7ec5312 9367fe4 7ec5312 9367fe4 7ec5312 ae5af4b 7ec5312 575a9d2 9367fe4 942c73d 6a6f453 9367fe4 97d6718 9367fe4 97d6718 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from transformers import CLIPImageProcessor
from modeling_ast import ASTForRestoration
from PIL import Image
import requests
from io import BytesIO
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm
import math
# --- 1. Configuration ---
# Maps each task label shown in the UI to the model repository loaded for it.
MODEL_IDS = {
    "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
    "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
    "去雾 (Dehaze)": "Suncongcong/AST_Dehazing",
}

# Example inputs per task, in the row format passed to gr.Examples:
# one single-element row per example image.
EXAMPLE_IMAGES = {
    "去雨痕 (Derain)": [
        ["derain_example1.png"],
        ["derain_example2.png"],
        ["derain_example3.png"],
    ],
    "去雨滴 (Deraindrop)": [
        ["deraindrop_example1.png"],
        ["deraindrop_example2.png"],
        ["deraindrop_example3.png"],
    ],
    "去雾 (Dehaze)": [
        ["dehaze_example1.jpg"],
        ["dehaze_example2.jpg"],
        ["dehaze_example3.jpg"],
    ],
}

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"正在使用的设备: {device}")
# --- 2. Load every model and the shared image processor ---
# MODELS maps a task label to its loaded model. If loading fails, every task
# is mapped to a callable that raises gr.Error instead.
MODELS = {}
PROCESSOR = None
print("正在加载所有模型和处理器...")
try:
    for task_name, repo_id in MODEL_IDS.items():
        print(f"正在加载模型: {task_name} ({repo_id})")
        # The processor is loaded once, from the first repository visited.
        if PROCESSOR is None:
            PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id)
            print("✅ 处理器加载成功。")
        model = ASTForRestoration.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
        MODELS[task_name] = model
        print(f"✅ 模型 '{task_name}' 加载成功。")
except Exception as e:
    print(f"加载模型时出错: {e}")
    # Python unbinds `e` as soon as this except clause ends, so the message
    # is captured in a name that outlives the clause. Reading `e` from inside
    # load_error_func would raise NameError when it is called later.
    LOAD_ERROR_MESSAGE = f"模型加载失败! 错误: {e}"

    def load_error_func(*args, **kwargs):
        """Stand-in for a model: always reports the load failure."""
        raise gr.Error(LOAD_ERROR_MESSAGE)

    MODELS = {task: load_error_func for task in MODEL_IDS}
else:
    # Only announce readiness when every model actually loaded.
    print("所有模型加载完毕,准备就绪!")
# --- 3. 定义不同任务的处理函数 ---
def process_with_pad_to_square(model, img_tensor):
    """Restore an image by running it through the model on a square canvas.

    Used for the derain and deraindrop tasks. The image is centred on a
    zero-filled square whose side is the longer image edge rounded up to a
    multiple of 128, the model is run on that square, and the pixels lying
    under the original image area are extracted from the output.

    Args:
        model: Callable applied to the padded (1, 3, S, S) tensor.
        img_tensor: Input image tensor; treated as (1, 3, H, W).

    Returns:
        Tensor of shape (1, 3, H, W) taken from the model output.
    """
    # Per the original author's note, the network prefers input sizes that
    # are an integer multiple of this value.
    size_multiple = 128.0

    source = img_tensor.to(device)
    _, _, height, width = source.size()
    side = int(math.ceil(max(height, width) / size_multiple) * size_multiple)

    # Window that the original image occupies once centred on the canvas.
    top = (side - height) // 2
    left = (side - width) // 2
    rows = slice(top, top + height)
    cols = slice(left, left + width)

    # Allocate on the same device/dtype as the input to avoid mismatches.
    canvas = torch.zeros(1, 3, side, side, device=source.device, dtype=source.dtype)
    canvas[:, :, rows, cols] = source
    keep = torch.zeros(1, 1, side, side, device=source.device, dtype=source.dtype)
    keep[:, :, rows, cols].fill_(1)

    with torch.no_grad():
        restored_square = model(canvas)

    # The mask must live on the same device as the model output.
    keep_bool = keep.bool().to(restored_square.device)
    selected = torch.masked_select(restored_square, keep_bool)
    return selected.reshape(1, 3, height, width)
def process_with_dehaze_tiling(model, img_tensor, progress):
    """Restore an image tile by tile with overlapping tiles (dehaze task).

    The image is replicate-padded on the bottom/right so that tiles of
    1152 pixels placed every 768 pixels cover it exactly, each tile is run
    through the model, and overlapping outputs are averaged.

    Args:
        model: Callable applied to each tile after it is moved to `device`.
        img_tensor: Input image tensor with four dimensions (b, c, h, w).
        progress: Object whose `tqdm` method wraps the row iteration
            (a gr.Progress in this app).

    Returns:
        CPU tensor cropped back to the input's height and width.
    """
    tile = 1152    # side length of each tile
    overlap = 384  # pixels shared by neighbouring tiles, to hide seams
    step = tile - overlap

    def pad_amount(extent):
        # Extra pixels needed so that (extent - overlap) is a multiple of step.
        if extent <= overlap:
            return 0
        return (step - (extent - overlap) % step) % step

    def tile_origins(extent):
        # Start offsets of the tiles along one axis.
        if extent > overlap:
            return range(0, extent - overlap, step)
        return [0]

    _, _, h_orig, w_orig = img_tensor.shape
    bottom_pad = pad_amount(h_orig)
    right_pad = pad_amount(w_orig)
    padded = F.pad(img_tensor, (0, right_pad, 0, bottom_pad), 'replicate')
    batch, channels, h_padded, w_padded = padded.shape

    # Accumulate on the CPU so the full-size result does not occupy GPU memory.
    accum = torch.zeros((batch, channels, h_padded, w_padded), device='cpu')
    hits = torch.zeros_like(accum)

    row_origins = tile_origins(h_padded)
    col_origins = tile_origins(w_padded)

    # Rows are reported through the Gradio progress bar.
    for top in progress.tqdm(row_origins, desc="正在分块去雾..."):
        for left in col_origins:
            # Clamp the tile to the padded image bounds.
            bottom = min(top + tile, h_padded)
            right = min(left + tile, w_padded)
            window = (slice(None), slice(None), slice(top, bottom), slice(left, right))
            with torch.no_grad():
                restored = model(padded[window].to(device)).cpu()
            accum[window] += restored
            hits[window] += 1

    # Average overlapping contributions; clamp guards against division by zero.
    averaged = accum / torch.clamp(hits, min=1)
    return averaged[:, :, :h_orig, :w_orig]
def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
    """Restore one image with the model registered for `task_name`.

    Dispatches to the tiled pipeline for the dehaze task and to the
    pad-to-square pipeline for every other task.

    Args:
        input_image: Image supplied by the UI, or None when nothing was given.
        task_name: Key into MODELS / MODEL_IDS selecting the task.
        progress: Progress tracker handed to the tiled dehaze pipeline.

    Returns:
        The restored PIL image; None when no input was given; the unmodified
        input image when processing fails (a warning is shown in that case).
    """
    if input_image is None:
        gr.Warning("请输入一张图片!")
        return None

    # Catch every runtime failure so a bad request never crashes the handler.
    try:
        model = MODELS[task_name]
        print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")

        # When loading failed, MODELS holds a callable that raises; invoking
        # it here surfaces the load error through the handler below.
        if not isinstance(model, torch.nn.Module):
            model()

        img = input_image.convert("RGB")
        img_tensor = to_tensor(img).unsqueeze(0)

        if task_name == "去雾 (Dehaze)":
            restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
        else:
            restored_tensor = process_with_pad_to_square(model, img_tensor)

        restored_tensor = torch.clamp(restored_tensor, 0, 1)
        restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
        return restored_image
    except Exception as e:
        print(f"处理图片时发生错误: {e}")
        # gr.Error is an exception and shows nothing unless raised. gr.Warning
        # displays when called, so the user sees the message while the
        # function can still return the original image below.
        gr.Warning(f"处理失败!错误信息: {e}")
        # Return the original image rather than an empty or stale output.
        return input_image
# --- 4. Gradio UI ---
def create_task_tab(task_name: str):
    """Build the UI tab for one task.

    The tab holds an input image and an output image side by side, a button
    that runs `process_image` for `task_name`, and, when EXAMPLE_IMAGES has
    entries for the task, a cached examples gallery.
    """
    tab_examples = EXAMPLE_IMAGES.get(task_name)

    with gr.TabItem(task_name, id=task_name):
        with gr.Row():
            source_view = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
            result_view = gr.Image(type="pil", label="输出图片 (Output)")
        run_button = gr.Button("开始处理 (Process)", variant="primary")

        # The handler accepts only the image, which is all the `inputs` list
        # provides. `progress` is deliberately not passed, so process_image
        # falls back to its default gr.Progress(...) argument.
        # NOTE(review): the function name is kept unchanged in case Gradio
        # derives the endpoint name from it - confirm before renaming.
        def specific_process_fn(img):
            return process_image(img, task_name)

        # A single input component maps to the handler's single parameter.
        run_button.click(fn=specific_process_fn, inputs=[source_view], outputs=result_view)

        if not tab_examples:
            return

        # Examples reuse the same handler that the button triggers.
        gr.Examples(
            examples=tab_examples,
            inputs=source_view,
            outputs=result_view,
            fn=specific_process_fn,
            cache_examples=True,
        )
# Build the main application interface: a heading plus one tab per task.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # The Markdown text is kept flush-left so the string content is unchanged.
    gr.Markdown(
        """
# 🖼️ 多功能图像复原工具 (AST 模型)
请选择一个任务,然后上传对应的图片或点击下方的示例图片进行处理。
"""
    )
    with gr.Tabs():
        for task_name in MODEL_IDS:
            create_task_tab(task_name)  # one tab per configured task

# Start the app, listening on all network interfaces.
demo.launch(server_name="0.0.0.0")