File size: 13,346 Bytes
bb9d70a
 
 
 
 
 
 
 
 
 
 
75642ae
 
7357080
75642ae
 
 
7357080
75642ae
 
 
bb9d70a
 
7357080
bb9d70a
 
 
 
7357080
d07783b
 
 
7357080
d07783b
bb9d70a
 
 
 
 
 
 
 
 
 
7357080
 
 
bb9d70a
d07783b
7357080
9d7213d
7357080
b1b404c
9d7213d
 
bb9d70a
 
 
9d7213d
bb9d70a
9d7213d
 
b1b404c
9d7213d
b1b404c
9d7213d
 
b1b404c
9d7213d
 
 
b1b404c
5d8652f
b1b404c
5d8652f
 
 
 
9d7213d
5d8652f
9d7213d
5d8652f
9d7213d
 
b1b404c
5d8652f
 
 
9d7213d
5d8652f
 
9d7213d
 
 
 
 
 
 
 
 
 
5d8652f
bb9d70a
 
 
 
 
 
5d8652f
7357080
75642ae
7357080
 
 
bb9d70a
 
 
7357080
 
bb9d70a
7357080
bb9d70a
7357080
bb9d70a
7357080
 
 
 
 
 
bb9d70a
 
 
 
 
 
7357080
bb9d70a
 
9d7213d
7357080
 
9d7213d
b1b404c
7357080
 
 
bb9d70a
 
 
 
 
b1b404c
7357080
 
 
bb9d70a
 
9d7213d
 
b1b404c
 
 
 
 
 
9d7213d
 
b1b404c
 
 
 
 
bb9d70a
7357080
9d7213d
7357080
 
bb9d70a
 
 
7357080
bb9d70a
 
75642ae
7357080
 
 
bb9d70a
7357080
 
bb9d70a
 
7357080
bb9d70a
 
7357080
bb9d70a
7357080
bb9d70a
 
 
 
 
7357080
bb9d70a
 
7357080
bb9d70a
7357080
 
 
 
bb9d70a
 
7357080
 
 
9d7213d
 
 
 
5d8652f
 
9d7213d
7357080
 
b1b404c
9d7213d
b1b404c
 
9d7213d
7357080
9d7213d
 
b1b404c
9d7213d
b1b404c
 
 
 
9d7213d
b1b404c
 
 
7357080
b1b404c
7357080
b1b404c
7357080
 
9d7213d
b1b404c
6a78ea7
b1b404c
9d7213d
b1b404c
 
9d7213d
b1b404c
9d7213d
 
b1b404c
 
 
 
 
7357080
b1b404c
9d7213d
7357080
b1b404c
 
 
7357080
 
bb9d70a
9d7213d
7357080
4037b93
bb9d70a
7357080
 
9d7213d
 
 
 
bb9d70a
 
 
 
 
7357080
b1b404c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
from __future__ import annotations

import base64
import html
import io
import os
import random
import threading
import time
import traceback
import uuid
from typing import Optional

try:
    import spaces
