# AST-Demo / app.py
# Hugging Face Space by suncongcong (commit 9367fe4)
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from transformers import CLIPImageProcessor
from modeling_ast import ASTForRestoration
from PIL import Image
import requests
from io import BytesIO
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm
import math
# --- 1. Configuration ---
# Task display name -> Hugging Face Hub repository id.
MODEL_IDS = {
    "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
    "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
    "去雾 (Dehaze)": "Suncongcong/AST_Dehazing"
}
# Task display name -> example image paths shown under each tab.
EXAMPLE_IMAGES = {
    "去雨痕 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]],
    "去雨滴 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]],
    "去雾 (Dehaze)": [["dehaze_example1.jpg"], ["dehaze_example2.jpg"], ["dehaze_example3.jpg"]]
}

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"正在使用的设备: {device}")

# --- 2. Load all models and the shared processor ---
MODELS = {}
PROCESSOR = None
print("正在加载所有模型和处理器...")
try:
    for task_name, repo_id in MODEL_IDS.items():
        print(f"正在加载模型: {task_name} ({repo_id})")
        # All repos share the same processor config; load it once.
        if PROCESSOR is None:
            PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id)
            print("✅ 处理器加载成功。")
        model = ASTForRestoration.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
        MODELS[task_name] = model
        print(f"✅ 模型 '{task_name}' 加载成功。")
except Exception as e:
    print(f"加载模型时出错: {e}")
    # BUG FIX: Python 3 unbinds `e` when this except block exits, so a
    # closure referencing `e` directly would raise NameError when called
    # later. Capture the message under a name that survives the block.
    _load_error_message = str(e)

    def load_error_func(*args, **kwargs):
        # Stub standing in for every model: surfaces the startup failure
        # in the UI the first time any task is invoked.
        raise gr.Error(f"模型加载失败! 错误: {_load_error_message}")

    MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
print("所有模型加载完毕,准备就绪!")
# --- 3. 定义不同任务的处理函数 ---
def process_with_pad_to_square(model, img_tensor):
    """Pad the image to a square, run the model once, then crop back.

    Used for the derain / deraindrop tasks: the network prefers square
    inputs whose side length is a multiple of 128 pixels.
    """
    SIDE_MULTIPLE = 128.0  # network-friendly side-length granularity

    def _center_on_square(batch):
        # Place `batch` at the center of a zero canvas whose side is the
        # smallest multiple of SIDE_MULTIPLE covering both dimensions, and
        # build a 1-channel mask marking the valid (non-padded) region.
        _, _, height, width = batch.size()
        side = int(math.ceil(max(height, width) / SIDE_MULTIPLE) * SIDE_MULTIPLE)
        canvas = torch.zeros(1, 3, side, side, device=batch.device, dtype=batch.dtype)
        valid = torch.zeros(1, 1, side, side, device=batch.device, dtype=batch.dtype)
        top = (side - height) // 2
        left = (side - width) // 2
        canvas[:, :, top:top + height, left:left + width] = batch
        valid[:, :, top:top + height, left:left + width] = 1
        return canvas, valid

    orig_h, orig_w = img_tensor.shape[2:]
    square_input, valid_mask = _center_on_square(img_tensor.to(device))
    with torch.no_grad():
        restored_square = model(square_input)
    # Select only the pixels that came from the original image.
    keep = valid_mask.bool().to(restored_square.device)
    return torch.masked_select(restored_square, keep).reshape(1, 3, orig_h, orig_w)
