from __future__ import annotations
import base64
import io
import os
import random
import threading
import time
import traceback
from typing import Optional
# Optional dependency: on Hugging Face Spaces the `spaces` package supplies the
# `@spaces.GPU` decorator. Anywhere else we install a no-op stand-in so the
# decorated functions behave exactly as if undecorated.
try:
    import spaces
except ImportError:
    class _SpacesFallback:
        """No-op replacement exposing the same ``GPU`` decorator factory."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Return a decorator that hands the function back unchanged."""
            def decorator(fn):
                return fn
            return decorator

    spaces = _SpacesFallback()
import gradio as gr
import torch
from PIL import Image
# Hub repo id for the model; overridable via the REALRESTORER_MODEL_REPO env var.
MODEL_REPO_ID = os.environ.get("REALRESTORER_MODEL_REPO", "RealRestorer/RealRestorer")
# Hugging Face access token for gated/private repos (empty string coerced to None).
HF_TOKEN = os.environ.get("HF_TOKEN") or None
# --- Default Parameters ---
DEFAULT_STEPS = 28  # default number of diffusion inference steps
DEFAULT_CFG = 3.0  # default classifier-free guidance scale
DEFAULT_SEED = 42  # default RNG seed shown in the UI (-1 means random at run time)
FIXED_SIZE_LEVEL = 1024  # size_level forwarded to the pipeline; not user-tunable
# Canned instruction prompts, one per supported restoration task. Keys double as
# the dropdown labels in the UI; values are sent verbatim to the model.
TASK_PRESETS = {
    "Low-light Enhancement": "Please restore this low-quality image, recovering its normal brightness and clarity.",
    "Deblurring": "Please deblur the image and make it sharper.",
    "Deraining": "Please remove the rain from the image and restore its clarity.",
    "Compression Artifact Removal": "Please restore the image clarity and artifacts.",
    "Deflare": "Please remove the lens flare and glare from the image.",
    "Demoire": "Please remove the moire patterns from the image.",
    "Dehazing": "Please dehaze the image.",
    "Denoising": "Please remove noise from the image.",
    "Reflection Removal": "Please remove the reflection from the image.",
    "Underwater Image Enhancement": "Please enhance this underwater image, restoring its natural colors, brightness, and clarity.",
    "Old Photo Restoration": "Please restore this old photo, repairing damage and improving its clarity and overall quality.",
    "Desnowing": "Please remove the snow from the image and restore its visibility and clarity."
}
DEFAULT_PRESET = "Low-light Enhancement"  # task selected when the UI first loads
# --- UI Header ---
# NOTE(review): this reads as plain text rendered through gr.HTML; the original
# HTML markup appears to have been stripped — confirm against the source file.
TITLE_HTML = """
RealRestorer
A powerful image restoration model supporting deblurring, denoising, deflaring, low-light enhancement, and more.
"""
# --- Core fix: layout-only CSS that does not interfere with Gradio's native colors ---
CUSTOM_CSS = """
/* 放宽最大宽度,横向占满 */
.gradio-container { max-width: 1400px !important; margin: auto !important; padding-top: 30px !important; }
/* 卡片容器:使用 Gradio 自带的 CSS 变量,完美适应深浅模式 */
.rr-section {
background: var(--background-fill-primary) !important;
border: 1px solid var(--border-color-primary) !important;
border-radius: 16px !important;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.05) !important;
padding: 24px !important;
margin-bottom: 24px !important;
}
/* 巨型行动按钮 */
#run-btn {
background: linear-gradient(135deg, #4f46e5 0%, #6366f1 100%) !important;
color: white !important;
border: none !important;
font-size: 1.25rem !important;
font-weight: 700 !important;
padding: 16px !important;
border-radius: 12px !important;
box-shadow: 0 4px 15px rgba(79, 70, 229, 0.4) !important;
transition: transform 0.2s, box-shadow 0.2s !important;
margin-top: 10px !important;
}
#run-btn:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(79, 70, 229, 0.5) !important;
}
/* 状态栏:使用主题变量 */
.rr-status {
background: var(--background-fill-secondary) !important;
border: 1px solid var(--border-color-primary) !important;
border-radius: 8px !important;
padding: 12px 16px !important;
font-size: 1rem !important;
color: var(--body-text-color) !important;
margin-bottom: 20px !important;
font-family: ui-monospace, monospace !important;
}
"""
# Lazily created singleton pipeline; read/written only via _load_pipeline().
PIPELINE = None
# Serializes pipeline construction across concurrent requests.
PIPELINE_LOCK = threading.Lock()
# Serializes actual inference so only one request runs on the device at a time.
INFERENCE_LOCK = threading.Lock()
@spaces.GPU(duration=180)
def _spaces_gpu_probe():
    """Do nothing; presumably exists so Spaces detects a GPU-decorated entry
    point at startup — TODO confirm against ZeroGPU requirements."""
    return None
