yangyufeng committed on
Commit bb9d70a · 0 Parent(s)

first commit

Files changed (4)
  1. .gitignore +7 -0
  2. README.md +38 -0
  3. app.py +414 -0
  4. requirements.txt +13 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__/
+ .DS_Store
+ .gradio/
+ .hf_cache/
+ .cache/
+ venv/
+ .venv/
README.md ADDED
@@ -0,0 +1,38 @@
+ ---
+ title: RealRestorer Demo
+ emoji: 🖼️
+ colorFrom: teal
+ colorTo: orange
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ python_version: 3.10
+ ---
+
+ # RealRestorer Hugging Face Space
+
+ This Space runs the public `RealRestorer/RealRestorer` model with a Gradio interface for single-image restoration.
+
+ It is designed as a lightweight Space repo:
+
+ - the model weights are loaded from Hugging Face
+ - the custom RealRestorer-enabled `diffusers` build is installed from the main GitHub repo
+ - the UI layout follows the existing internal Gradio demo, but is simplified for Spaces deployment
+
+ ## Notes
+
+ - A GPU Space is strongly recommended. The released model is not practical on CPU.
+ - The default model repo is `RealRestorer/RealRestorer`.
+ - You can override the model repo with the `REALRESTORER_MODEL_REPO` Space environment variable.
+
+ ## Local Run
+
+ ```bash
+ python -m pip install -r requirements.txt
+ python app.py
+ ```
+
+ ## Space Variables
+
+ - `REALRESTORER_MODEL_REPO`: optional model repo id, default `RealRestorer/RealRestorer`
+ - `HF_TOKEN`: optional token if you want to load a private model repo
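The Space variables above can also be exercised in a local run. A minimal sketch, where `your-org/your-finetune` and the token value are placeholders for illustration, not real identifiers:

```shell
# Placeholder values: substitute a real model repo id and, if needed, a real token.
export REALRESTORER_MODEL_REPO="your-org/your-finetune"
export HF_TOKEN="hf_xxxxxxxx"   # only required for private model repos
python -m pip install -r requirements.txt
python app.py
```

On a Space, the same variables are set in the Space settings rather than in the shell.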
app.py ADDED
@@ -0,0 +1,414 @@
+ from __future__ import annotations
+
+ import base64
+ import io
+ import os
+ import random
+ import threading
+ import time
+ import traceback
+ from typing import Optional
+
+ import gradio as gr
+ import torch
+ from PIL import Image
+
+ MODEL_REPO_ID = os.environ.get("REALRESTORER_MODEL_REPO", "RealRestorer/RealRestorer")
+ HF_TOKEN = os.environ.get("HF_TOKEN") or None
+
+ TASK_PRESETS = {
+     "Low-light Enhancement": "Please restore this low-quality image, recovering its normal brightness and clarity.",
+     "Deblurring": "Please deblur the image and make it sharper.",
+     "Deraining": "Please remove the rain from the image and restore its clarity.",
+     "Compression Artifact Removal": "Please remove the compression artifacts and restore the image's clarity.",
+     "Deflare": "Please remove the lens flare and glare from the image.",
+     "Demoire": "Please remove the moire patterns from the image.",
+     "Dehazing": "Please dehaze the image.",
+     "Denoising": "Please remove noise from the image.",
+     "Reflection Removal": "Please remove the reflection from the image.",
+ }
+
+ DEFAULT_PRESET = "Reflection Removal"
+ DEFAULT_STATUS = (
+     "Model not loaded yet. Upload an image, pick a preset or edit the prompt, "
+     "then click Run Inference."
+ )
+ DEFAULT_SLIDER = """
+ <div style="text-align:center; padding:96px 24px; color:#7b8794; font-size:1.05rem;">
+   The interactive before/after slider will appear here after the first run.
+ </div>
+ """
+
+ CUSTOM_CSS = """
+ :root {
+   --rr-ink: #102a43;
+   --rr-muted: #5c6b7a;
+   --rr-panel: rgba(255, 255, 255, 0.9);
+   --rr-edge: rgba(16, 42, 67, 0.12);
+   --rr-accent: #0f766e;
+   --rr-accent-2: #f59e0b;
+ }
+
+ body {
+   background:
+     radial-gradient(circle at top left, rgba(15, 118, 110, 0.12), transparent 30%),
+     radial-gradient(circle at top right, rgba(245, 158, 11, 0.12), transparent 28%),
+     linear-gradient(180deg, #f7fafc 0%, #eef2f7 100%);
+ }
+
+ .gradio-container {
+   max-width: 1380px !important;
+   margin: 0 auto !important;
+ }
+
+ .rr-hero {
+   padding: 20px 8px 10px;
+ }
+
+ .rr-hero h1 {
+   margin: 0;
+   text-align: center;
+   font-size: 2.55rem;
+   line-height: 1.05;
+   letter-spacing: -0.04em;
+   color: var(--rr-ink);
+ }
+
+ .rr-hero p {
+   margin: 12px auto 0;
+   max-width: 820px;
+   text-align: center;
+   color: var(--rr-muted);
+   font-size: 1rem;
+   line-height: 1.6;
+ }
+
+ .rr-shell {
+   border: 1px solid var(--rr-edge);
+   border-radius: 24px;
+   background: rgba(255, 255, 255, 0.7);
+   backdrop-filter: blur(16px);
+   box-shadow: 0 18px 48px rgba(15, 23, 42, 0.08);
+   padding: 18px;
+ }
+
+ .rr-note {
+   border-left: 4px solid var(--rr-accent);
+   padding: 12px 14px;
+   border-radius: 12px;
+   background: rgba(15, 118, 110, 0.08);
+   color: var(--rr-ink);
+   margin-bottom: 8px;
+ }
+
+ .rr-foot {
+   text-align: center;
+   color: var(--rr-muted);
+   font-size: 0.9rem;
+   padding: 10px 0 2px;
+ }
+
+ #run-btn {
+   background: linear-gradient(135deg, var(--rr-accent) 0%, #0b8a7b 55%, var(--rr-accent-2) 100%) !important;
+   color: white !important;
+   border: none !important;
+   font-weight: 700 !important;
+ }
+ """
+
+ # Lazily initialized pipeline, shared across requests.
+ PIPELINE = None
+ PIPELINE_LOCK = threading.Lock()
+ INFERENCE_LOCK = threading.Lock()
+ PIPELINE_DEVICE = "cpu"
+ PIPELINE_DTYPE = "float32"
+
+
+ def _pick_device() -> str:
+     if torch.cuda.is_available():
+         return "cuda"
+     if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+         return "mps"
+     return "cpu"
+
+
+ def _pick_dtype(device: str) -> tuple[torch.dtype, str]:
+     if device == "cuda":
+         if torch.cuda.is_bf16_supported():
+             return torch.bfloat16, "bfloat16"
+         return torch.float16, "float16"
+     return torch.float32, "float32"
+
+
+ def _load_pipeline():
+     """Load the RealRestorer pipeline once, with double-checked locking."""
+     global PIPELINE, PIPELINE_DEVICE, PIPELINE_DTYPE
+
+     if PIPELINE is not None:
+         return PIPELINE
+
+     with PIPELINE_LOCK:
+         if PIPELINE is not None:
+             return PIPELINE
+
+         from diffusers import RealRestorerPipeline
+
+         device = _pick_device()
+         torch_dtype, dtype_name = _pick_dtype(device)
+         pipe = RealRestorerPipeline.from_pretrained(
+             MODEL_REPO_ID,
+             torch_dtype=torch_dtype,
+             token=HF_TOKEN,
+         )
+         if device == "cuda":
+             pipe.enable_model_cpu_offload(device=device)
+         else:
+             pipe.to(device)
+
+         PIPELINE = pipe
+         PIPELINE_DEVICE = device
+         PIPELINE_DTYPE = dtype_name
+         return PIPELINE
+
+
+ def _resolve_seed(seed: float | int | None) -> int:
+     """Treat None or a negative value as a request for a random seed."""
+     if seed is None:
+         return random.randint(0, 2**31 - 1)
+     seed_value = int(seed)
+     if seed_value < 0:
+         return random.randint(0, 2**31 - 1)
+     return seed_value
+
+
+ def _placeholder_slider(message: str) -> str:
+     return (
+         "<div style='text-align:center; padding:72px 24px; color:#7b8794; font-size:1rem;'>"
+         f"{message}"
+         "</div>"
+     )
+
+
+ def _pil_to_data_url(image: Image.Image) -> str:
+     buffer = io.BytesIO()
+     image.save(buffer, format="PNG")
+     encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
+     return f"data:image/png;base64,{encoded}"
+
+
+ def _build_slider_html(before_image: Image.Image, after_image: Image.Image) -> str:
+     """Render a self-contained before/after comparison widget as inline HTML."""
+     if before_image.size != after_image.size:
+         before_image = before_image.resize(after_image.size, Image.LANCZOS)
+
+     before_url = _pil_to_data_url(before_image)
+     after_url = _pil_to_data_url(after_image)
+     width, height = after_image.size
+     slider_id = f"rr_slider_{int(time.time() * 1000)}"
+     on_input = (
+         f"var p=this.value;"
+         f"document.getElementById('{slider_id}_top').style.clipPath='inset(0 '+(100-p)+'% 0 0)';"
+         f"document.getElementById('{slider_id}_line').style.left=p+'%';"
+     )
+
+     return f"""
+ <div style="display:flex; justify-content:space-between; align-items:end; gap:12px; margin-bottom:14px; color:#102a43;">
+   <div>
+     <div style="font-size:1.12rem; font-weight:700;">Before / After Comparison</div>
+     <div style="font-size:0.92rem; color:#5c6b7a;">Drag the slider to inspect restored details.</div>
+   </div>
+   <div style="font-size:0.92rem; color:#5c6b7a;">{width} x {height}px</div>
+ </div>
+ <div style="position:relative; width:100%; aspect-ratio:{width}/{height}; overflow:hidden; border-radius:18px; border:1px solid rgba(16, 42, 67, 0.12); box-shadow:0 18px 36px rgba(15, 23, 42, 0.12); background:white;">
+   <img src="{after_url}" style="position:absolute; inset:0; width:100%; height:100%; object-fit:cover;" draggable="false" />
+   <img id="{slider_id}_top" src="{before_url}" style="position:absolute; inset:0; width:100%; height:100%; object-fit:cover; clip-path:inset(0 50% 0 0);" draggable="false" />
+   <div id="{slider_id}_line" style="position:absolute; top:0; left:50%; width:3px; height:100%; background:white; box-shadow:0 0 10px rgba(0,0,0,0.35); transform:translateX(-50%);">
+     <div style="position:absolute; top:50%; left:50%; transform:translate(-50%, -50%); width:44px; height:44px; border-radius:999px; background:white; display:flex; align-items:center; justify-content:center; box-shadow:0 10px 24px rgba(15, 23, 42, 0.18); color:#102a43; font-weight:700;">
+       &#8596;
+     </div>
+   </div>
+ </div>
+ <div style="padding-top:18px;">
+   <input type="range" min="0" max="100" value="50" oninput="{on_input}" style="width:100%; cursor:ew-resize;" />
+   <div style="display:flex; justify-content:space-between; color:#486581; font-size:0.92rem; padding-top:6px;">
+     <span>Original</span>
+     <span>Restored</span>
+   </div>
+ </div>
+ """
+
+
+ def _on_preset_change(preset_name: str) -> str:
+     return TASK_PRESETS.get(preset_name, "")
+
+
+ def run_inference(
+     image: Optional[Image.Image],
+     preset_name: str,
+     prompt: str,
+     steps: float,
+     guidance_scale: float,
+     size_level: float,
+     seed: float,
+ ):
+     if image is None:
+         return None, "Please upload an input image first.", _placeholder_slider("Please upload an image to start.")
+
+     source_image = image.convert("RGB")
+     final_prompt = prompt.strip() or TASK_PRESETS.get(preset_name, TASK_PRESETS[DEFAULT_PRESET])
+     final_seed = _resolve_seed(seed)
+
+     try:
+         pipeline = _load_pipeline()
+         start_time = time.time()
+         with INFERENCE_LOCK:
+             output = pipeline(
+                 image=source_image,
+                 prompt=final_prompt,
+                 num_inference_steps=int(steps),
+                 guidance_scale=float(guidance_scale),
+                 size_level=int(size_level),
+                 seed=final_seed,
+             ).images[0]
+         elapsed = time.time() - start_time
+     except Exception as exc:
+         traceback.print_exc()
+         error_text = (
+             f"Inference failed: {exc}\n"
+             f"Model repo: {MODEL_REPO_ID}"
+         )
+         return None, error_text, _placeholder_slider("Inference failed. Check the Space logs for details.")
+
+     status = (
+         f"Done in {elapsed:.2f}s\n"
+         f"Model: {MODEL_REPO_ID}\n"
+         f"Device: {PIPELINE_DEVICE}\n"
+         f"Dtype: {PIPELINE_DTYPE}\n"
+         f"Steps: {int(steps)} | Guidance: {float(guidance_scale):.1f} | Size level: {int(size_level)}\n"
+         f"Seed: {final_seed}"
+     )
+     slider_html = _build_slider_html(source_image, output)
+     return output, status, slider_html
+
+
+ def build_demo() -> gr.Blocks:
291
+ with gr.Blocks(css=CUSTOM_CSS, title="RealRestorer Demo") as demo:
292
+ gr.HTML(
293
+ """
294
+ <div class="rr-hero">
295
+ <h1>RealRestorer Image Restoration</h1>
296
+ <p>
297
+ Restore real-world degraded photos with the released RealRestorer model.
298
+ Upload one image, choose a restoration preset, and compare the result with the
299
+ interactive before/after slider.
300
+ </p>
301
+ </div>
302
+ """
303
+ )
304
+
305
+ with gr.Group(elem_classes=["rr-shell"]):
306
+ gr.HTML(
307
+ """
308
+ <div class="rr-note">
309
+ This Space loads the public <b>RealRestorer/RealRestorer</b> model.
310
+ A GPU Space is recommended because the released checkpoint is heavy.
311
+ </div>
312
+ """
313
+ )
314
+
315
+ with gr.Row():
316
+ with gr.Column(scale=1):
317
+ input_image = gr.Image(
318
+ label="Input Image",
319
+ type="pil",
320
+ height=360,
321
+ )
322
+
323
+ with gr.Column(scale=1):
324
+ preset_dropdown = gr.Dropdown(
325
+ choices=list(TASK_PRESETS.keys()),
326
+ value=DEFAULT_PRESET,
327
+ label="Preset",
328
+ )
329
+ prompt_box = gr.Textbox(
330
+ label="Instruction",
331
+ value=TASK_PRESETS[DEFAULT_PRESET],
332
+ lines=4,
333
+ )
334
+
335
+ with gr.Accordion("Advanced Settings", open=False):
336
+ steps_slider = gr.Slider(
337
+ minimum=12,
338
+ maximum=40,
339
+ value=28,
340
+ step=1,
341
+ label="Inference Steps",
342
+ )
343
+ guidance_slider = gr.Slider(
344
+ minimum=1.0,
345
+ maximum=6.0,
346
+ value=3.0,
347
+ step=0.1,
348
+ label="Guidance Scale",
349
+ )
350
+ size_level_slider = gr.Slider(
351
+ minimum=512,
352
+ maximum=1280,
353
+ value=1024,
354
+ step=64,
355
+ label="Resize Target",
356
+ )
357
+ seed_box = gr.Number(
358
+ label="Seed (-1 for random)",
359
+ value=-1,
360
+ precision=0,
361
+ )
362
+
363
+ run_button = gr.Button("Run Inference", variant="primary", elem_id="run-btn")
364
+ status_box = gr.Textbox(
365
+ label="Status",
366
+ value=DEFAULT_STATUS,
367
+ lines=6,
368
+ interactive=False,
369
+ )
370
+
371
+ gr.Markdown("---")
372
+
373
+ with gr.Tabs():
374
+ with gr.Tab("Interactive Compare"):
375
+ slider_html = gr.HTML(DEFAULT_SLIDER)
376
+ with gr.Tab("Restored Output"):
377
+ output_image = gr.Image(label="Output Image", type="pil", interactive=False)
378
+
379
+ gr.HTML(
380
+ """
381
+ <div class="rr-foot">
382
+ Prompts are fully editable. Presets are only shortcuts for common restoration tasks.
383
+ </div>
384
+ """
385
+ )
386
+
387
+ preset_dropdown.change(
388
+ fn=_on_preset_change,
389
+ inputs=[preset_dropdown],
390
+ outputs=[prompt_box],
391
+ )
392
+ run_button.click(
393
+ fn=run_inference,
394
+ inputs=[
395
+ input_image,
396
+ preset_dropdown,
397
+ prompt_box,
398
+ steps_slider,
399
+ guidance_slider,
400
+ size_level_slider,
401
+ seed_box,
402
+ ],
403
+ outputs=[output_image, status_box, slider_html],
404
+ )
405
+
406
+ return demo
407
+
408
+
409
+ demo = build_demo()
410
+ demo.queue(max_size=8, default_concurrency_limit=1)
411
+
412
+
413
+ if __name__ == "__main__":
414
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ accelerate>=1.0.0
+ einops>=0.8.0
+ gradio>=5.0.0,<6.0.0
+ huggingface_hub>=0.36.0
+ opencv-python>=4.10.0
+ Pillow>=10.0.0
+ safetensors>=0.5.0
+ sentencepiece>=0.2.0
+ timm>=1.0.0
+ torch>=2.5.0
+ torchvision>=0.20.0
+ transformers>=4.57.0
+ git+https://github.com/yfyang007/RealRestorer.git@main#subdirectory=diffusers