# Hugging Face Spaces page header: "Running on Zero" (ZeroGPU hardware).
from __future__ import annotations

import base64
import html
import io
import os
import random
import threading
import time
import traceback
from typing import Optional
try:
    import spaces
except ImportError:
    class _SpacesFallback:
        """No-op stand-in for the ``spaces`` package outside Hugging Face Spaces.

        Supports both decorator forms of ``spaces.GPU``:

        * bare ``@spaces.GPU``
        * parameterized ``@spaces.GPU(...)``

        Both leave the decorated function unchanged.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            # Bare form: the decorated function arrives as the only positional argument.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            # Parameterized form: return a decorator that passes the function through.
            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
# --- Model source ---
# Hub repository the weights are loaded from; overridable via REALRESTORER_MODEL_REPO.
MODEL_REPO_ID = os.getenv("REALRESTORER_MODEL_REPO", "RealRestorer/RealRestorer")
# Hub access token. An unset HF_TOKEN and an empty HF_TOKEN both become None.
HF_TOKEN = os.getenv("HF_TOKEN") or None

# --- Default Parameters ---
DEFAULT_STEPS = 28  # initial value of the "Inference Steps" slider
DEFAULT_CFG = 3.0  # initial guidance (CFG) scale
DEFAULT_SEED = 42  # initial seed; a negative seed in the UI requests a random one
FIXED_SIZE_LEVEL = 1024  # passed unchanged as `size_level` on every pipeline call

# Task name -> default instruction text.
# Insertion order is the order shown in the "Task Preset" dropdown. When the
# instruction box is left blank, the preset text is used as the prompt.
TASK_PRESETS: dict[str, str] = {
    "Low-light Enhancement": "Please restore this low-quality image, recovering its normal brightness and clarity.",
    "Deblurring": "Please deblur the image and make it sharper.",
    "Deraining": "Please remove the rain from the image and restore its clarity.",
    "Compression Artifact Removal": "Please restore the image clarity and artifacts.",
    "Deflare": "Please remove the lens flare and glare from the image.",
    "Demoire": "Please remove the moire patterns from the image.",
    "Dehazing": "Please dehaze the image.",
    "Denoising": "Please remove noise from the image.",
    "Reflection Removal": "Please remove the reflection from the image.",
    "Underwater Image Enhancement": "Please enhance this underwater image, restoring its natural colors, brightness, and clarity.",
    "Old Photo Restoration": "Please restore this old photo, repairing damage and improving its clarity and overall quality.",
    "Desnowing": "Please remove the snow from the image and restore its visibility and clarity.",
}
# Preset selected when the UI first loads.
DEFAULT_PRESET = "Low-light Enhancement"
# --- UI Header ---
# Centered banner at the top of the page: a gradient "RealRestorer" wordmark and a
# one-line tagline. The tagline color comes from a Gradio theme variable, so it
# follows the light/dark mode.
TITLE_HTML: str = """
<div style="text-align: center; margin-bottom: 20px;">
<h1 style="font-size: 3rem; margin: 0 0 10px 0; background: linear-gradient(135deg, #4f46e5 0%, #ec4899 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; letter-spacing: -1px;">RealRestorer</h1>
<p style="font-size: 1.15rem; color: var(--body-text-color-subdued); margin: 0; font-weight: 500;">A powerful image restoration model supporting deblurring, denoising, deflaring, low-light enhancement, and more.</p>
</div>
"""
# --- Layout-only CSS: it never overrides Gradio's own color palette ---
# What it does:
#   * widens the page container to 1400px;
#   * styles `.rr-section` cards and the `.rr-status` line using Gradio theme
#     variables (var(--...)), so both adapt to light and dark mode;
#   * gives the `#run-btn` button its own gradient and hover lift.
# The /* ... */ comments inside the string are part of the CSS sent to the browser.
CUSTOM_CSS: str = """
/* 放宽最大宽度,横向占满 */
.gradio-container { max-width: 1400px !important; margin: auto !important; padding-top: 30px !important; }
/* 卡片容器:使用 Gradio 自带的 CSS 变量,完美适应深浅模式 */
.rr-section {
background: var(--background-fill-primary) !important;
border: 1px solid var(--border-color-primary) !important;
border-radius: 16px !important;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.05) !important;
padding: 24px !important;
margin-bottom: 24px !important;
}
/* 巨型行动按钮 */
#run-btn {
background: linear-gradient(135deg, #4f46e5 0%, #6366f1 100%) !important;
color: white !important;
border: none !important;
font-size: 1.25rem !important;
font-weight: 700 !important;
padding: 16px !important;
border-radius: 12px !important;
box-shadow: 0 4px 15px rgba(79, 70, 229, 0.4) !important;
transition: transform 0.2s, box-shadow 0.2s !important;
margin-top: 10px !important;
}
#run-btn:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(79, 70, 229, 0.5) !important;
}
/* 状态栏:使用主题变量 */
.rr-status {
background: var(--background-fill-secondary) !important;
border: 1px solid var(--border-color-primary) !important;
border-radius: 8px !important;
padding: 12px 16px !important;
font-size: 1rem !important;
color: var(--body-text-color) !important;
margin-bottom: 20px !important;
font-family: ui-monospace, monospace !important;
}
"""
# Cached diffusion pipeline; stays None until _load_pipeline() builds it.
PIPELINE = None
# Two independent locks: PIPELINE_LOCK guards the one-time construction of the
# pipeline; INFERENCE_LOCK is held around each pipeline call in run_inference.
PIPELINE_LOCK, INFERENCE_LOCK = threading.Lock(), threading.Lock()
def _spaces_gpu_probe():
    """Do nothing and return None.

    NOTE(review): the name suggests a ZeroGPU startup probe, i.e. a function
    decorated with ``spaces.GPU``. No decorator is applied here, and nothing in
    the visible code calls it. Confirm whether ``@spaces.GPU`` was intended.
    """
    result = None
    return result
def _pick_device():
    """Return the preferred torch device name.

    Checks in priority order: "cuda", then "mps" (Apple Metal), and falls back
    to "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    # Older torch builds have no `torch.backends.mps` module at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