except ImportError:
    class _SpacesFallback:
        """No-op stand-in for the Hugging Face ``spaces`` package.

        Used when ``spaces`` is not installed (e.g. running locally). ``GPU``
        mirrors both decorator forms of ``spaces.GPU``:

        * bare ``@spaces.GPU``, which passes the function as the sole argument;
        * parameterised ``@spaces.GPU(duration=...)``, which returns a decorator.

        In both cases the decorated function is returned unchanged.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            # Bare form: the decorated function itself is the only argument.
            # Without this branch the inner ``decorator`` would silently replace it.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()

import gradio as gr
import torch
from PIL import Image

# Hub repository the pipeline is loaded from; overridable via the environment.
MODEL_REPO_ID = os.getenv("REALRESTORER_MODEL_REPO", "RealRestorer/RealRestorer")
# Optional access token forwarded to from_pretrained; unset or empty becomes None.
HF_TOKEN = os.getenv("HF_TOKEN") or None

# --- Default Parameters ---
# Initial values for the UI controls (steps slider, CFG slider, seed box).
# FIXED_SIZE_LEVEL is passed as ``size_level`` to every pipeline call.
DEFAULT_STEPS = 28
DEFAULT_CFG = 3.0
DEFAULT_SEED = 42
FIXED_SIZE_LEVEL = 1024

# Task name -> instruction prompt.
# The keys populate the "Task Preset" dropdown in insertion order. A value is
# copied into the instruction box when its preset is picked. It is also used
# as the prompt when the instruction box is left blank.
# NOTE(review): the "Compression Artifact Removal" wording reads oddly
# ("restore the image clarity and artifacts"). It may intentionally match the
# model's training prompts, so confirm before rewording.
TASK_PRESETS = {
    "Low-light Enhancement": "Please restore this low-quality image, recovering its normal brightness and clarity.",
    "Deblurring": "Please deblur the image and make it sharper.",
    "Deraining": "Please remove the rain from the image and restore its clarity.",
    "Compression Artifact Removal": "Please restore the image clarity and artifacts.",
    "Deflare": "Please remove the lens flare and glare from the image.",
    "Demoire": "Please remove the moire patterns from the image.",
    "Dehazing": "Please dehaze the image.",
    "Denoising": "Please remove noise from the image.",
    "Reflection Removal": "Please remove the reflection from the image.",
    "Underwater Image Enhancement": "Please enhance this underwater image, restoring its natural colors, brightness, and clarity.",
    "Old Photo Restoration": "Please restore this old photo, repairing damage and improving its clarity and overall quality.",
    "Desnowing": "Please remove the snow from the image and restore its visibility and clarity."
}
# Preset selected when the page loads. Its prompt is also the fallback when an
# unknown task name reaches run_inference.
DEFAULT_PRESET = "Low-light Enhancement"

# --- UI Header ---
# Static HTML banner shown at the top of the page (rendered through gr.HTML).
# Only the subtitle colour uses a Gradio theme variable; the title gradient is hard-coded.
TITLE_HTML = """
<div style="text-align: center; margin-bottom: 20px;">
    <h1 style="font-size: 3rem; margin: 0 0 10px 0; background: linear-gradient(135deg, #4f46e5 0%, #ec4899 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; letter-spacing: -1px;">RealRestorer</h1>
    <p style="font-size: 1.15rem; color: var(--body-text-color-subdued); margin: 0; font-weight: 500;">A powerful image restoration model supporting deblurring, denoising, deflaring, low-light enhancement, and more.</p>
</div>
"""

# --- Page CSS: mostly layout. Card and status colours come from Gradio theme
# variables (var(--...)) so they follow light/dark mode; only the Run button
# uses a hard-coded gradient. ---
# Selectors: .rr-section = the two card columns, #run-btn = the Run button,
# .rr-status = the status line. The /* ... */ comments inside the string are
# part of the stylesheet text itself.
CUSTOM_CSS = """
/* 放宽最大宽度,横向占满 */
.gradio-container { max-width: 1400px !important; margin: auto !important; padding-top: 30px !important; }

/* 卡片容器:使用 Gradio 自带的 CSS 变量,完美适应深浅模式 */
.rr-section {
    background: var(--background-fill-primary) !important;
    border: 1px solid var(--border-color-primary) !important;
    border-radius: 16px !important;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.05) !important;
    padding: 24px !important;
    margin-bottom: 24px !important;
}

/* 巨型行动按钮 */
#run-btn {
    background: linear-gradient(135deg, #4f46e5 0%, #6366f1 100%) !important;
    color: white !important;
    border: none !important;
    font-size: 1.25rem !important;
    font-weight: 700 !important;
    padding: 16px !important;
    border-radius: 12px !important;
    box-shadow: 0 4px 15px rgba(79, 70, 229, 0.4) !important;
    transition: transform 0.2s, box-shadow 0.2s !important;
    margin-top: 10px !important;
}
#run-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(79, 70, 229, 0.5) !important;
}

