ritianyu committed on
Commit
1c7685e
·
1 Parent(s): b6640e8
Files changed (4) hide show
  1. .gitignore +3 -2
  2. InfiniDepth/utils/hf_demo_utils.py +5 -2
  3. app.py +125 -8
  4. requirements.txt +1 -1
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  checkpoints
2
- __pycache__
3
- .pyc
 
 
1
  checkpoints
2
+ __pycache__/
3
+ *.pyc
4
+ example_data/
InfiniDepth/utils/hf_demo_utils.py CHANGED
@@ -58,7 +58,8 @@ class ModelCache:
58
 
59
  def _parse_image_size(size_text: str) -> tuple[int, int]:
60
  try:
61
- h_text, w_text = size_text.lower().split("x")
 
62
  return int(h_text), int(w_text)
63
  except Exception as exc:
64
  raise ValueError(f"Invalid image size format: {size_text}, expected like 768x1024") from exc
@@ -228,6 +229,8 @@ def run_single_image_demo(
228
  raise ValueError("upsample_ratio must be in [1, 8]")
229
  if max_points_preview < 1000:
230
  raise ValueError("max_points_preview must be at least 1000")
 
 
231
 
232
  device = torch.device("cuda")
233
  image, org_h, org_w = _prepare_image_tensor(image_np, input_size, device)
@@ -238,7 +241,7 @@ def run_single_image_demo(
238
  org_w=org_w,
239
  h=h_in,
240
  w=w_in,
241
- output_size=input_size,
242
  upsample_ratio=upsample_ratio,
243
  )
244
 
 
58
 
59
  def _parse_image_size(size_text: str) -> tuple[int, int]:
60
  try:
61
+ normalized = size_text.lower().replace(" ", "")
62
+ h_text, w_text = normalized.split("x")
63
  return int(h_text), int(w_text)
64
  except Exception as exc:
65
  raise ValueError(f"Invalid image size format: {size_text}, expected like 768x1024") from exc
 
229
  raise ValueError("upsample_ratio must be in [1, 8]")
230
  if max_points_preview < 1000:
231
  raise ValueError("max_points_preview must be at least 1000")
232
+ else:
233
+ output_size = input_size
234
 
235
  device = torch.device("cuda")
236
  image, org_h, org_w = _prepare_image_tensor(image_np, input_size, device)
 
241
  org_w=org_w,
242
  h=h_in,
243
  w=w_in,
244
+ output_size=output_size,
245
  upsample_ratio=upsample_ratio,
246
  )
247
 
app.py CHANGED
@@ -33,10 +33,19 @@ except ImportError:
33
 
34
  MODEL_CACHE = ModelCache()
35
  OUTPUT_ROOT = Path(tempfile.gettempdir()) / "infinidepth_hf_demo"
 
 
 
36
 
37
  CUSTOM_CSS = """
38
- .gradio-container {
39
- max-width: 1280px !important;
 
 
 
 
 
 
40
  }
41
 
42
  #hero {
@@ -93,6 +102,97 @@ def _none_if_invalid(value: Optional[float]) -> Optional[float]:
93
  return None
94
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def _prepare_output_dir(request: Optional[gr.Request]) -> Path:
97
  session_hash = "local"
98
  if request is not None and getattr(request, "session_hash", None):
@@ -194,11 +294,11 @@ DESCRIPTION_MD = """
194
  """
195
 
196
 
197
- with gr.Blocks(title="InfiniDepth Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
198
  gr.Markdown(DESCRIPTION_MD)
199
 
200
- with gr.Row():
201
- with gr.Column(scale=5):
202
  image_input = gr.Image(type="numpy", label="Input RGB Image")
203
  depth_input = gr.File(
204
  label="Optional Depth Map (.png/.npy/.npz/.h5/.hdf5/.exr)",
@@ -227,11 +327,12 @@ with gr.Blocks(title="InfiniDepth Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS)
227
  value=1,
228
  step=1,
229
  label="Super-resolution Ratio",
 
230
  )
231
  max_points_preview = gr.Slider(
232
  minimum=10000,
233
  maximum=1000000,
234
- value=60000,
235
  step=5000,
236
  label="Max Preview Points",
237
  )
@@ -248,14 +349,14 @@ with gr.Blocks(title="InfiniDepth Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS)
248
  "Use lower preview points for faster 3D interaction."
249
  )
250
 
251
- with gr.Column(scale=7):
252
  with gr.Tabs():
253
  with gr.Tab("3D View"):
254
  pcd_viewer = gr.Model3D(
255
  label="Point Cloud Viewer",
256
  display_mode="solid",
257
  clear_color=[1, 1, 1, 1],
258
- height=560,
259
  )
260
  with gr.Tab("Depth"):
261
  depth_output = gr.Image(type="numpy", label="Predicted Depth (Colorized)")
@@ -267,6 +368,7 @@ with gr.Blocks(title="InfiniDepth Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS)
267
  )
268
  status = gr.Textbox(label="Status", interactive=False)
269
 
 
270
  run_button.click(
271
  fn=lambda: (None, None, [], "Running..."),
272
  outputs=[depth_output, pcd_viewer, files_output, status],
@@ -288,6 +390,21 @@ with gr.Blocks(title="InfiniDepth Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS)
288
  outputs=[depth_output, pcd_viewer, files_output, status],
289
  )
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  demo = demo.queue()
292
 
293
 
 
33
 
34
  MODEL_CACHE = ModelCache()
35
  OUTPUT_ROOT = Path(tempfile.gettempdir()) / "infinidepth_hf_demo"
36
+ EXAMPLE_DATA_ROOT = Path(__file__).resolve().parent / "example_data"
37
+ EXAMPLE_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp"}
38
+ EXAMPLE_DEPTH_EXTENSIONS = {".png", ".npy", ".npz", ".h5", ".hdf5", ".exr"}
39
 
40
  CUSTOM_CSS = """
