AlbeRota commited on
Commit
b604e51
·
1 Parent(s): 4ed98e6

Fix cache weights

Browse files
.cache/configs/pretrained_config.yaml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### BASELINE: CONVERGES AFTER LONG
2
+
3
+ parameters:
4
+
5
+ ### MODEL ARCHITECTURE
6
+ MODEL:
7
+ value:
8
+ MODEL_CLASS: "UnReflect_Model_TokenInpainter" # Main model class name (must match class in models.py)
9
+ MODEL_MODULE: "models" # Module name to import model classes from (default: "models")
10
+ RGB_ENCODER:
11
+ ENCODER: "facebook/dinov3-vitl16-pretrain-lvd1689m" # DINOv3 encoder model name (HuggingFace format)
12
+ IMAGE_SIZE: 448 # Input image size (height and width in pixels)
13
+ RETURN_SELECTED_LAYERS: [3, 6, 9, 12] # Transformer layer indices to extract features from (0-indexed)
14
+ RGB_ENCODER_LR: 0.0 # Learning rate for RGB encoder (0.0 = frozen, must be explicitly set)
15
+ DECODERS:
16
+ diffuse:
17
+ USE_FILM: False # Enable FiLM (Feature-wise Linear Modulation) conditioning in decoder
18
+ FEATURE_DIM: 1024 # Feature dimension for decoder (should match encoder output)
19
+ REASSEMBLE_OUT_CHANNELS: [768,1024,1536,2048] # Output channels for each decoder stage (DPT-style reassembly)
20
+ REASSEMBLE_FACTORS: [4.0, 2.0, 1.0, 0.5] # Spatial upsampling factors for each stage
21
+ READOUT_TYPE: "ignore" # Readout type for DPT decoder ("ignore", "project", etc.)
22
+ FROM_PRETRAINED: "weights/rgb_decoder.pth" # Path to pretrained decoder weights (optional)
23
+ USE_BN: False # Use batch normalization in decoder
24
+ DROPOUT: 0.1 # Dropout rate in decoder layers
25
+ OUTPUT_IMAGE_SIZE: [448,448] # Output image resolution [height, width]
26
+ OUTPUT_CHANNELS: 3 # Number of output channels (3 for RGB diffuse image)
27
+ DECODER_LR: 1.0e-5 # Custom learning rate for decoder (0.0 = frozen, 1.0 = same as base LR)
28
+ NUM_FUSION_BLOCKS_TRAINABLE: 1 # Number of fusion blocks to train (0-4, null = train all if DECODER_LR != 0)
29
+ TRAIN_RGB_HEAD: True # Whether to train RGB head (true/false, null = train if DECODER_LR != 0)
30
+ highlight:
31
+ USE_FILM: False # Enable FiLM conditioning in highlight decoder
32
+ FEATURE_DIM: 1024 # Feature dimension for highlight decoder
33
+ REASSEMBLE_OUT_CHANNELS: [96,192,384,768] # Output channels for each decoder stage
34
+ REASSEMBLE_FACTORS: [4.0, 2.0, 1.0, 0.5] # Spatial upsampling factors for each stage
35
+ READOUT_TYPE: "ignore" # Readout type for DPT decoder
36
+ USE_BN: False # Use batch normalization in decoder
37
+ DROPOUT: 0.1 # Dropout rate in decoder layers
38
+ OUTPUT_IMAGE_SIZE: [448,448] # Output image resolution [height, width]
39
+ OUTPUT_CHANNELS: 1 # Number of output channels (1 for highlight mask)
40
+ DECODER_LR: 5.0e-4 # Custom learning rate for decoder (0.0 = frozen, 1.0 = same as base LR)
41
+ NUM_FUSION_BLOCKS_TRAINABLE: null # Number of fusion blocks to train (0-4, null = train all if DECODER_LR != 0)
42
+ TOKEN_INPAINTER:
43
+ TOKEN_INPAINTER_CLASS: "TokenInpainter_Prior" # Token inpainter class name
44
+ TOKEN_INPAINTER_MODULE: "token_inpainters" # Module name to import token inpainter from
45
+ FROM_PRETRAINED: "weights/token_inpainter.pth" # Path to pretrained token inpainter weights
46
+ TOKEN_INPAINTER_LR: 1.0e-5 # Learning rate for token inpainter (can differ from base LR)
47
+ DEPTH: 6 # Number of transformer blocks
48
+ HEADS: 16 # Number of attention heads
49
+ DROP: 0 # Dropout rate
50
+ USE_POSITIONAL_ENCODING: True # Enable 2D sinusoidal positional encodings
51
+ USE_FINAL_NORM: True # Enable final LayerNorm before output projection
52
+ USE_LOCAL_PRIOR: True # Blend local mean prior for masked seeds
53
+ LOCAL_PRIOR_WEIGHT: 0.5 # Weight for local prior blending (1.0 = only mask_token, 0.0 = only local mean)
54
+ LOCAL_PRIOR_KERNEL: 5 # Kernel size for local prior blending (> 1)
55
+ SEED_NOISE_STD: 0.02 # Standard deviation of noise added to masked seeds during training
56
+ INPAINT_MASK_DILATION:
57
+ value: 1 # Dilation kernel size (pixels) for inpaint mask - Must be odd
58
+ USE_TORCH_COMPILE: # Enable PyTorch 2.0 torch.compile for faster training (experimental)
59
+ value: False
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
app.py CHANGED
@@ -3,7 +3,7 @@
3
  from __future__ import annotations