def _load_pipeline():
    """Build the RealRestorer pipeline on first use and return the cached instance.

    The cached global is checked once without the lock and again while holding
    PIPELINE_LOCK. This way concurrent first callers construct the pipeline only
    once.

    dtype selection:
      * CUDA with bf16 support -> torch.bfloat16;
      * other CUDA devices     -> torch.float16;
      * MPS / CPU              -> torch.float32.

    On CUDA, model CPU offload is enabled instead of moving the whole pipeline to
    the device. On other devices the pipeline is moved with ``.to(device)``.
    """
    global PIPELINE
    cached = PIPELINE
    if cached is not None:
        return cached

    with PIPELINE_LOCK:
        if PIPELINE is None:
            # Imported lazily, only when the pipeline is first built.
            from diffusers import RealRestorerPipeline

            device = _pick_device()
            if device == "cuda":
                dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            else:
                dtype = torch.float32

            pipe = RealRestorerPipeline.from_pretrained(
                MODEL_REPO_ID, torch_dtype=dtype, token=HF_TOKEN
            )
            if device == "cuda":
                pipe.enable_model_cpu_offload(device=device)
            else:
                pipe.to(device)
            PIPELINE = pipe
        return PIPELINE
def _pil_to_data_url(image: Image.Image) -> str:
    """Encode *image* as PNG and return it as a base64 ``data:image/png`` URL.

    The result can be used directly as the ``src`` of an inline ``<img>`` tag.
    """
    with io.BytesIO() as stream:
        image.save(stream, format="PNG")
        payload = base64.b64encode(stream.getvalue()).decode("utf-8")
    return "data:image/png;base64," + payload
