lidavidsh committed on
Commit
246c42f
ยท
1 Parent(s): cfa6140

init push

Browse files
Files changed (4) hide show
  1. api_server.py +464 -0
  2. app.py +372 -46
  3. pyproject.toml +3 -3
  4. requirements.txt +1 -22
api_server.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
# Backend API server for Depth Anything 3 remote inference

import os
import sys
import asyncio
import base64
import io
import json
import uuid
from typing import Dict, Any, Optional
from datetime import datetime
import glob
import shutil
import zipfile

import numpy as np
import torch
from fastapi import FastAPI, WebSocket, HTTPException, Query
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

# Make the vendored repo importable before pulling in its modules.
sys.path.append("depth-anything-3/")

from depth_anything_3.api import DepthAnything3  # noqa: E402
from depth_anything_3.utils.export.glb import export_to_glb  # noqa: E402
from depth_anything_3.utils.export.gs import export_to_gs_video  # noqa: E402

# FastAPI application; CORS is fully open because the UI may be served
# from a different origin than this backend.
app = FastAPI(title="Depth Anything 3 Inference API", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Singletons populated by load_model() at startup.
model: Optional[DepthAnything3] = None
device: Optional[str] = None

# In-memory job registry: job_id -> {"status": "queued/processing/completed/failed",
# "result": {...}, "error": ..., "created_at": ...}
jobs: Dict[str, Dict[str, Any]] = {}

# Live WebSocket connections keyed by client_id.
websocket_connections: Dict[str, WebSocket] = {}
55
+ # -------------------------------------------------------------------------
56
+ # Request/Response Models
57
+ # -------------------------------------------------------------------------
58
class ImageData(BaseModel):
    """One uploaded image: its original filename plus base64-encoded bytes."""

    filename: str
    data: str  # base64-encoded image bytes
61
+
62
+
63
class Options(BaseModel):
    """Inference and export tuning options sent by the frontend."""

    process_res_method: Optional[str] = "upper_bound_resize"
    selected_first_frame: Optional[str] = ""  # filename of the reference frame
    infer_gs: Optional[bool] = False  # also run 3DGS inference
    # Export tuning (backend defaults used when omitted)
    conf_thresh_percentile: Optional[float] = 40.0
    num_max_points: Optional[int] = 1_000_000
    show_cameras: Optional[bool] = True
    gs_trj_mode: Optional[str] = "extend"  # "extend" | "smooth"
    gs_video_quality: Optional[str] = "low"  # "low" | "high"
73
+
74
+
75
class InferenceRequest(BaseModel):
    """Payload for POST /inference: images plus client routing information."""

    images: list[ImageData]
    client_id: str  # WebSocket client that should receive progress updates
    options: Optional[Options] = None
79
+
80
+
81
class InferenceResponse(BaseModel):
    """Acknowledgement returned immediately after a job is submitted."""

    job_id: str
    status: str = "queued"
84
+
85
+
86
+ # -------------------------------------------------------------------------
87
+ # Model Loading
88
+ # -------------------------------------------------------------------------
89
def load_model():
    """Load the Depth Anything 3 model onto the GPU at server startup.

    Populates the module-level ``model`` and ``device`` globals.

    Raises:
        RuntimeError: if CUDA is unavailable (GPU is mandatory for DA3).
    """
    global model, device

    print("Initializing and loading Depth Anything 3 model...")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. GPU is required for DA3 inference.")

    device = "cuda"
    # Model source: HF Hub repo id or a local checkpoint directory.
    model_dir = os.getenv("DA3_MODEL_DIR", "depth-anything/DA3NESTED-GIANT-LARGE")

    model = DepthAnything3.from_pretrained(model_dir).to(device)  # type: ignore
    model.eval()

    print(f"Model loaded successfully on {device} from {model_dir}")
106
+
107
+
108
+ # -------------------------------------------------------------------------
109
+ # Helpers
110
+ # -------------------------------------------------------------------------
111
+ def _serialize_bytes(b: bytes) -> str:
112
+ """Serialize raw bytes to base64 string"""
113
+ return base64.b64encode(b).decode("utf-8")
114
+
115
+
116
+ def _serialize_file(path: str) -> str:
117
+ """Serialize a file at 'path' to base64 string"""
118
+ with open(path, "rb") as f:
119
+ return _serialize_bytes(f.read())
120
+
121
+
122
+ def _zip_dir_to_bytes(dir_path: str) -> bytes:
123
+ """Zip a directory and return zip bytes"""
124
+ buffer = io.BytesIO()
125
+ with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf:
126
+ for root, _, files in os.walk(dir_path):
127
+ for fn in files:
128
+ full = os.path.join(root, fn)
129
+ arcname = os.path.relpath(full, start=dir_path)
130
+ zf.write(full, arcname)
131
+ buffer.seek(0)
132
+ return buffer.read()
133
+
134
+
135
+ def _actual_process_method(name: str) -> str:
136
+ """Map frontend option to actual processing method used by DA3"""
137
+ mapping = {
138
+ "high_res": "lower_bound_resize",
139
+ "low_res": "upper_bound_resize",
140
+ "upper_bound_resize": "upper_bound_resize",
141
+ "lower_bound_resize": "lower_bound_resize",
142
+ "upper_bound_crop": "upper_bound_crop",
143
+ }
144
+ return mapping.get(name or "upper_bound_resize", "upper_bound_resize")
145
+
146
+
147
+ def _save_predictions_npz(target_dir: str, prediction: Any):
148
+ """Save predictions data to predictions.npz for caching."""
149
+ try:
150
+ output_file = os.path.join(target_dir, "predictions.npz")
151
+ save_dict: Dict[str, Any] = {}
152
+
153
+ if getattr(prediction, "processed_images", None) is not None:
154
+ save_dict["images"] = prediction.processed_images
155
+ if getattr(prediction, "depth", None) is not None:
156
+ save_dict["depths"] = np.round(prediction.depth, 6)
157
+ if getattr(prediction, "conf", None) is not None:
158
+ save_dict["conf"] = np.round(prediction.conf, 2)
159
+ if getattr(prediction, "extrinsics", None) is not None:
160
+ save_dict["extrinsics"] = prediction.extrinsics
161
+ if getattr(prediction, "intrinsics", None) is not None:
162
+ save_dict["intrinsics"] = prediction.intrinsics
163
+
164
+ np.savez_compressed(output_file, **save_dict)
165
+ print(f"[backend] Saved predictions cache to: {output_file}")
166
+ except Exception as e:
167
+ print(f"[backend] Warning: Failed to save predictions cache: {e}")
168
+
169
+
170
+ # -------------------------------------------------------------------------
171
+ # Core Inference Function
172
+ # -------------------------------------------------------------------------
173
async def run_inference(
    job_id: str,
    target_dir: str,
    client_id: Optional[str] = None,
    options: Optional[Options] = None,
):
    """Run DA3 model inference on uploaded images and export all artifacts.

    Reads images from ``<target_dir>/images``, runs the global ``model``,
    exports scene.glb / depth_vis / predictions.npz (and optionally a 3DGS
    video) into ``target_dir``, packages them as base64 into
    ``jobs[job_id]["result"]``, and streams progress over the submitting
    client's WebSocket.

    Args:
        job_id: Key into the global ``jobs`` registry (entry must exist).
        target_dir: Temp directory holding the uploads; always deleted at the
            end (fix: previously it leaked when the job failed).
        client_id: Optional WebSocket client to notify of progress.
        options: Inference/export tuning from the frontend.
    """

    async def _send_ws(message: Dict[str, Any]) -> None:
        # Push a progress message to the submitting client, if still connected.
        if client_id and client_id in websocket_connections:
            await websocket_connections[client_id].send_json(message)

    try:
        jobs[job_id]["status"] = "processing"
        await _send_ws({"type": "executing", "data": {"job_id": job_id, "node": "start"}})

        if model is None:
            # startup_event() should have loaded the model; fail fast with a
            # clear error instead of an AttributeError later.
            raise RuntimeError("Model is not loaded; server startup may have failed")

        # Collect uploaded images in deterministic (sorted) order.
        image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
        print(f"Found {len(image_names)} images for job {job_id}")
        if len(image_names) == 0:
            raise ValueError("No images found in target directory")

        # Reorder so the user-selected frame becomes the reference (first) frame.
        selected_first = options.selected_first_frame if options else ""
        if selected_first:
            sel_path = next(
                (p for p in image_names if os.path.basename(p) == selected_first), None
            )
            if sel_path:
                image_names = [sel_path] + [p for p in image_names if p != sel_path]
                print(f"Selected first frame: {selected_first} -> {sel_path}")

        await _send_ws({"type": "executing", "data": {"job_id": job_id, "node": "preprocess"}})

        # Run inference; artifacts are exported explicitly below, not inline.
        print(f"Running inference for job {job_id}...")
        actual_method = _actual_process_method(
            options.process_res_method if options else "upper_bound_resize"
        )
        with torch.no_grad():
            prediction = model.inference(
                image=image_names,
                process_res_method=actual_method,
                export_dir=None,  # export manually below
                export_format="mini_npz",
                infer_gs=bool(options.infer_gs) if options else False,
            )

        await _send_ws({"type": "executing", "data": {"job_id": job_id, "node": "postprocess"}})

        # Export GLB (point cloud + cameras); non-fatal on failure.
        try:
            export_to_glb(
                prediction,
                export_dir=target_dir,
                num_max_points=int(options.num_max_points) if options else 1_000_000,
                conf_thresh_percentile=float(options.conf_thresh_percentile) if options else 40.0,
                show_cameras=bool(options.show_cameras) if options else True,
            )
            print(f"[backend] Exported GLB + depth_vis to {target_dir}")
        except Exception as e:
            print(f"[backend] GLB export failed: {e}")

        # Optional 3DGS novel-view video; also non-fatal.
        if options and bool(options.infer_gs):
            try:
                mode_mapping = {"extend": "extend", "smooth": "interpolate_smooth"}
                export_to_gs_video(
                    prediction,
                    export_dir=target_dir,
                    chunk_size=4,
                    trj_mode=mode_mapping.get(options.gs_trj_mode or "extend", "extend"),
                    enable_tqdm=False,
                    vis_depth="hcat",
                    video_quality=options.gs_video_quality or "low",
                )
                print(f"[backend] Exported GS video to {target_dir}")
            except Exception as e:
                print(f"[backend] GS video export failed: {e}")

        # Cache raw predictions for the frontend.
        _save_predictions_npz(target_dir, prediction)

        # Package artifacts as base64 strings for JSON transport.
        artifacts: Dict[str, Any] = {}
        glb_path = os.path.join(target_dir, "scene.glb")
        if os.path.exists(glb_path):
            artifacts["glb"] = _serialize_file(glb_path)

        depth_vis_dir = os.path.join(target_dir, "depth_vis")
        if os.path.isdir(depth_vis_dir):
            try:
                artifacts["depth_vis_zip"] = _serialize_bytes(_zip_dir_to_bytes(depth_vis_dir))
            except Exception as e:
                print(f"[backend] depth_vis zip failed: {e}")

        npz_path = os.path.join(target_dir, "predictions.npz")
        if os.path.exists(npz_path):
            artifacts["predictions_npz"] = _serialize_file(npz_path)

        # Optional GS video: exporter may use a fixed name, so take the first mp4.
        mp4_candidates = glob.glob(os.path.join(target_dir, "*.mp4"))
        if mp4_candidates:
            artifacts["gs_video"] = _serialize_file(mp4_candidates[0])

        jobs[job_id]["status"] = "completed"
        jobs[job_id]["result"] = {"artifacts": artifacts}

        # node=None signals completion to the frontend.
        await _send_ws({"type": "executing", "data": {"job_id": job_id, "node": None}})

        print(f"Job {job_id} completed successfully")

    except Exception as e:
        print(f"Error in job {job_id}: {str(e)}")
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        try:
            await _send_ws({"type": "error", "data": {"job_id": job_id, "error": str(e)}})
        except Exception:
            pass
    finally:
        # Always release GPU memory and delete the temp upload dir, even on
        # failure — previously both only happened on the success path.
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
        shutil.rmtree(target_dir, ignore_errors=True)
327
+
328
+
329
+ # -------------------------------------------------------------------------
330
+ # API Endpoints
331
+ # -------------------------------------------------------------------------
332
@app.on_event("startup")
async def startup_event():
    """Eagerly load the DA3 model when the server process boots."""
    load_model()
336
+
337
+
338
@app.get("/")
async def root():
    """Liveness probe: reports that the API process is up."""
    return {"status": "ok", "service": "Depth Anything 3 Inference API"}
342
+
343
+
344
@app.post("/inference")
async def create_inference(request: InferenceRequest, token: str = Query(...)):
    """
    Submit an inference job.

    Decodes the uploaded base64 images into a per-job temp directory, registers
    the job in the in-memory registry, and kicks off ``run_inference`` as a
    background task.

    Args:
        request: InferenceRequest containing images, client_id, options
        token: Authentication token (currently not validated, for compatibility)

    Returns:
        InferenceResponse with job_id

    Raises:
        HTTPException(400): if any image fails to decode or save.
    """
    job_id = str(uuid.uuid4())

    # Per-job temp directory for the uploaded images.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"/tmp/da3_job_{job_id}_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")
    os.makedirs(target_dir_images, exist_ok=True)

    try:
        for img_data in request.images:
            # Security fix: client-supplied filenames are untrusted; basename()
            # strips any directory components so "../../etc/x" cannot escape
            # the job's images directory.
            safe_name = os.path.basename(img_data.filename)
            if not safe_name:
                raise ValueError(f"Invalid filename: {img_data.filename!r}")
            img_bytes = base64.b64decode(img_data.data)
            with open(os.path.join(target_dir_images, safe_name), "wb") as f:
                f.write(img_bytes)

        # Register the job before scheduling so polling never sees a gap.
        jobs[job_id] = {
            "status": "queued",
            "result": None,
            "created_at": datetime.now().isoformat(),
        }

        # Run inference in the background; the caller polls /result/{job_id}.
        asyncio.create_task(run_inference(job_id, target_dir, request.client_id, request.options))

        return InferenceResponse(job_id=job_id, status="queued")

    except Exception as e:
        shutil.rmtree(target_dir, ignore_errors=True)
        raise HTTPException(status_code=400, detail=f"Failed to process images: {str(e)}")
388
+
389
+
390
@app.get("/result/{job_id}")
async def get_result(job_id: str, token: str = Query(...)):
    """
    Get the inference result for a job.

    Args:
        job_id: Job ID
        token: Authentication token (currently not validated, for compatibility)

    Returns:
        ``{job_id: {"status": ...}}`` while pending, or
        ``{job_id: {"artifacts": ...}}`` once completed.

    Raises:
        HTTPException(404): unknown job_id.
        HTTPException(500): the job failed; detail carries the error message.
    """
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    status = job["status"]
    if status == "failed":
        raise HTTPException(status_code=500, detail=job.get("error", "Job failed"))
    if status == "completed":
        return {job_id: job["result"]}
    return {job_id: {"status": status}}
414
+
415
+
416
@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket, clientId: str = Query(...), token: str = Query(...)
):
    """
    WebSocket endpoint for real-time progress updates.

    Registers the connection under ``clientId`` so background jobs can push
    progress, and echoes received text frames as a heartbeat.

    Args:
        websocket: WebSocket connection
        clientId: Client ID
        token: Authentication token (currently not validated, for compatibility)
    """
    await websocket.accept()
    websocket_connections[clientId] = websocket

    try:
        # Echo loop doubles as a keep-alive; exits when the client disconnects.
        while True:
            message = await websocket.receive_text()
            await websocket.send_text(message)
    except Exception as exc:
        print(f"WebSocket error for client {clientId}: {str(exc)}")
    finally:
        # Drop the registration regardless of how the loop ended.
        websocket_connections.pop(clientId, None)
