github-actions[bot] committed on
Commit
ad65a2a
·
1 Parent(s): a1f9c6d

Sync from GitHub: bcbc0c1c101625b271610b9d2a7d3fa0a10bd1fe

Browse files
Files changed (1) hide show
  1. app.py +120 -49
app.py CHANGED
@@ -4,6 +4,8 @@
4
 
5
  """Gradio demo for rgbd-depth on Hugging Face Spaces."""
6
 
 
 
7
  import gradio as gr
8
  import numpy as np
9
  import torch
@@ -14,42 +16,91 @@ from rgbddepth import RGBDDepth
14
  # Global model cache
15
  MODELS = {}
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- def load_model(encoder: str, use_xformers: bool = False):
19
- """Load model with caching."""
20
- cache_key = f"{encoder}_{use_xformers}"
 
 
 
 
 
 
 
 
21
 
22
  if cache_key not in MODELS:
23
- # Model configs
24
- configs = {
25
- "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
26
- "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
27
- "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
28
- "vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
29
  }
30
 
31
- config = configs[encoder].copy()
32
- config["use_xformers"] = use_xformers
33
-
34
  model = RGBDDepth(**config)
35
 
36
- # Try to load weights if checkpoint exists
37
- try:
38
- checkpoint = torch.load(f"checkpoints/{encoder}.pt", map_location="cpu")
39
- if "model" in checkpoint:
40
- states = {k[7:]: v for k, v in checkpoint["model"].items()}
41
- elif "state_dict" in checkpoint:
42
- states = {k[9:]: v for k, v in checkpoint["state_dict"].items()}
43
- else:
44
- states = checkpoint
45
 
