# Hugging Face Space header (page chrome): "Spaces: Running on Zero" — this app runs on ZeroGPU hardware.
| """Gradio demo for UnReflectAnything: remove specular reflections from images.""" | |
| from __future__ import annotations | |
| import shutil | |
| import sys | |
| from pathlib import Path | |
| from typing import NamedTuple | |
| # Allow importing unreflectanything when run from gradio_space (e.g. HF Space with root dir) | |
| _REPO_ROOT = Path(__file__).resolve().parent.parent | |
| if _REPO_ROOT not in sys.path: | |
| sys.path.insert(0, str(_REPO_ROOT)) | |
| _GRADIO_DIR = Path(__file__).resolve().parent | |
| try: | |
| import spaces | |
| except ModuleNotFoundError: | |
| spaces = None | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| HF_REPO = "AlbeRota/UnReflectAnything" | |
| IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp") | |
class HFAssets(NamedTuple):
    """Paths to assets downloaded from the Hugging Face repo."""

    weights_path: str        # local cache path to weights/full_model_weights.pt
    config_path: str         # local cache path to configs/pretrained_config.yaml
    logo_path: str           # local cache path to assets/logo.png
    sample_images_dir: Path  # "sample_images" directory inside the repo snapshot
def _download_from_hf() -> HFAssets:
    """Fetch weights, config, logo, and sample images from the HF repo.

    Returns:
        HFAssets with local cache paths for every downloaded asset.
    """
    weights = hf_hub_download(repo_id=HF_REPO, filename="weights/full_model_weights.pt")
    print("Weights path: ", weights)
    config = hf_hub_download(repo_id=HF_REPO, filename="configs/pretrained_config.yaml")
    logo = hf_hub_download(repo_id=HF_REPO, filename="assets/logo.png")
    # Only the sample_images/ subtree is needed from the full repo snapshot.
    snapshot_root = Path(
        snapshot_download(repo_id=HF_REPO, allow_patterns=["sample_images/*"])
    )
    return HFAssets(
        weights_path=weights,
        config_path=config,
        logo_path=logo,
        sample_images_dir=snapshot_root / "sample_images",
    )
# Module-level cache so the HF download happens at most once per process.
_cached_assets: HFAssets | None = None


def _get_assets() -> HFAssets:
    """Return HF assets, downloading once and caching."""
    global _cached_assets
    if _cached_assets is not None:
        return _cached_assets
    _cached_assets = _download_from_hf()
    return _cached_assets
# Local copy of sample images under cwd so Gradio never needs allowed_paths for examples
_SAMPLE_IMAGES_COPY_DIR: Path | None = None


def _get_sample_image_paths() -> list[str]:
    """Copy sample images from the HF cache next to this file and return their paths.

    Copying under cwd means Gradio can serve them without allowed_paths.
    """
    global _SAMPLE_IMAGES_COPY_DIR
    source_dir = _get_assets().sample_images_dir
    if not source_dir.is_dir():
        return []
    copy_dir = _GRADIO_DIR / "sample_images"
    copy_dir.mkdir(parents=True, exist_ok=True)
    collected: list[str] = []
    for candidate in sorted(source_dir.iterdir()):
        if candidate.is_file() and candidate.suffix.lower() in IMAGE_EXTENSIONS:
            target = copy_dir / candidate.name
            # Refresh the copy only when it is missing or older than the cached original.
            if not target.exists() or target.stat().st_mtime < candidate.stat().st_mtime:
                shutil.copy2(candidate, target)
            collected.append(str(target.resolve()))
    _SAMPLE_IMAGES_COPY_DIR = copy_dir
    return collected
def _get_sample_image_arrays() -> list[np.ndarray]:
    """Load sample images as (H, W, 3) uint8 arrays so gr.Examples shows previews."""
    from PIL import Image

    images: list[np.ndarray] = []
    for path in _get_sample_image_paths():
        try:
            rgb = Image.open(path).convert("RGB")
        except Exception:
            # Best-effort: silently skip anything PIL cannot decode.
            continue
        images.append(np.array(rgb))
    return images
# Single model instance; loaded in background at app start or on first inference.
_cached_ura_model = None
_cached_device = None


def _get_model(device: str):
    """Return the pretrained model, loading it once and moving it to *device*."""
    global _cached_ura_model, _cached_device
    assets = _get_assets()
    from unreflectanything import model

    if _cached_ura_model is None:
        # First call: construct the model directly on the requested device.
        print(f"Loading model initially on {device}...")
        _cached_ura_model = model(
            pretrained=True,
            weights_path=assets.weights_path,
            config_path=assets.config_path,
            device=device,
            verbose=False,
            skip_path_resolution=True,
        )
        _cached_device = device
    elif _cached_device != device:
        # Already loaded but on the wrong device: migrate it.
        print(f"Moving model from {_cached_device} to {device}...")
        _cached_ura_model.to(device)
        _cached_device = device
    return _cached_ura_model