def process_with_dehaze_tiling(model, img_tensor, progress):
    """Restore a hazy image tile-by-tile using overlapping crops.

    The image is padded (replicate) so a grid of CROP_SIZE tiles with
    OVERLAP-pixel overlap covers it exactly; overlapping predictions are
    averaged to hide seams, then the padding is cropped away.
    """
    CROP_SIZE = 1152  # side length of each tile fed to the model
    OVERLAP = 384     # overlap between neighbouring tiles (avoids edge artifacts)
    stride = CROP_SIZE - OVERLAP

    batch, channels, height, width = img_tensor.shape

    # Pad so (dim - OVERLAP) becomes a multiple of the stride, which makes
    # the last tile end exactly at the padded border.
    pad_h = (stride - (height - OVERLAP) % stride) % stride if height > OVERLAP else 0
    pad_w = (stride - (width - OVERLAP) % stride) % stride if width > OVERLAP else 0
    padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), 'replicate')
    _, _, full_h, full_w = padded.shape

    # Accumulate on the CPU so a large canvas does not exhaust GPU memory.
    canvas = torch.zeros((batch, channels, full_h, full_w), device='cpu')
    counts = torch.zeros_like(canvas)

    row_starts = range(0, full_h - OVERLAP, stride) if full_h > OVERLAP else [0]
    col_starts = range(0, full_w - OVERLAP, stride) if full_w > OVERLAP else [0]

    # Gradio progress bar over the rows of tiles.
    for top in progress.tqdm(row_starts, desc="正在分块去雾..."):
        for left in col_starts:
            bottom = min(top + CROP_SIZE, full_h)
            right = min(left + CROP_SIZE, full_w)
            tile = padded[:, :, top:bottom, left:right]
            with torch.no_grad():
                restored_tile = model(tile.to(device)).cpu()
            canvas[:, :, top:bottom, left:right] += restored_tile
            counts[:, :, top:bottom, left:right] += 1

    # Average where tiles overlapped, then drop the padding.
    averaged = canvas / torch.clamp(counts, min=1)
    return averaged[:, :, :height, :width]
def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
    """Dispatch `input_image` to the restoration pipeline for `task_name`.

    Returns the restored PIL image, or None when no image was supplied.
    Raises gr.Error so any failure is shown to the user in the UI.
    """
    if input_image is None:
        gr.Warning("请输入一张图片!")
        return None
    try:
        model = MODELS[task_name]
        print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")
        # If startup loading failed, MODELS holds a stub callable instead of
        # a module; calling it raises gr.Error describing the load failure.
        if not isinstance(model, torch.nn.Module):
            model()
        img = input_image.convert("RGB")
        img_tensor = to_tensor(img).unsqueeze(0)  # (1, 3, H, W) in [0, 1]
        if task_name == "去雾 (Dehaze)":
            # Dehazing uses overlapped tiling (large inputs, progress bar).
            restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
        else:
            # Derain / deraindrop use a single padded-to-square forward pass.
            restored_tensor = process_with_pad_to_square(model, img_tensor)
        restored_tensor = torch.clamp(restored_tensor, 0, 1)
        return to_pil_image(restored_tensor.cpu().squeeze(0))
    except Exception as e:
        print(f"处理图片时发生错误: {e}")
        # BUG FIX: gr.Error must be *raised* to appear in the UI. The old
        # code merely constructed it and silently returned the unprocessed
        # input image, which looked like a (wrong) result to the user.
        raise gr.Error(f"处理失败!错误信息: {e}") from e
# --- 4. Gradio UI ---
def create_task_tab(task_name: str):
    """Build one UI tab (input/output images, button, examples) for a task."""
    with gr.TabItem(task_name, id=task_name):
        with gr.Row():
            input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
            output_img = gr.Image(type="pil", label="输出图片 (Output)")
        submit_btn = gr.Button("开始处理 (Process)", variant="primary")

        # Single-argument wrapper: the click event only supplies the image;
        # process_image falls back to its default gr.Progress(...) value.
        def specific_process_fn(img):
            return process_image(img, task_name)

        submit_btn.click(fn=specific_process_fn, inputs=[input_img], outputs=output_img)

        # Only render the examples gallery when this task has examples.
        if EXAMPLE_IMAGES.get(task_name):
            gr.Examples(
                examples=EXAMPLE_IMAGES.get(task_name, []),
                inputs=input_img,
                outputs=output_img,
                fn=specific_process_fn,  # reuse the button's handler
                cache_examples=True,
            )
# 创建应用主界面
# --- 4. Build and launch the Gradio app ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ 多功能图像复原工具 (AST 模型)
        请选择一个任务,然后上传对应的图片或点击下方的示例图片进行处理。
        """
    )
    with gr.Tabs():
        # One tab per configured task.
        for task_name in MODEL_IDS.keys():
            create_task_tab(task_name)

# Guard the launch so importing this module (e.g. from tooling or tests)
# does not start a server; Hugging Face Spaces runs app.py as __main__.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")