46
- model.load_state_dict(states, strict=False)
47
- print(f"✓ Loaded checkpoint for {encoder}")
48
- except FileNotFoundError:
49
- print(f"⚠ No checkpoint found for {encoder}, using random weights (demo only)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # Move to GPU if available
52
- device = "cuda" if torch.cuda.is_available() else "cpu"
53
  model = model.to(device).eval()
54
 
55
  MODELS[cache_key] = model
@@ -60,7 +111,7 @@ def load_model(encoder: str, use_xformers: bool = False):
60
  def process_depth(
61
  rgb_image: np.ndarray,
62
  depth_image: np.ndarray,
63
- encoder: str = "vitl",
64
  input_size: int = 518,
65
  depth_scale: float = 1000.0,
66
  max_depth: float = 25.0,
@@ -73,7 +124,7 @@ def process_depth(
73
  Args:
74
  rgb_image: RGB image as numpy array [H, W, 3]
75
  depth_image: Depth image as numpy array [H, W] or [H, W, 3]
76
- encoder: Model encoder type
77
  input_size: Input size for inference
78
  depth_scale: Scale factor for depth values
79
  max_depth: Maximum valid depth value
@@ -105,7 +156,7 @@ def process_depth(
105
  simi_depth[valid_mask] = 1.0 / depth_normalized[valid_mask]
106
 
107
  # Load model
108
- model = load_model(encoder, use_xformers and torch.cuda.is_available())
109
  device = next(model.parameters()).device
110
 
111
  # Determine precision
@@ -116,6 +167,15 @@ def process_depth(
116
  else:
117
  dtype = None # FP32
118
 
 
 
 
 
 
 
 
 
 
119
  # Run inference
120
  if dtype is not None:
121
  device_type = "cuda" if device.type == "cuda" else "cpu"
@@ -124,9 +184,15 @@ def process_depth(
124
  else:
125
  pred = model.infer_image(rgb_image, simi_depth, input_size=input_size)
126
 
 
 
 
127
  # Convert from inverse depth to depth
128
  pred = np.where(pred > 1e-8, 1.0 / pred, 0.0)
129
 
 
 
 
130
  # Colorize for visualization
131
  try:
132
  import matplotlib
@@ -148,14 +214,16 @@ def process_depth(
148
 
149
  except ImportError:
150
  # Fallback to grayscale if matplotlib not available
151
- pred_norm = ((pred - pred.min()) / (pred.max() - pred.min() + 1e-8) * 255).astype(np.uint8)
152
- output_image = Image.fromarray(pred_norm, mode='L').convert('RGB')
 
 
153
 
154
  # Create info message
155
  info = f"""
156
  ✅ **Refinement complete!**
157
 
158
- **Model:** {encoder.upper()}
159
  **Precision:** {precision.upper()}
160
  **Device:** {device.type.upper()}
161
  **Input size:** {input_size}px
@@ -171,16 +239,17 @@ def process_depth(
171
 
172
  # Create Gradio interface
173
  with gr.Blocks(title="rgbd-depth Demo") as demo:
174
- gr.Markdown("""
 
175
  # 🎨 rgbd-depth: RGB-D Depth Refinement
176
 
177
  High-quality depth map refinement using Vision Transformers. Based on [ByteDance's camera-depth-models](https://manipulation-as-in-simulation.github.io/).
178
 
179
- ⚠️ **Note:** This demo uses random weights for demonstration. For real results:
180
- 1. Download checkpoints from [Hugging Face](https://huggingface.co/collections/depth-anything/camera-depth-models-68b521181dedd223f4b020db)
181
- 2. Place in `checkpoints/` directory
182
- 3. Restart the app
183
- """)
184
 
185
  with gr.Row():
186
  with gr.Column():
@@ -199,11 +268,11 @@ with gr.Blocks(title="rgbd-depth Demo") as demo:
199
  )
200
 
201
  with gr.Accordion("⚙️ Advanced Settings", open=False):
202
- encoder_choice = gr.Radio(
203
- choices=["vits", "vitb", "vitl", "vitg"],
204
- value="vitl",
205
- label="Encoder Model",
206
- info="Larger = better quality but slower",
207
  )
208
 
209
  input_size = gr.Slider(
@@ -235,7 +304,7 @@ with gr.Blocks(title="rgbd-depth Demo") as demo:
235
  )
236
 
237
  use_xformers = gr.Checkbox(
238
- value=False,
239
  label="Use xFormers (CUDA only)",
240
  info="~8% faster on CUDA with xFormers installed",
241
  )
@@ -276,7 +345,7 @@ with gr.Blocks(title="rgbd-depth Demo") as demo:
276
  inputs=[
277
  rgb_input,
278
  depth_input,
279
- encoder_choice,
280
  input_size,
281
  depth_scale,
282
  max_depth,
@@ -288,7 +357,8 @@ with gr.Blocks(title="rgbd-depth Demo") as demo:
288
  )
289
 
290
  # Footer
291
- gr.Markdown("""
 
292
  ---
293
 
294
  ### 🔗 Links
@@ -316,7 +386,8 @@ with gr.Blocks(title="rgbd-depth Demo") as demo:
316
  ---
317
 
318
  Built with ❤️ by [Aedelon](https://github.com/Aedelon) | Powered by [Gradio](https://gradio.app)
319
- """)
 
320
 
321
  if __name__ == "__main__":
322
- demo.launch()
 
4
 
5
  """Gradio demo for rgbd-depth on Hugging Face Spaces."""
6
 
7
+ from pathlib import Path
8
+
9
  import gradio as gr
10
  import numpy as np
11
  import torch
 
16
  # Global model cache
17
  MODELS = {}
18
 
19
+ # Model mappings from HuggingFace (all are vitl encoder)
20
+ # Format: "camera_model": ("repo_id", "checkpoint_filename")
21
+ HF_MODELS = {
22
+ "d435": ("depth-anything/camera-depth-model-d435", "cdm_d435.ckpt"),
23
+ "d405": ("depth-anything/camera-depth-model-d405", "cdm_d405.ckpt"),
24
+ "l515": ("depth-anything/camera-depth-model-l515", "cdm_l515.ckpt"),
25
+ "zed2i": ("depth-anything/camera-depth-model-zed2i", "cdm_zed2i.ckpt"),
26
+ }
27
+
28
+ # Default model
29
+ DEFAULT_MODEL = "d435"
30
+
31
+
32
+ def download_model(camera_model: str = DEFAULT_MODEL):
33
+ """Download model from HuggingFace Hub."""
34
+ try:
35
+ from huggingface_hub import hf_hub_download
36
+
37
+ repo_id, filename = HF_MODELS.get(camera_model, HF_MODELS[DEFAULT_MODEL])
38
+ print(f"📥 Downloading {camera_model} model from {repo_id}/{filename}...")
39
+
40
+ # Download the checkpoint
41
+ checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=".cache")
42
 
43
+ print(f"✓ Downloaded to {checkpoint_path}")
44
+ return checkpoint_path
45
+
46
+ except Exception as e:
47
+ print(f"❌ Failed to download model: {e}")
48
+ return None
49
+
50
+
51
+ def load_model(camera_model: str = DEFAULT_MODEL, use_xformers: bool = False):
52
+ """Load model with automatic download from HuggingFace."""
53
+ cache_key = f"{camera_model}_{use_xformers}"
54
 
55
  if cache_key not in MODELS:
56
+ # All HF models use vitl encoder
57
+ config = {
58
+ "encoder": "vitl",
59
+ "features": 256,
60
+ "out_channels": [256, 512, 1024, 1024],
61
+ "use_xformers": use_xformers,
62
  }
63
 
 
 
 
64
  model = RGBDDepth(**config)
65
 
66
+ # Try to load weights
67
+ checkpoint_path = None
 
 
 
 
 
 
 
68
 
69
+ # 1. Try local checkpoints/ directory first
70
+ local_path = Path(f"checkpoints/{camera_model}.pt")
71
+ if local_path.exists():
72
+ checkpoint_path = str(local_path)
73
+ print(f"✓ Using local checkpoint: {checkpoint_path}")
74
+ else:
75
+ # 2. Download from HuggingFace
76
+ checkpoint_path = download_model(camera_model)
77
+
78
+ # Load checkpoint if available
79
+ if checkpoint_path:
80
+ try:
81
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
82
+ if "model" in checkpoint:
83
+ states = {k[7:]: v for k, v in checkpoint["model"].items()}
84
+ elif "state_dict" in checkpoint:
85
+ states = {k[9:]: v for k, v in checkpoint["state_dict"].items()}
86
+ else:
87
+ states = checkpoint
88
+
89
+ model.load_state_dict(states, strict=False)
90
+ print(f"✓ Loaded checkpoint for {camera_model}")
91
+ except Exception as e:
92
+ print(f"⚠ Failed to load checkpoint: {e}, using random weights")
93
+ else:
94
+ print(f"⚠ No checkpoint available for {camera_model}, using random weights (demo only)")
95
+
96
+ # Move to GPU if available (CUDA or MPS for macOS)
97
+ if torch.cuda.is_available():
98
+ device = "cuda"
99
+ elif torch.backends.mps.is_available():
100
+ device = "mps"
101
+ else:
102
+ device = "cpu"
103
 
 
 
104
  model = model.to(device).eval()
105
 
106
  MODELS[cache_key] = model
 
111
  def process_depth(
112
  rgb_image: np.ndarray,
113
  depth_image: np.ndarray,
114
+ camera_model: str = DEFAULT_MODEL,
115
  input_size: int = 518,
116
  depth_scale: float = 1000.0,
117
  max_depth: float = 25.0,
 
124
  Args:
125
  rgb_image: RGB image as numpy array [H, W, 3]
126
  depth_image: Depth image as numpy array [H, W] or [H, W, 3]
127
+ camera_model: Camera model to use (d435, d405, l515, zed2i)
128
  input_size: Input size for inference
129
  depth_scale: Scale factor for depth values
130
  max_depth: Maximum valid depth value
 
156
  simi_depth[valid_mask] = 1.0 / depth_normalized[valid_mask]
157
 
158
  # Load model
159
+ model = load_model(camera_model, use_xformers and torch.cuda.is_available())
160
  device = next(model.parameters()).device
161
 
162
  # Determine precision
 
167
  else:
168
  dtype = None # FP32
169
 
170
+ # DEBUG: Print input stats
171
+ print(f"[DEBUG] depth_image raw: min={depth_image.min():.1f}, max={depth_image.max():.1f}")
172
+ print(
173
+ f"[DEBUG] depth_normalized: min={depth_normalized[depth_normalized>0].min():.4f}, max={depth_normalized.max():.4f}"
174
+ )
175
+ print(
176
+ f"[DEBUG] simi_depth: min={simi_depth[simi_depth>0].min():.4f}, max={simi_depth.max():.4f}"
177
+ )
178
+
179
  # Run inference
180
  if dtype is not None:
181
  device_type = "cuda" if device.type == "cuda" else "cpu"
 
184
  else:
185
  pred = model.infer_image(rgb_image, simi_depth, input_size=input_size)
186
 
187
+ # DEBUG: Print prediction stats before reconversion
188
+ print(f"[DEBUG] pred (inverse depth): min={pred[pred>0].min():.4f}, max={pred.max():.4f}")
189
+
190
  # Convert from inverse depth to depth
191
  pred = np.where(pred > 1e-8, 1.0 / pred, 0.0)
192
 
193
+ # DEBUG: Print final depth stats
194
+ print(f"[DEBUG] pred (depth): min={pred[pred>0].min():.4f}, max={pred.max():.4f}")
195
+
196
  # Colorize for visualization
197
  try:
198
  import matplotlib
 
214
 
215
  except ImportError:
216
  # Fallback to grayscale if matplotlib not available
217
+ pred_norm = ((pred - pred.min()) / (pred.max() - pred.min() + 1e-8) * 255).astype(
218
+ np.uint8
219
+ )
220
+ output_image = Image.fromarray(pred_norm, mode="L").convert("RGB")
221
 
222
  # Create info message
223
  info = f"""
224
  ✅ **Refinement complete!**
225
 
226
+ **Camera Model:** {camera_model.upper()}
227
  **Precision:** {precision.upper()}
228
  **Device:** {device.type.upper()}
229
  **Input size:** {input_size}px
 
239
 
240
  # Create Gradio interface
241
  with gr.Blocks(title="rgbd-depth Demo") as demo:
242
+ gr.Markdown(
243
+ """
244
  # 🎨 rgbd-depth: RGB-D Depth Refinement
245
 
246
  High-quality depth map refinement using Vision Transformers. Based on [ByteDance's camera-depth-models](https://manipulation-as-in-simulation.github.io/).
247
 
248
+ 📥 **Models are automatically downloaded from Hugging Face on first use!**
249
+
250
+ Choose your camera model (D435, D405, L515, or ZED 2i) and the trained weights will be downloaded automatically.
251
+ """
252
+ )
253
 
254
  with gr.Row():
255
  with gr.Column():
 
268
  )
269
 
270
  with gr.Accordion("⚙️ Advanced Settings", open=False):
271
+ camera_choice = gr.Dropdown(
272
+ choices=["d435", "d405", "l515", "zed2i"],
273
+ value=DEFAULT_MODEL,
274
+ label="Camera Model",
275
+ info="Choose the camera model for trained weights (auto-downloads from HF)",
276
  )
277
 
278
  input_size = gr.Slider(
 
304
  )
305
 
306
  use_xformers = gr.Checkbox(
307
+ value=False, # Set to True to test xFormers by default
308
  label="Use xFormers (CUDA only)",
309
  info="~8% faster on CUDA with xFormers installed",
310
  )
 
345
  inputs=[
346
  rgb_input,
347
  depth_input,
348
+ camera_choice,
349
  input_size,
350
  depth_scale,
351
  max_depth,
 
357
  )
358
 
359
  # Footer
360
+ gr.Markdown(
361
+ """
362
  ---
363
 
364
  ### 🔗 Links
 
386
  ---
387
 
388
  Built with ❤️ by [Aedelon](https://github.com/Aedelon) | Powered by [Gradio](https://gradio.app)
389
+ """
390
+ )
391
 
392
  if __name__ == "__main__":
393
+ demo.launch()