github-actions[bot] committed on
Commit
b4dbe35
·
1 Parent(s): 0567380

Sync from GitHub: f84ca3dac3c962b2c71c590ad187e2352331038b

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +41 -94
README.md CHANGED
@@ -26,7 +26,7 @@ Optimized Python package for RGB-D depth refinement using Vision Transformer enc
26
 
27
  [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/Aedelon/rgbd-depth)
28
 
29
- Try **rgbd-depth** directly in your browser with our interactive Gradio demo—no installation required.
30
 
31
  **Available on Hugging Face Spaces:** Upload your RGB and depth images, adjust parameters (camera model, precision, resolution), and get refined depth maps instantly. Models are automatically downloaded from Hugging Face Hub on first use.
32
 
 
26
 
27
  [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/Aedelon/rgbd-depth)
28
 
29
+ Try **rgbd-depth** directly in your browser with our interactive Gradio demo—no installation required. Upload your images and refine depth maps instantly.
30
 
31
  **Available on Hugging Face Spaces:** Upload your RGB and depth images, adjust parameters (camera model, precision, resolution), and get refined depth maps instantly. Models are automatically downloaded from Hugging Face Hub on first use.
32
 
app.py CHANGED
@@ -8,102 +8,48 @@ import gradio as gr
8
  import numpy as np
9
  import torch
10
  from PIL import Image
11
- from pathlib import Path
12
 
13
  from rgbddepth import RGBDDepth
14
 
15
  # Global model cache
16
  MODELS = {}
17
 
18
- # Model mappings from HuggingFace (all are vitl encoder)
19
- # Format: "camera_model": ("repo_id", "checkpoint_filename")
20
- HF_MODELS = {
21
- "d435": ("depth-anything/camera-depth-model-d435", "cdm_d435.ckpt"),
22
- "d405": ("depth-anything/camera-depth-model-d405", "cdm_d405.ckpt"),
23
- "l515": ("depth-anything/camera-depth-model-l515", "cdm_l515.ckpt"),
24
- "zed2i": ("depth-anything/camera-depth-model-zed2i", "cdm_zed2i.ckpt"),
25
- }
26
 
27
- # Default model
28
- DEFAULT_MODEL = "d435"
29
-
30
-
31
- def download_model(camera_model: str = DEFAULT_MODEL):
32
- """Download model from HuggingFace Hub."""
33
- try:
34
- from huggingface_hub import hf_hub_download
35
-
36
- repo_id, filename = HF_MODELS.get(camera_model, HF_MODELS[DEFAULT_MODEL])
37
- print(f"📥 Downloading {camera_model} model from {repo_id}/{filename}...")
38
-
39
- # Download the checkpoint
40
- checkpoint_path = hf_hub_download(
41
- repo_id=repo_id,
42
- filename=filename,
43
- cache_dir=".cache"
44
- )
45
-
46
- print(f"✓ Downloaded to {checkpoint_path}")
47
- return checkpoint_path
48
-
49
- except Exception as e:
50
- print(f"❌ Failed to download model: {e}")
51
- return None
52
-
53
-
54
- def load_model(camera_model: str = DEFAULT_MODEL, use_xformers: bool = False):
55
- """Load model with automatic download from HuggingFace."""
56
- cache_key = f"{camera_model}_{use_xformers}"
57
 
58
  if cache_key not in MODELS:
59
- # All HF models use vitl encoder
60
- config = {
61
- "encoder": "vitl",
62
- "features": 256,
63
- "out_channels": [256, 512, 1024, 1024],
64
- "use_xformers": use_xformers,
65
  }
66
 
67
- model = RGBDDepth(**config)
 
68
 
69
- # Try to load weights
70
- checkpoint_path = None
71
 