def _build_slider_html(
    before_image: Optional[Image.Image],
    after_image: Optional[Image.Image],
) -> str:
    """Render a before/after comparison slider as self-contained HTML.

    Colors come from Gradio theme CSS variables (e.g.
    ``var(--body-text-color-subdued)``), so the panel adapts to dark mode.

    Args:
        before_image: The original input image, or None.
        after_image: The restored image, or None.

    Returns:
        A dashed placeholder panel when either image is None. Otherwise, markup
        with both images inlined as PNG data URLs: the original is stacked on
        top of the restored image and clipped from the right, and an invisible
        range input drives the clip and the divider line.

    The frame's width is capped at ``80vh * width / height``. Together with
    ``aspect-ratio`` and ``max-height: 80vh``, this keeps the frame at the
    image's exact proportions even for tall images. Without the cap, a clamped
    height letterboxes the image inside a wider box, and the slider's
    percentage no longer corresponds to the image's horizontal fraction.
    """
    if before_image is None or after_image is None:
        return """
<div style='height: 500px; width: 100%; display:flex; align-items:center; justify-content:center; color: var(--body-text-color-subdued); border: 2px dashed var(--border-color-primary); border-radius: 12px; font-size: 1.2rem; background: var(--background-fill-secondary);'>
Upload an image and run to see the full-width comparison here.
</div>
"""
    # Match the original to the restored image's resolution so the two layers overlap exactly.
    if before_image.size != after_image.size:
        before_image = before_image.resize(after_image.size, Image.LANCZOS)
    before_url = _pil_to_data_url(before_image)
    after_url = _pil_to_data_url(after_image)
    width, height = after_image.size
    # Millisecond timestamp gives this slider's elements their DOM ids.
    slider_id = f"slider_{int(time.time() * 1000)}"
    # At position p (0-100): clip (100 - p)% off the right of the original and move the divider to p%.
    on_input = f"var p=this.value; document.getElementById('{slider_id}_top').style.clipPath='inset(0 '+(100-p)+'% 0 0)'; document.getElementById('{slider_id}_line').style.left=p+'%';"
    return f"""
<div style="display: flex; justify-content: center; width: 100%; background: var(--background-fill-secondary); border-radius: 12px; padding: 20px; border: 1px solid var(--border-color-primary); box-sizing: border-box;">
<div style="position:relative; width:100%; max-width:calc(80vh * {width} / {height}); max-height:80vh; aspect-ratio:{width}/{height}; overflow:hidden; border-radius:8px; box-shadow: 0 4px 15px rgba(0,0,0,0.2);">
<!-- 底图: 修复后 -->
<img src="{after_url}" style="position:absolute; top:0; left:0; width:100%; height:100%; object-fit:contain;" draggable="false" />
<!-- 顶图: 原图 -->
<img id="{slider_id}_top" src="{before_url}" style="position:absolute; top:0; left:0; width:100%; height:100%; object-fit:contain; clip-path:inset(0 50% 0 0);" draggable="false" />
<!-- 拖拽线与把手 -->
<div id="{slider_id}_line" style="position:absolute; top:0; left:50%; width:4px; height:100%; background:white; box-shadow:0 0 10px rgba(0,0,0,0.5); transform:translateX(-50%); pointer-events:none; z-index:10;">
<div style="position:absolute; top:50%; left:50%; transform:translate(-50%, -50%); width:48px; height:48px; border-radius:50%; background:white; display:flex; align-items:center; justify-content:center; box-shadow:0 2px 10px rgba(0,0,0,0.4); font-weight:bold; color:#1e293b; font-size:1.3rem;">↔</div>
</div>
<!-- 透明原生滑动条 -->
<input type="range" min="0" max="100" value="50" oninput="{on_input}" style="position:absolute; top:0; left:0; width:100%; height:100%; opacity:0; cursor:ew-resize; z-index:20; margin:0; appearance:auto;" />
</div>
</div>
<div style="display:flex; justify-content:space-between; color:var(--body-text-color-subdued); font-size:1.05rem; padding-top:12px; font-weight:600;">
<span>⬅️ Original</span>
<span>Restored ➡️</span>
</div>
"""
def _on_preset_change(preset_name: str):
    """Return the instruction text of *preset_name*.

    Unknown names, including None, yield an empty string.
    """
    try:
        return TASK_PRESETS[preset_name]
    except KeyError:
        return ""
def run_inference(
    image: Optional[Image.Image],
    task_name: str,
    prompt: str,
    steps: float,
    guidance_scale: float,
    seed: float,
    progress=gr.Progress(track_tqdm=False),
):
    """Restore *image* with the RealRestorer pipeline.

    Args:
        image: Uploaded input image; None when nothing has been uploaded.
        task_name: Preset name; its instruction is used when *prompt* is blank.
        prompt: Free-form instruction. None or whitespace-only text falls back
            to the preset.
        steps: Number of inference steps (truncated to int).
        guidance_scale: Guidance (CFG) scale.
        seed: Random seed. None (a cleared ``gr.Number``) or any negative value
            draws a fresh random seed.
        progress: Gradio progress tracker injected by the event handler.

    Returns:
        ``(output_image, status_text, slider_html)``. On failure the image is
        None and the status describes the error. Because the status is later
        rendered inside a ``gr.HTML`` component, the exception text is
        HTML-escaped.
    """
    if image is None:
        return None, "💡 Error: Please upload an image first.", _build_slider_html(None, None)

    source_image = image.convert("RGB")
    # Blank (or missing) instructions fall back to the selected task preset.
    final_prompt = (prompt or "").strip() or TASK_PRESETS.get(task_name, TASK_PRESETS[DEFAULT_PRESET])
    # gr.Number returns None when the field is cleared; treat that like "-1 = random"
    # instead of crashing on `None >= 0` outside the try block below.
    if seed is not None and seed >= 0:
        final_seed = int(seed)
    else:
        final_seed = random.randint(0, 2**31 - 1)

    try:
        progress(0.1, desc="Preparing model...")
        pipeline = _load_pipeline()
        start_time = time.time()
        with INFERENCE_LOCK:
            progress(0.3, desc="Restoring Image...")
            output = pipeline(
                image=source_image,
                prompt=final_prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(guidance_scale),
                size_level=FIXED_SIZE_LEVEL,
                seed=final_seed,
            ).images[0]
        elapsed = time.time() - start_time
        status = f"✅ Success | Time: {elapsed:.1f}s | Steps: {int(steps)} | CFG: {guidance_scale} | Seed: {final_seed}"
        slider_html = _build_slider_html(source_image, output)
        return output, status, slider_html
    except Exception as exc:
        # UI boundary: print the traceback to the console and show the error in the UI.
        # The message is escaped because it ends up inside an HTML status box.
        traceback.print_exc()
        return None, f"❌ Failed: {html.escape(str(exc))}", _build_slider_html(None, None)
