import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from transformers import CLIPImageProcessor
from modeling_ast import ASTForRestoration
from PIL import Image
import requests
from io import BytesIO
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm
import math

# --- 1. Configuration ---
# Task display name -> Hugging Face Hub repo id. The display names are also
# used as dict keys and tab ids throughout the app, so they must not change.
MODEL_IDS = {
    "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
    "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
    "去雾 (Dehaze)": "Suncongcong/AST_Dehazing",
}

# Example images shown under each task tab (paths relative to the app root).
EXAMPLE_IMAGES = {
    "去雨痕 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]],
    "去雨滴 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]],
    "去雾 (Dehaze)": [["dehaze_example1.jpg"], ["dehaze_example2.jpg"], ["dehaze_example3.jpg"]],
}

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"正在使用的设备: {device}")

# --- 2. Load all models and the shared processor ---
MODELS = {}
PROCESSOR = None
print("正在加载所有模型和处理器...")
try:
    for task_name, repo_id in MODEL_IDS.items():
        print(f"正在加载模型: {task_name} ({repo_id})")
        if PROCESSOR is None:
            # The processor config is shared between repos; load it once.
            PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id)
            print("✅ 处理器加载成功。")
        model = ASTForRestoration.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
        MODELS[task_name] = model
        print(f"✅ 模型 '{task_name}' 加载成功。")
except Exception as e:
    print(f"加载模型时出错: {e}")
    # BUG FIX: in Python 3 the name bound by `except ... as e` is deleted when
    # the except block exits, so a closure that reads `e` lazily would raise
    # NameError at call time instead of the intended gr.Error. Capture the
    # message eagerly and bind it as a default argument.
    _load_error_msg = f"模型加载失败! 错误: {e}"

    def load_error_func(*args, _msg=_load_error_msg, **kwargs):
        # Stand-in for every model: raising gr.Error surfaces the load
        # failure in the UI the first time any task is used.
        raise gr.Error(_msg)

    MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
print("所有模型加载完毕,准备就绪!")
# --- 3. Task-specific processing functions ---


def process_with_pad_to_square(model, img_tensor):
    """Pad the image into a square, restore it, and crop it back.

    Used for the derain / deraindrop tasks, whose network wants input sides
    to be a multiple of ``factor`` (128 here — presumably the model's window
    size; TODO confirm against the AST architecture).

    Args:
        model: restoration network; assumed to map a (1, 3, S, S) tensor to a
            tensor of the same shape.
        img_tensor: float tensor of shape (1, 3, H, W).

    Returns:
        Restored tensor of shape (1, 3, H, W) on the model's device.
    """

    def expand2square(timg, factor=128.0):
        # Pad to the smallest square whose side is a multiple of `factor`
        # and contains the image, centering the original content.
        _, _, h, w = timg.size()
        side = int(math.ceil(max(h, w) / factor) * factor)
        # Allocate on the input's device/dtype so no implicit copies occur.
        img_padded = torch.zeros(1, 3, side, side, device=timg.device, dtype=timg.dtype)
        mask = torch.zeros(1, 1, side, side, device=timg.device, dtype=timg.dtype)
        pad_h = (side - h) // 2
        pad_w = (side - w) // 2
        img_padded[:, :, pad_h:pad_h + h, pad_w:pad_w + w] = timg
        mask[:, :, pad_h:pad_h + h, pad_w:pad_w + w].fill_(1)
        return img_padded, mask

    original_h, original_w = img_tensor.shape[2:]
    padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
    with torch.no_grad():
        restored_padded = model(padded_input)
    # The mask marks the original image region inside the padded square;
    # masked_select + reshape crops the restored content back out.
    mask_bool = mask.bool().to(restored_padded.device)
    restored_tensor = torch.masked_select(restored_padded, mask_bool).reshape(
        1, 3, original_h, original_w
    )
    return restored_tensor