72
- # 1. Try local checkpoints/ directory first
73
- local_path = Path(f"checkpoints/{camera_model}.pt")
74
- if local_path.exists():
75
- checkpoint_path = str(local_path)
76
- print(f"✓ Using local checkpoint: {checkpoint_path}")
77
- else:
78
- # 2. Download from HuggingFace
79
- checkpoint_path = download_model(camera_model)
80
-
81
- # Load checkpoint if available
82
- if checkpoint_path:
83
- try:
84
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
85
- if "model" in checkpoint:
86
- states = {k[7:]: v for k, v in checkpoint["model"].items()}
87
- elif "state_dict" in checkpoint:
88
- states = {k[9:]: v for k, v in checkpoint["state_dict"].items()}
89
- else:
90
- states = checkpoint
91
-
92
- model.load_state_dict(states, strict=False)
93
- print(f"✓ Loaded checkpoint for {camera_model}")
94
- except Exception as e:
95
- print(f"⚠ Failed to load checkpoint: {e}, using random weights")
96
- else:
97
- print(f"⚠ No checkpoint available for {camera_model}, using random weights (demo only)")
98
 
99
- # Move to GPU if available (CUDA or MPS for macOS)
100
- if torch.cuda.is_available():
101
- device = "cuda"
102
- elif torch.backends.mps.is_available():
103
- device = "mps"
104
- else:
105
- device = "cpu"
106
 
 
 
107
  model = model.to(device).eval()
108
 
109
  MODELS[cache_key] = model
