Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from transformers import CLIPImageProcessor | |
| from modeling_ast import ASTForRestoration | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| from torchvision.transforms.functional import to_pil_image, to_tensor | |
| from tqdm import tqdm | |
| import math | |
| # --- 1. 配置 --- | |
| MODEL_IDS = { | |
| "去雨痕 (Derain)": "Suncongcong/AST_DeRain", | |
| "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop", | |
| "去雾 (Dehaze)": "Suncongcong/AST_Dehazing" | |
| } | |
| EXAMPLE_IMAGES = { | |
| "去雨痕 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]], | |
| "去雨滴 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]], | |
| "去雾 (Dehaze)": [["dehaze_example1.jpg"],["dehaze_example2.jpg"],["dehaze_example3.jpg"]] | |
| } | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"正在使用的设备: {device}") | |
| # --- 2. 加载所有模型和处理器 --- | |
| MODELS = {} | |
| PROCESSOR = None | |
| print("正在加载所有模型和处理器...") | |
| try: | |
| for task_name, repo_id in MODEL_IDS.items(): | |
| print(f"正在加载模型: {task_name} ({repo_id})") | |
| if PROCESSOR is None: | |
| PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id) | |
| print("✅ 处理器加载成功。") | |
| model = ASTForRestoration.from_pretrained(repo_id, trust_remote_code=True).to(device).eval() | |
| MODELS[task_name] = model | |
| print(f"✅ 模型 '{task_name}' 加载成功。") | |
| except Exception as e: | |
| print(f"加载模型时出错: {e}") | |
| def load_error_func(*args, **kwargs): | |
| raise gr.Error(f"模型加载失败! 错误: {e}") | |
| MODELS = {task: load_error_func for task in MODEL_IDS.keys()} | |
| print("所有模型加载完毕,准备就绪!") | |
| # --- 3. 定义不同任务的处理函数 --- | |
| def process_with_pad_to_square(model, img_tensor): | |
| """将图片填充为正方形后进行处理,适用于去雨/去雨滴任务。""" | |
| def expand2square(timg, factor=128.0): | |
| # factor: 模型的网络结构要求输入的尺寸最好是该值的整数倍 | |
| _, _, h, w = timg.size() | |
| X = int(math.ceil(max(h, w) / factor) * factor) | |
| # 确保创建的张量在正确的设备上 | |
| img_padded = torch.zeros(1, 3, X, X, device=timg.device, dtype=timg.dtype) | |
| mask = torch.zeros(1, 1, X, X, device=timg.device, dtype=timg.dtype) | |
| pad_h = (X - h) // 2 | |
| pad_w = (X - w) // 2 | |
| img_padded[:, :, pad_h:pad_h + h, pad_w:pad_w + w] = timg | |
| mask[:, :, pad_h:pad_h + h, pad_w:pad_w + w].fill_(1) | |
| return img_padded, mask | |
| original_h, original_w = img_tensor.shape[2:] | |
| padded_input, mask = expand2square(img_tensor.to(device), factor=128.0) | |
| with torch.no_grad(): | |
| restored_padded = model(padded_input) | |
| # 确保 mask 和 restored_padded 在同一设备上 | |
| mask_bool = mask.bool().to(restored_padded.device) | |
| restored_tensor = torch.masked_select( | |
| restored_padded, mask_bool | |
| ).reshape(1, 3, original_h, original_w) | |
| return restored_tensor | |
| def process_with_dehaze_tiling(model, img_tensor, progress): | |
| """使用重叠分块策略处理图像,适用于去雾任务。""" | |
| # 将“魔法数字”定义为常量并添加注释 | |
| CROP_SIZE = 1152 # 每个图块的尺寸 | |
| OVERLAP = 384 # 图块之间的重叠区域大小,以避免边缘效应 | |
| b, c, h_orig, w_orig = img_tensor.shape | |
| stride = CROP_SIZE - OVERLAP | |
| # 计算需要填充的尺寸 | |
| h_pad = (stride - (h_orig - OVERLAP) % stride) % stride if h_orig > OVERLAP else 0 | |
| w_pad = (stride - (w_orig - OVERLAP) % stride) % stride if w_orig > OVERLAP else 0 | |
| img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate') | |
| b, c, h_padded, w_padded = img_padded.shape | |
| # 使用CPU来存储最终结果,避免占用大量显存 | |
| output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu') | |
| weight_map = torch.zeros_like(output_canvas) | |
| h_steps_range = range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0] | |
| w_steps_range = range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0] | |
| # 使用Gradio的进度条 | |
| for y in progress.tqdm(h_steps_range, desc="正在分块去雾..."): | |
| for x in w_steps_range: | |
| # 确保切片范围正确 | |
| y_end = min(y + CROP_SIZE, h_padded) | |
| x_end = min(x + CROP_SIZE, w_padded) | |
| patch_in = img_padded[:, :, y:y_end, x:x_end] | |
| with torch.no_grad(): | |
| patch_out = model(patch_in.to(device)).cpu() | |
| output_canvas[:, :, y:y_end, x:x_end] += patch_out | |
| weight_map[:, :, y:y_end, x:x_end] += 1 | |
| restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1) | |
| return restored_padded_tensor[:, :, :h_orig, :w_orig] | |
| def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)): | |
| """主处理函数,根据任务名分派到不同的处理流程。""" | |
| if input_image is None: | |
| gr.Warning("请输入一张图片!") | |
| return None | |
| # 增加完整的运行时错误捕获 | |
| try: | |
| model = MODELS[task_name] | |
| print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}") | |
| # 检查模型是否加载成功 | |
| if not isinstance(model, torch.nn.Module): | |
| model() # 如果加载失败,这里会触发 load_error_func 并抛出异常 | |
| img = input_image.convert("RGB") | |
| img_tensor = to_tensor(img).unsqueeze(0) | |
| if task_name == "去雾 (Dehaze)": | |
| restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress) | |
| else: | |
| restored_tensor = process_with_pad_to_square(model, img_tensor) | |
| restored_tensor = torch.clamp(restored_tensor, 0, 1) | |
| restored_image = to_pil_image(restored_tensor.cpu().squeeze(0)) | |
| return restored_image | |
| except Exception as e: | |
| print(f"处理图片时发生错误: {e}") | |
| # 在UI上给用户一个清晰的错误提示 | |
| gr.Error(f"处理失败!错误信息: {e}") | |
| # 返回原始图像,而不是空着或保留上一次的结果 | |
| return input_image | |
| # --- 4. Gradio UI --- | |
| def create_task_tab(task_name: str): | |
| """动态创建每个任务的UI选项卡。""" | |
| with gr.TabItem(task_name, id=task_name): | |
| with gr.Row(): | |
| input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})") | |
| output_img = gr.Image(type="pil", label="输出图片 (Output)") | |
| submit_btn = gr.Button("开始处理 (Process)", variant="primary") | |
| # ✨ 修正后的处理函数 ✨ | |
| # 这个函数只接收它能从 inputs 中得到的 `img` 参数。 | |
| def specific_process_fn(img): | |
| # 调用 process_image 时不传递 progress 参数, | |
| # 从而让 process_image 自动使用其函数定义中的默认值: progress=gr.Progress(...) | |
| return process_image(img, task_name) | |
| # click 事件的 inputs 列表只有一个元素,对应 specific_process_fn 的 img 参数 | |
| submit_btn.click(fn=specific_process_fn, inputs=[input_img], outputs=output_img) | |
| if EXAMPLE_IMAGES.get(task_name): | |
| gr.Examples( | |
| examples=EXAMPLE_IMAGES.get(task_name, []), | |
| inputs=input_img, | |
| outputs=output_img, | |
| fn=specific_process_fn, # 复用上面为按钮创建的处理函数 | |
| cache_examples=True, | |
| ) | |
| # 创建应用主界面 | |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| gr.Markdown( | |
| """ | |
| # 🖼️ 多功能图像复原工具 (AST 模型) | |
| 请选择一个任务,然后上传对应的图片或点击下方的示例图片进行处理。 | |
| """ | |
| ) | |
| with gr.Tabs(): | |
| for task_name in MODEL_IDS.keys(): | |
| create_task_tab(task_name) # 调用函数为每个任务创建Tab | |
| # 启动应用 | |
| demo.launch(server_name="0.0.0.0") |