def build_demo():
    """Assemble the RealRestorer Gradio Blocks UI and return it without launching.

    Layout:
      * a title banner;
      * an "Upload & Settings" card with the image, task preset, instruction,
        guidance / steps / seed controls and the Run button;
      * a "Restoration Results" card with a status line and two tabs: the
        comparison slider and the raw output image.

    Clicking Run triggers three chained steps:
      1. show a "processing" status and reset the slider;
      2. call run_inference;
      3. wrap the returned status text in the styled ``rr-status`` container.
    """
    # Plain Soft theme: Gradio itself handles the light/dark color switching.
    theme = gr.themes.Soft(primary_hue="indigo", text_size=gr.themes.sizes.text_lg)

    with gr.Blocks(css=CUSTOM_CSS, title="RealRestorer", theme=theme) as demo:
        gr.HTML(TITLE_HTML)

        # ---- Section 1: upload and settings (stacked vertically) ----
        with gr.Column(elem_classes=["rr-section"]):
            gr.Markdown("### 📥 1. Upload & Settings")
            # Wide upload area spanning the card.
            input_image = gr.Image(label="Upload Image", type="pil", height=420)

            # Row 1: task preset (narrow) beside the editable instruction (wide).
            with gr.Row():
                with gr.Column(scale=1):
                    task_dropdown = gr.Dropdown(
                        choices=list(TASK_PRESETS),
                        value=DEFAULT_PRESET,
                        label="Task Preset",
                    )
                with gr.Column(scale=2):
                    prompt_box = gr.Textbox(
                        label="Instruction (输入指令)",
                        value=TASK_PRESETS[DEFAULT_PRESET],
                        lines=1,
                    )

            # Row 2: sampling controls in three equal-width columns.
            with gr.Row():
                with gr.Column(scale=1):
                    guidance_slider = gr.Slider(
                        minimum=1.0,
                        maximum=6.0,
                        value=DEFAULT_CFG,
                        step=0.1,
                        label="CFG / Guidance Scale",
                    )
                with gr.Column(scale=1):
                    steps_slider = gr.Slider(
                        minimum=12,
                        maximum=40,
                        value=DEFAULT_STEPS,
                        step=1,
                        label="Inference Steps",
                    )
                with gr.Column(scale=1):
                    seed_box = gr.Number(label="Seed (-1 for random)", value=DEFAULT_SEED, precision=0)

            # Large run button, styled by the #run-btn CSS rule.
            run_button = gr.Button("🚀 Run Restoration", elem_id="run-btn")

        # ---- Section 2: results (full width) ----
        with gr.Column(elem_classes=["rr-section"]):
            gr.Markdown("### 🖼️ 2. Restoration Results")
            status_box = gr.HTML(
                value="<div class='rr-status'>💡 Status: Ready. Upload an image, adjust settings, and click Run.</div>"
            )
            with gr.Tabs():
                with gr.Tab("Compare View"):
                    slider_view = gr.HTML(_build_slider_html(None, None))
                with gr.Tab("Output Image"):
                    output_image = gr.Image(
                        label="Restored Result",
                        type="pil",
                        interactive=False,
                        height=600,
                        show_label=False,
                    )

        # ---- Events ----
        # Selecting a preset overwrites the instruction box with its text.
        task_dropdown.change(fn=_on_preset_change, inputs=[task_dropdown], outputs=[prompt_box])

        inference_inputs = [input_image, task_dropdown, prompt_box, steps_slider, guidance_slider, seed_box]

        # Step 1: immediate feedback; reset the comparison panel to its placeholder.
        processing = run_button.click(
            fn=lambda: ("<div class='rr-status'>⏳ Status: Processing... Please wait.</div>", _build_slider_html(None, None)),
            outputs=[status_box, slider_view],
        )
        # Step 2: the restoration itself.
        restoring = processing.then(
            fn=run_inference,
            inputs=inference_inputs,
            outputs=[output_image, status_box, slider_view],
        )
        # Step 3: run_inference returns bare status text; wrap it in the styled container.
        restoring.then(
            fn=lambda status: f"<div class='rr-status'>{status}</div>",
            inputs=[status_box],
            outputs=[status_box],
        )

    return demo
if __name__ == "__main__":
    demo = build_demo()
    # Queue up to 8 pending requests; each event runs at most one job at a time.
    demo.queue(max_size=8, default_concurrency_limit=1)
    # Listen on all interfaces and surface Python errors in the browser.
    demo.launch(server_name="0.0.0.0", show_error=True)