File size: 8,271 Bytes
a143c1a
 
9ae093c
942c73d
a143c1a
 
 
 
 
942c73d
 
d5dd056
a143c1a
9367fe4
8525111
6a6f453
 
 
8525111
e10fd9a
6a6f453
 
ae5af4b
e10fd9a
a143c1a
 
 
9367fe4
8525111
 
 
 
6a6f453
 
 
 
 
 
 
 
8525111
6a6f453
 
 
 
8525111
 
7ec5312
ae5af4b
7ec5312
ae5af4b
7ec5312
6a6f453
7ec5312
ae5af4b
 
9367fe4
7ec5312
 
 
 
 
 
 
6a6f453
7ec5312
6a6f453
 
7ec5312
6a6f453
 
7ec5312
9367fe4
7ec5312
 
6a6f453
7ec5312
6a6f453
7ec5312
ae5af4b
 
7ec5312
ae5af4b
7ec5312
9367fe4
7ec5312
 
 
ae5af4b
 
7ec5312
 
9367fe4
 
7ec5312
ae5af4b
7ec5312
 
 
ae5af4b
 
7ec5312
9367fe4
 
7ec5312
9367fe4
7ec5312
 
9367fe4
 
 
 
 
ae5af4b
 
7ec5312
9367fe4
 
7ec5312
ae5af4b
 
 
7ec5312
ae5af4b
7ec5312
 
 
 
 
9367fe4
7ec5312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9367fe4
7ec5312
 
 
 
 
 
 
9367fe4
 
 
 
 
 
 
 
7ec5312
ae5af4b
7ec5312
 
 
 
 
 
 
 
575a9d2
9367fe4
942c73d
6a6f453
 
 
 
 
 
 
 
9367fe4
97d6718
9367fe4
97d6718
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from transformers import CLIPImageProcessor
from modeling_ast import ASTForRestoration
from PIL import Image
import requests
from io import BytesIO
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm
import math

# --- 1. 配置 ---
MODEL_IDS = {
    "去雨痕 (Derain)": "Suncongcong/AST_DeRain",
    "去雨滴 (Deraindrop)": "Suncongcong/AST_DeRainDrop",
    "去雾 (Dehaze)": "Suncongcong/AST_Dehazing"
}
EXAMPLE_IMAGES = {
    "去雨痕 (Derain)": [["derain_example1.png"], ["derain_example2.png"], ["derain_example3.png"]],
    "去雨滴 (Deraindrop)": [["deraindrop_example1.png"], ["deraindrop_example2.png"], ["deraindrop_example3.png"]],
    "去雾 (Dehaze)": [["dehaze_example1.jpg"],["dehaze_example2.jpg"],["dehaze_example3.jpg"]]
}
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"正在使用的设备: {device}")

# --- 2. 加载所有模型和处理器 ---
MODELS = {}
PROCESSOR = None
print("正在加载所有模型和处理器...")
try:
    for task_name, repo_id in MODEL_IDS.items():
        print(f"正在加载模型: {task_name} ({repo_id})")
        if PROCESSOR is None:
            PROCESSOR = CLIPImageProcessor.from_pretrained(repo_id)
            print("✅ 处理器加载成功。")
        model = ASTForRestoration.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
        MODELS[task_name] = model
        print(f"✅ 模型 '{task_name}' 加载成功。")
except Exception as e:
    print(f"加载模型时出错: {e}")
    def load_error_func(*args, **kwargs):
        raise gr.Error(f"模型加载失败! 错误: {e}")
    MODELS = {task: load_error_func for task in MODEL_IDS.keys()}
print("所有模型加载完毕,准备就绪!")


# --- 3. 定义不同任务的处理函数 ---

def process_with_pad_to_square(model, img_tensor):
    """将图片填充为正方形后进行处理,适用于去雨/去雨滴任务。"""
    def expand2square(timg, factor=128.0):
        # factor: 模型的网络结构要求输入的尺寸最好是该值的整数倍
        _, _, h, w = timg.size()
        X = int(math.ceil(max(h, w) / factor) * factor)
        # 确保创建的张量在正确的设备上
        img_padded = torch.zeros(1, 3, X, X, device=timg.device, dtype=timg.dtype)
        mask = torch.zeros(1, 1, X, X, device=timg.device, dtype=timg.dtype)
        
        pad_h = (X - h) // 2
        pad_w = (X - w) // 2
        img_padded[:, :, pad_h:pad_h + h, pad_w:pad_w + w] = timg
        mask[:, :, pad_h:pad_h + h, pad_w:pad_w + w].fill_(1)
        return img_padded, mask
        
    original_h, original_w = img_tensor.shape[2:]
    padded_input, mask = expand2square(img_tensor.to(device), factor=128.0)
    
    with torch.no_grad():
        restored_padded = model(padded_input)
        
    # 确保 mask 和 restored_padded 在同一设备上
    mask_bool = mask.bool().to(restored_padded.device)
    
    restored_tensor = torch.masked_select(
        restored_padded, mask_bool
    ).reshape(1, 3, original_h, original_w)
    
    return restored_tensor


