File size: 10,971 Bytes
4ed98e6
 
 
 
f7eee2e
4ed98e6
 
7c909e3
4ed98e6
 
 
 
 
 
b604e51
f7eee2e
 
 
 
4ed98e6
 
 
 
7c909e3
4ed98e6
7c909e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b604e51
7c909e3
af9d144
b604e51
de5712c
b604e51
7c909e3
af9d144
b604e51
7c909e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b604e51
7c909e3
 
 
 
 
 
b604e51
 
f7eee2e
 
 
 
7c909e3
f7eee2e
 
7c909e3
f7eee2e
 
b604e51
f7eee2e
 
 
 
 
 
 
 
 
 
 
b604e51
 
 
65044b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b604e51
 
 
 
 
 
2c3f571
b604e51
7c909e3
2c3f571
d336690
 
2c3f571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d336690
4ed98e6
b604e51
7c909e3
2c3f571
 
b604e51
2c3f571
f7eee2e
b604e51
2c3f571
b604e51
 
2c3f571
b604e51
2c3f571
b604e51
2c3f571
 
b604e51
2c3f571
b604e51
 
2c3f571
 
 
b604e51
2c3f571
 
 
 
f7eee2e
b604e51
f7eee2e
2c3f571
b604e51
f7eee2e
2c3f571
f7eee2e
65044b4
2c3f571
 
 
b604e51
 
 
 
 
2c3f571
 
b604e51
 
 
 
 
 
 
 
 
7c909e3
b604e51
 
 
7c909e3
f44e433
7c909e3
f44e433
 
 
 
 
 
b604e51
 
 
 
 
 
 
 
 
 
4ed98e6
 
 
 
f7eee2e
b604e51
 
4ed98e6
b604e51
f7eee2e
4ed98e6
b604e51
 
4ed98e6
b604e51
4ed98e6
b604e51
 
 
4ed98e6
65044b4
 
b604e51
65044b4
b604e51
 
 
 
 
 
 
 
 
 
f7eee2e
b604e51
4ed98e6
 
 
 
 
b604e51
 
7c909e3
b604e51
7c909e3
 
f7eee2e
 
 
 
 
 
 
 
 
b604e51
 
 
f7eee2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c3f571
4ed98e6
65044b4
2c3f571
 
65044b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
"""Gradio demo for UnReflectAnything: remove specular reflections from images."""

from __future__ import annotations

import shutil
import sys
from pathlib import Path
from typing import NamedTuple

# Allow importing unreflectanything when run from gradio_space (e.g. HF Space with root dir)
_REPO_ROOT = Path(__file__).resolve().parent.parent
# sys.path holds strings, so compare against str(_REPO_ROOT); comparing the
# Path object directly was always False-negative and re-inserted on every run.
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

# Directory containing this file; used below as the root for local sample-image copies.
_GRADIO_DIR = Path(__file__).resolve().parent
try:
    # `spaces` only exists on Hugging Face ZeroGPU Spaces; it provides the
    # @spaces.GPU decorator used in build_ui().
    import spaces
except ModuleNotFoundError:
    # Running locally / outside a Space: keep a None sentinel so later code
    # can branch on `if spaces`.
    spaces = None
import gradio as gr
import numpy as np
import torch

from huggingface_hub import hf_hub_download, snapshot_download

# Hugging Face Hub repo that hosts the weights, config, logo, and sample images.
HF_REPO = "AlbeRota/UnReflectAnything"
# File suffixes accepted when scanning the sample_images directory (lower-cased compare).
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")


class HFAssets(NamedTuple):
    """Paths to assets downloaded from the Hugging Face repo."""

    # Local cache path of weights/full_model_weights.pt (str, as returned by hf_hub_download)
    weights_path: str
    # Local cache path of configs/pretrained_config.yaml
    config_path: str
    # Local cache path of assets/logo.png
    logo_path: str
    # Directory inside the snapshot checkout that holds the demo example images
    sample_images_dir: Path


def _download_from_hf() -> HFAssets:
    """Download weights, config, logo, and sample images from the HF repo.

    Returns:
        HFAssets with local cache paths for every asset.
    """
    # Individual files come via hf_hub_download; the sample-image folder needs
    # a (pattern-restricted) snapshot_download because it is a directory.
    weights_path = hf_hub_download(repo_id=HF_REPO, filename="weights/full_model_weights.pt")
    print("Weights path: ", weights_path)
    config_path = hf_hub_download(repo_id=HF_REPO, filename="configs/pretrained_config.yaml")
    logo_path = hf_hub_download(repo_id=HF_REPO, filename="assets/logo.png")
    snapshot_root = Path(
        snapshot_download(repo_id=HF_REPO, allow_patterns=["sample_images/*"])
    )
    return HFAssets(
        weights_path=weights_path,
        config_path=config_path,
        logo_path=logo_path,
        sample_images_dir=snapshot_root / "sample_images",
    )