41
+ #main-layout {
42
+ width: 100%;
43
+ gap: 16px;
44
+ }
45
+
46
+ #left-panel,
47
+ #right-panel {
48
+ min-width: 0 !important;
49
  }
50
 
51
  #hero {
 
102
  return None
103
 
104
 
105
+ def _strip_rgb_suffix(stem: str) -> str:
106
+ suffixes = ("_rgb", "-rgb", ".rgb", "_image", "-image", ".image", "_color", "-color", ".color")
107
+ lowered = stem.lower()
108
+ for suffix in suffixes:
109
+ if lowered.endswith(suffix):
110
+ return lowered[: -len(suffix)].strip("._-")
111
+ return lowered
112
+
113
+
114
def _find_paired_depth(rgb_path: Path) -> Optional[Path]:
    """Find the depth file that belongs to an example RGB image.

    Only the directory containing *rgb_path* is searched. Candidate file
    names are built from the RGB stem, and from the stem with its RGB
    marker removed, combined with ``_depth`` / ``-depth`` / ``.depth`` or
    nothing, followed by the generic names ``depth``, ``gt_depth`` and
    ``depth_gt``. Each candidate is tried with every extension in
    ``EXAMPLE_DEPTH_EXTENSIONS``.

    Matching is case-insensitive against the real directory listing, so a
    file such as ``Scene01_Depth.PNG`` is found on case-sensitive
    filesystems as well. Extensions are tried in sorted order so the result
    does not depend on set iteration order.

    If no candidate name matches, the directory is scanned for depth-typed
    files whose stem contains ``depth``; that file is returned only when
    exactly one exists.

    Args:
        rgb_path: Path of the RGB example image.

    Returns:
        The paired depth file as it appears in the directory listing, or
        ``None`` when no unambiguous match exists.
    """
    parent = rgb_path.parent
    stem = rgb_path.stem.lower()
    base = _strip_rgb_suffix(stem)
    rgb_resolved = rgb_path.resolve()

    # One sorted listing serves both the name lookup and the fallback scan;
    # the RGB image itself is never eligible as its own depth map.
    siblings = [
        path
        for path in sorted(parent.iterdir())
        if path.is_file() and path.resolve() != rgb_resolved
    ]

    # Index by lowercased file name. On a case collision the first entry in
    # sorted order wins, which keeps the choice deterministic.
    by_name: dict[str, Path] = {}
    for path in siblings:
        by_name.setdefault(path.name.lower(), path)

    candidate_stems = [
        f"{key}{tail}"
        for key in (stem, base)
        if key
        for tail in ("_depth", "-depth", ".depth", "")
    ]
    candidate_stems.extend(["depth", "gt_depth", "depth_gt"])

    # Trim separator characters, drop empties and de-duplicate while keeping
    # the priority order above.
    trimmed = (item.strip("._-") for item in candidate_stems)
    ordered_candidate_stems = list(dict.fromkeys(item for item in trimmed if item))

    extensions = sorted(EXAMPLE_DEPTH_EXTENSIONS)
    for candidate_stem in ordered_candidate_stems:
        for ext in extensions:
            match = by_name.get(f"{candidate_stem}{ext}".lower())
            if match is not None:
                return match

    fallback = [
        path
        for path in siblings
        if path.suffix.lower() in EXAMPLE_DEPTH_EXTENSIONS and "depth" in path.stem.lower()
    ]
    # Several depth-like files would make the pairing a guess; give up instead.
    if len(fallback) == 1:
        return fallback[0]
    return None
