"""Gradio demo for UnReflectAnything: remove specular reflections from images.""" from __future__ import annotations import shutil import sys from pathlib import Path from typing import NamedTuple # Allow importing unreflectanything when run from gradio_space (e.g. HF Space with root dir) _REPO_ROOT = Path(__file__).resolve().parent.parent if _REPO_ROOT not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) # Guard against missing '__main__' in worker threads (wandb/pydantic compat) if "__main__" not in sys.modules: import types sys.modules["__main__"] = types.ModuleType("__main__") _GRADIO_DIR = Path(__file__).resolve().parent try: import spaces except ModuleNotFoundError: spaces = None import gradio as gr import numpy as np import torch from huggingface_hub import hf_hub_download, snapshot_download HF_REPO = "AlbeRota/UnReflectAnything" IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp") class HFAssets(NamedTuple): """Paths to assets downloaded from the Hugging Face repo.""" weights_path: str config_path: str logo_path: str sample_images_dir: Path def _download_from_hf() -> HFAssets: """Download weights, config, logo, and sample images from the HF repo. 
Returns paths to all assets.""" weights_path = hf_hub_download( repo_id=HF_REPO, filename="weights/full_model_weights.pt", ) print("Weights path: ", weights_path) # config_path = hf_hub_download( # repo_id=HF_REPO, # filename="configs/pretrained_config.yaml", # ) logo_path = hf_hub_download( repo_id=HF_REPO, filename="assets/logo.png", ) sample_images_root = Path( snapshot_download( repo_id=HF_REPO, allow_patterns=["sample_images/*"], ) ) sample_images_dir = sample_images_root / "sample_images" return HFAssets( weights_path=weights_path, config_path=Path(__file__).parent / "pretrained_config.yaml", logo_path=logo_path, sample_images_dir=sample_images_dir, ) _cached_assets: HFAssets | None = None def _get_assets() -> HFAssets: """Return HF assets, downloading once and caching.""" global _cached_assets if _cached_assets is None: _cached_assets = _download_from_hf() return _cached_assets # Local copy of sample images under cwd so Gradio never needs allowed_paths for examples _SAMPLE_IMAGES_COPY_DIR: Path | None = None def _get_sample_image_paths() -> list[str]: """Return paths of sample images under cwd (copied from HF cache) so Gradio can use them without allowed_paths.""" global _SAMPLE_IMAGES_COPY_DIR assets = _get_assets() src = assets.sample_images_dir if not src.is_dir(): return [] dest = _GRADIO_DIR / "sample_images" dest.mkdir(parents=True, exist_ok=True) paths = [] for p in sorted(src.iterdir()): if not p.is_file() or p.suffix.lower() not in IMAGE_EXTENSIONS: continue dst_file = dest / p.name if not dst_file.exists() or dst_file.stat().st_mtime < p.stat().st_mtime: shutil.copy2(p, dst_file) paths.append(str(dst_file.resolve())) _SAMPLE_IMAGES_COPY_DIR = dest return paths def _get_sample_image_arrays() -> list[np.ndarray]: """Load sample images as numpy arrays (H, W, 3) uint8 for gr.Examples so the input Image shows a preview.""" from PIL import Image paths = _get_sample_image_paths() arrays = [] for p in paths: try: img = Image.open(p).convert("RGB") 
arrays.append(np.array(img)) except Exception: continue return arrays # Single model instance; loaded in background at app start or on first inference. _cached_ura_model = None _cached_device = None def _get_model(device: str): """Return the pretrained model, loading it once and moving to the requested device.""" global _cached_ura_model, _cached_device assets = _get_assets() from unreflectanything import model # If the model isn't loaded yet, initialize it if _cached_ura_model is None: print(f"Loading model initially on {device}...") _cached_ura_model = model( pretrained=True, weights_path=assets.weights_path, # config_path=assets.config_path, device=device, verbose=False, ) _cached_device = device # If the model is loaded but on the wrong device, move it if _cached_device != device: print(f"Moving model from {_cached_device} to {device}...") _cached_ura_model.to(device) _cached_device = device return _cached_ura_model def build_ui(): _get_assets() # PREVENT: _get_model("cuda") here. It will crash ZeroGPU during startup. print("UI building... 
Model will initialize on first inference.") # Note: Use the decorator directly on the function that does the heavy lifting def _extract_tokens_nc(tokens) -> torch.Tensor: """Extract [N, C] from list (last layer) or tensor [B, N, C] (first sample).""" t = tokens[-1] if isinstance(tokens, list) else tokens t = t[0].cpu().float() if t.dim() == 3 else t.cpu().float() return t.squeeze(0) if t.dim() == 3 else t # [N, C] def _tokens_pair_to_rgb( tokens_completed, tokens_input, h: int, w: int, ) -> tuple[np.ndarray, np.ndarray]: """PCA fit once on completed tokens; apply same mean and V to both; joint min/max norm.""" from PIL import Image as PILImage t_comp = _extract_tokens_nc(tokens_completed) # [N, C] t_inp = _extract_tokens_nc(tokens_input) # [N, C] mean = t_comp.mean(dim=0, keepdim=True) # [1, C] – fit on completed only centered_comp = t_comp - mean # [N, C] U, S, V = torch.svd_lowrank(centered_comp, q=3) # V: [C, 3] # Project both with same parameters (same mean, same V) proj_comp = (t_comp - mean) @ V # [N, 3] proj_inp = (t_inp - mean) @ V # [N, 3] # Joint min/max so both images share the same color scale lo = min(proj_comp.min().item(), proj_inp.min().item()) hi = max(proj_comp.max().item(), proj_inp.max().item()) eps = 1e-8 proj_comp = (proj_comp - lo) / (hi - lo + eps) proj_inp = (proj_inp - lo) / (hi - lo + eps) grid = int(t_comp.shape[0] ** 0.5) def to_img(proj: torch.Tensor) -> np.ndarray: arr = (proj.reshape(grid, grid, 3).numpy() * 255).clip(0, 255).astype(np.uint8) return np.array(PILImage.fromarray(arr).resize((w, h), PILImage.BILINEAR)) return to_img(proj_comp), to_img(proj_inp) def _gray_to_rgb(tensor_1c: torch.Tensor, h: int, w: int) -> np.ndarray: """Convert [B, 1, H_model, W_model] to resized [H, W, 3] uint8 grayscale-as-RGB.""" from torchvision.transforms import functional as TF resized = TF.resize(tensor_1c.cpu(), [h, w], antialias=True) # [B, 1, H, W] gray = (resized[0, 0].numpy().clip(0.0, 1.0) * 255).astype(np.uint8) # [H, W] return 
np.stack([gray] * 3, axis=-1) # [H, W, 3] @spaces.GPU if spaces else lambda x: x def run_inference( image: np.ndarray | None, threshold: float = 0.3, dilation: int = 40, ) -> dict[str, np.ndarray] | None: """Run reflection removal; return all visualisable outputs as numpy arrays.""" if image is None: return None from torchvision.transforms import functional as TF import time device = "cuda" if torch.cuda.is_available() else "cpu" ura_model = _get_model(device) target_side = ura_model.image_size h, w = image.shape[:2] tensor = TF.to_tensor(image).unsqueeze(0) # [1, 3, H, W] tensor = TF.resize(tensor, [target_side, target_side], antialias=True) tensor = tensor.to(device, dtype=torch.float32) with torch.no_grad(): start_time = time.time() out = ura_model( images=tensor, threshold=threshold, dilation=int(dilation), return_dict=True, ) end_time = time.time() inference_time_ms = (end_time - start_time) * 1000 gr.Info(f"Inference complete in {inference_time_ms:.1f} ms") results: dict[str, np.ndarray] = {} # Diffuse: [1, 3, H, W] -> [H, W, 3] uint8 diffuse = TF.resize(out["diffuse"].cpu(), [h, w], antialias=True) results["diffuse"] = (diffuse[0].numpy().transpose(1, 2, 0).clip(0.0, 1.0) * 255).astype(np.uint8) # Detected highlight: RGBA overlay superimposed on darkened input hl_data = out.get("highlight") if hl_data is not None: m = TF.resize(hl_data.cpu(), [h, w], antialias=True)[0, 0].numpy().clip(0.0, 1.0) # [H, W] image_dark = (image.astype(np.float32) * 0.5).clip(0, 255) # [H, W, 3] base overlay_rgb = np.array([255, 200, 0], dtype=np.float32) # amber alpha = (0.5 * m)[:, :, np.newaxis] # [H, W, 1] comp_rgb = (1 - alpha) * image_dark + alpha * overlay_rgb # [H, W, 3] comp_uint8 = np.clip(comp_rgb, 0, 255).astype(np.uint8) results["highlight_overlay"] = np.concatenate( [comp_uint8, np.full((h, w, 1), 255, dtype=np.uint8)], axis=-1 ) # [H, W, 4] RGBA results["highlight_gray"] = _gray_to_rgb(hl_data, h, w) # Highlight mask (binary/dilated) mask_data = 
out.get("highlight_mask") if mask_data is not None: results["highlight_mask"] = _gray_to_rgb(mask_data, h, w) # DINOv3 tokens (PCA visualization) – same PCA fit for both, joint color scale tokens_completed_data = out.get("tokens_completed") tokens_input_data = out.get("tokens_input") if tokens_completed_data is not None and tokens_input_data is not None: img_comp, img_inp = _tokens_pair_to_rgb(tokens_completed_data, tokens_input_data, h, w) results["tokens_completed"] = img_comp results["tokens_input"] = img_inp elif tokens_completed_data is not None: t = _extract_tokens_nc(tokens_completed_data) mean = t.mean(dim=0, keepdim=True) V = torch.svd_lowrank(t - mean, q=3)[2] proj = (t - mean) @ V lo, hi = proj.min().item(), proj.max().item() proj = (proj - lo) / (hi - lo + 1e-8) grid = int(t.shape[0] ** 0.5) from PIL import Image as PILImage arr = (proj.reshape(grid, grid, 3).numpy() * 255).clip(0, 255).astype(np.uint8) results["tokens_completed"] = np.array(PILImage.fromarray(arr).resize((w, h), PILImage.BILINEAR)) results["tokens_input"] = results["tokens_completed"] elif tokens_input_data is not None: t = _extract_tokens_nc(tokens_input_data) mean = t.mean(dim=0, keepdim=True) V = torch.svd_lowrank(t - mean, q=3)[2] proj = (t - mean) @ V lo, hi = proj.min().item(), proj.max().item() proj = (proj - lo) / (hi - lo + 1e-8) grid = int(t.shape[0] ** 0.5) from PIL import Image as PILImage arr = (proj.reshape(grid, grid, 3).numpy() * 255).clip(0, 255).astype(np.uint8) results["tokens_input"] = np.array(PILImage.fromarray(arr).resize((w, h), PILImage.BILINEAR)) results["tokens_completed"] = results["tokens_input"] return results VIEW_MODES = ["Diffuse", "Highlight", "Inpaint mask", "DINOv3 space"] def run_inference_slider( image: np.ndarray | None, threshold: float, dilation: int, ) -> tuple: """Return 4 slider tuples: (left, right) for each view mode.""" results = run_inference(image, threshold, dilation) if results is None: return (None,) * 4 diffuse = results["diffuse"] # 
        # Darken input for slider so highlights are more visible; Gradio expects uint8 [0,255]
        image_dark = (image.astype(np.float32) * 0.5).clip(0, 255).astype(np.uint8)
        # Fall back to the diffuse image for any output the model did not produce.
        hl_overlay = results.get("highlight_overlay", diffuse)
        hl_gray = results.get("highlight_gray", diffuse)
        hl_mask = results.get("highlight_mask", diffuse)
        tok_comp = results.get("tokens_completed", diffuse)
        tok_inp = results.get("tokens_input", diffuse)
        # Tuple order must match the order of the sliders / VIEW_MODES below.
        return (
            (image, diffuse),  # Diffuse
            (image_dark, hl_overlay),  # Detected highlight
            (hl_gray, hl_mask),  # Highlight mask
            (tok_inp, tok_comp),  # DINOv3 tokens
        )

    # Assets are already cached by the earlier _get_assets() call; used for the logo.
    assets = _get_assets()
    with gr.Blocks(title="UnReflectAnything") as demo:
        with gr.Row(elem_classes="mobile-stack"):
            with gr.Column(scale=0, min_width=100):
                if Path(assets.logo_path).is_file():
                    gr.Image(
                        value=assets.logo_path,
                        show_label=False,
                        interactive=False,
                        height=100,
                        container=False,
                        # NOTE(review): presumably hides the image action
                        # buttons — confirm against the installed Gradio version.
                        buttons=[],
                    )
            with gr.Column(scale=1):
                gr.Markdown(
                    """