# Module-level cache so the HF downloads happen at most once per process.
_cached_assets: HFAssets | None = None


def _get_assets() -> HFAssets:
    """Return HF assets, downloading once and caching."""
    global _cached_assets
    if _cached_assets is not None:
        return _cached_assets
    # First call: populate the cache.
    _cached_assets = _download_from_hf()
    return _cached_assets


# Local copy of sample images under cwd so Gradio never needs allowed_paths for examples
_SAMPLE_IMAGES_COPY_DIR: Path | None = None


def _get_sample_image_paths() -> list[str]:
    """Return paths of sample images under cwd (copied from HF cache) so Gradio can use them without allowed_paths."""
    global _SAMPLE_IMAGES_COPY_DIR
    src = _get_assets().sample_images_dir
    if not src.is_dir():
        return []
    dest = _GRADIO_DIR / "sample_images"
    dest.mkdir(parents=True, exist_ok=True)
    collected: list[str] = []
    for entry in sorted(src.iterdir()):
        if entry.is_file() and entry.suffix.lower() in IMAGE_EXTENSIONS:
            local_copy = dest / entry.name
            # Re-copy only when missing or stale relative to the HF-cache original.
            if not local_copy.exists() or local_copy.stat().st_mtime < entry.stat().st_mtime:
                shutil.copy2(entry, local_copy)
            collected.append(str(local_copy.resolve()))
    _SAMPLE_IMAGES_COPY_DIR = dest
    return collected


def _get_sample_image_arrays() -> list[np.ndarray]:
    """Load sample images as numpy arrays (H, W, 3) uint8 for gr.Examples so the input Image shows a preview."""
    from PIL import Image

    loaded: list[np.ndarray] = []
    for path in _get_sample_image_paths():
        try:
            loaded.append(np.array(Image.open(path).convert("RGB")))
        except Exception:
            # Best-effort: silently skip any file PIL cannot decode.
            continue
    return loaded


# Single model instance; loaded in background at app start or on first inference.
_cached_ura_model = None
_cached_device = None


def _get_model(device: str):
    """Return the pretrained model, loading it once and moving to the requested device."""
    global _cached_ura_model, _cached_device
    assets = _get_assets()

    # Imported lazily so the module imports even before sys.path is extended.
    from unreflectanything import model

    if _cached_ura_model is None:
        # First request: construct the model directly on the target device.
        print(f"Loading model initially on {device}...")
        _cached_ura_model = model(
            pretrained=True,
            weights_path=assets.weights_path,
            config_path=assets.config_path,
            device=device,
            verbose=False,
            skip_path_resolution=True,
        )
        _cached_device = device
        return _cached_ura_model

    if _cached_device != device:
        # Cached copy lives on another device: migrate it instead of reloading.
        print(f"Moving model from {_cached_device} to {device}...")
        _cached_ura_model.to(device)
        _cached_device = device

    return _cached_ura_model

