# Commit 65044b4 — "Hopefully fixing input preview" (scraped page header, commented out so the file parses)
"""Gradio demo for UnReflectAnything: remove specular reflections from images."""
from __future__ import annotations
import shutil
import sys
from pathlib import Path
from typing import NamedTuple
# Allow importing unreflectanything when run from gradio_space (e.g. HF Space with root dir)
_REPO_ROOT = Path(__file__).resolve().parent.parent
if _REPO_ROOT not in sys.path:
sys.path.insert(0, str(_REPO_ROOT))
_GRADIO_DIR = Path(__file__).resolve().parent
try:
import spaces
except ModuleNotFoundError:
spaces = None
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
HF_REPO = "AlbeRota/UnReflectAnything"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")
class HFAssets(NamedTuple):
    """Paths to assets downloaded from the Hugging Face repo.

    Produced by _download_from_hf(): the first three fields are local cache
    paths returned by hf_hub_download; sample_images_dir points into the
    snapshot_download cache.
    """
    # Local cache path to "weights/full_model_weights.pt".
    weights_path: str
    # Local cache path to "configs/pretrained_config.yaml".
    config_path: str
    # Local cache path to "assets/logo.png" shown in the UI header.
    logo_path: str
    # Directory containing the sample images pulled via snapshot_download.
    sample_images_dir: Path
def _download_from_hf() -> HFAssets:
    """Download weights, config, logo, and sample images from the HF repo. Returns paths to all assets."""

    def _fetch(filename: str) -> str:
        # Single-file download into the local HF cache; returns the cached path.
        return hf_hub_download(repo_id=HF_REPO, filename=filename)

    weights_path = _fetch("weights/full_model_weights.pt")
    print("Weights path: ", weights_path)
    config_path = _fetch("configs/pretrained_config.yaml")
    logo_path = _fetch("assets/logo.png")
    # Sample images arrive as a partial repo snapshot filtered to the one folder.
    snapshot_root = Path(
        snapshot_download(repo_id=HF_REPO, allow_patterns=["sample_images/*"])
    )
    return HFAssets(
        weights_path=weights_path,
        config_path=config_path,
        logo_path=logo_path,
        sample_images_dir=snapshot_root / "sample_images",
    )
# Process-wide cache so the HF downloads happen at most once.
_cached_assets: HFAssets | None = None


def _get_assets() -> HFAssets:
    """Return HF assets, downloading once and caching."""
    global _cached_assets
    if _cached_assets is not None:
        return _cached_assets
    _cached_assets = _download_from_hf()
    return _cached_assets
# Local copy of sample images under cwd so Gradio never needs allowed_paths for examples
_SAMPLE_IMAGES_COPY_DIR: Path | None = None


def _get_sample_image_paths() -> list[str]:
    """Return paths of sample images under cwd (copied from HF cache) so Gradio can use them without allowed_paths."""
    global _SAMPLE_IMAGES_COPY_DIR
    source_dir = _get_assets().sample_images_dir
    if not source_dir.is_dir():
        return []
    copy_dir = _GRADIO_DIR / "sample_images"
    copy_dir.mkdir(parents=True, exist_ok=True)
    copied: list[str] = []
    for src_file in sorted(source_dir.iterdir()):
        # Mirror only plain image files; skip directories and other suffixes.
        if not (src_file.is_file() and src_file.suffix.lower() in IMAGE_EXTENSIONS):
            continue
        target = copy_dir / src_file.name
        # Refresh the copy only when it is missing or older than the cached original.
        if not target.exists() or target.stat().st_mtime < src_file.stat().st_mtime:
            shutil.copy2(src_file, target)
        copied.append(str(target.resolve()))
    _SAMPLE_IMAGES_COPY_DIR = copy_dir
    return copied
def _get_sample_image_arrays() -> list[np.ndarray]:
    """Load sample images as numpy arrays (H, W, 3) uint8 for gr.Examples so the input Image shows a preview."""
    from PIL import Image

    loaded: list[np.ndarray] = []
    for path in _get_sample_image_paths():
        # Best-effort: any file that fails to open/convert is silently skipped.
        try:
            rgb = Image.open(path).convert("RGB")
            loaded.append(np.array(rgb))
        except Exception:
            continue
    return loaded
# Single model instance; loaded in background at app start or on first inference.
_cached_ura_model = None
_cached_device = None


