Spaces:
Running
on
Zero
Running
on
Zero
Fix cach weights
Browse files- .cache/configs/pretrained_config.yaml +59 -0
- .gradio/certificate.pem +31 -0
- app.py +173 -82
- tmp/engine_initializers.log +0 -0
- tmp/main.log +0 -0
- tmp/models.log +0 -0
- tmp/optimization.log +0 -0
- tmp/rgbp.log +0 -0
- tmp/run_resume.log +0 -0
.cache/configs/pretrained_config.yaml
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### BASELINE: CONVERGES AFTER LONG
|
| 2 |
+
|
| 3 |
+
parameters:
|
| 4 |
+
|
| 5 |
+
### MODEL ARCHITECTURE
|
| 6 |
+
MODEL:
|
| 7 |
+
value:
|
| 8 |
+
MODEL_CLASS: "UnReflect_Model_TokenInpainter" # Main model class name (must match class in models.py)
|
| 9 |
+
MODEL_MODULE: "models" # Module name to import model classes from (default: "models")
|
| 10 |
+
RGB_ENCODER:
|
| 11 |
+
ENCODER: "facebook/dinov3-vitl16-pretrain-lvd1689m" # DINOv3 encoder model name (HuggingFace format)
|
| 12 |
+
IMAGE_SIZE: 448 # Input image size (height and width in pixels)
|
| 13 |
+
RETURN_SELECTED_LAYERS: [3, 6, 9, 12] # Transformer layer indices to extract features from (0-indexed)
|
| 14 |
+
RGB_ENCODER_LR: 0.0 # Learning rate for RGB encoder (0.0 = frozen, must be explicitly set)
|
| 15 |
+
DECODERS:
|
| 16 |
+
diffuse:
|
| 17 |
+
USE_FILM: False # Enable FiLM (Feature-wise Linear Modulation) conditioning in decoder
|
| 18 |
+
FEATURE_DIM: 1024 # Feature dimension for decoder (should match encoder output)
|
| 19 |
+
REASSEMBLE_OUT_CHANNELS: [768,1024,1536,2048] # Output channels for each decoder stage (DPT-style reassembly)
|
| 20 |
+
REASSEMBLE_FACTORS: [4.0, 2.0, 1.0, 0.5] # Spatial upsampling factors for each stage
|
| 21 |
+
READOUT_TYPE: "ignore" # Readout type for DPT decoder ("ignore", "project", etc.)
|
| 22 |
+
FROM_PRETRAINED: "weights/rgb_decoder.pth" # Path to pretrained decoder weights (optional)
|
| 23 |
+
USE_BN: False # Use batch normalization in decoder
|
| 24 |
+
DROPOUT: 0.1 # Dropout rate in decoder layers
|
| 25 |
+
OUTPUT_IMAGE_SIZE: [448,448] # Output image resolution [height, width]
|
| 26 |
+
OUTPUT_CHANNELS: 3 # Number of output channels (3 for RGB diffuse image)
|
| 27 |
+
DECODER_LR: 1.0e-5 # Custom learning rate for decoder (0.0 = frozen, 1.0 = same as base LR)
|
| 28 |
+
NUM_FUSION_BLOCKS_TRAINABLE: 1 # Number of fusion blocks to train (0-4, null = train all if DECODER_LR != 0)
|
| 29 |
+
TRAIN_RGB_HEAD: True # Whether to train RGB head (true/false, null = train if DECODER_LR != 0)
|
| 30 |
+
highlight:
|
| 31 |
+
USE_FILM: False # Enable FiLM conditioning in highlight decoder
|
| 32 |
+
FEATURE_DIM: 1024 # Feature dimension for highlight decoder
|
| 33 |
+
REASSEMBLE_OUT_CHANNELS: [96,192,384,768] # Output channels for each decoder stage
|
| 34 |
+
REASSEMBLE_FACTORS: [4.0, 2.0, 1.0, 0.5] # Spatial upsampling factors for each stage
|
| 35 |
+
READOUT_TYPE: "ignore" # Readout type for DPT decoder
|
| 36 |
+
USE_BN: False # Use batch normalization in decoder
|
| 37 |
+
DROPOUT: 0.1 # Dropout rate in decoder layers
|
| 38 |
+
OUTPUT_IMAGE_SIZE: [448,448] # Output image resolution [height, width]
|
| 39 |
+
OUTPUT_CHANNELS: 1 # Number of output channels (1 for highlight mask)
|
| 40 |
+
DECODER_LR: 5.0e-4 # Custom learning rate for decoder (0.0 = frozen, 1.0 = same as base LR)
|
| 41 |
+
NUM_FUSION_BLOCKS_TRAINABLE: null # Number of fusion blocks to train (0-4, null = train all if DECODER_LR != 0)
|
| 42 |
+
TOKEN_INPAINTER:
|
| 43 |
+
TOKEN_INPAINTER_CLASS: "TokenInpainter_Prior" # Token inpainter class name
|
| 44 |
+
TOKEN_INPAINTER_MODULE: "token_inpainters" # Module name to import token inpainter from
|
| 45 |
+
FROM_PRETRAINED: "weights/token_inpainter.pth" # Path to pretrained token inpainter weights
|
| 46 |
+
TOKEN_INPAINTER_LR: 1.0e-5 # Learning rate for token inpainter (can differ from base LR)
|
| 47 |
+
DEPTH: 6 # Number of transformer blocks
|
| 48 |
+
HEADS: 16 # Number of attention heads
|
| 49 |
+
DROP: 0 # Dropout rate
|
| 50 |
+
USE_POSITIONAL_ENCODING: True # Enable 2D sinusoidal positional encodings
|
| 51 |
+
USE_FINAL_NORM: True # Enable final LayerNorm before output projection
|
| 52 |
+
USE_LOCAL_PRIOR: True # Blend local mean prior for masked seeds
|
| 53 |
+
LOCAL_PRIOR_WEIGHT: 0.5 # Weight for local prior blending (1.0 = only mask_token, 0.0 = only local mean)
|
| 54 |
+
LOCAL_PRIOR_KERNEL: 5 # Kernel size for local prior blending (> 1)
|
| 55 |
+
SEED_NOISE_STD: 0.02 # Standard deviation of noise added to masked seeds during training
|
| 56 |
+
INPAINT_MASK_DILATION:
|
| 57 |
+
value: 1 # Dilation kernel size (pixels) for inpaint mask - Must be odd
|
| 58 |
+
USE_TORCH_COMPILE: # Enable PyTorch 2.0 torch.compile for faster training (experimental)
|
| 59 |
+
value: False
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
app.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import sys
|
| 6 |
-
import
|
| 7 |
from pathlib import Path
|
| 8 |
|
| 9 |
# Allow importing unreflectanything when run from gradio_space (e.g. HF Space with root dir)
|
|
@@ -11,112 +11,203 @@ _REPO_ROOT = Path(__file__).resolve().parent.parent
|
|
| 11 |
if _REPO_ROOT not in sys.path:
|
| 12 |
sys.path.insert(0, str(_REPO_ROOT))
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
import gradio as gr
|
| 15 |
import numpy as np
|
| 16 |
import torch
|
| 17 |
|
|
|
|
| 18 |
|
| 19 |
def _ensure_weights():
|
| 20 |
"""Download weights to cache if not present."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
from unreflectanything import download
|
| 22 |
-
from unreflectanything._shared import
|
| 23 |
|
| 24 |
-
|
| 25 |
-
if not
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
-
def
|
| 30 |
-
|
| 31 |
-
brightness_threshold: float,
|
| 32 |
-
) -> np.ndarray | None:
|
| 33 |
-
"""Run reflection removal on a single image. Returns RGB numpy [H,W,3] in 0–255 or None."""
|
| 34 |
-
if image is None:
|
| 35 |
-
return None
|
| 36 |
-
from unreflectanything import inference
|
| 37 |
|
| 38 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
)
|
| 68 |
-
else:
|
| 69 |
-
raise
|
| 70 |
-
# result: [1, 3, H, W], float 0–1
|
| 71 |
-
out = result[0].cpu().numpy().transpose(1, 2, 0)
|
| 72 |
-
out = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
|
| 73 |
-
return out
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def build_ui():
|
| 77 |
-
_ensure_weights()
|
| 78 |
-
|
| 79 |
-
with gr.Blocks(
|
| 80 |
-
title="UnReflectAnything",
|
| 81 |
-
theme=gr.themes.Soft(primary_hue="green", secondary_hue="purple"),
|
| 82 |
-
) as demo:
|
| 83 |
-
gr.Markdown(
|
| 84 |
-
"""
|
| 85 |
-
# UnReflectAnything
|
| 86 |
-
Remove **specular reflections** from a single image. Upload an image and adjust the highlight threshold if needed.
|
| 87 |
-
"""
|
| 88 |
-
)
|
| 89 |
with gr.Row():
|
| 90 |
inp = gr.Image(
|
| 91 |
-
label="Input image",
|
| 92 |
type="numpy",
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
-
|
| 96 |
-
label="
|
| 97 |
type="numpy",
|
| 98 |
-
height=
|
|
|
|
| 99 |
)
|
| 100 |
-
|
| 101 |
-
minimum=0.0,
|
| 102 |
-
maximum=1.0,
|
| 103 |
-
value=0.8,
|
| 104 |
-
step=0.05,
|
| 105 |
-
label="Brightness threshold (highlight detection)",
|
| 106 |
-
)
|
| 107 |
-
run_btn = gr.Button("Remove reflections", variant="primary")
|
| 108 |
run_btn.click(
|
| 109 |
-
fn=
|
| 110 |
-
inputs=[inp
|
| 111 |
-
outputs=
|
| 112 |
-
)
|
| 113 |
-
gr.Markdown(
|
| 114 |
-
"Weights are cached after first run. On CPU inference may be slow."
|
| 115 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
return demo
|
| 117 |
|
| 118 |
|
| 119 |
demo = build_ui()
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
if __name__ == "__main__":
|
| 122 |
-
demo.launch(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import sys
|
| 6 |
+
import threading
|
| 7 |
from pathlib import Path
|
| 8 |
|
| 9 |
# Allow importing unreflectanything when run from gradio_space (e.g. HF Space with root dir)
|
|
|
|
| 11 |
if _REPO_ROOT not in sys.path:
|
| 12 |
sys.path.insert(0, str(_REPO_ROOT))
|
| 13 |
|
| 14 |
+
# Logo path: put your PNG in gradio_space/logo.png (next to app.py)
|
| 15 |
+
_GRADIO_DIR = Path(__file__).resolve().parent
|
| 16 |
+
|
| 17 |
import gradio as gr
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
| 20 |
|
| 21 |
+
from huggingface_hub import hf_hub_download
|
| 22 |
|
| 23 |
def _ensure_weights():
|
| 24 |
"""Download weights to cache if not present."""
|
| 25 |
+
weights_path = hf_hub_download(
|
| 26 |
+
repo_id="AlbeRota/UnReflectAnything",
|
| 27 |
+
filename="weights/full_model_weights.pt"
|
| 28 |
+
)
|
| 29 |
+
config_path = hf_hub_download(
|
| 30 |
+
repo_id="AlbeRota/UnReflectAnything",
|
| 31 |
+
filename="configs/pretrained_config.yaml"
|
| 32 |
+
)
|
| 33 |
+
return weights_path, config_path
|
| 34 |
+
|
| 35 |
+
def _ensure_sample_images() -> Path | None:
|
| 36 |
+
"""Ensure sample images are downloaded to the standard cache dir and return it.
|
| 37 |
+
|
| 38 |
+
Uses the same cache layout as the rest of the library:
|
| 39 |
+
get_cache_dir("images") / <files>.
|
| 40 |
+
"""
|
| 41 |
from unreflectanything import download
|
| 42 |
+
from unreflectanything._shared import get_cache_dir
|
| 43 |
|
| 44 |
+
images_dir = get_cache_dir("images")
|
| 45 |
+
if not images_dir.is_dir():
|
| 46 |
+
try:
|
| 47 |
+
download("images")
|
| 48 |
+
except Exception:
|
| 49 |
+
return None
|
| 50 |
+
return images_dir
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _get_sample_images():
|
| 54 |
+
"""Return list of sample image paths from the images cache directory."""
|
| 55 |
+
from unreflectanything._shared import DEFAULT_IMAGE_EXTENSIONS
|
| 56 |
+
|
| 57 |
+
images_dir = _ensure_sample_images()
|
| 58 |
+
if images_dir is None or not images_dir.is_dir():
|
| 59 |
+
return []
|
| 60 |
+
paths = []
|
| 61 |
+
for p in sorted(images_dir.iterdir()):
|
| 62 |
+
if p.is_file() and p.suffix.lower() in DEFAULT_IMAGE_EXTENSIONS:
|
| 63 |
+
paths.append(str(p))
|
| 64 |
+
return paths
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Single model instance; loaded in background at app start or on first inference.
|
| 68 |
+
_cached_ura_model = None
|
| 69 |
+
_cached_device = None
|
| 70 |
+
_model_load_lock = threading.Lock()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _get_model(device: str):
|
| 74 |
+
"""Return the pretrained model, loading it once and reusing. Ensures weights exist (downloads if missing)."""
|
| 75 |
+
global _cached_ura_model, _cached_device
|
| 76 |
+
weights_path, config_path = _ensure_weights()
|
| 77 |
+
with _model_load_lock:
|
| 78 |
+
if _cached_ura_model is not None and _cached_device == device:
|
| 79 |
+
return _cached_ura_model
|
| 80 |
+
from unreflectanything import model
|
| 81 |
+
|
| 82 |
+
_cached_ura_model = model(
|
| 83 |
+
pretrained=True,
|
| 84 |
+
# weights_path=os.path.join(os.path.dirname(__file__), ".cache", "weights", "full_model_weights.pt"),
|
| 85 |
+
# config_path=os.path.join(os.path.dirname(__file__), ".cache", "configs", "pretrained_config.yaml"),
|
| 86 |
+
weights_path=weights_path,
|
| 87 |
+
config_path=config_path,
|
| 88 |
+
device=device,
|
| 89 |
+
verbose=False,
|
| 90 |
+
)
|
| 91 |
+
_cached_device = device
|
| 92 |
+
return _cached_ura_model
|
| 93 |
|
| 94 |
|
| 95 |
+
def build_ui():
|
| 96 |
+
_ensure_sample_images()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 99 |
+
# Start loading the model in the background so it is ready (or nearly ready) by first use.
|
| 100 |
+
threading.Thread(target=_get_model, args=(device,), daemon=True).start()
|
| 101 |
+
|
| 102 |
+
def run_inference(image: np.ndarray | None) -> np.ndarray | None:
|
| 103 |
+
"""Run reflection removal using the cached model. Returns RGB numpy [H,W,3] in 0–255 or None."""
|
| 104 |
+
if image is None:
|
| 105 |
+
return None
|
| 106 |
+
from torchvision.transforms import functional as TF
|
| 107 |
+
|
| 108 |
+
ura_model = _get_model(device)
|
| 109 |
+
target_side = ura_model.image_size
|
| 110 |
+
# image: [H, W, 3] uint8 0–255
|
| 111 |
+
h, w = image.shape[:2]
|
| 112 |
+
tensor = TF.to_tensor(image).unsqueeze(0) # [1, 3, H, W], [0, 1]
|
| 113 |
+
tensor = TF.resize(tensor, [target_side, target_side], antialias=True)
|
| 114 |
+
tensor = tensor.to(ura_model.device, dtype=torch.float32)
|
| 115 |
+
mask = tensor.mean(1, keepdim=True) > 0.9 # [1, 1, S, S]
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
diffuse = ura_model(images=tensor, inpaint_mask_override=mask)
|
| 118 |
+
diffuse = diffuse.cpu()
|
| 119 |
+
diffuse = TF.resize(diffuse, [h, w], antialias=True)
|
| 120 |
+
out = diffuse[0].numpy().transpose(1, 2, 0)
|
| 121 |
+
out = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
|
| 122 |
+
return out
|
| 123 |
+
|
| 124 |
+
def run_inference_slider(
|
| 125 |
+
image: np.ndarray | None,
|
| 126 |
+
) -> tuple[np.ndarray | None, np.ndarray | None] | None:
|
| 127 |
+
"""Run inference and return (input, output) for ImageSlider."""
|
| 128 |
+
out = run_inference(image)
|
| 129 |
+
if out is None:
|
| 130 |
+
return None
|
| 131 |
+
return (image, out)
|
| 132 |
+
|
| 133 |
+
with gr.Blocks(title="UnReflectAnything") as demo:
|
| 134 |
+
with gr.Row():
|
| 135 |
+
with gr.Column(scale=0, min_width=100):
|
| 136 |
+
# if LOGO_PATH.is_file():
|
| 137 |
+
# gr.Image(
|
| 138 |
+
# value=str(LOGO_PATH),
|
| 139 |
+
# show_label=False,
|
| 140 |
+
# interactive=False,
|
| 141 |
+
# height=100,
|
| 142 |
+
# container=False,
|
| 143 |
+
# buttons=[],
|
| 144 |
+
# )
|
| 145 |
+
with gr.Column(scale=1):
|
| 146 |
+
gr.Markdown(
|
| 147 |
+
"""
|
| 148 |
+
# UnReflectAnything
|
| 149 |
+
UnReflectAnything inputs any RGB image and **removes specular highlights**,
|
| 150 |
+
returning a clean diffuse-only outputs. We trained UnReflectAnything by synthetizing
|
| 151 |
+
specularities and supervising in DINOv3 feature space.
|
| 152 |
+
UnReflectAnything works on both natural indoor and **surgical/endoscopic** domain data.
|
| 153 |
+
Visit the [Project Page](https://alberto-rota.github.io/UnReflectAnything/)!
|
| 154 |
+
"""
|
| 155 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
with gr.Row():
|
| 157 |
inp = gr.Image(
|
|
|
|
| 158 |
type="numpy",
|
| 159 |
+
label="Image input",
|
| 160 |
+
height=600,
|
| 161 |
+
width=600,
|
| 162 |
)
|
| 163 |
+
out_slider = gr.ImageSlider(
|
| 164 |
+
label="Input",
|
| 165 |
type="numpy",
|
| 166 |
+
height=600,
|
| 167 |
+
show_label=True,
|
| 168 |
)
|
| 169 |
+
run_btn = gr.Button("Run UnReflectAnything", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
run_btn.click(
|
| 171 |
+
fn=run_inference_slider,
|
| 172 |
+
inputs=[inp],
|
| 173 |
+
outputs=out_slider,
|
|
|
|
|
|
|
|
|
|
| 174 |
)
|
| 175 |
+
sample_paths = _get_sample_images()
|
| 176 |
+
if sample_paths:
|
| 177 |
+
gr.Examples(
|
| 178 |
+
examples=[[p] for p in sample_paths],
|
| 179 |
+
inputs=inp,
|
| 180 |
+
label="Pre-loaded examples",
|
| 181 |
+
examples_per_page=20,
|
| 182 |
+
)
|
| 183 |
+
gr.HTML("""<hr>""")
|
| 184 |
+
gr.Markdown("""
|
| 185 |
+
[Project Page](https://alberto-rota.github.io/UnReflectAnything/) ⋅
|
| 186 |
+
[GitHub](https://github.com/alberto-rota/UnReflectAnything) ⋅
|
| 187 |
+
[Model Card](https://huggingface.co/AlbeRota/UnReflectAnything) ⋅
|
| 188 |
+
[Paper](https://arxiv.org/abs/2512.09583) ⋅
|
| 189 |
+
[Contact](mailto:alberto1.rota@polimi.it) ⋅
|
| 190 |
+
""")
|
| 191 |
return demo
|
| 192 |
|
| 193 |
|
| 194 |
demo = build_ui()
|
| 195 |
|
| 196 |
+
|
| 197 |
+
def _launch_allowed_paths():
|
| 198 |
+
"""Paths Gradio is allowed to serve (e.g. for gr.Examples from cache)."""
|
| 199 |
+
from unreflectanything._shared import get_cache_dir
|
| 200 |
+
|
| 201 |
+
paths = [str(_GRADIO_DIR)]
|
| 202 |
+
images_cache = get_cache_dir("images")
|
| 203 |
+
if images_cache.is_dir():
|
| 204 |
+
paths.append(str(images_cache))
|
| 205 |
+
return paths
|
| 206 |
+
|
| 207 |
+
|
| 208 |
if __name__ == "__main__":
|
| 209 |
+
demo.launch(
|
| 210 |
+
share=True,
|
| 211 |
+
allowed_paths=_launch_allowed_paths(),
|
| 212 |
+
theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue"),
|
| 213 |
+
)
|
tmp/engine_initializers.log
ADDED
|
File without changes
|
tmp/main.log
ADDED
|
File without changes
|
tmp/models.log
ADDED
|
File without changes
|
tmp/optimization.log
ADDED
|
File without changes
|
tmp/rgbp.log
ADDED
|
File without changes
|
tmp/run_resume.log
ADDED
|
File without changes
|