def build_ui():
    """Build and return the Gradio Blocks demo.

    Downloads HF assets eagerly but defers model creation to the first
    inference call — calling into CUDA here would crash ZeroGPU at startup.
    """
    _get_assets()
    # PREVENT: _get_model("cuda") here. It will crash ZeroGPU during startup.
    print("UI building... Model will initialize on first inference.")

    # Note: Use the decorator directly on the function that does the heavy lifting.
    # Falls back to a no-op decorator when `spaces` is unavailable (local runs).
    @spaces.GPU if spaces else lambda x: x
    def run_inference(image: np.ndarray | None) -> np.ndarray | None:
        """Run reflection removal using the cached model on GPU.

        Args:
            image: Input image as an (H, W, 3) array, or None if the UI
                field is empty.  # assumes uint8 RGB from gr.Image — TODO confirm

        Returns:
            Diffuse-only uint8 image at the original resolution, or None.
        """
        if image is None:
            return None

        from torchvision.transforms import functional as TF
        import time

        # Now it is safe to request 'cuda' because we are inside the @spaces.GPU wrapper
        device = "cuda" if (torch.cuda.is_available() and spaces) else "cpu"
        ura_model = _get_model(device)

        target_side = ura_model.image_size  # model's square working resolution
        h, w = image.shape[:2]

        # Pre-processing: float tensor, square-resize to model resolution, move to device
        tensor = TF.to_tensor(image).unsqueeze(0)  # [1, 3, H, W]
        tensor = TF.resize(tensor, [target_side, target_side], antialias=True)
        tensor = tensor.to(device, dtype=torch.float32)

        # Create mask based on highlights: channel-mean intensity above 0.9
        # is treated as specular (heuristic threshold)
        mask = tensor.mean(1, keepdim=True) > 0.9 

        with torch.no_grad():
            start_time = time.time()
            # The model is already on 'device' thanks to _get_model
            diffuse = ura_model(images=tensor, inpaint_mask_override=mask)
            end_time = time.time()

        inference_time_ms = (end_time - start_time) * 1000
        gr.Success(f"Inference complete in {inference_time_ms:.1f} ms") # Use gr.Info for better UX

        # Post-processing: back to CPU, restore original H x W, convert to uint8 HWC
        diffuse = diffuse.cpu()
        diffuse = TF.resize(diffuse, [h, w], antialias=True)
        out = diffuse[0].numpy().transpose(1, 2, 0)
        out = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
        return out

    # ... keep your run_inference_slider and UI layout code the same ...

    def run_inference_slider(
        image: np.ndarray | None,
    ) -> tuple[np.ndarray | None, np.ndarray | None] | None:
        """Run inference and return (input, output) for ImageSlider."""
        out = run_inference(image)
        if out is None:
            return None
        return (image, out)

    assets = _get_assets()
    with gr.Blocks(title="UnReflectAnything") as demo:
        # Header row: logo (when present in the HF repo) next to the description.
        with gr.Row():
            with gr.Column(scale=0, min_width=100):
                if Path(assets.logo_path).is_file():
                    gr.Image(
                        value=assets.logo_path,
                        show_label=False,
                        interactive=False,
                        height=100,
                        container=False,
                        buttons=[],
                    )
            with gr.Column(scale=1):
                gr.Markdown(
                    """
                    # UnReflectAnything
                    UnReflectAnything inputs any RGB image and **removes specular highlights**, 
                    returning a clean diffuse-only outputs. We trained UnReflectAnything by synthetizing 
                    specularities and supervising in DINOv3 feature space. 
                    UnReflectAnything works on both natural indoor and **surgical/endoscopic** domain data. 
                    Visit the [Project Page](https://alberto-rota.github.io/UnReflectAnything/)!                  
                    """
                )
        # Main row: input image on the left, before/after slider on the right.
        with gr.Row():
            inp = gr.Image(
                type="numpy",
                label="Input",
                height=600,
                width=600,
            )
            out_slider = gr.ImageSlider(
                label="Output",
                type="numpy",
                height=600,
                show_label=True,
            )
        run_btn = gr.Button("Run UnReflectAnything", variant="primary")
        run_btn.click(
            fn=run_inference_slider,
            inputs=[inp],
            outputs=out_slider,
        )
        # Examples are passed as arrays (not paths) so no allowed_paths are needed.
        sample_arrays = _get_sample_image_arrays()
        if sample_arrays:
            gr.Examples(
                examples=[[arr] for arr in sample_arrays],
                inputs=inp,
                label="Pre-loaded examples",
                examples_per_page=20,
            )
        gr.HTML("""<hr>""")
        gr.Markdown("""
                    [Project Page](https://alberto-rota.github.io/UnReflectAnything/) ⋅
                    [GitHub](https://github.com/alberto-rota/UnReflectAnything) ⋅
                    [Model Card](https://huggingface.co/AlbeRota/UnReflectAnything) ⋅
                    [Paper](https://arxiv.org/abs/2512.09583) ⋅
                    [Contact](mailto:alberto1.rota@polimi.it)
                    """)
    return demo


# Built at import time so HF Spaces can discover the module-level `demo` object.
demo = build_ui()


def _launch_allowed_paths():
    """Paths Gradio is allowed to serve (e.g. for gr.Examples from HF cache)."""
    allowed = [str(_GRADIO_DIR)]
    try:
        sample_dir = _get_assets().sample_images_dir
        if sample_dir.is_dir():
            allowed.append(str(sample_dir.resolve()))
        # Also allow parent (snapshot root) in case Gradio resolves paths from repo root
        snapshot_root = sample_dir.parent
        if snapshot_root.is_dir():
            allowed.append(str(snapshot_root.resolve()))
    except Exception as e:
        # Best-effort: a failed download must not prevent the app from launching.
        print(f"Warning: could not add HF sample_images to allowed_paths: {e}")
    return allowed


def _launch_kwargs():
    """Default kwargs for launch() so allowed_paths are always set (e.g. when HF Spaces runs demo.launch())."""
    theme = gr.themes.Soft(primary_hue="orange", secondary_hue="blue")
    return {"allowed_paths": _launch_allowed_paths(), "theme": theme}


# Ensure launch() always receives allowed_paths (e.g. when HF Spaces runner calls demo.launch() without args)
_original_launch = demo.launch


def _launch_with_allowed_paths(*args, **kwargs):
    """Delegate to the original launch, filling in any default kwargs the caller omitted."""
    # Caller-supplied kwargs win over our defaults.
    merged = {**_launch_kwargs(), **kwargs}
    return _original_launch(*args, **merged)


demo.launch = _launch_with_allowed_paths


# Launch unconditionally: whether this file is executed directly or imported
# (e.g. by the Hugging Face Spaces runner), the previous if/else ran the exact
# same call in both branches, so the duplicated branches are collapsed here.
demo.launch(ssr_mode=True, server_name="0.0.0.0", server_port=7860)