/* 状态栏:使用主题变量 */
.rr-status {
    background: var(--background-fill-secondary) !important;
    border: 1px solid var(--border-color-primary) !important;
    border-radius: 8px !important;
    padding: 12px 16px !important;
    font-size: 1rem !important;
    color: var(--body-text-color) !important;
    margin-bottom: 20px !important;
    font-family: ui-monospace, monospace !important;
}
"""

# Lazily created pipeline shared by every request; populated by _load_pipeline().
PIPELINE = None
# First lock guards one-time construction; second serialises calls into the pipeline.
PIPELINE_LOCK, INFERENCE_LOCK = threading.Lock(), threading.Lock()

@spaces.GPU(duration=180)
def _spaces_gpu_probe():
    """Do nothing and return None; exists only as a ``spaces.GPU``-decorated function.

    NOTE(review): presumably kept so a ZeroGPU Space detects a GPU function at
    import time. Confirm before removing.
    """

def _pick_device():
    if torch.cuda.is_available(): return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return "mps"
    return "cpu"

def _load_pipeline():
    global PIPELINE
    if PIPELINE is not None: return PIPELINE
    with PIPELINE_LOCK:
        if PIPELINE is not None: return PIPELINE
        from diffusers import RealRestorerPipeline
        
        device = _pick_device()
        dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float16 if device == "cuda" else torch.float32
        
        pipe = RealRestorerPipeline.from_pretrained(MODEL_REPO_ID, torch_dtype=dtype, token=HF_TOKEN)
        if device == "cuda": pipe.enable_model_cpu_offload(device=device)
        else: pipe.to(device)
        
        PIPELINE = pipe
        return PIPELINE

def _pil_to_data_url(image: Image.Image) -> str:
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"

def _build_slider_html(before_image: Image.Image, after_image: Image.Image) -> str:
    # 巧妙利用 var(--body-text-color) 等变量,适配深色模式
    if before_image is None or after_image is None:
        return """
        <div style='height: 500px; width: 100%; display:flex; align-items:center; justify-content:center; color: var(--body-text-color-subdued); border: 2px dashed var(--border-color-primary); border-radius: 12px; font-size: 1.2rem; background: var(--background-fill-secondary);'>
            Upload an image and run to see the full-width comparison here.
        </div>
        """
    
    if before_image.size != after_image.size:
        before_image = before_image.resize(after_image.size, Image.LANCZOS)

    before_url = _pil_to_data_url(before_image)
    after_url = _pil_to_data_url(after_image)
    width, height = after_image.size
    slider_id = f"slider_{int(time.time() * 1000)}"
    
    on_input = f"var p=this.value; document.getElementById('{slider_id}_top').style.clipPath='inset(0 '+(100-p)+'% 0 0)'; document.getElementById('{slider_id}_line').style.left=p+'%';"

    return f"""
    <div style="display: flex; justify-content: center; width: 100%; background: var(--background-fill-secondary); border-radius: 12px; padding: 20px; border: 1px solid var(--border-color-primary); box-sizing: border-box;">
        <div style="position:relative; width:100%; max-height:80vh; aspect-ratio:{width}/{height}; overflow:hidden; border-radius:8px; box-shadow: 0 4px 15px rgba(0,0,0,0.2);">
          <!-- 底图: 修复后 -->
          <img src="{after_url}" style="position:absolute; top:0; left:0; width:100%; height:100%; object-fit:contain;" draggable="false" />
          <!-- 顶图: 原图 -->
          <img id="{slider_id}_top" src="{before_url}" style="position:absolute; top:0; left:0; width:100%; height:100%; object-fit:contain; clip-path:inset(0 50% 0 0);" draggable="false" />
          
          <!-- 拖拽线与把手 -->
          <div id="{slider_id}_line" style="position:absolute; top:0; left:50%; width:4px; height:100%; background:white; box-shadow:0 0 10px rgba(0,0,0,0.5); transform:translateX(-50%); pointer-events:none; z-index:10;">
            <div style="position:absolute; top:50%; left:50%; transform:translate(-50%, -50%); width:48px; height:48px; border-radius:50%; background:white; display:flex; align-items:center; justify-content:center; box-shadow:0 2px 10px rgba(0,0,0,0.4); font-weight:bold; color:#1e293b; font-size:1.3rem;">↔</div>
          </div>
          
          <!-- 透明原生滑动条 -->
          <input type="range" min="0" max="100" value="50" oninput="{on_input}" style="position:absolute; top:0; left:0; width:100%; height:100%; opacity:0; cursor:ew-resize; z-index:20; margin:0; appearance:auto;" />
        </div>
    </div>
    
    <div style="display:flex; justify-content:space-between; color:var(--body-text-color-subdued); font-size:1.05rem; padding-top:12px; font-weight:600;">
      <span>⬅️ Original</span>
      <span>Restored ➡️</span>
    </div>
    """

def _on_preset_change(preset_name: str):
    return TASK_PRESETS.get(preset_name, "")

@spaces.GPU(duration=180)
def run_inference(image: Optional[Image.Image], task_name: str, prompt: str, steps: float, guidance_scale: float, seed: float, progress=gr.Progress(track_tqdm=False)):
    """Restore *image* with the RealRestorer pipeline and build the UI outputs.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        task_name: Key into ``TASK_PRESETS``; that preset's prompt is used when
            *prompt* is blank. Unknown names fall back to ``DEFAULT_PRESET``.
        prompt: Free-text instruction. Blank or ``None`` falls back to the preset.
        steps: Number of inference steps (cast to ``int``).
        guidance_scale: Guidance (CFG) scale (cast to ``float`` for the pipeline).
        seed: Seed to use. A negative value or ``None`` picks a random seed.
        progress: Gradio progress tracker.

    Returns:
        ``(restored_image_or_None, status_text, slider_html)``. The status text
        is displayed by a ``gr.HTML`` component, so exception text in it is
        HTML-escaped.
    """
    if image is None:
        return None, "💡 Error: Please upload an image first.", _build_slider_html(None, None)

    source_image = image.convert("RGB")
    # prompt may arrive as None (e.g. through the API); treat it like an empty box.
    final_prompt = (prompt or "").strip() or TASK_PRESETS.get(task_name, TASK_PRESETS[DEFAULT_PRESET])
    # seed may arrive as None (e.g. an emptied gr.Number field); treat it like -1 (random).
    if seed is None or seed < 0:
        final_seed = random.randint(0, 2**31 - 1)
    else:
        final_seed = int(seed)

    try:
        progress(0.1, desc="Preparing model...")
        pipeline = _load_pipeline()

        with INFERENCE_LOCK:
            # Start timing once the lock is held so waiting for another request
            # is not reported as restoration time.
            start_time = time.time()
            progress(0.3, desc="Restoring Image...")
            output = pipeline(
                image=source_image,
                prompt=final_prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(guidance_scale),
                size_level=FIXED_SIZE_LEVEL,
                seed=final_seed,
            ).images[0]

        elapsed = time.time() - start_time
        status = f"✅ Success | Time: {elapsed:.1f}s | Steps: {int(steps)} | CFG: {guidance_scale} | Seed: {final_seed}"
        slider_html = _build_slider_html(source_image, output)
        return output, status, slider_html

    except Exception as exc:
        # UI boundary: log the full traceback, then show an escaped message.
        # Messages such as "<class 'str'>" would otherwise be parsed as HTML tags.
        traceback.print_exc()
        return None, f"❌ Failed: {html.escape(str(exc))}", _build_slider_html(None, None)

def build_demo():
    """Construct and return the RealRestorer Gradio Blocks app (not yet launched).

    Layout: an upload/settings card, then a results card containing a status
    line and two tabs (compare slider, plain output image).

    Events:
    * Choosing a preset copies its prompt into the instruction box.
    * The Run button triggers a three-step chain: show a "Processing" status
      and reset the slider, call ``run_inference``, then wrap the returned
      status text in the ``rr-status`` div.
    """
    # Plain Soft theme; switching colours between light and dark mode is left to Gradio.
    theme = gr.themes.Soft(
        primary_hue="indigo",
        text_size=gr.themes.sizes.text_lg
    )

    with gr.Blocks(css=CUSTOM_CSS, title="RealRestorer", theme=theme) as demo:
        gr.HTML(TITLE_HTML)

        # ==========================================
        # Section 1: image upload and control parameters (stacked vertically)
        # ==========================================
        with gr.Column(elem_classes=["rr-section"]):
            gr.Markdown("### 📥 1. Upload & Settings")

            # Full-width upload box
            input_image = gr.Image(label="Upload Image", type="pil", height=420)

            # Controls laid out side by side in rows
            with gr.Row():
                with gr.Column(scale=1):
                    task_dropdown = gr.Dropdown(choices=list(TASK_PRESETS.keys()), value=DEFAULT_PRESET, label="Task Preset")
                with gr.Column(scale=2):
                    prompt_box = gr.Textbox(label="Instruction (输入指令)", value=TASK_PRESETS[DEFAULT_PRESET], lines=1)

            with gr.Row():
                with gr.Column(scale=1):
                    guidance_slider = gr.Slider(minimum=1.0, maximum=6.0, value=DEFAULT_CFG, step=0.1, label="CFG / Guidance Scale")
                with gr.Column(scale=1):
                    steps_slider = gr.Slider(minimum=12, maximum=40, value=DEFAULT_STEPS, step=1, label="Inference Steps")
                with gr.Column(scale=1):
                    seed_box = gr.Number(label="Seed (-1 for random)", value=DEFAULT_SEED, precision=0)

            # Large Run button, styled by the #run-btn rule in CUSTOM_CSS
            run_button = gr.Button("🚀 Run Restoration", elem_id="run-btn")

        # ==========================================
        # Section 2: results display (full width)
        # ==========================================
        with gr.Column(elem_classes=["rr-section"]):
            gr.Markdown("### 🖼️ 2. Restoration Results")

            status_box = gr.HTML(
                value="<div class='rr-status'>💡 Status: Ready. Upload an image, adjust settings, and click Run.</div>"
            )

            with gr.Tabs():
                with gr.Tab("Compare View"):
                    slider_view = gr.HTML(_build_slider_html(None, None))

                with gr.Tab("Output Image"):
                    output_image = gr.Image(label="Restored Result", type="pil", interactive=False, height=600, show_label=False)

        # ==========================================
        # Events
        # ==========================================
        # Selecting a preset overwrites the instruction box with that preset's prompt.
        task_dropdown.change(fn=_on_preset_change, inputs=[task_dropdown], outputs=[prompt_box])

        # Step 1 shows a "Processing" status and resets the slider.
        # Step 2 runs the model.
        # Step 3 reads back the plain status string run_inference put in status_box
        # and wraps it in the styled rr-status container.
        run_button.click(
            fn=lambda: ("<div class='rr-status'>⏳ Status: Processing... Please wait.</div>", _build_slider_html(None, None)),
            outputs=[status_box, slider_view]
        ).then(
            fn=run_inference,
            inputs=[input_image, task_dropdown, prompt_box, steps_slider, guidance_slider, seed_box],
            outputs=[output_image, status_box, slider_view],
        ).then(
            fn=lambda status: f"<div class='rr-status'>{status}</div>",
            inputs=[status_box],
            outputs=[status_box]
        )

    return demo

if __name__ == "__main__":
    # Build the UI, enable the request queue (at most 8 waiting, one running
    # at a time by default), then serve on all network interfaces.
    demo = build_demo()
    demo.queue(max_size=8, default_concurrency_limit=1)
    demo.launch(server_name="0.0.0.0", show_error=True)