442
+
443
+
444
@app.get("/history/{job_id}")
async def get_history(job_id: str, token: str = Query(...)):
    """
    Get job history — a straight alias for ``/result/{job_id}`` kept for
    client compatibility.

    Args:
        job_id: Job ID
        token: Authentication token

    Returns:
        Same payload as ``get_result``.
    """
    return await get_result(job_id, token)
457
+
458
+
459
+ # -------------------------------------------------------------------------
460
+ # Main
461
+ # -------------------------------------------------------------------------
462
if __name__ == "__main__":
    # Serve on all interfaces at the standard HF Spaces port (7860).
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
app.py CHANGED
@@ -13,89 +13,416 @@
13
  # limitations under the License.
14
 
15
  """
16
- Hugging Face Spaces App for Depth Anything 3.
17
 
18
- This app uses the @spaces.GPU decorator to dynamically allocate GPU resources
19
- for model inference on Hugging Face Spaces.
 
 
20
  """
21
 
22
  import os
23
- import spaces
 
 
 
 
 
 
 
 
 
 
 
24
  from depth_anything_3.app.gradio_app import DepthAnything3App
25
  from depth_anything_3.app.modules.model_inference import ModelInference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # Apply @spaces.GPU decorator to run_inference method
28
- # This ensures GPU operations happen in isolated subprocess
29
- # Model loading and inference will occur in GPU subprocess, not main process
30
- original_run_inference = ModelInference.run_inference
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- @spaces.GPU(duration=120) # Request GPU for up to 120 seconds per inference
33
- def gpu_run_inference(self, *args, **kwargs):
34
  """
35
- GPU-accelerated inference with Spaces decorator.
36
-
37
- This function runs in a GPU subprocess where:
38
- - Model is loaded and moved to GPU (safe)
39
- - CUDA operations are allowed
40
- - All CUDA tensors are moved to CPU before return (for pickle safety)
 
 
 
41
  """