def process_with_dehaze_tiling(model, img_tensor, progress):
    """Restore a (possibly large) image with an overlapping-tile strategy.

    Used for the dehazing task: the image is cut into overlapping tiles,
    each tile is restored independently, and overlapping outputs are
    averaged to hide seam artifacts.

    Args:
        model: restoration network applied per tile.
        img_tensor: float tensor of shape (1, 3, H, W) on the CPU.
        progress: a ``gr.Progress`` instance used to report tiling progress.

    Returns:
        Restored tensor of shape (1, 3, H, W) on the CPU.
    """
    CROP_SIZE = 1152  # side length of each tile
    OVERLAP = 384     # overlap between neighbouring tiles to avoid edge effects

    b, c, h_orig, w_orig = img_tensor.shape
    stride = CROP_SIZE - OVERLAP

    # Pad (bottom/right, replicate) so the tile grid covers the image exactly.
    h_pad = (stride - (h_orig - OVERLAP) % stride) % stride if h_orig > OVERLAP else 0
    w_pad = (stride - (w_orig - OVERLAP) % stride) % stride if w_orig > OVERLAP else 0
    img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
    b, c, h_padded, w_padded = img_padded.shape

    # Accumulate on the CPU so GPU memory stays flat regardless of image size.
    output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
    weight_map = torch.zeros_like(output_canvas)

    h_steps_range = range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0]
    w_steps_range = range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0]

    # progress.tqdm drives the Gradio progress bar for the outer loop.
    for y in progress.tqdm(h_steps_range, desc="正在分块去雾..."):
        for x in w_steps_range:
            # Clamp tile bounds at the padded image edge.
            y_end = min(y + CROP_SIZE, h_padded)
            x_end = min(x + CROP_SIZE, w_padded)
            patch_in = img_padded[:, :, y:y_end, x:x_end]
            with torch.no_grad():
                patch_out = model(patch_in.to(device)).cpu()
            output_canvas[:, :, y:y_end, x:x_end] += patch_out
            weight_map[:, :, y:y_end, x:x_end] += 1

    # Average overlapping contributions; clamp guards against divide-by-zero.
    restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
    return restored_padded_tensor[:, :, :h_orig, :w_orig]


def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
    """Main entry point: dispatch the input image to the task's pipeline.

    Args:
        input_image: user-supplied PIL image (any mode; converted to RGB).
        task_name: one of the keys of ``MODEL_IDS``.
        progress: Gradio progress tracker (default instance tracks tqdm).

    Returns:
        The restored PIL image, or ``None`` when no input was given.

    Raises:
        gr.Error: when the model failed to load or processing fails.
    """
    if input_image is None:
        gr.Warning("请输入一张图片!")
        return None
    try:
        model = MODELS[task_name]
        print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")

        # If loading failed at startup, MODELS holds a callable placeholder
        # instead of a module; invoking it raises the stored gr.Error.
        if not isinstance(model, torch.nn.Module):
            model()

        img = input_image.convert("RGB")
        img_tensor = to_tensor(img).unsqueeze(0)

        if task_name == "去雾 (Dehaze)":
            restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
        else:
            restored_tensor = process_with_pad_to_square(model, img_tensor)

        restored_tensor = torch.clamp(restored_tensor, 0, 1)
        return to_pil_image(restored_tensor.cpu().squeeze(0))
    except gr.Error:
        # Already a user-facing error (e.g. model load failure) — let it show.
        raise
    except Exception as e:
        print(f"处理图片时发生错误: {e}")
        # BUG FIX: gr.Error only reaches the UI when it is *raised*; the
        # original merely constructed it, so failures were silently swallowed
        # and the unprocessed input was displayed as if it were a result.
        raise gr.Error(f"处理失败!错误信息: {e}") from e


# --- 4. Gradio UI ---


def create_task_tab(task_name: str):
    """Build the input/output widgets, button, and examples for one task tab."""
    with gr.TabItem(task_name, id=task_name):
        with gr.Row():
            input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
            output_img = gr.Image(type="pil", label="输出图片 (Output)")
        submit_btn = gr.Button("开始处理 (Process)", variant="primary")

        # Bind the current task name for this tab. `progress` is intentionally
        # not passed so process_image uses its gr.Progress(...) default.
        def specific_process_fn(img):
            return process_image(img, task_name)

        # The inputs list has one element, matching specific_process_fn's
        # single `img` parameter.
        submit_btn.click(fn=specific_process_fn, inputs=[input_img], outputs=output_img)

        if EXAMPLE_IMAGES.get(task_name):
            gr.Examples(
                examples=EXAMPLE_IMAGES.get(task_name, []),
                inputs=input_img,
                outputs=output_img,
                fn=specific_process_fn,  # reuse the button's handler
                cache_examples=True,
            )


# Assemble the main application: one tab per configured task.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ 多功能图像复原工具 (AST 模型)
        请选择一个任务,然后上传对应的图片或点击下方的示例图片进行处理。
        """
    )
    with gr.Tabs():
        for task_name in MODEL_IDS.keys():
            create_task_tab(task_name)

# Launch, binding to all interfaces (e.g. for container deployment).
demo.launch(server_name="0.0.0.0")