@@ -114,7 +60,7 @@ def load_model(camera_model: str = DEFAULT_MODEL, use_xformers: bool = False):
114
  def process_depth(
115
  rgb_image: np.ndarray,
116
  depth_image: np.ndarray,
117
- camera_model: str = DEFAULT_MODEL,
118
  input_size: int = 518,
119
  depth_scale: float = 1000.0,
120
  max_depth: float = 25.0,
@@ -127,7 +73,7 @@ def process_depth(
127
  Args:
128
  rgb_image: RGB image as numpy array [H, W, 3]
129
  depth_image: Depth image as numpy array [H, W] or [H, W, 3]
130
- camera_model: Camera model to use (d435, d405, l515, zed2i)
131
  input_size: Input size for inference
132
  depth_scale: Scale factor for depth values
133
  max_depth: Maximum valid depth value
@@ -159,7 +105,7 @@ def process_depth(
159
  simi_depth[valid_mask] = 1.0 / depth_normalized[valid_mask]
160
 
161
  # Load model
162
- model = load_model(camera_model, use_xformers and torch.cuda.is_available())
163
  device = next(model.parameters()).device
164
 
165
  # Determine precision
@@ -209,7 +155,7 @@ def process_depth(
209
  info = f"""
210
  ✅ **Refinement complete!**
211
 
212
- **Camera Model:** {camera_model.upper()}
213
  **Precision:** {precision.upper()}
214
  **Device:** {device.type.upper()}
215
  **Input size:** {input_size}px
@@ -230,9 +176,10 @@ with gr.Blocks(title="rgbd-depth Demo") as demo:
230
 
231
  High-quality depth map refinement using Vision Transformers. Based on [ByteDance's camera-depth-models](https://manipulation-as-in-simulation.github.io/).
232
 
233
- 📥 **Models are automatically downloaded from Hugging Face on first use!**
234
-
235
- Choose your camera model (D435, D405, L515, or ZED 2i) and the trained weights will be downloaded automatically.
 
236
  """)
237
 
238
  with gr.Row():
@@ -252,11 +199,11 @@ with gr.Blocks(title="rgbd-depth Demo") as demo:
252
  )
253
 
254
  with gr.Accordion("⚙️ Advanced Settings", open=False):
255
- camera_choice = gr.Dropdown(
256
- choices=["d435", "d405", "l515", "zed2i"],
257
- value=DEFAULT_MODEL,
258
- label="Camera Model",
259
- info="Choose the camera model for trained weights (auto-downloads from HF)",
260
  )
261
 
262
  input_size = gr.Slider(
@@ -329,7 +276,7 @@ with gr.Blocks(title="rgbd-depth Demo") as demo:
329
  inputs=[
330
  rgb_input,
331
  depth_input,
332
- camera_choice,
333
  input_size,
334
  depth_scale,
335
  max_depth,
 
8
  import numpy as np
9
  import torch
10
  from PIL import Image
 
11
 
12
  from rgbddepth import RGBDDepth
13
 
14
  # Global model cache
15
  MODELS = {}
16
 
 
 
 
 
 
 
 
 
17
 
18
+ def load_model(encoder: str, use_xformers: bool = False):
19
+ """Load model with caching."""
20
+ cache_key = f"{encoder}_{use_xformers}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  if cache_key not in MODELS:
23
+ # Model configs
24
+ configs = {
25
+ "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
26
+ "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
27
+ "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
28
+ "vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
29
  }
30
 
31
+ config = configs[encoder].copy()
32
+ config["use_xformers"] = use_xformers
33
 
34
+ model = RGBDDepth(**config)
 
35
 
36
+ # Try to load weights if checkpoint exists
37
+ try:
38
+ checkpoint = torch.load(f"checkpoints/{encoder}.pt", map_location="cpu")
39
+ if "model" in checkpoint:
40
+ states = {k[7:]: v for k, v in checkpoint["model"].items()}
41
+ elif "state_dict" in checkpoint:
42
+ states = {k[9:]: v for k, v in checkpoint["state_dict"].items()}
43
+ else:
44
+ states = checkpoint
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ model.load_state_dict(states, strict=False)
47
+ print(f"✓ Loaded checkpoint for {encoder}")
48
+ except FileNotFoundError:
49
+ print(f"⚠ No checkpoint found for {encoder}, using random weights (demo only)")
 
 
 
50
 
51
+ # Move to GPU if available
52
+ device = "cuda" if torch.cuda.is_available() else "cpu"
53
  model = model.to(device).eval()
54
 
55
  MODELS[cache_key] = model
 
60
  def process_depth(
61
  rgb_image: np.ndarray,
62
  depth_image: np.ndarray,
63
+ encoder: str = "vitl",
64
  input_size: int = 518,
65
  depth_scale: float = 1000.0,
66
  max_depth: float = 25.0,
 
73
  Args:
74
  rgb_image: RGB image as numpy array [H, W, 3]
75
  depth_image: Depth image as numpy array [H, W] or [H, W, 3]
76
+ encoder: Model encoder type
77
  input_size: Input size for inference
78
  depth_scale: Scale factor for depth values
79
  max_depth: Maximum valid depth value
 
105
  simi_depth[valid_mask] = 1.0 / depth_normalized[valid_mask]
106
 
107
  # Load model
108
+ model = load_model(encoder, use_xformers and torch.cuda.is_available())
109
  device = next(model.parameters()).device
110
 
111
  # Determine precision
 
155
  info = f"""
156
  ✅ **Refinement complete!**
157
 
158
+ **Model:** {encoder.upper()}
159
  **Precision:** {precision.upper()}
160
  **Device:** {device.type.upper()}
161
  **Input size:** {input_size}px
 
176
 
177
  High-quality depth map refinement using Vision Transformers. Based on [ByteDance's camera-depth-models](https://manipulation-as-in-simulation.github.io/).
178
 
179
+ ⚠️ **Note:** This demo uses random weights for demonstration. For real results:
180
+ 1. Download checkpoints from [Hugging Face](https://huggingface.co/collections/depth-anything/camera-depth-models-68b521181dedd223f4b020db)
181
+ 2. Place in `checkpoints/` directory
182
+ 3. Restart the app
183
  """)
184
 
185
  with gr.Row():
 
199
  )
200
 
201
  with gr.Accordion("⚙️ Advanced Settings", open=False):
202
+ encoder_choice = gr.Radio(
203
+ choices=["vits", "vitb", "vitl", "vitg"],
204
+ value="vitl",
205
+ label="Encoder Model",
206
+ info="Larger = better quality but slower",
207
  )
208
 
209
  input_size = gr.Slider(
 
276
  inputs=[
277
  rgb_input,
278
  depth_input,
279
+ encoder_choice,
280
  input_size,
281
  depth_scale,
282
  max_depth,