4
 
5
  import sys
6
- import tempfile
7
  from pathlib import Path
8
 
9
  # Allow importing unreflectanything when run from gradio_space (e.g. HF Space with root dir)
@@ -11,112 +11,203 @@ _REPO_ROOT = Path(__file__).resolve().parent.parent
11
  if _REPO_ROOT not in sys.path:
12
  sys.path.insert(0, str(_REPO_ROOT))
13
 
 
 
 
14
  import gradio as gr
15
  import numpy as np
16
  import torch
17
 
 
18
 
19
  def _ensure_weights():
20
  """Download weights to cache if not present."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  from unreflectanything import download
22
- from unreflectanything._shared import DEFAULT_WEIGHTS_FILENAME, get_cache_dir
23
 
24
- weights_dir = get_cache_dir("weights")
25
- if not (weights_dir / DEFAULT_WEIGHTS_FILENAME).exists():
26
- download("weights")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
- def run_inference(
30
- image: np.ndarray | None,
31
- brightness_threshold: float,
32
- ) -> np.ndarray | None:
33
- """Run reflection removal on a single image. Returns RGB numpy [H,W,3] in 0–255 or None."""
34
- if image is None:
35
- return None
36
- from unreflectanything import inference
37
 
38
  device = "cuda" if torch.cuda.is_available() else "cpu"
39
- with tempfile.TemporaryDirectory() as tmpdir:
40
- inp_path = Path(tmpdir) / "input.png"
41
- out_path = Path(tmpdir) / "output.png"
42
- # Gradio passes RGB numpy (H, W, 3) in 0–255
43
- from PIL import Image
44
-
45
- Image.fromarray(image.astype(np.uint8)).save(inp_path)
46
- try:
47
- result = inference(
48
- input=str(inp_path),
49
- output=None,
50
- device=device,
51
- batch_size=1,
52
- brightness_threshold=brightness_threshold,
53
- resize_output=True,
54
- verbose=False,
55
- )
56
- except FileNotFoundError as e:
57
- if "Weights not found" in str(e) or "Run 'unreflect download" in str(e):
58
- _ensure_weights()
59
- result = inference(
60
- input=str(inp_path),
61
- output=None,
62
- device=device,
63
- batch_size=1,
64
- brightness_threshold=brightness_threshold,
65
- resize_output=True,
66
- verbose=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  )
68
- else:
69
- raise
70
- # result: [1, 3, H, W], float 0–1
71
- out = result[0].cpu().numpy().transpose(1, 2, 0)
72
- out = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
73
- return out
74
-
75
-
76
- def build_ui():
77
- _ensure_weights()
78
-
79
- with gr.Blocks(
80
- title="UnReflectAnything",
81
- theme=gr.themes.Soft(primary_hue="green", secondary_hue="purple"),
82
- ) as demo:
83
- gr.Markdown(
84
- """
85
- # UnReflectAnything
86
- Remove **specular reflections** from a single image. Upload an image and adjust the highlight threshold if needed.
87
- """
88
- )
89
  with gr.Row():
90
  inp = gr.Image(
91
- label="Input image",
92
  type="numpy",
93
- height=360,
 
 
94
  )
95
- out = gr.Image(
96
- label="Reflection‑removed (diffuse)",
97
  type="numpy",
98
- height=360,
 
99
  )
100
- brightness = gr.Slider(
101
- minimum=0.0,
102
- maximum=1.0,
103
- value=0.8,
104
- step=0.05,
105
- label="Brightness threshold (highlight detection)",
106
- )
107
- run_btn = gr.Button("Remove reflections", variant="primary")
108
  run_btn.click(
109
- fn=run_inference,
110
- inputs=[inp, brightness],
111
- outputs=out,
112
- )
113
- gr.Markdown(
114
- "Weights are cached after first run. On CPU inference may be slow."
115
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  return demo
117
 
118
 
119
  demo = build_ui()
120
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  if __name__ == "__main__":
122
- demo.launch()
 
 
 
 
 
3
  from __future__ import annotations
4
 
5
  import sys
6
+ import threading
7
  from pathlib import Path
8
 
9
  # Allow importing unreflectanything when run from gradio_space (e.g. HF Space with root dir)
 
11
  if _REPO_ROOT not in sys.path:
12
  sys.path.insert(0, str(_REPO_ROOT))
13
 
14
+ # Logo path: put your PNG in gradio_space/logo.png (next to app.py)
15
+ _GRADIO_DIR = Path(__file__).resolve().parent
16
+
17
  import gradio as gr
18
  import numpy as np
19
  import torch
20
 
21
+ from huggingface_hub import hf_hub_download
22
 
23
  def _ensure_weights():
24
  """Download weights to cache if not present."""
25
+ weights_path = hf_hub_download(
26
+ repo_id="AlbeRota/UnReflectAnything",
27
+ filename="weights/full_model_weights.pt"
28
+ )
29
+ config_path = hf_hub_download(
30
+ repo_id="AlbeRota/UnReflectAnything",
31
+ filename="configs/pretrained_config.yaml"
32
+ )
33
+ return weights_path, config_path
34
+
35
+ def _ensure_sample_images() -> Path | None:
36
+ """Ensure sample images are downloaded to the standard cache dir and return it.
37
+
38
+ Uses the same cache layout as the rest of the library:
39
+ get_cache_dir("images") / <files>.
40
+ """
41
  from unreflectanything import download
42
+ from unreflectanything._shared import get_cache_dir
43
 
44
+ images_dir = get_cache_dir("images")
45
+ if not images_dir.is_dir():
46
+ try:
47
+ download("images")
48
+ except Exception:
49
+ return None
50
+ return images_dir
51
+
52
+
53
+ def _get_sample_images():
54
+ """Return list of sample image paths from the images cache directory."""
55
+ from unreflectanything._shared import DEFAULT_IMAGE_EXTENSIONS
56
+
57
+ images_dir = _ensure_sample_images()
58
+ if images_dir is None or not images_dir.is_dir():
59
+ return []
60
+ paths = []
61
+ for p in sorted(images_dir.iterdir()):
62
+ if p.is_file() and p.suffix.lower() in DEFAULT_IMAGE_EXTENSIONS:
63
+ paths.append(str(p))
64
+ return paths
65
+
66
+
67
+ # Single model instance; loaded in background at app start or on first inference.
68
+ _cached_ura_model = None
69
+ _cached_device = None
70
+ _model_load_lock = threading.Lock()
71
+
72
+
73
+ def _get_model(device: str):
74
+ """Return the pretrained model, loading it once and reusing. Ensures weights exist (downloads if missing)."""
75
+ global _cached_ura_model, _cached_device
76
+ weights_path, config_path = _ensure_weights()
77
+ with _model_load_lock:
78
+ if _cached_ura_model is not None and _cached_device == device:
79
+ return _cached_ura_model
80
+ from unreflectanything import model
81
+
82
+ _cached_ura_model = model(
83
+ pretrained=True,
84
+ # weights_path=os.path.join(os.path.dirname(__file__), ".cache", "weights", "full_model_weights.pt"),
85
+ # config_path=os.path.join(os.path.dirname(__file__), ".cache", "configs", "pretrained_config.yaml"),
86
+ weights_path=weights_path,
87
+ config_path=config_path,
88
+ device=device,
89
+ verbose=False,
90
+ )
91
+ _cached_device = device
92
+ return _cached_ura_model
93
 
94
 
95
+ def build_ui():
96
+ _ensure_sample_images()
 
 
 
 
 
 
97
 
98
  device = "cuda" if torch.cuda.is_available() else "cpu"
99
+ # Start loading the model in the background so it is ready (or nearly ready) by first use.
100
+ threading.Thread(target=_get_model, args=(device,), daemon=True).start()
101
+
102
+ def run_inference(image: np.ndarray | None) -> np.ndarray | None:
103
+ """Run reflection removal using the cached model. Returns RGB numpy [H,W,3] in 0–255 or None."""
104
+ if image is None:
105
+ return None
106
+ from torchvision.transforms import functional as TF
107
+
108
+ ura_model = _get_model(device)
109
+ target_side = ura_model.image_size
110
+ # image: [H, W, 3] uint8 0–255
111
+ h, w = image.shape[:2]
112
+ tensor = TF.to_tensor(image).unsqueeze(0) # [1, 3, H, W], [0, 1]
113
+ tensor = TF.resize(tensor, [target_side, target_side], antialias=True)
114
+ tensor = tensor.to(ura_model.device, dtype=torch.float32)
115
+ mask = tensor.mean(1, keepdim=True) > 0.9 # [1, 1, S, S]
116
+ with torch.no_grad():
117
+ diffuse = ura_model(images=tensor, inpaint_mask_override=mask)
118
+ diffuse = diffuse.cpu()
119
+ diffuse = TF.resize(diffuse, [h, w], antialias=True)
120
+ out = diffuse[0].numpy().transpose(1, 2, 0)
121
+ out = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
122
+ return out
123
+
124
+ def run_inference_slider(
125
+ image: np.ndarray | None,
126
+ ) -> tuple[np.ndarray | None, np.ndarray | None] | None:
127
+ """Run inference and return (input, output) for ImageSlider."""
128
+ out = run_inference(image)
129
+ if out is None:
130
+ return None
131
+ return (image, out)
132
+
133
+ with gr.Blocks(title="UnReflectAnything") as demo:
134
+ with gr.Row():
135
+ with gr.Column(scale=0, min_width=100):
136
+ # if LOGO_PATH.is_file():
137
+ # gr.Image(
138
+ # value=str(LOGO_PATH),
139
+ # show_label=False,
140
+ # interactive=False,
141
+ # height=100,
142
+ # container=False,
143
+ # buttons=[],
144
+ # )
145
+ with gr.Column(scale=1):
146
+ gr.Markdown(
147
+ """
148
+ # UnReflectAnything
149
+ UnReflectAnything inputs any RGB image and **removes specular highlights**,
150
+ returning clean diffuse-only outputs. We trained UnReflectAnything by synthesizing
151
+ specularities and supervising in DINOv3 feature space.
152
+ UnReflectAnything works on both natural indoor and **surgical/endoscopic** domain data.
153
+ Visit the [Project Page](https://alberto-rota.github.io/UnReflectAnything/)!
154
+ """
155
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  with gr.Row():
157
  inp = gr.Image(
 
158
  type="numpy",
159
+ label="Image input",
160
+ height=600,
161
+ width=600,
162
  )
163
+ out_slider = gr.ImageSlider(
164
+ label="Input",
165
  type="numpy",
166
+ height=600,
167
+ show_label=True,
168
  )
169
+ run_btn = gr.Button("Run UnReflectAnything", variant="primary")
 
 
 
 
 
 
 
170
  run_btn.click(
171
+ fn=run_inference_slider,
172
+ inputs=[inp],
173
+ outputs=out_slider,
 
 
 
174
  )
175
+ sample_paths = _get_sample_images()
176
+ if sample_paths:
177
+ gr.Examples(
178
+ examples=[[p] for p in sample_paths],
179
+ inputs=inp,
180
+ label="Pre-loaded examples",
181
+ examples_per_page=20,
182
+ )
183
+ gr.HTML("""<hr>""")
184
+ gr.Markdown("""
185
+ [Project Page](https://alberto-rota.github.io/UnReflectAnything/) ⋅
186
+ [GitHub](https://github.com/alberto-rota/UnReflectAnything) ⋅
187
+ [Model Card](https://huggingface.co/AlbeRota/UnReflectAnything) ⋅
188
+ [Paper](https://arxiv.org/abs/2512.09583) ⋅
189
+ [Contact](mailto:alberto1.rota@polimi.it) ⋅
190
+ """)
191
  return demo
192
 
193
 
194
  demo = build_ui()
195
 
196
+
197
+ def _launch_allowed_paths():
198
+ """Paths Gradio is allowed to serve (e.g. for gr.Examples from cache)."""
199
+ from unreflectanything._shared import get_cache_dir
200
+
201
+ paths = [str(_GRADIO_DIR)]
202
+ images_cache = get_cache_dir("images")
203
+ if images_cache.is_dir():
204
+ paths.append(str(images_cache))
205
+ return paths
206
+
207
+
208
  if __name__ == "__main__":
209
+ demo.launch(
210
+ share=True,
211
+ allowed_paths=_launch_allowed_paths(),
212
+ theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue"),
213
+ )
tmp/engine_initializers.log ADDED
File without changes
tmp/main.log ADDED
File without changes
tmp/models.log ADDED
File without changes
tmp/optimization.log ADDED
File without changes
tmp/rgbp.log ADDED
File without changes
tmp/run_resume.log ADDED
File without changes