def _get_model(device: str):
    """Return the pretrained model, loading it once and moving to the requested device."""
    global _cached_ura_model, _cached_device
    hf_assets = _get_assets()
    from unreflectanything import model

    if _cached_ura_model is None:
        # First call: construct the model directly on the requested device.
        print(f"Loading model initially on {device}...")
        _cached_ura_model = model(
            pretrained=True,
            weights_path=hf_assets.weights_path,
            config_path=hf_assets.config_path,
            device=device,
            verbose=False,
            skip_path_resolution=True,
        )
        _cached_device = device
    elif _cached_device != device:
        # Already loaded but on another device (e.g. cpu -> cuda on ZeroGPU): migrate it.
        print(f"Moving model from {_cached_device} to {device}...")
        _cached_ura_model.to(device)
        _cached_device = device
    return _cached_ura_model
def build_ui():
    """Build and return the Gradio Blocks demo.

    The model is deliberately NOT loaded here: on ZeroGPU Spaces, requesting
    CUDA outside a @spaces.GPU call crashes startup, so inference lazily
    initializes the model on first use.
    """
    _get_assets()
    # PREVENT: _get_model("cuda") here. It will crash ZeroGPU during startup.
    print("UI building... Model will initialize on first inference.")

    # Note: Use the decorator directly on the function that does the heavy lifting.
    # Arbitrary decorator expressions require Python 3.9+ (PEP 614); falls back to
    # an identity decorator when the `spaces` module is unavailable (local runs).
    @spaces.GPU if spaces else lambda x: x
    def run_inference(image: np.ndarray | None) -> np.ndarray | None:
        """Run reflection removal using the cached model on GPU."""
        if image is None:
            return None
        from torchvision.transforms import functional as TF
        import time
        # Now it is safe to request 'cuda' because we are inside the @spaces.GPU wrapper
        device = "cuda" if (torch.cuda.is_available() and spaces) else "cpu"
        ura_model = _get_model(device)
        # Model input is square at `image_size`; original H/W are kept to resize back.
        target_side = ura_model.image_size
        h, w = image.shape[:2]
        # Pre-processing: HWC uint8 -> [1, 3, H, W] float in [0, 1], squared and on device.
        tensor = TF.to_tensor(image).unsqueeze(0)  # [1, 3, H, W]
        tensor = TF.resize(tensor, [target_side, target_side], antialias=True)
        tensor = tensor.to(device, dtype=torch.float32)
        # Create mask based on highlights: pixels whose channel mean exceeds 0.9
        # are treated as specular and handed to the model as an inpaint mask.
        mask = tensor.mean(1, keepdim=True) > 0.9
        with torch.no_grad():
            start_time = time.time()
            # The model is already on 'device' thanks to _get_model
            diffuse = ura_model(images=tensor, inpaint_mask_override=mask)
            end_time = time.time()
        inference_time_ms = (end_time - start_time) * 1000
        # NOTE(review): gr.Success is only available in newer Gradio releases — confirm
        # the Space's gradio version; gr.Info is the older equivalent.
        gr.Success(f"Inference complete in {inference_time_ms:.1f} ms")  # Use gr.Info for better UX
        # Post-processing: back to CPU, original resolution, HWC uint8 in [0, 255].
        diffuse = diffuse.cpu()
        diffuse = TF.resize(diffuse, [h, w], antialias=True)
        out = diffuse[0].numpy().transpose(1, 2, 0)
        out = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
        return out

    # ... keep your run_inference_slider and UI layout code the same ...
    def run_inference_slider(
        image: np.ndarray | None,
    ) -> tuple[np.ndarray | None, np.ndarray | None] | None:
        """Run inference and return (input, output) for ImageSlider."""
        out = run_inference(image)
        if out is None:
            return None
        return (image, out)

    assets = _get_assets()
    with gr.Blocks(title="UnReflectAnything") as demo:
        # Header row: logo (if downloaded) next to the project blurb.
        with gr.Row():
            with gr.Column(scale=0, min_width=100):
                if Path(assets.logo_path).is_file():
                    gr.Image(
                        value=assets.logo_path,
                        show_label=False,
                        interactive=False,
                        height=100,
                        container=False,
                        # NOTE(review): `buttons=[]` hides the image toolbar on recent
                        # Gradio; older versions may not accept this kwarg — confirm.
                        buttons=[],
                    )
            with gr.Column(scale=1):
                gr.Markdown(
                    """
# UnReflectAnything
UnReflectAnything inputs any RGB image and **removes specular highlights**,
returning a clean diffuse-only outputs. We trained UnReflectAnything by synthetizing
specularities and supervising in DINOv3 feature space.
UnReflectAnything works on both natural indoor and **surgical/endoscopic** domain data.
Visit the [Project Page](https://alberto-rota.github.io/UnReflectAnything/)!
"""
                )
        # Main row: input image on the left, before/after slider on the right.
        with gr.Row():
            inp = gr.Image(
                type="numpy",
                label="Input",
                height=600,
                width=600,
            )
            out_slider = gr.ImageSlider(
                label="Output",
                type="numpy",
                height=600,
                show_label=True,
            )
        run_btn = gr.Button("Run UnReflectAnything", variant="primary")
        run_btn.click(
            fn=run_inference_slider,
            inputs=[inp],
            outputs=out_slider,
        )
        # Examples are passed as in-memory arrays (not paths) so the input
        # component renders a preview without needing allowed_paths.
        sample_arrays = _get_sample_image_arrays()
        if sample_arrays:
            gr.Examples(
                examples=[[arr] for arr in sample_arrays],
                inputs=inp,
                label="Pre-loaded examples",
                examples_per_page=20,
            )
        gr.HTML("""<hr>""")
        gr.Markdown("""
[Project Page](https://alberto-rota.github.io/UnReflectAnything/) ⋅
[GitHub](https://github.com/alberto-rota/UnReflectAnything) ⋅
[Model Card](https://huggingface.co/AlbeRota/UnReflectAnything) ⋅
[Paper](https://arxiv.org/abs/2512.09583) ⋅
[Contact](mailto:alberto1.rota@polimi.it)
""")
    return demo