42
- return original_run_inference(self, *args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # Replace the original method with the GPU-decorated version
45
- ModelInference.run_inference = gpu_run_inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Initialize and launch the app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  if __name__ == "__main__":
49
- # Configure directories for Hugging Face Spaces
 
 
 
 
 
 
50
  model_dir = os.environ.get("DA3_MODEL_DIR", "depth-anything/DA3NESTED-GIANT-LARGE")
51
  workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "workspace/gradio")
52
  gallery_dir = os.environ.get("DA3_GALLERY_DIR", "workspace/gallery")
53
-
54
  # Create directories if they don't exist
55
  os.makedirs(workspace_dir, exist_ok=True)
56
  os.makedirs(gallery_dir, exist_ok=True)
57
-
58
- # Initialize the app
59
  app = DepthAnything3App(
60
  model_dir=model_dir,
61
  workspace_dir=workspace_dir,
62
- gallery_dir=gallery_dir
63
  )
64
-
65
  # Check if examples directory exists
66
  examples_dir = os.path.join(workspace_dir, "examples")
67
  examples_exist = os.path.exists(examples_dir)
68
-
69
- # Check if caching is enabled via environment variable (default: True if examples exist)
70
- # Allow disabling via environment variable: DA3_CACHE_EXAMPLES=false
71
  cache_examples_env = os.environ.get("DA3_CACHE_EXAMPLES", "").lower()