def _pick_device():
if torch.cuda.is_available(): return "cuda"
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return "mps"
return "cpu"
def _load_pipeline():
    """Lazily construct and cache the global RealRestorer pipeline.

    Uses double-checked locking on PIPELINE_LOCK so concurrent first requests
    load the model only once. Returns the cached pipeline on later calls.
    """
    global PIPELINE
    if PIPELINE is not None: return PIPELINE
    with PIPELINE_LOCK:
        # Re-check inside the lock: another thread may have finished loading.
        if PIPELINE is not None: return PIPELINE
        # Imported here so the app can start even before diffusers is needed.
        from diffusers import RealRestorerPipeline
        device = _pick_device()
        # Prefer bf16 on CUDA GPUs that support it, fp16 on other CUDA GPUs,
        # fp32 everywhere else (CPU / MPS).
        dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float16 if device == "cuda" else torch.float32
        pipe = RealRestorerPipeline.from_pretrained(MODEL_REPO_ID, torch_dtype=dtype, token=HF_TOKEN)
        # CPU offload keeps VRAM usage low on CUDA; otherwise move the pipeline wholesale.
        if device == "cuda": pipe.enable_model_cpu_offload(device=device)
        else: pipe.to(device)
        PIPELINE = pipe
    return PIPELINE
def _pil_to_data_url(image: Image.Image) -> str:
buffer = io.BytesIO()
image.save(buffer, format="PNG")
return f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
def _build_slider_html(before_image: Image.Image, after_image: Image.Image) -> str:
    """Build the before/after comparison widget HTML for the "Compare View" tab.

    Returns a placeholder string when either image is missing.

    NOTE(review): the returned markup below looks like it lost its HTML tags —
    the computed data URLs, dimensions, slider id and oninput handler are never
    interpolated into the visible template. Likely stripped during extraction;
    confirm against the original file before relying on this output.
    """
    # Relies on CSS variables such as var(--body-text-color) to adapt to dark mode.
    if before_image is None or after_image is None:
        # Placeholder shown before the first successful run.
        return """
Upload an image and run to see the full-width comparison here.
"""
    # Match the "before" image to the restored image's size so the overlay aligns.
    if before_image.size != after_image.size:
        before_image = before_image.resize(after_image.size, Image.LANCZOS)
    before_url = _pil_to_data_url(before_image)
    after_url = _pil_to_data_url(after_image)
    width, height = after_image.size
    # Millisecond timestamp keeps slider element ids unique across runs.
    slider_id = f"slider_{int(time.time() * 1000)}"
    # Inline JS for the range input: clips the top image and moves the divider line.
    on_input = f"var p=this.value; document.getElementById('{slider_id}_top').style.clipPath='inset(0 '+(100-p)+'% 0 0)'; document.getElementById('{slider_id}_line').style.left=p+'%';"
    return f"""
⬅️ Original
Restored ➡️
"""
def _on_preset_change(preset_name: str):
    """Return the canned prompt for *preset_name*, or "" for unknown presets."""
    try:
        return TASK_PRESETS[preset_name]
    except KeyError:
        return ""
@spaces.GPU(duration=180)
def run_inference(image: Optional[Image.Image], task_name: str, prompt: str, steps: float, guidance_scale: float, seed: float, progress=gr.Progress(track_tqdm=False)):
    """Run the restoration pipeline and return ``(output_image, status, slider_html)``.

    Args:
        image: Uploaded PIL image, or None when nothing was uploaded.
        task_name: Selected TASK_PRESETS key (used when *prompt* is blank).
        prompt: Free-form instruction; may be empty or None.
        steps: Inference step count from the slider (None tolerated).
        guidance_scale: CFG value from the slider (None tolerated).
        seed: RNG seed; negative or None means "pick a random seed".
        progress: Gradio progress reporter (injected by Gradio).

    Returns:
        Tuple of (restored PIL image or None, status string, comparison HTML).
    """
    if image is None:
        return None, "💡 Error: Please upload an image first.", _build_slider_html(None, None)
    source_image = image.convert("RGB")
    # FIX: a cleared Textbox can deliver None — previously prompt.strip() raised
    # AttributeError. Blank prompts fall back to the task preset.
    final_prompt = (prompt or "").strip() or TASK_PRESETS.get(task_name, TASK_PRESETS[DEFAULT_PRESET])
    # FIX: a cleared gr.Number delivers None — previously `seed >= 0` raised
    # TypeError. Treat None like -1 (random seed). Likewise default the sliders.
    final_seed = int(seed) if seed is not None and seed >= 0 else random.randint(0, 2**31 - 1)
    steps = DEFAULT_STEPS if steps is None else steps
    guidance_scale = DEFAULT_CFG if guidance_scale is None else guidance_scale
    try:
        progress(0.1, desc="Preparing model...")
        pipeline = _load_pipeline()
        start_time = time.time()
        # Serialize inference: one request at a time on the device.
        with INFERENCE_LOCK:
            progress(0.3, desc="Restoring Image...")
            output = pipeline(
                image=source_image,
                prompt=final_prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(guidance_scale),
                size_level=FIXED_SIZE_LEVEL,
                seed=final_seed,
            ).images[0]
        elapsed = time.time() - start_time
        status = f"✅ Success | Time: {elapsed:.1f}s | Steps: {int(steps)} | CFG: {guidance_scale} | Seed: {final_seed}"
        slider_html = _build_slider_html(source_image, output)
        return output, status, slider_html
    except Exception as exc:
        # Surface the failure in the UI; the full traceback goes to the server log.
        traceback.print_exc()
        return None, f"❌ Failed: {exc}", _build_slider_html(None, None)