# Build the Blocks app at import time so HF Spaces (which imports this module) finds `demo`.
demo = build_ui()
def _launch_allowed_paths():
    """Paths Gradio is allowed to serve (e.g. for gr.Examples from HF cache)."""
    allowed = [str(_GRADIO_DIR)]
    # Best-effort: asset download failures must not prevent the app from launching.
    try:
        sample_dir = _get_assets().sample_images_dir
        if sample_dir.is_dir():
            allowed.append(str(sample_dir.resolve()))
            # Also allow parent (snapshot root) in case Gradio resolves paths from repo root
            snapshot_root = sample_dir.parent
            if snapshot_root.is_dir():
                allowed.append(str(snapshot_root.resolve()))
    except Exception as e:
        print(f"Warning: could not add HF sample_images to allowed_paths: {e}")
    return allowed
def _launch_kwargs():
    """Default kwargs for launch() so allowed_paths are always set (e.g. when HF Spaces runs demo.launch()).

    Returns:
        dict: keyword arguments safe to merge into ``Blocks.launch()``.

    Note:
        This used to also inject ``theme=gr.themes.Soft(...)``, but
        ``Blocks.launch()`` has no ``theme`` parameter — themes are a
        ``gr.Blocks(...)`` constructor argument — so every launch raised
        ``TypeError: launch() got an unexpected keyword argument 'theme'``.
        The theme entry is removed here; pass the theme to ``gr.Blocks``
        inside ``build_ui()`` instead.
    """
    return {
        "allowed_paths": _launch_allowed_paths(),
    }
# Ensure launch() always receives allowed_paths (e.g. when HF Spaces runner calls demo.launch() without args)
_original_launch = demo.launch


def _launch_with_allowed_paths(*args, **kwargs):
    """Delegate to the real launch(), filling in any default kwargs the caller omitted."""
    merged = dict(_launch_kwargs())
    merged.update(kwargs)  # caller-supplied kwargs always win over defaults
    return _original_launch(*args, **merged)


demo.launch = _launch_with_allowed_paths
# Launch unconditionally: the original `if __name__ == "__main__"` branch and the
# HF-Spaces import branch executed the exact same call, so the duplicated if/else
# was collapsed. Behavior is unchanged — the app launches both when run as a
# script and when Hugging Face imports the file.
demo.launch(ssr_mode=True, server_name="0.0.0.0", server_port=7860)