def process_with_dehaze_tiling(model, img_tensor, progress):
    """使用重叠分块策略处理图像,适用于去雾任务。"""
    # 将“魔法数字”定义为常量并添加注释
    CROP_SIZE = 1152  # 每个图块的尺寸
    OVERLAP = 384   # 图块之间的重叠区域大小,以避免边缘效应
    
    b, c, h_orig, w_orig = img_tensor.shape
    stride = CROP_SIZE - OVERLAP

    # 计算需要填充的尺寸
    h_pad = (stride - (h_orig - OVERLAP) % stride) % stride if h_orig > OVERLAP else 0
    w_pad = (stride - (w_orig - OVERLAP) % stride) % stride if w_orig > OVERLAP else 0

    img_padded = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'replicate')
    b, c, h_padded, w_padded = img_padded.shape

    # 使用CPU来存储最终结果,避免占用大量显存
    output_canvas = torch.zeros((b, c, h_padded, w_padded), device='cpu')
    weight_map = torch.zeros_like(output_canvas)
    
    h_steps_range = range(0, h_padded - OVERLAP, stride) if h_padded > OVERLAP else [0]
    w_steps_range = range(0, w_padded - OVERLAP, stride) if w_padded > OVERLAP else [0]
    
    # 使用Gradio的进度条
    for y in progress.tqdm(h_steps_range, desc="正在分块去雾..."):
        for x in w_steps_range:
            # 确保切片范围正确
            y_end = min(y + CROP_SIZE, h_padded)
            x_end = min(x + CROP_SIZE, w_padded)
            patch_in = img_padded[:, :, y:y_end, x:x_end]

            with torch.no_grad():
                patch_out = model(patch_in.to(device)).cpu()
            
            output_canvas[:, :, y:y_end, x:x_end] += patch_out
            weight_map[:, :, y:y_end, x:x_end] += 1
            
    restored_padded_tensor = output_canvas / torch.clamp(weight_map, min=1)
    return restored_padded_tensor[:, :, :h_orig, :w_orig]


def process_image(input_image: Image.Image, task_name: str, progress=gr.Progress(track_tqdm=True)):
    """主处理函数,根据任务名分派到不同的处理流程。"""
    if input_image is None:
        gr.Warning("请输入一张图片!")
        return None
        
    # 增加完整的运行时错误捕获
    try:
        model = MODELS[task_name]
        print(f"已选择任务: {task_name}, 使用模型: {MODEL_IDS[task_name]}")
        
        # 检查模型是否加载成功
        if not isinstance(model, torch.nn.Module):
            model() # 如果加载失败,这里会触发 load_error_func 并抛出异常
        
        img = input_image.convert("RGB")
        img_tensor = to_tensor(img).unsqueeze(0)
        
        if task_name == "去雾 (Dehaze)":
            restored_tensor = process_with_dehaze_tiling(model, img_tensor, progress)
        else:
            restored_tensor = process_with_pad_to_square(model, img_tensor)
            
        restored_tensor = torch.clamp(restored_tensor, 0, 1)
        restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
        return restored_image
        
    except Exception as e:
        print(f"处理图片时发生错误: {e}")
        # 在UI上给用户一个清晰的错误提示
        gr.Error(f"处理失败!错误信息: {e}")
        # 返回原始图像,而不是空着或保留上一次的结果
        return input_image


# --- 4. Gradio UI ---

def create_task_tab(task_name: str):
    """动态创建每个任务的UI选项卡。"""
    with gr.TabItem(task_name, id=task_name):
        with gr.Row():
            input_img = gr.Image(type="pil", label=f"输入图片 (Input for {task_name})")
            output_img = gr.Image(type="pil", label="输出图片 (Output)")
        
        submit_btn = gr.Button("开始处理 (Process)", variant="primary")
        
        # ✨ 修正后的处理函数 ✨
        # 这个函数只接收它能从 inputs 中得到的 `img` 参数。
        def specific_process_fn(img):
            # 调用 process_image 时不传递 progress 参数,
            # 从而让 process_image 自动使用其函数定义中的默认值: progress=gr.Progress(...)
            return process_image(img, task_name)

        # click 事件的 inputs 列表只有一个元素,对应 specific_process_fn 的 img 参数
        submit_btn.click(fn=specific_process_fn, inputs=[input_img], outputs=output_img)
        
        if EXAMPLE_IMAGES.get(task_name):
            gr.Examples(
                examples=EXAMPLE_IMAGES.get(task_name, []),
                inputs=input_img,
                outputs=output_img,
                fn=specific_process_fn, # 复用上面为按钮创建的处理函数
                cache_examples=True,
            )

# 创建应用主界面
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ 多功能图像复原工具 (AST 模型)
        请选择一个任务,然后上传对应的图片或点击下方的示例图片进行处理。
        """
    )
    with gr.Tabs():
        for task_name in MODEL_IDS.keys():
            create_task_tab(task_name) # 调用函数为每个任务创建Tab

# 启动应用
demo.launch(server_name="0.0.0.0")