72
  if cache_examples_env in ("false", "0", "no"):
73
  cache_examples = False
74
  elif cache_examples_env in ("true", "1", "yes"):
75
  cache_examples = True
76
  else:
77
- # Default: enable caching if examples directory exists
78
  cache_examples = examples_exist
79
-
80
- # Get cache_gs_tag from environment variable (default: "dl3dv")
81
  cache_gs_tag = os.environ.get("DA3_CACHE_GS_TAG", "dl3dv")
82
-
83
- # Launch with Spaces-friendly settings
84
- print("๐Ÿš€ Launching Depth Anything 3 on Hugging Face Spaces...")
85
- print(f"๐Ÿ“ฆ Model Directory: {model_dir}")
 
86
  print(f"๐Ÿ“ Workspace Directory: {workspace_dir}")
87
  print(f"๐Ÿ–ผ๏ธ Gallery Directory: {gallery_dir}")
88
  print(f"๐Ÿ’พ Cache Examples: {cache_examples}")
89
  if cache_examples:
90
  if cache_gs_tag:
91
- print(f"๐Ÿท๏ธ Cache GS Tag: '{cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)")
 
 
92
  else:
93
  print("๐Ÿท๏ธ Cache GS Tag: None (all scenes will use low-res only)")
94
-
95
- # Pre-cache examples if requested
96
  if cache_examples:
97
  print("\n" + "=" * 60)
98
- print("Pre-caching mode enabled")
99
  if cache_gs_tag:
