gagndeep committed on
Commit c4d348a · 1 Parent(s): 1b4e413
Files changed (1)
  1. model_utils.py +231 -569
model_utils.py CHANGED
@@ -1,612 +1,274 @@
- """SHARP inference + optional CUDA video rendering utilities.
-
- Design goals:
- - Reuse SHARP's own predict/render pipeline (no subprocess calls).
- - Be robust on Hugging Face Spaces + ZeroGPU.
- - Cache model weights and predictor construction across requests.
-
- Public API (used by the Gradio app):
- - TrajectoryType
- - predict_and_maybe_render_gpu(...)
  """

  from __future__ import annotations

- import os
- import threading
- import time
- import uuid
- from contextlib import contextmanager
- from dataclasses import dataclass
  from pathlib import Path
- from typing import Final, Literal

- import torch

  try:
-     import spaces
- except Exception:  # pragma: no cover
-     spaces = None  # type: ignore[assignment]

- try:
-     # Prefer HF cache / Hub downloads (works with Spaces `preload_from_hub`).
-     from huggingface_hub import hf_hub_download, try_to_load_from_cache
- except Exception:  # pragma: no cover
-     hf_hub_download = None  # type: ignore[assignment]
-     try_to_load_from_cache = None  # type: ignore[assignment]

- from sharp.cli.predict import DEFAULT_MODEL_URL, predict_image
- from sharp.cli.render import render_gaussians as sharp_render_gaussians
- from sharp.models import PredictorParams, create_predictor
- from sharp.utils import camera, io
- from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply
- from sharp.utils.gsplat import GSplatRenderer

- TrajectoryType = Literal["swipe", "shake", "rotate", "rotate_forward"]

  # -----------------------------------------------------------------------------
  # Helpers
  # -----------------------------------------------------------------------------

-
- def _now_ms() -> int:
-     return int(time.time() * 1000)
-
-
  def _ensure_dir(path: Path) -> Path:
      path.mkdir(parents=True, exist_ok=True)
      return path

-
- def _make_even(x: int) -> int:
-     return x if x % 2 == 0 else x + 1
-
-
- def _select_device(preference: str = "auto") -> torch.device:
-     """Select the best available device for inference (CPU/CUDA/MPS)."""
-     if preference not in {"auto", "cpu", "cuda", "mps"}:
-         raise ValueError("device preference must be one of: auto|cpu|cuda|mps")
-
-     if preference == "cpu":
-         return torch.device("cpu")
-     if preference == "cuda":
-         return torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     if preference == "mps":
-         return torch.device("mps" if torch.backends.mps.is_available() else "cpu")
-
-     # auto
-     if torch.cuda.is_available():
-         return torch.device("cuda")
-     if torch.backends.mps.is_available():
-         return torch.device("mps")
-     return torch.device("cpu")
-
-
- # -----------------------------------------------------------------------------
- # Prediction outputs
- # -----------------------------------------------------------------------------
-
-
- @dataclass(frozen=True, slots=True)
- class PredictionOutputs:
-     """Outputs of SHARP inference (plus derived metadata for rendering)."""
-
-     ply_path: Path
-     gaussians: Gaussians3D
-     metadata_for_render: SceneMetaData
-     input_resolution_hw: tuple[int, int]
-     focal_length_px: float
-
-
- # -----------------------------------------------------------------------------
- # Patch SHARP VideoWriter to properly close the optional depth writer
- # -----------------------------------------------------------------------------
-
-
- class _PatchedVideoWriter(io.VideoWriter):
-     """Ensure depth writer is closed so files can be safely cleaned up."""
-
-     def __init__(
-         self, output_path: Path, fps: float = 30.0, render_depth: bool = True
-     ) -> None:
-         super().__init__(output_path, fps=fps, render_depth=render_depth)
-         # Ensure attribute exists for downstream code paths.
-         if not hasattr(self, "depth_writer"):
-             self.depth_writer = None  # type: ignore[attribute-defined-outside-init]
-
-     def close(self):
-         super().close()
-         depth_writer = getattr(self, "depth_writer", None)
          try:
-             if depth_writer is not None:
-                 depth_writer.close()
-         except Exception:
-             pass
-
-
- @contextmanager
- def _patched_sharp_videowriter():
-     """Temporarily patch `sharp.utils.io.VideoWriter` used by `sharp.cli.render`."""
-     original = io.VideoWriter
-     io.VideoWriter = _PatchedVideoWriter  # type: ignore[assignment]
      try:
-         yield
-     finally:
-         io.VideoWriter = original  # type: ignore[assignment]
-
-
- # -----------------------------------------------------------------------------
- # Model wrapper
- # -----------------------------------------------------------------------------
-
-
- class ModelWrapper:
-     """Cached SHARP model wrapper for Gradio/Spaces."""
-
-     def __init__(
-         self,
-         *,
-         outputs_dir: str | Path = "outputs",
-         checkpoint_url: str = DEFAULT_MODEL_URL,
-         checkpoint_path: str | Path | None = None,
-         device_preference: str = "auto",
-         keep_model_on_device: bool | None = None,
-         hf_repo_id: str | None = None,
-         hf_filename: str | None = None,
-         hf_revision: str | None = None,
-     ) -> None:
-         self.outputs_dir = _ensure_dir(Path(outputs_dir))
-         self.checkpoint_url = checkpoint_url
-
-         env_ckpt = os.getenv("SHARP_CHECKPOINT_PATH") or os.getenv("SHARP_CHECKPOINT")
-         if checkpoint_path:
-             self.checkpoint_path = Path(checkpoint_path)
-         elif env_ckpt:
-             self.checkpoint_path = Path(env_ckpt)
-         else:
-             self.checkpoint_path = None
-
-         # Optional Hugging Face Hub fallback (useful when direct CDN download fails).
-         self.hf_repo_id = hf_repo_id or os.getenv("SHARP_HF_REPO_ID", "apple/Sharp")
-         self.hf_filename = hf_filename or os.getenv(
-             "SHARP_HF_FILENAME", "sharp_2572gikvuh.pt"
          )
-         self.hf_revision = hf_revision or os.getenv("SHARP_HF_REVISION") or None

-         self.device_preference = device_preference
-
-         # For ZeroGPU, it's safer to not keep large tensors on CUDA across calls.
-         if keep_model_on_device is None:
-             keep_env = (
-                 os.getenv("SHARP_KEEP_MODEL_ON_DEVICE")
-             )
-             self.keep_model_on_device = keep_env == "1"
-         else:
-             self.keep_model_on_device = keep_model_on_device

-         self._lock = threading.RLock()
-         self._predictor: torch.nn.Module | None = None
-         self._predictor_device: torch.device | None = None
-         self._state_dict: dict | None = None

-     def has_cuda(self) -> bool:
-         return torch.cuda.is_available()

-     def _load_state_dict(self) -> dict:
-         with self._lock:
-             if self._state_dict is not None:
-                 return self._state_dict

-             # 1) Explicit local checkpoint path
-             if self.checkpoint_path is not None:
-                 try:
-                     self._state_dict = torch.load(
-                         self.checkpoint_path,
-                         weights_only=True,
-                         map_location="cpu",
                      )
-                     return self._state_dict
-                 except Exception as e:
-                     raise RuntimeError(
-                         "Failed to load SHARP checkpoint from local path.\n\n"
-                         f"Path:\n {self.checkpoint_path}\n\n"
-                         f"Original error:\n {type(e).__name__}: {e}"
-                     ) from e
-
-             # 2) HF cache (no-network): best match for Spaces `preload_from_hub`.
-             hf_cache_error: Exception | None = None
-             if try_to_load_from_cache is not None:
-                 try:
-                     cached = try_to_load_from_cache(
-                         repo_id=self.hf_repo_id,
-                         filename=self.hf_filename,
-                         revision=self.hf_revision,
-                         repo_type="model",
                      )
-                 except TypeError:
-                     cached = try_to_load_from_cache(self.hf_repo_id, self.hf_filename)  # type: ignore[misc]

-                 try:
-                     if isinstance(cached, str) and Path(cached).exists():
-                         self._state_dict = torch.load(
-                             cached, weights_only=True, map_location="cpu"
-                         )
-                         return self._state_dict
-                 except Exception as e:
-                     hf_cache_error = e

-             # 3) HF Hub download (reuse cache when available; may download otherwise).
-             hf_error: Exception | None = None
-             if hf_hub_download is not None:
-                 # Attempt "local only" mode if supported (avoids network).
-                 try:
-                     import inspect

-                     if "local_files_only" in inspect.signature(hf_hub_download).parameters:
-                         ckpt_path = hf_hub_download(
-                             repo_id=self.hf_repo_id,
-                             filename=self.hf_filename,
-                             revision=self.hf_revision,
-                             local_files_only=True,
-                         )
-                         if Path(ckpt_path).exists():
-                             self._state_dict = torch.load(
-                                 ckpt_path, weights_only=True, map_location="cpu"
-                             )
-                             return self._state_dict
-                 except Exception:
-                     pass
-
-                 try:
-                     ckpt_path = hf_hub_download(
-                         repo_id=self.hf_repo_id,
-                         filename=self.hf_filename,
-                         revision=self.hf_revision,
-                     )
-                     self._state_dict = torch.load(
-                         ckpt_path,
-                         weights_only=True,
-                         map_location="cpu",
-                     )
-                     return self._state_dict
-                 except Exception as e:
-                     hf_error = e
-
-             # 4) Default upstream CDN (torch hub cache). Last resort.
-             url_error: Exception | None = None
-             try:
-                 self._state_dict = torch.hub.load_state_dict_from_url(
-                     self.checkpoint_url,
-                     progress=True,
-                     map_location="cpu",
                  )
-                 return self._state_dict
-             except Exception as e:
-                 url_error = e
-
-             # If we got here: all options failed.
-             hint_lines = [
-                 "Failed to load SHARP checkpoint.",
-                 "",
-                 "Tried (in order):",
-                 f" 1) HF cache (preload_from_hub): repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
-                 f" 2) HF Hub download: repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
-                 f" 3) URL (torch hub): {self.checkpoint_url}",
-                 "",
-                 "If network access is restricted, set a local checkpoint path:",
-                 " - SHARP_CHECKPOINT_PATH=/path/to/sharp_2572gikvuh.pt",
-                 "",
-                 "Original errors:",
-             ]
-             if try_to_load_from_cache is None:
-                 hint_lines.append(" HF cache: huggingface_hub not installed")
-             elif hf_cache_error is not None:
-                 hint_lines.append(
-                     f" HF cache: {type(hf_cache_error).__name__}: {hf_cache_error}"
-                 )
-             else:
-                 hint_lines.append(" HF cache: (not found in cache)")
-
-             if hf_hub_download is None:
-                 hint_lines.append(" HF download: huggingface_hub not installed")
-             else:
-                 hint_lines.append(f" HF download: {type(hf_error).__name__}: {hf_error}")
-
-             hint_lines.append(f" URL: {type(url_error).__name__}: {url_error}")
-
-             raise RuntimeError("\n".join(hint_lines))
-
-     def _get_predictor(self, device: torch.device) -> torch.nn.Module:
-         with self._lock:
-             if self._predictor is None:
-                 state_dict = self._load_state_dict()
-                 predictor = create_predictor(PredictorParams())
-                 predictor.load_state_dict(state_dict)
-                 predictor.eval()
-                 self._predictor = predictor
-                 self._predictor_device = torch.device("cpu")
-
-         assert self._predictor is not None
-         assert self._predictor_device is not None
-
-         if self._predictor_device != device:
-             self._predictor.to(device)
-             self._predictor_device = device
-
-         return self._predictor
-
-     def _maybe_move_model_back_to_cpu(self) -> None:
-         if self.keep_model_on_device:
-             return
-         with self._lock:
-             if self._predictor is not None and self._predictor_device is not None:
-                 if self._predictor_device.type != "cpu":
-                     self._predictor.to("cpu")
-                     self._predictor_device = torch.device("cpu")
-                 if torch.cuda.is_available():
-                     torch.cuda.empty_cache()
-
-     def _make_output_stem(self, input_path: Path) -> str:
-         return f"{input_path.stem}-{_now_ms()}-{uuid.uuid4().hex[:8]}"
-
-     def predict_to_ply(self, image_path: str | Path) -> PredictionOutputs:
-         """Run SHARP inference and export a .ply file."""
-         image_path = Path(image_path)
-         if not image_path.exists():
-             raise FileNotFoundError(f"Image does not exist: {image_path}")
-
-         device = _select_device(self.device_preference)
-         predictor = self._get_predictor(device)
-
-         image_np, _, f_px = io.load_rgb(image_path)
-         height, width = image_np.shape[:2]
-
-         with torch.no_grad():
-             gaussians = predict_image(predictor, image_np, f_px, device)
-
-         stem = self._make_output_stem(image_path)
-         ply_path = self.outputs_dir / f"{stem}.ply"
-
-         # save_ply expects (height, width).
-         save_ply(gaussians, f_px, (height, width), ply_path)
-
-         # SceneMetaData expects (width, height) for resolution.
-         metadata_for_render = SceneMetaData(
-             focal_length_px=float(f_px),
-             resolution_px=(int(width), int(height)),
-             color_space="linearRGB",
-         )
-
-         self._maybe_move_model_back_to_cpu()
-
-         return PredictionOutputs(
-             ply_path=ply_path,
-             gaussians=gaussians,
-             metadata_for_render=metadata_for_render,
-             input_resolution_hw=(int(height), int(width)),
-             focal_length_px=float(f_px),
-         )
-
-     def _render_video_impl(
-         self,
-         *,
-         gaussians: Gaussians3D,
-         metadata: SceneMetaData,
-         output_path: Path,
-         trajectory_type: TrajectoryType,
-         num_frames: int,
-         fps: int,
-         output_long_side: int | None,
-     ) -> Path:
-         if not torch.cuda.is_available():
-             raise RuntimeError("Rendering requires CUDA (gsplat).")
-
-         if num_frames < 2:
-             raise ValueError("num_frames must be >= 2")
-         if fps < 1:
-             raise ValueError("fps must be >= 1")
-
-         # Keep aligned with upstream CLI pipeline where possible.
-         if output_long_side is None and int(fps) == 30:
-             params = camera.TrajectoryParams(
-                 type=trajectory_type,
-                 num_steps=int(num_frames),
-                 num_repeats=1,
              )
-             with _patched_sharp_videowriter():
-                 sharp_render_gaussians(
-                     gaussians=gaussians,
-                     metadata=metadata,
-                     params=params,
-                     output_path=output_path,
-                 )
-             depth_path = output_path.with_suffix(".depth.mp4")
-             try:
-                 if depth_path.exists():
-                     depth_path.unlink()
-             except Exception:
-                 pass
-             return output_path
-
-         # Adapted pipeline for custom output resolution / FPS.
-         src_w, src_h = metadata.resolution_px
-         src_f = float(metadata.focal_length_px)
-
-         if output_long_side is None:
-             out_w, out_h, out_f = src_w, src_h, src_f
-         else:
-             long_side = max(src_w, src_h)
-             scale = float(output_long_side) / float(long_side)
-             out_w = _make_even(max(2, int(round(src_w * scale))))
-             out_h = _make_even(max(2, int(round(src_h * scale))))
-             out_f = src_f * scale
-
-         traj_params = camera.TrajectoryParams(
-             type=trajectory_type,
-             num_steps=int(num_frames),
-             num_repeats=1,
-         )
-
-         device = torch.device("cuda")
-         gaussians_cuda = gaussians.to(device)
-
-         intrinsics = torch.tensor(
-             [
-                 [out_f, 0.0, (out_w - 1) / 2.0, 0.0],
-                 [0.0, out_f, (out_h - 1) / 2.0, 0.0],
-                 [0.0, 0.0, 1.0, 0.0],
-                 [0.0, 0.0, 0.0, 1.0],
              ],
-             device=device,
-             dtype=torch.float32,
-         )
-
-         cam_model = camera.create_camera_model(
-             gaussians_cuda,
-             intrinsics,
-             resolution_px=(out_w, out_h),
-             lookat_mode=traj_params.lookat_mode,
          )

-         trajectory = camera.create_eye_trajectory(
-             gaussians_cuda,
-             traj_params,
-             resolution_px=(out_w, out_h),
-             f_px=out_f,
-         )
-
-         renderer = GSplatRenderer(color_space=metadata.color_space)
-
-         # IMPORTANT: Keep render_depth=True (avoids upstream AttributeError).
-         video_writer = _PatchedVideoWriter(output_path, fps=float(fps), render_depth=True)
-
-         for eye_position in trajectory:
-             cam_info = cam_model.compute(eye_position)
-             rendering = renderer(
-                 gaussians_cuda,
-                 extrinsics=cam_info.extrinsics[None].to(device),
-                 intrinsics=cam_info.intrinsics[None].to(device),
-                 image_width=cam_info.width,
-                 image_height=cam_info.height,
-             )
-             color = (rendering.color[0].permute(1, 2, 0) * 255.0).to(dtype=torch.uint8)
-             depth = rendering.depth[0]
-             video_writer.add_frame(color, depth)
-
-         video_writer.close()
-
-         depth_path = output_path.with_suffix(".depth.mp4")
-         try:
-             if depth_path.exists():
-                 depth_path.unlink()
-         except Exception:
-             pass
-
-         return output_path
-
-     def render_video(
-         self,
-         *,
-         gaussians: Gaussians3D,
-         metadata: SceneMetaData,
-         output_stem: str,
-         trajectory_type: TrajectoryType = "rotate_forward",
-         num_frames: int = 60,
-         fps: int = 30,
-         output_long_side: int | None = None,
-     ) -> Path:
-         """Render a camera trajectory as an MP4 (CUDA-only)."""
-         output_path = self.outputs_dir / f"{output_stem}.mp4"
-         return self._render_video_impl(
-             gaussians=gaussians,
-             metadata=metadata,
-             output_path=output_path,
-             trajectory_type=trajectory_type,
-             num_frames=num_frames,
-             fps=fps,
-             output_long_side=output_long_side,
-         )
-
-     def predict_and_maybe_render(
-         self,
-         image_path: str | Path,
-         *,
-         trajectory_type: TrajectoryType,
-         num_frames: int,
-         fps: int,
-         output_long_side: int | None,
-         render_video: bool = True,
-     ) -> tuple[Path | None, Path]:
-         """One-shot helper for the UI: returns (video_path, ply_path)."""
-         pred = self.predict_to_ply(image_path)
-
-         if not render_video:
-             return None, pred.ply_path
-
-         if not torch.cuda.is_available():
-             return None, pred.ply_path
-
-         output_stem = pred.ply_path.with_suffix("").name
-         video_path = self.render_video(
-             gaussians=pred.gaussians,
-             metadata=pred.metadata_for_render,
-             output_stem=output_stem,
-             trajectory_type=trajectory_type,
-             num_frames=num_frames,
-             fps=fps,
-             output_long_side=output_long_side,
-         )
-         return video_path, pred.ply_path
-

  # -----------------------------------------------------------------------------
- # ZeroGPU entrypoints
  # -----------------------------------------------------------------------------
- #
- # IMPORTANT: Do NOT decorate bound instance methods with `@spaces.GPU` on ZeroGPU.
- # The wrapper uses multiprocessing queues and pickles args/kwargs. If `self` is
- # included, Python will try to pickle the whole instance. ModelWrapper contains
- # a threading.RLock (not pickleable) and the model itself should not be pickled.
- #
- # Expose module-level functions that accept only pickleable arguments and
- # create/cache the ModelWrapper inside the GPU worker process.
-
- DEFAULT_OUTPUTS_DIR: Final[Path] = _ensure_dir(Path(__file__).resolve().parent / "outputs")
-
- _GLOBAL_MODEL: ModelWrapper | None = None
- _GLOBAL_MODEL_INIT_LOCK: Final[threading.Lock] = threading.Lock()
-
-
- def get_global_model(*, outputs_dir: str | Path = DEFAULT_OUTPUTS_DIR) -> ModelWrapper:
-     global _GLOBAL_MODEL
-     with _GLOBAL_MODEL_INIT_LOCK:
-         if _GLOBAL_MODEL is None:
-             _GLOBAL_MODEL = ModelWrapper(outputs_dir=outputs_dir)
-     return _GLOBAL_MODEL
-
-
- def predict_and_maybe_render(
-     image_path: str | Path,
-     *,
-     trajectory_type: TrajectoryType,
-     num_frames: int,
-     fps: int,
-     output_long_side: int | None,
-     render_video: bool = True,
- ) -> tuple[Path | None, Path]:
-     model = get_global_model()
-     return model.predict_and_maybe_render(
-         image_path,
-         trajectory_type=trajectory_type,
-         num_frames=num_frames,
-         fps=fps,
-         output_long_side=output_long_side,
-         render_video=render_video,
-     )

- # Export the GPU-wrapped callable (or a no-op wrapper locally).
- if spaces is not None:
-     predict_and_maybe_render_gpu = spaces.GPU(duration=180)(predict_and_maybe_render)
- else:  # pragma: no cover
-     predict_and_maybe_render_gpu = predict_and_maybe_render

+ """
+ SHARP Gradio Demo (Fixed)
+ - Standard Two-Column Layout
+ - Robust Error Handling
+ - Glitch-free Examples (Load-only)
  """

  from __future__ import annotations

+ import warnings
+ import json
  from pathlib import Path
+ from typing import Final
+ import gradio as gr

+ # Suppress internal warnings
+ warnings.filterwarnings("ignore", category=FutureWarning, module="torch.distributed")

+ # Ensure model_utils is present
+ # We wrap this import to prevent app crash if model_utils is missing during UI dev
  try:
+     from model_utils import TrajectoryType, predict_and_maybe_render_gpu
+ except ImportError:
+     # Dummy mocks for testing/building UI without backend
+     class TrajectoryType:
+         pass
+     def predict_and_maybe_render_gpu(*args, **kwargs):
+         return None, Path("dummy.ply")

+ # -----------------------------------------------------------------------------
+ # Paths & Config
+ # -----------------------------------------------------------------------------

+ APP_DIR: Final[Path] = Path(__file__).resolve().parent
+ OUTPUTS_DIR: Final[Path] = APP_DIR / "outputs"
+ ASSETS_DIR: Final[Path] = APP_DIR / "assets"
+ EXAMPLES_DIR: Final[Path] = ASSETS_DIR / "examples"

+ IMAGE_EXTS: Final[tuple[str, ...]] = (".png", ".jpg", ".jpeg", ".webp")

  # -----------------------------------------------------------------------------
  # Helpers
  # -----------------------------------------------------------------------------

  def _ensure_dir(path: Path) -> Path:
      path.mkdir(parents=True, exist_ok=True)
      return path

+ def get_example_files() -> list[list[str]]:
+     """Discover images in assets/examples for the UI."""
+     _ensure_dir(EXAMPLES_DIR)
+
+     # Check manifest.json first
+     manifest_path = EXAMPLES_DIR / "manifest.json"
+     if manifest_path.exists():
          try:
+             data = json.loads(manifest_path.read_text(encoding="utf-8"))
+             examples = []
+             for entry in data:
+                 if "image" in entry:
+                     img_path = EXAMPLES_DIR / entry["image"]
+                     if img_path.exists():
+                         examples.append([str(img_path)])
+             if examples:
+                 return examples
+         except Exception as e:
+             print(f"Manifest error: {e}")
+
+     # Fallback: simple file scan
+     examples = []
+     for ext in IMAGE_EXTS:
+         for img in sorted(EXAMPLES_DIR.glob(f"*{ext}")):
+             examples.append([str(img)])
+     return examples
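Since `get_example_files` above consults `manifest.json` before falling back to a directory scan, it may help to see the shape that loader accepts: a JSON array of objects, each with an `image` key naming a file inside `assets/examples/`. A minimal sketch that writes such a manifest (the image filenames are hypothetical; entries without an `image` key, or pointing at missing files, are silently skipped by the loader):

import json
from pathlib import Path

# Hypothetical example entries for assets/examples/manifest.json.
manifest = [
    {"image": "living_room.jpg"},
    {"image": "garden.png"},
]
Path("assets/examples/manifest.json").write_text(
    json.dumps(manifest, indent=2), encoding="utf-8"
)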
+
+ def run_sharp(
+     image_path: str | None,
+     trajectory_preset: str,
+     output_long_side: int | float | None,
+     num_frames: int | float,
+     fps: int | float,
+     render_video: bool,
+     progress=gr.Progress()
+ ) -> tuple[str | None, str | None, str]:
+     """
+     Main Inference Function
+     """
+     if not image_path:
+         raise gr.Error("Please upload an image first.")
+
+     # 1. Safe Integer Conversion (Handle None or Float inputs from sliders)
      try:
+         out_long_side_val = int(output_long_side) if output_long_side and int(output_long_side) > 0 else None
+         n_frames = int(num_frames)
+         fps_val = int(fps)
+     except (TypeError, ValueError):
+         # Fallbacks if UI sends weird data
+         out_long_side_val = None
+         n_frames = 60
+         fps_val = 30
+
+     # 2. Safe Trajectory Mapping
+     # Map UI friendly names to internal keys
+     traj_map = {
+         "Orbit (Standard)": "rotate",
+         "Orbit (Forward)": "rotate_forward",
+         "Swipe Left": "swipe",
+         "Shake": "shake",
+         "Zoom In": "zoom",
+         "Dolly": "dolly"
+     }
+
+     internal_name = traj_map.get(trajectory_preset, "rotate")
+
+     # Try to find the Enum member safely
+     traj_enum = internal_name  # Default to string if Enum logic fails
+     try:
+         if hasattr(TrajectoryType, internal_name.upper()):
+             traj_enum = getattr(TrajectoryType, internal_name.upper())
+         elif hasattr(TrajectoryType, internal_name):
+             traj_enum = getattr(TrajectoryType, internal_name)
+     except Exception:
+         print(f"Warning: Could not resolve TrajectoryType.{internal_name}, passing string '{internal_name}'")
+         traj_enum = internal_name
+
+     # 3. Run Inference
+     try:
+         progress(0.1, desc="Initializing model...")
+
+         video_path, ply_path = predict_and_maybe_render_gpu(
+             image_path,
+             trajectory_type=traj_enum,
+             num_frames=n_frames,
+             fps=fps_val,
+             output_long_side=out_long_side_val,
+             render_video=bool(render_video),
          )

+         status_msg = f"✅ **Success**\n\nPLY: `{ply_path.name}`"
+         if video_path:
+             status_msg += f"\nVideo: `{video_path.name}`"
+
+         return (
+             str(video_path) if video_path else None,
+             str(ply_path),
+             status_msg
+         )

+     except Exception as e:
+         # Catch all errors to prevent UI crash
+         raise gr.Error(f"Generation failed: {str(e)}")
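The integer coercion at the top of `run_sharp` is what lets the resolution dropdown use 0 for "Original": 0 is falsy, so the expression falls through to None and the source resolution is kept. A worked sketch of the same rule (the helper name is hypothetical; it just mirrors the inline expression above):

# Mirrors the out_long_side_val expression in run_sharp.
def coerce_long_side(value):
    return int(value) if value and int(value) > 0 else None

assert coerce_long_side(0) is None       # "Original" dropdown choice
assert coerce_long_side(None) is None    # nothing selected
assert coerce_long_side(512) == 512      # "512px"
assert coerce_long_side(1024.0) == 1024  # sliders may deliver floats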
 
+ # -----------------------------------------------------------------------------
+ # UI Construction
+ # -----------------------------------------------------------------------------

+ def build_demo() -> gr.Blocks:
+     theme = gr.themes.Default()
+
+     css = """
+     .container { max-width: 1200px; margin: auto; }
+     #header { text-align: center; margin-bottom: 20px; }
+     """
+
+     with gr.Blocks(theme=theme, css=css, title="SHARP 3D") as demo:
+
+         # --- Header ---
+         with gr.Column(elem_id="header"):
+             gr.Markdown("# SHARP: Single-Image 3D Generator")
+             gr.Markdown("Convert any static image into a 3D Gaussian Splat scene instantly.")
+
+         # --- Main Two-Column Layout ---
+         with gr.Row(equal_height=False):
+
+             # --- LEFT COLUMN: Input & Controls ---
+             with gr.Column():
+                 image_in = gr.Image(
+                     label="Input Image",
+                     type="filepath",
+                     sources=["upload", "clipboard"],
+                     height=350
+                 )

+                 # Controls are visible (no accordion)
+                 with gr.Group():
+                     gr.Markdown("### 🎥 Settings")
+                     trajectory_preset = gr.Dropdown(
+                         label="Camera Movement",
+                         choices=[
+                             "Orbit (Standard)",
+                             "Orbit (Forward)",
+                             "Swipe Left",
+                             "Shake",
+                             "Zoom In",
+                             "Dolly"
+                         ],
+                         value="Orbit (Forward)",
+                         interactive=True
                      )
+
+                     output_res = gr.Dropdown(
+                         label="Output Resolution",
+                         choices=[("Original", 0), ("512px", 512), ("1024px", 1024)],
+                         value=0,
+                         interactive=True
                      )

+                 # Advanced (Collapsible)
+                 with gr.Accordion("Advanced Options", open=False):
+                     frames = gr.Slider(label="Frames", minimum=24, maximum=120, step=1, value=60)
+                     fps_in = gr.Slider(label="FPS", minimum=8, maximum=60, step=1, value=30)
+                     render_toggle = gr.Checkbox(label="Render Video Preview", value=True)

+                 run_btn = gr.Button("🚀 Generate 3D Scene", variant="primary", size="lg")

+             # --- RIGHT COLUMN: Output ---
+             with gr.Column():
+                 video_out = gr.Video(
+                     label="3D Preview",
+                     autoplay=True,
+                     height=350
                  )
+
+                 with gr.Group():
+                     ply_download = gr.DownloadButton(
+                         label="Download .PLY File",
+                         variant="secondary",
+                         visible=True
+                     )
+                     status_md = gr.Markdown("Waiting for input...")
+
+         # --- Footer: Examples ---
+         gr.Markdown("### 📝 Examples")
+         example_files = get_example_files()
+
+         if example_files:
+             gr.Examples(
+                 examples=example_files,
+                 inputs=[image_in],
+                 # CRITICAL FIX: We do NOT set fn=run_sharp here.
+                 # This ensures clicking an example ONLY fills the image input.
+                 # The user must click "Generate" to run (prevents the 'None' arguments crash).
+                 label="Click an image to load it:"
              )
+
+         # --- Event Binding ---
+         run_btn.click(
+             fn=run_sharp,
+             inputs=[
+                 image_in,
+                 trajectory_preset,
+                 output_res,
+                 frames,
+                 fps_in,
+                 render_toggle
              ],
+             outputs=[video_out, ply_download, status_md],
+             concurrency_limit=1
          )

+     return demo
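The CRITICAL FIX comment above keeps `gr.Examples` load-only. For contrast, a sketch of the wiring this layout deliberately avoids: with `fn` bound and `run_on_click` set, every example click runs inference immediately, and since each example row stores only the image path, the remaining inputs would arrive as defaults or None (this assumes `run_on_click` behaves as in current Gradio releases):

# What the demo intentionally does NOT do: each example row would have to
# carry values for all six inputs rather than just the image path.
gr.Examples(
    examples=example_files,  # would need 6 values per row, not 1
    inputs=[image_in, trajectory_preset, output_res, frames, fps_in, render_toggle],
    outputs=[video_out, ply_download, status_md],
    fn=run_sharp,
    run_on_click=True,
)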

  # -----------------------------------------------------------------------------
+ # Entry Point
  # -----------------------------------------------------------------------------

+ _ensure_dir(OUTPUTS_DIR)

+ if __name__ == "__main__":
+     demo = build_demo()
+     demo.queue().launch(
+         allowed_paths=[str(ASSETS_DIR)],
+         ssr_mode=False
+     )
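The module this commit deletes carried a ZeroGPU warning that still applies if the backend is restored: `@spaces.GPU` pickles a call's arguments into a separate GPU worker process, so it must wrap a module-level function, never a bound method whose `self` drags along unpicklable locks and model weights. A minimal sketch of that pattern, with hypothetical names (`load_model` stands in for whatever builds the predictor):

import spaces

_MODEL = None  # built lazily inside the GPU worker, never pickled

def _get_model():
    global _MODEL
    if _MODEL is None:
        _MODEL = load_model()  # hypothetical loader
    return _MODEL

@spaces.GPU(duration=120)
def infer(image_path: str):
    # Only the pickleable string crosses the process boundary; no locks,
    # no CUDA tensors, no `self`.
    return _get_model()(image_path)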