def build_ui():
    """Build and return the Gradio Blocks app.

    Assets are downloaded eagerly, but the model is loaded lazily on the first
    inference call: initializing CUDA during startup would crash a ZeroGPU Space.
    """
    _get_assets()
    # PREVENT: _get_model("cuda") here. It will crash ZeroGPU during startup.
    print("UI building... Model will initialize on first inference.")

    def run_inference(image: np.ndarray | None) -> np.ndarray | None:
        """Run reflection removal using the cached model on GPU."""
        if image is None:
            return None
        from torchvision.transforms import functional as TF
        import time

        # Safe to request 'cuda' here: on ZeroGPU this function runs inside the
        # spaces.GPU wrapper applied just below.
        device = "cuda" if (torch.cuda.is_available() and spaces) else "cpu"
        ura_model = _get_model(device)
        target_side = ura_model.image_size
        h, w = image.shape[:2]
        # Pre-processing: HWC uint8 -> [1, 3, S, S] float32 on the target device.
        tensor = TF.to_tensor(image).unsqueeze(0)  # [1, 3, H, W]
        tensor = TF.resize(tensor, [target_side, target_side], antialias=True)
        tensor = tensor.to(device, dtype=torch.float32)
        # Create mask based on highlights (mean channel intensity above 0.9).
        mask = tensor.mean(1, keepdim=True) > 0.9
        with torch.no_grad():
            start_time = time.time()
            # The model is already on 'device' thanks to _get_model
            diffuse = ura_model(images=tensor, inpaint_mask_override=mask)
            end_time = time.time()
        inference_time_ms = (end_time - start_time) * 1000
        gr.Success(f"Inference complete in {inference_time_ms:.1f} ms")
        # Post-processing: back to original resolution, uint8 HWC.
        diffuse = diffuse.cpu()
        diffuse = TF.resize(diffuse, [h, w], antialias=True)
        out = diffuse[0].numpy().transpose(1, 2, 0)
        out = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
        return out

    # Bug fix: the comments above promised a @spaces.GPU wrapper, but it was
    # never applied. On a ZeroGPU Space the GPU is only attached while the
    # wrapped function executes, so decorate when `spaces` is available.
    if spaces is not None:
        run_inference = spaces.GPU(run_inference)

    def run_inference_slider(
        image: np.ndarray | None,
    ) -> tuple[np.ndarray | None, np.ndarray | None] | None:
        """Run inference and return (input, output) for ImageSlider."""
        out = run_inference(image)
        if out is None:
            return None
        return (image, out)

    assets = _get_assets()
    with gr.Blocks(title="UnReflectAnything") as demo:
        with gr.Row():
            with gr.Column(scale=0, min_width=100):
                if Path(assets.logo_path).is_file():
                    gr.Image(
                        value=assets.logo_path,
                        show_label=False,
                        interactive=False,
                        height=100,
                        container=False,
                        buttons=[],
                    )
            with gr.Column(scale=1):
                gr.Markdown(
                    """
                    # UnReflectAnything
                    UnReflectAnything inputs any RGB image and **removes specular highlights**,
                    returning a clean diffuse-only outputs. We trained UnReflectAnything by synthetizing
                    specularities and supervising in DINOv3 feature space.
                    UnReflectAnything works on both natural indoor and **surgical/endoscopic** domain data.
                    Visit the [Project Page](https://alberto-rota.github.io/UnReflectAnything/)!
                    """
                )
        with gr.Row():
            inp = gr.Image(
                type="numpy",
                label="Input",
                height=600,
                width=600,
            )
            out_slider = gr.ImageSlider(
                label="Output",
                type="numpy",
                height=600,
                show_label=True,
            )
        run_btn = gr.Button("Run UnReflectAnything", variant="primary")
        run_btn.click(
            fn=run_inference_slider,
            inputs=[inp],
            outputs=out_slider,
        )
        sample_arrays = _get_sample_image_arrays()
        if sample_arrays:
            gr.Examples(
                examples=[[arr] for arr in sample_arrays],
                inputs=inp,
                label="Pre-loaded examples",
                examples_per_page=20,
            )
        gr.HTML("""<hr>""")
        gr.Markdown("""
        [Project Page](https://alberto-rota.github.io/UnReflectAnything/) ⋅
        [GitHub](https://github.com/alberto-rota/UnReflectAnything) ⋅
        [Model Card](https://huggingface.co/AlbeRota/UnReflectAnything) ⋅
        [Paper](https://arxiv.org/abs/2512.09583) ⋅
        [Contact](mailto:alberto1.rota@polimi.it)
        """)
    return demo
# Build at import time so HF Spaces (which may import rather than run this file) finds `demo`.
demo = build_ui()
def _launch_allowed_paths():
    """Paths Gradio is allowed to serve (e.g. for gr.Examples from HF cache)."""
    allowed = [str(_GRADIO_DIR)]
    try:
        sample_dir = _get_assets().sample_images_dir
        if sample_dir.is_dir():
            allowed.append(str(sample_dir.resolve()))
            # Also allow parent (snapshot root) in case Gradio resolves paths from repo root
            snapshot_root = sample_dir.parent
            if snapshot_root.is_dir():
                allowed.append(str(snapshot_root.resolve()))
    except Exception as e:
        print(f"Warning: could not add HF sample_images to allowed_paths: {e}")
    return allowed
def _launch_kwargs():
    """Default kwargs for launch() so allowed_paths are always set (e.g. when HF Spaces runs demo.launch())."""
    # NOTE(review): `theme` is normally an argument of gr.Blocks(...), not launch();
    # confirm the installed Gradio version accepts it here.
    theme = gr.themes.Soft(primary_hue="orange", secondary_hue="blue")
    return dict(allowed_paths=_launch_allowed_paths(), theme=theme)
# Ensure launch() always receives allowed_paths (e.g. when HF Spaces runner calls demo.launch() without args)
_original_launch = demo.launch


def _launch_with_allowed_paths(*args, **kwargs):
    """Delegate to the real launch(), filling in defaults from _launch_kwargs().

    Caller-supplied kwargs always win over the defaults.
    """
    merged = {**_launch_kwargs(), **kwargs}
    return _original_launch(*args, **merged)


demo.launch = _launch_with_allowed_paths
# Launch unconditionally: the original if __name__ == "__main__" / else made the
# exact same call in both branches (HF Spaces may run this file as __main__ or
# import it), so the duplicated conditional was dead weight.
demo.launch(ssr_mode=True, server_name="0.0.0.0", server_port=7860)