100
  print(f"Scenes containing '{cache_gs_tag}' will use HIGH-RES + 3DGS")
101
  print("Other scenes will use LOW-RES only")
@@ -112,11 +439,10 @@ if __name__ == "__main__":
112
  gs_trj_mode="smooth",
113
  gs_video_quality="low",
114
  )
115
-
116
- # Launch with minimal, Spaces-compatible configuration
117
- # Some parameters may cause routing issues, so we use minimal config
118
  app.launch(
119
- host="0.0.0.0", # Required for Spaces
120
- port=7860, # Standard Gradio port
121
- share=False # Not needed on Spaces
122
  )
 
13
  # limitations under the License.
14
 
15
  """
16
+ Depth Anything 3 Frontend App (Gradio UI) with remote backend inference via WebSocket/HTTP.
17
 
18
+ - Frontend responsibilities remain unchanged (UI, gallery, export glb/3DGS, caching examples)
19
+ - Model inference is delegated to a remote backend specified by DA3_HOST
20
+ - Communication helpers (_open_ws/_submit_inference/_get_result) are defined here (app.py),
21
+ similar to VGGT repo style.
22
  """
23
 
24
import os
import glob
import json
import uuid
import base64
import io
import zipfile
from typing import Any, Dict, Optional, Tuple

import numpy as np
import requests
import torch
import websocket

from depth_anything_3.app.gradio_app import DepthAnything3App
from depth_anything_3.app.modules.model_inference import ModelInference
from depth_anything_3.specs import Gaussians, Prediction
+
41
# -------------------------------------------------------------------------
# Remote Backend Host (must be set)
# -------------------------------------------------------------------------
# Address of the remote DA3 inference backend, expected format "ip:port".
DA3_HOST = os.getenv("DA3_HOST")
45
+
46
+
47
+ # -------------------------------------------------------------------------
48
+ # Remote service communication functions (VGGT style)
49
+ # -------------------------------------------------------------------------
50
def _open_ws(client_id: str, token: str):
    """Open a WebSocket to the remote DA3 backend for progress updates.

    Raises:
        RuntimeError: if DA3_HOST is unset.
    """
    if not DA3_HOST:
        raise RuntimeError(
            "DA3_HOST is not set. Please set env DA3_HOST=ip:port for remote inference."
        )
    conn = websocket.WebSocket()
    conn.connect(f"ws://{DA3_HOST}/ws?clientId={client_id}&token={token}", timeout=1800)
    return conn
