from __future__ import annotations import base64 import io import os import random import threading import time import traceback from typing import Optional try: import spaces except ImportError: class _SpacesFallback: @staticmethod def GPU(*args, **kwargs): def decorator(fn): return fn return decorator spaces = _SpacesFallback() import gradio as gr import torch from PIL import Image MODEL_REPO_ID = os.environ.get("REALRESTORER_MODEL_REPO", "RealRestorer/RealRestorer") HF_TOKEN = os.environ.get("HF_TOKEN") or None # --- Default Parameters --- DEFAULT_STEPS = 28 DEFAULT_CFG = 3.0 DEFAULT_SEED = 42 FIXED_SIZE_LEVEL = 1024 TASK_PRESETS = { "Low-light Enhancement": "Please restore this low-quality image, recovering its normal brightness and clarity.", "Deblurring": "Please deblur the image and make it sharper.", "Deraining": "Please remove the rain from the image and restore its clarity.", "Compression Artifact Removal": "Please restore the image clarity and artifacts.", "Deflare": "Please remove the lens flare and glare from the image.", "Demoire": "Please remove the moire patterns from the image.", "Dehazing": "Please dehaze the image.", "Denoising": "Please remove noise from the image.", "Reflection Removal": "Please remove the reflection from the image.", "Underwater Image Enhancement": "Please enhance this underwater image, restoring its natural colors, brightness, and clarity.", "Old Photo Restoration": "Please restore this old photo, repairing damage and improving its clarity and overall quality.", "Desnowing": "Please remove the snow from the image and restore its visibility and clarity." } DEFAULT_PRESET = "Low-light Enhancement" # --- UI Header --- TITLE_HTML = """

RealRestorer

A powerful image restoration model supporting deblurring, denoising, deflaring, low-light enhancement, and more.