def build_demo():
    """Assemble and return the Gradio Blocks UI.

    FIX: three string literals in this function (the ready-status HTML, the
    processing-placeholder lambda, and the final status pass-through lambda)
    were split across physical lines — unterminated string literals, i.e.
    SyntaxErrors. They are rejoined onto single lines below.
    """
    # Use the plain Soft theme; Gradio natively handles light/dark color changes.
    theme = gr.themes.Soft(
        primary_hue="indigo",
        text_size=gr.themes.sizes.text_lg
    )
    with gr.Blocks(css=CUSTOM_CSS, title="RealRestorer", theme=theme) as demo:
        gr.HTML(TITLE_HTML)
        # ==========================================
        # Section 1: image upload & control parameters (stacked vertically)
        # ==========================================
        with gr.Column(elem_classes=["rr-section"]):
            gr.Markdown("### 📥 1. Upload & Settings")
            # Wide, full-width upload box.
            input_image = gr.Image(label="Upload Image", type="pil", height=420)
            # Controls laid out horizontally.
            with gr.Row():
                with gr.Column(scale=1):
                    task_dropdown = gr.Dropdown(choices=list(TASK_PRESETS.keys()), value=DEFAULT_PRESET, label="Task Preset")
                with gr.Column(scale=2):
                    prompt_box = gr.Textbox(label="Instruction (输入指令)", value=TASK_PRESETS[DEFAULT_PRESET], lines=1)
            with gr.Row():
                with gr.Column(scale=1):
                    guidance_slider = gr.Slider(minimum=1.0, maximum=6.0, value=DEFAULT_CFG, step=0.1, label="CFG / Guidance Scale")
                with gr.Column(scale=1):
                    steps_slider = gr.Slider(minimum=12, maximum=40, value=DEFAULT_STEPS, step=1, label="Inference Steps")
                with gr.Column(scale=1):
                    seed_box = gr.Number(label="Seed (-1 for random)", value=DEFAULT_SEED, precision=0)
            # Large call-to-action button (styled via #run-btn in CUSTOM_CSS).
            run_button = gr.Button("🚀 Run Restoration", elem_id="run-btn")
        # ==========================================
        # Section 2: results (full width)
        # ==========================================
        with gr.Column(elem_classes=["rr-section"]):
            gr.Markdown("### 🖼️ 2. Restoration Results")
            # NOTE(review): this plain string likely had surrounding HTML (e.g. a
            # .rr-status wrapper) stripped during extraction — confirm and restore.
            status_box = gr.HTML(
                value="💡 Status: Ready. Upload an image, adjust settings, and click Run."
            )
            with gr.Tabs():
                with gr.Tab("Compare View"):
                    slider_view = gr.HTML(_build_slider_html(None, None))
                with gr.Tab("Output Image"):
                    output_image = gr.Image(label="Restored Result", type="pil", interactive=False, height=600, show_label=False)
        # ==========================================
        # Events
        # ==========================================
        task_dropdown.change(fn=_on_preset_change, inputs=[task_dropdown], outputs=[prompt_box])
        run_button.click(
            # Immediately show a "processing" placeholder and clear the slider.
            fn=lambda: ("⏳ Status: Processing... Please wait.", _build_slider_html(None, None)),
            outputs=[status_box, slider_view]
        ).then(
            fn=run_inference,
            inputs=[input_image, task_dropdown, prompt_box, steps_slider, guidance_slider, seed_box],
            outputs=[output_image, status_box, slider_view],
        ).then(
            # NOTE(review): after rejoining the broken literal this step is an
            # identity pass-through; the original probably wrapped the status in
            # HTML markup that was stripped — restore the wrapper if known.
            fn=lambda status: f"{status}",
            inputs=[status_box],
            outputs=[status_box]
        )
    return demo
if __name__ == "__main__":
    demo = build_demo()
    # Queue limits inference to one concurrent run (matching INFERENCE_LOCK);
    # bind on all interfaces so the app is reachable inside a container.
    demo.queue(max_size=8, default_concurrency_limit=1).launch(server_name="0.0.0.0", show_error=True)