59
+
60
+
61
def _submit_inference(target_dir: str, client_id: str, token: str, options: Dict[str, Any]) -> str:
    """Submit an inference job to the remote DA3 service.

    Base64-encodes every file under ``<target_dir>/images`` and POSTs them to
    the backend's /inference endpoint.

    Returns:
        The backend-assigned job id.

    Raises:
        RuntimeError: DA3_HOST unset, HTTP error, or malformed response.
        ValueError: no images found under target_dir.
    """
    if not DA3_HOST:
        raise RuntimeError(
            "DA3_HOST is not set. Please set env DA3_HOST=ip:port for remote inference."
        )

    paths = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not paths:
        raise ValueError("No images found. Check your upload.")

    # Encode each image as base64 for the JSON payload.
    images_data = []
    for img_path in paths:
        with open(img_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("utf-8")
        images_data.append({"filename": os.path.basename(img_path), "data": encoded})

    payload = {
        "images": images_data,
        "client_id": client_id,
        "options": options,
    }

    resp = requests.post(f"http://{DA3_HOST}/inference?token={token}", json=payload, timeout=1800)
    if resp.status_code != 200:
        raise RuntimeError(f"DA3 service /inference error: {resp.text}")

    data = resp.json()
    if "job_id" not in data:
        raise RuntimeError(f"/inference response missing job_id: {data}")

    return data["job_id"]
98
+
99
 
100
def _get_result(job_id: str, token: str) -> Dict[str, Any]:
    """Fetch the inference result for *job_id* from the remote DA3 service.

    Raises:
        RuntimeError: if DA3_HOST is unset.
        requests.HTTPError: on a non-2xx response (via raise_for_status).
    """
    if not DA3_HOST:
        raise RuntimeError(
            "DA3_HOST is not set. Please set env DA3_HOST=ip:port for remote inference."
        )
    resp = requests.get(f"http://{DA3_HOST}/result/{job_id}?token={token}", timeout=1800)
    resp.raise_for_status()
    return resp.json()
109
+
110
+
111
+ def _deserialize_np(b64_str: str) -> Any:
112
+ """Deserialize base64-encoded numpy array saved via np.save into Python object"""
113
+ arr_bytes = base64.b64decode(b64_str)
114
+ return np.load(io.BytesIO(arr_bytes), allow_pickle=True)
115
 
116
+
117
def _build_prediction_from_remote(preds: Dict[str, Any]) -> Prediction:
    """
    Build a lightweight Prediction object from the remote 'predictions' dictionary.

    Fix: the original body referenced ``torch`` and ``Gaussians`` that were
    never imported in this file, so any payload containing gaussians raised
    NameError (imports are now provided at the top of the file).

    Expected keys (base64 npy unless otherwise specified):
        - depths: <b64npy> (N,H,W)
        - conf: <b64npy> (N,H,W) [required by export_to_glb]
        - extrinsics: <b64npy> (N,4,4)
        - intrinsics: <b64npy> (N,3,3)
        - processed_images: <b64npy> (N,H,W,3) uint8 [required by export_to_glb]
        - sky_mask: <b64npy> (optional)
        - gaussians: {means, scales, rotations, harmonics, opacities} (optional, each b64npy)
    """

    def _field(container: Dict[str, Any], key: str):
        # Decode a base64-npy field when present; None-safe.
        raw = container.get(key)
        return _deserialize_np(raw) if raw is not None else None

    depth = _field(preds, "depths")
    conf = _field(preds, "conf")
    extrinsics = _field(preds, "extrinsics")
    intrinsics = _field(preds, "intrinsics")
    processed_images = _field(preds, "processed_images")
    sky_mask = _field(preds, "sky_mask")

    # export_to_glb requires confidences; fall back to all-ones if missing.
    if conf is None and depth is not None:
        conf = np.ones_like(depth, dtype=np.float32)

    gaussians_obj: Optional[Gaussians] = None
    gdict = preds.get("gaussians")
    if gdict is not None:

        def _tensor(key: str):
            # Gaussians fields are stored as CPU torch tensors.
            arr = _field(gdict, key)
            return torch.from_numpy(arr) if arr is not None else None

        gaussians_obj = Gaussians(
            means=_tensor("means"),
            scales=_tensor("scales"),
            rotations=_tensor("rotations"),
            harmonics=_tensor("harmonics"),
            opacities=_tensor("opacities"),
        )

    return Prediction(
        depth=depth,
        is_metric=1,
        sky=sky_mask,
        conf=conf,
        extrinsics=extrinsics,
        intrinsics=intrinsics,
        processed_images=processed_images,
        gaussians=gaussians_obj,
        aux={},  # optional aux dict
        scale_factor=None,
    )
190
 
191
+
192
+ # -------------------------------------------------------------------------
193
+ # Monkey-patch ModelInference.run_inference to use remote backend
194
+ # -------------------------------------------------------------------------
195
+ def remote_run_inference(
196
+ self: ModelInference,
197
+ target_dir: str,
198
+ filter_black_bg: bool = False,
199
+ filter_white_bg: bool = False,
200
+ process_res_method: str = "upper_bound_resize",
201
+ show_camera: bool = True,
202
+ selected_first_frame: Optional[str] = None,
203
+ save_percentage: float = 30.0,
204
+ num_max_points: int = 1_000_000,
205
+ infer_gs: bool = False,
206
+ gs_trj_mode: str = "extend",
207
+ gs_video_quality: str = "high",
208
+ ) -> Tuple[Any, Dict[int, Dict[str, Any]]]:
209
+ """
210
+ Remote inference via DA3_HOST. Frontend ONLY consumes artifacts returned by backend:
211
+ - Writes scene.glb, depth_vis/, predictions.npz, (optional) gs_video.mp4 into target_dir
212
+ - Builds processed_data dict from files
213
+ - Returns (prediction, processed_data) where prediction is reconstructed from predictions.npz
214
+ """
215
+ if not DA3_HOST:
216
+ raise RuntimeError(
217
+ "DA3_HOST is not set. Please set env DA3_HOST=ip:port for remote inference."
218
+ )
219
+
220
+ # Validate images exist
221
+ image_folder_path = os.path.join(target_dir, "images")
222
+ all_image_paths = sorted(glob.glob(os.path.join(image_folder_path, "*")))
223
+ if len(all_image_paths) == 0:
224
+ raise ValueError("No images found. Check your upload.")
225
+
226
+ # Compose options to send to backend (no export on frontend)
227
+ options = {
228
+ "process_res_method": process_res_method,
229
+ "selected_first_frame": selected_first_frame or "",
230
+ "infer_gs": bool(infer_gs),
231
+ "conf_thresh_percentile": float(save_percentage),
232
+ "num_max_points": int(num_max_points),
233
+ "show_cameras": bool(show_camera),
234
+ "gs_trj_mode": gs_trj_mode,
235
+ "gs_video_quality": gs_video_quality,
236
+ }
237
+
238
+ # IDs and WebSocket
239
+ client_id = str(uuid.uuid4())
240
+ token = str(uuid.uuid4())
241
+ ws = _open_ws(client_id, token)
242
+
243
+ # Submit inference job
244
+ job_id = _submit_inference(target_dir, client_id, token, options)
245
+
246
+ # Monitor progress via WebSocket
247
+ ws.settimeout(180)
248
+ try:
249
+ while True:
250
+ out = ws.recv()
251
+ if isinstance(out, (bytes, bytearray)):
252
+ continue
253
+ msg = json.loads(out)
254
+ if msg.get("type") == "executing":
255
+ data = msg.get("data", {})
256
+ if data.get("job_id") != job_id:
257
+ continue
258
+ node = data.get("node")
259
+ if node is None:
260
+ # Job complete
261
+ break
262
+ except Exception as e:
263
+ print(f"WebSocket error: {e}")
264
+ finally:
265
+ try:
266
+ ws.close()
267
+ except Exception:
268
+ pass
269
+
270
+ # Fetch final result
271
+ result = _get_result(job_id, token)
272
+ if job_id not in result:
273
+ raise RuntimeError(f"Remote result missing job_id entry: {result}")
274
+ job_entry = result[job_id]
275
+ if job_entry.get("status") != "completed":
276
+ raise RuntimeError(f"Remote job not completed or failed: {job_entry}")
277
+
278
+ artifacts = job_entry.get("artifacts", {})
279
+ if not artifacts:
280
+ raise RuntimeError(f"No artifacts returned from backend for job {job_id}")
281
+
282
+ # Write artifacts to target_dir
283
+ os.makedirs(target_dir, exist_ok=True)
284
+
285
+ # scene.glb
286
+ glb_b64 = artifacts.get("glb")
287
+ if glb_b64:
288
+ with open(os.path.join(target_dir, "scene.glb"), "wb") as f:
289
+ f.write(base64.b64decode(glb_b64))
290
+
291
+ # depth_vis
292
+ depth_vis_b64 = artifacts.get("depth_vis_zip")
293
+ if depth_vis_b64:
294
+ depth_vis_dir = os.path.join(target_dir, "depth_vis")
295
+ os.makedirs(depth_vis_dir, exist_ok=True)
296
+ zip_bytes = base64.b64decode(depth_vis_b64)
297
+ with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
298
+ zf.extractall(depth_vis_dir)
299
+
300
+ # predictions.npz
301
+ pred_npz_b64 = artifacts.get("predictions_npz")
302
+ prediction: Any = None
303
+ if pred_npz_b64:
304
+ npz_path = os.path.join(target_dir, "predictions.npz")
305
+ with open(npz_path, "wb") as f:
306
+ f.write(base64.b64decode(pred_npz_b64))
307
+ try:
308
+ loaded = np.load(npz_path, allow_pickle=True)
309
+ # reconstruct Prediction dataclass from npz content
310
+ images = loaded["images"] if "images" in loaded.files else None
311
+ depths = loaded["depths"] if "depths" in loaded.files else None
312
+ conf = loaded["conf"] if "conf" in loaded.files else None
313
+ extrinsics = loaded["extrinsics"] if "extrinsics" in loaded.files else None
314
+ intrinsics = loaded["intrinsics"] if "intrinsics" in loaded.files else None
315
+
316
+ prediction = Prediction(
317
+ depth=depths,
318
+ is_metric=1,
319
+ sky=None,
320
+ conf=(
321
+ conf
322
+ if conf is not None
323
+ else (np.ones_like(depths, dtype=np.float32) if depths is not None else None)
324
+ ),
325
+ extrinsics=extrinsics,
326
+ intrinsics=intrinsics,
327
+ processed_images=images,
328
+ gaussians=None,
329
+ aux={},
330
+ scale_factor=None,
331
+ )
332
+ except Exception as e:
333
+ print(f"Failed to reconstruct Prediction from predictions.npz: {e}")
334
+ prediction = Prediction(
335
+ depth=None,
336
+ is_metric=1,
337
+ sky=None,
338
+ conf=None,
339
+ extrinsics=None,
340
+ intrinsics=None,
341
+ processed_images=None,
342
+ gaussians=None,
343
+ aux={},
344
+ scale_factor=None,
345
+ )
346
+
347
+ # Optional GS video
348
+ gs_video_b64 = artifacts.get("gs_video")
349
+ if gs_video_b64:
350
+ gs_dir = os.path.join(target_dir, "gs_video")
351
+ os.makedirs(gs_dir, exist_ok=True)
352
+ with open(os.path.join(gs_dir, "gs_video.mp4"), "wb") as f:
353
+ f.write(base64.b64decode(gs_video_b64))
354
+
355
+ # Build processed_data from files (depth_vis + optional images from predictions.npz)
356
+ processed_data = self._process_results(target_dir, prediction, all_image_paths)
357
+
358
+ return prediction, processed_data
359
+
360
+
361
+ # Replace original ModelInference.run_inference with remote version
362
+ ModelInference.run_inference = remote_run_inference
363
+
364
+
365
+ # -------------------------------------------------------------------------
366
+ # Initialize and launch the frontend app (unchanged UI behavior)
367
+ # -------------------------------------------------------------------------
368
  if __name__ == "__main__":
369
+ # Enforce remote backend configuration
370
+ if not DA3_HOST:
371
+ raise RuntimeError(
372
+ "DA3_HOST is not set. Please export DA3_HOST=ip:port to use remote backend inference."
373
+ )
374
+
375
+ # Configure directories for frontend workspace/gallery
376
  model_dir = os.environ.get("DA3_MODEL_DIR", "depth-anything/DA3NESTED-GIANT-LARGE")
377
  workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "workspace/gradio")