""" # --- 核心修复:纯布局 CSS,不干扰 Gradio 原生颜色 --- CUSTOM_CSS = """ /* 放宽最大宽度,横向占满 */ .gradio-container { max-width: 1400px !important; margin: auto !important; padding-top: 30px !important; } /* 卡片容器:使用 Gradio 自带的 CSS 变量,完美适应深浅模式 */ .rr-section { background: var(--background-fill-primary) !important; border: 1px solid var(--border-color-primary) !important; border-radius: 16px !important; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.05) !important; padding: 24px !important; margin-bottom: 24px !important; } /* 巨型行动按钮 */ #run-btn { background: linear-gradient(135deg, #4f46e5 0%, #6366f1 100%) !important; color: white !important; border: none !important; font-size: 1.25rem !important; font-weight: 700 !important; padding: 16px !important; border-radius: 12px !important; box-shadow: 0 4px 15px rgba(79, 70, 229, 0.4) !important; transition: transform 0.2s, box-shadow 0.2s !important; margin-top: 10px !important; } #run-btn:hover { transform: translateY(-2px) !important; box-shadow: 0 6px 20px rgba(79, 70, 229, 0.5) !important; } /* 状态栏:使用主题变量 */ .rr-status { background: var(--background-fill-secondary) !important; border: 1px solid var(--border-color-primary) !important; border-radius: 8px !important; padding: 12px 16px !important; font-size: 1rem !important; color: var(--body-text-color) !important; margin-bottom: 20px !important; font-family: ui-monospace, monospace !important; } """ PIPELINE = None PIPELINE_LOCK = threading.Lock() INFERENCE_LOCK = threading.Lock() @spaces.GPU(duration=180) def _spaces_gpu_probe(): return None def _pick_device(): if torch.cuda.is_available(): return "cuda" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return "mps" return "cpu" def _load_pipeline(): global PIPELINE if PIPELINE is not None: return PIPELINE with PIPELINE_LOCK: if PIPELINE is not None: return PIPELINE from diffusers import RealRestorerPipeline device = _pick_device() dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float16 if device == "cuda" else torch.float32 pipe = RealRestorerPipeline.from_pretrained(MODEL_REPO_ID, torch_dtype=dtype, token=HF_TOKEN) if device == "cuda": pipe.enable_model_cpu_offload(device=device) else: pipe.to(device) PIPELINE = pipe return PIPELINE def _pil_to_data_url(image: Image.Image) -> str: buffer = io.BytesIO() image.save(buffer, format="PNG") return f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}" def _build_slider_html(before_image: Image.Image, after_image: Image.Image) -> str: # 巧妙利用 var(--body-text-color) 等变量,适配深色模式 if before_image is None or after_image is None: return """
Upload an image and run to see the full-width comparison here.
""" if before_image.size != after_image.size: before_image = before_image.resize(after_image.size, Image.LANCZOS) before_url = _pil_to_data_url(before_image) after_url = _pil_to_data_url(after_image) width, height = after_image.size slider_id = f"slider_{int(time.time() * 1000)}" on_input = f"var p=this.value; document.getElementById('{slider_id}_top').style.clipPath='inset(0 '+(100-p)+'% 0 0)'; document.getElementById('{slider_id}_line').style.left=p+'%';" return f"""
⬅️ Original Restored ➡️
""" def _on_preset_change(preset_name: str): return TASK_PRESETS.get(preset_name, "") @spaces.GPU(duration=180) def run_inference(image: Optional[Image.Image], task_name: str, prompt: str, steps: float, guidance_scale: float, seed: float, progress=gr.Progress(track_tqdm=False)): if image is None: return None, "💡 Error: Please upload an image first.", _build_slider_html(None, None) source_image = image.convert("RGB") final_prompt = prompt.strip() or TASK_PRESETS.get(task_name, TASK_PRESETS[DEFAULT_PRESET]) final_seed = int(seed) if seed >= 0 else random.randint(0, 2**31 - 1) try: progress(0.1, desc="Preparing model...") pipeline = _load_pipeline() start_time = time.time() with INFERENCE_LOCK: progress(0.3, desc="Restoring Image...") output = pipeline( image=source_image, prompt=final_prompt, num_inference_steps=int(steps), guidance_scale=float(guidance_scale), size_level=FIXED_SIZE_LEVEL, seed=final_seed, ).images[0] elapsed = time.time() - start_time status = f"✅ Success | Time: {elapsed:.1f}s | Steps: {int(steps)} | CFG: {guidance_scale} | Seed: {final_seed}" slider_html = _build_slider_html(source_image, output) return output, status, slider_html except Exception as exc: traceback.print_exc() return None, f"❌ Failed: {exc}", _build_slider_html(None, None) def build_demo(): # 使用纯净的 Soft 主题,交由 Gradio 原生处理深浅色模式的颜色变化 theme = gr.themes.Soft( primary_hue="indigo", text_size=gr.themes.sizes.text_lg ) with gr.Blocks(css=CUSTOM_CSS, title="RealRestorer", theme=theme) as demo: gr.HTML(TITLE_HTML) # ========================================== # Section 1: 图像上传与控制参数 (纵向排列) # ========================================== with gr.Column(elem_classes=["rr-section"]): gr.Markdown("### 📥 1. Upload & Settings") # 横向宽屏上传框 input_image = gr.Image(label="Upload Image", type="pil", height=420) # 控制参数横向铺开 with gr.Row(): with gr.Column(scale=1): task_dropdown = gr.Dropdown(choices=list(TASK_PRESETS.keys()), value=DEFAULT_PRESET, label="Task Preset") with gr.Column(scale=2): prompt_box = gr.Textbox(label="Instruction (输入指令)", value=TASK_PRESETS[DEFAULT_PRESET], lines=1) with gr.Row(): with gr.Column(scale=1): guidance_slider = gr.Slider(minimum=1.0, maximum=6.0, value=DEFAULT_CFG, step=0.1, label="CFG / Guidance Scale") with gr.Column(scale=1): steps_slider = gr.Slider(minimum=12, maximum=40, value=DEFAULT_STEPS, step=1, label="Inference Steps") with gr.Column(scale=1): seed_box = gr.Number(label="Seed (-1 for random)", value=DEFAULT_SEED, precision=0) # 巨型运行按钮 run_button = gr.Button("🚀 Run Restoration", elem_id="run-btn") # ========================================== # Section 2: 结果展示 (横跨全屏) # ========================================== with gr.Column(elem_classes=["rr-section"]): gr.Markdown("### 🖼️ 2. Restoration Results") status_box = gr.HTML( value="
💡 Status: Ready. Upload an image, adjust settings, and click Run.
" ) with gr.Tabs(): with gr.Tab("Compare View"): slider_view = gr.HTML(_build_slider_html(None, None)) with gr.Tab("Output Image"): output_image = gr.Image(label="Restored Result", type="pil", interactive=False, height=600, show_label=False) # ========================================== # Events # ========================================== task_dropdown.change(fn=_on_preset_change, inputs=[task_dropdown], outputs=[prompt_box]) run_button.click( fn=lambda: ("
⏳ Status: Processing... Please wait.
", _build_slider_html(None, None)), outputs=[status_box, slider_view] ).then( fn=run_inference, inputs=[input_image, task_dropdown, prompt_box, steps_slider, guidance_slider, seed_box], outputs=[output_image, status_box, slider_view], ).then( fn=lambda status: f"
{status}
", inputs=[status_box], outputs=[status_box] ) return demo if __name__ == "__main__": demo = build_demo() demo.queue(max_size=8, default_concurrency_limit=1).launch(server_name="0.0.0.0", show_error=True)