# UnReflectAnything
UnReflectAnything inputs any RGB image and **removes specular highlights**, returning a clean diffuse-only outputs.
We trained UnReflectAnything by synthetizing specularities and supervising in DINOv3 feature space.
UnReflectAnything works on both natural indoor and **surgical/endoscopic** domain data.
Visit the [Project Page](https://alberto-rota.github.io/UnReflectAnything/)!
"""
                )
        # Display labels for the four output sliders; order must match VIEW_MODES.
        slider_labels = [
            "Diffuse",
            "Highlight",
            "Inpaint",
            "DINOv3 Space",
        ]
        with gr.Row(elem_classes="mobile-stack"):
            inp = gr.Image(
                type="numpy",
                label="Input",
                height=600,
            )
            # One ImageSlider per view mode; only the first is visible initially.
            sliders = []
            for i, lbl in enumerate(slider_labels):
                sliders.append(
                    gr.ImageSlider(
                        label=lbl,
                        type="numpy",
                        height=600,
                        show_label=True,
                        visible=(i == 0),
                    )
                )
        with gr.Row(elem_classes="mobile-stack"):
            # NOTE(review): default 0.2 here differs from run_inference's
            # default threshold of 0.3 — confirm which is intended.
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.01,
                label="Highlight Threshold",
                info="Brightness threshold for detecting specular highlights",
            )
            dilation_slider = gr.Slider(
                minimum=0,
                maximum=100,
                value=40,
                step=1,
                label="Mask Dilation",
                info="Dilation (px) applied to the detected highlight mask",
            )
            view_radio = gr.Radio(
                choices=VIEW_MODES,
                value=VIEW_MODES[0],
                label="Output view",
            )
        run_btn = gr.Button("Run UnReflectAnything", variant="primary")
        # Running inference fills all four sliders at once.
        run_btn.click(
            fn=run_inference_slider,
            inputs=[inp, threshold_slider, dilation_slider],
            outputs=sliders,
        )
        # Switching the view toggles visibility: exactly one slider shown,
        # matched by position in VIEW_MODES.
        view_radio.change(
            fn=lambda mode: [gr.update(visible=(m == mode)) for m in VIEW_MODES],
            inputs=view_radio,
            outputs=sliders,
        )
        # Pre-loaded examples as numpy arrays so the input component shows previews.
        sample_arrays = _get_sample_image_arrays()
        if sample_arrays:
            gr.Examples(
                examples=[[arr] for arr in sample_arrays],
                inputs=inp,
                label="Pre-loaded examples",
                examples_per_page=20,
            )
        gr.HTML("""