378
  gallery_dir = os.environ.get("DA3_GALLERY_DIR", "workspace/gallery")
379
+
380
  # Create directories if they don't exist
381
  os.makedirs(workspace_dir, exist_ok=True)
382
  os.makedirs(gallery_dir, exist_ok=True)
383
+
384
+ # Initialize the app (frontend UI)
385
  app = DepthAnything3App(
386
  model_dir=model_dir,
387
  workspace_dir=workspace_dir,
388
+ gallery_dir=gallery_dir,
389
  )
390
+
391
  # Check if examples directory exists
392
  examples_dir = os.path.join(workspace_dir, "examples")
393
  examples_exist = os.path.exists(examples_dir)
394
+
395
+ # Check caching (default: True if examples exist)
 
396
  cache_examples_env = os.environ.get("DA3_CACHE_EXAMPLES", "").lower()
397
  if cache_examples_env in ("false", "0", "no"):
398
  cache_examples = False
399
  elif cache_examples_env in ("true", "1", "yes"):
400
  cache_examples = True
401
  else:
 
402
  cache_examples = examples_exist
403
+
404
+ # Cache tag for 3DGS
405
  cache_gs_tag = os.environ.get("DA3_CACHE_GS_TAG", "dl3dv")
406
+
407
+ # Launch logs
408
+ print("๐Ÿš€ Launching Depth Anything 3 Frontend (remote backend mode)...")
409
+ print(f"๐ŸŒ DA3_HOST (backend): {DA3_HOST}")
410
+ print(f"๐Ÿ“ฆ Model Directory (frontend env only): {model_dir}")
411
  print(f"๐Ÿ“ Workspace Directory: {workspace_dir}")