159
+
160
+
161
def _collect_example_samples(limit: int = 24) -> list[dict[str, Optional[str]]]:
    """Collect example RGB images, with paired depth files, from ``EXAMPLE_DATA_ROOT``.

    The example directory is walked recursively in sorted path order. A file
    counts as an RGB example when its extension is in
    ``EXAMPLE_IMAGE_EXTENSIONS`` and its stem does not contain ``depth``
    (such files are depth maps, not inputs).

    Args:
        limit: Maximum number of samples to return. Zero or a negative value
            yields an empty list.

    Returns:
        A list of ``{"rgb": <posix path>, "depth": <posix path or None>}``
        dictionaries; empty when the example directory does not exist.
    """
    if not EXAMPLE_DATA_ROOT.exists():
        return []

    max_rows = int(limit)
    # Checked up front: testing the limit only after appending would still
    # return one row for limit <= 0.
    if max_rows <= 0:
        return []

    rows: list[dict[str, Optional[str]]] = []
    for path in sorted(EXAMPLE_DATA_ROOT.rglob("*")):
        if not path.is_file():
            continue
        if path.suffix.lower() not in EXAMPLE_IMAGE_EXTENSIONS:
            continue
        if "depth" in path.stem.lower():
            continue

        depth_path = _find_paired_depth(path)
        rows.append(
            {
                "rgb": path.as_posix(),
                "depth": depth_path.as_posix() if depth_path else None,
            }
        )
        if len(rows) >= max_rows:
            break
    return rows
184
+
185
+
186
+ def _build_examples_rows(samples: list[dict[str, Optional[str]]]) -> list[list[Optional[str]]]:
187
+ rows: list[list[Optional[str]]] = []
188
+ for sample in samples:
189
+ rgb_path = sample.get("rgb")
190
+ if not rgb_path:
191
+ continue
192
+ rows.append([rgb_path, sample.get("depth")])
193
+ return rows
194
+
195
+
196
  def _prepare_output_dir(request: Optional[gr.Request]) -> Path:
197
  session_hash = "local"
198
  if request is not None and getattr(request, "session_hash", None):
 
294
  """
295
 
296
 
297
+ with gr.Blocks(title="InfiniDepth Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS, fill_width=True) as demo:
298
  gr.Markdown(DESCRIPTION_MD)
299
 
300
+ with gr.Row(elem_id="main-layout"):
301
+ with gr.Column(elem_id="left-panel"):
302
  image_input = gr.Image(type="numpy", label="Input RGB Image")
303
  depth_input = gr.File(
304
  label="Optional Depth Map (.png/.npy/.npz/.h5/.hdf5/.exr)",
 
327
  value=1,
328
  step=1,
329
  label="Super-resolution Ratio",
330
+ visible=True,
331
  )
332
  max_points_preview = gr.Slider(
333
  minimum=10000,
334
  maximum=1000000,
335
+ value=500000,
336
  step=5000,
337
  label="Max Preview Points",
338
  )
 
349
  "Use lower preview points for faster 3D interaction."
350
  )
351
 
352
+ with gr.Column(elem_id="right-panel"):
353
  with gr.Tabs():
354
  with gr.Tab("3D View"):
355
  pcd_viewer = gr.Model3D(
356
  label="Point Cloud Viewer",
357
  display_mode="solid",
358
  clear_color=[1, 1, 1, 1],
359
+ height="60vh",
360
  )
361
  with gr.Tab("Depth"):
362
  depth_output = gr.Image(type="numpy", label="Predicted Depth (Colorized)")
 
368
  )
369
  status = gr.Textbox(label="Status", interactive=False)
370
 
371
+
372
  run_button.click(
373
  fn=lambda: (None, None, [], "Running..."),
374
  outputs=[depth_output, pcd_viewer, files_output, status],
 
390
  outputs=[depth_output, pcd_viewer, files_output, status],
391
  )
392
 
393
+ example_samples = _collect_example_samples()
394
+ if example_samples:
395
+ example_rows = _build_examples_rows(example_samples)
396
+ gr.Markdown("### Example Data")
397
+ gr.Markdown("Use template-style examples. RGB is always loaded; paired depth is loaded when available.")
398
+ gr.Examples(
399
+ examples=example_rows,
400
+ inputs=[image_input, depth_input],
401
+ label="Example Data",
402
+ cache_examples=False,
403
+ examples_per_page=10,
404
+ )
405
+ else:
406
+ gr.Markdown("### Example Data\nNo images found in `example_data/`.")
407
+
408
  demo = demo.queue()
409
 
410
 
requirements.txt CHANGED
@@ -4,7 +4,7 @@
4
  torch==2.9.1
5
  torchvision==0.24.1
6
  torchaudio==2.9.1
7
- xformers==0.0.33.post1
8
  hydra-colorlog
9
  hydra-core
10
  h5py
 
4
  torch==2.9.1
5
  torchvision==0.24.1
6
  torchaudio==2.9.1
7
+ xformers==0.0.33.post2
8
  hydra-colorlog
9
  hydra-core
10
  h5py