412
  print(f"๐Ÿ–ผ๏ธ Gallery Directory: {gallery_dir}")
413
  print(f"๐Ÿ’พ Cache Examples: {cache_examples}")
414
  if cache_examples:
415
  if cache_gs_tag:
416
+ print(
417
+ f"๐Ÿท๏ธ Cache GS Tag: '{cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)"
418
+ )
419
  else:
420
  print("๐Ÿท๏ธ Cache GS Tag: None (all scenes will use low-res only)")
421
+
422
+ # Pre-cache examples (requests inference from remote backend; artifacts still stored locally)
423
  if cache_examples:
424
  print("\n" + "=" * 60)
425
+ print("Pre-caching mode enabled (remote backend inference)")
426
  if cache_gs_tag:
427
  print(f"Scenes containing '{cache_gs_tag}' will use HIGH-RES + 3DGS")
428
  print("Other scenes will use LOW-RES only")
 
439
  gs_trj_mode="smooth",
440
  gs_video_quality="low",
441
  )
442
+
443
+ # Launch Gradio frontend (minimal, Spaces-compatible configuration)
 
444
  app.launch(
445
+ host="0.0.0.0",
446
+ port=7860,
447
+ share=False,
448
  )
pyproject.toml CHANGED
@@ -14,14 +14,14 @@ authors = [{ name = "Your Name" }]
14
  dependencies = [
15
  "pre-commit",
16
  "trimesh",
17
- "torch>=2",
18
- "torchvision",
19
  "einops",
20
  "huggingface_hub",
21
  "imageio",
22
  "numpy<2",
23
  "opencv-python",
24
- "xformers",
25
  "open3d",
26
  "fastapi",
27
  "uvicorn",
 
14
  dependencies = [
15
  "pre-commit",
16
  "trimesh",
17
+ # "torch>=2",
18
+ # "torchvision",
19
  "einops",
20
  "huggingface_hub",
21
  "imageio",
22
  "numpy<2",
23
  "opencv-python",
24
+ # "xformers",
25
  "open3d",
26
  "fastapi",
27
  "uvicorn",
requirements.txt CHANGED
@@ -1,8 +1,6 @@
1
  # Core dependencies
2
  torch>=2.0.0
3
- torchvision
4
  einops
5
- huggingface_hub
6
  numpy<2
7
  opencv-python
8
 
@@ -10,12 +8,8 @@ opencv-python
10
  gradio>=5.0.0
11
  spaces
12
  pillow>=9.0
13
- evo
14
 
15
  # 3D and visualization
16
- trimesh
17
- open3d
18
- plyfile
19
 
20
  # Image processing
21
  imageio
@@ -23,26 +17,11 @@ pillow_heif
23
  safetensors
24
 
25
  # Video processing
26
- moviepy==1.0.3
27
 
28
  # Math and geometry
29
- e3nn
30
 
31
  # Utilities
32
  requests
 
33
  omegaconf
34
  typer>=0.9.0
35
-
36
- # Web frameworks (if using API features)
37
- fastapi
38
- uvicorn
39
-
40
- # xformers - commented out due to potential build issues on Spaces
41
- # If needed, uncomment and use a version compatible with your PyTorch/CUDA:
42
- # xformers==0.0.22
43
- # Or install after deployment: pip install xformers --no-deps
44
-
45
- # 3D Gaussian Splatting
46
- # Note: This requires CUDA during build. If build fails on Spaces, see alternative solutions.
47
- gsplat @ https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.3/gsplat-1.5.3+pt24cu124-cp310-cp310-linux_x86_64.whl
48
-
 
1
  # Core dependencies
2
  torch>=2.0.0
 
3
  einops
 
4
  numpy<2
5
  opencv-python
6
 
 
8
  gradio>=5.0.0
9
  spaces
10
  pillow>=9.0
 
11
 
12
  # 3D and visualization
 
 
 
13
 
14
  # Image processing
15
  imageio
 
17
  safetensors
18
 
19
  # Video processing
 
20
 
21
  # Math and geometry
 
22
 
23
  # Utilities
24
  requests
25
+ websocket-client
26
  omegaconf
27
  typer>=0.9.0