#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Backend API server for VGGT model inference

import os
import sys
import asyncio
import base64
import io
import json
import threading
import uuid
from typing import Dict, Any, Optional
from datetime import datetime
import glob
import shutil

import numpy as np
import torch
from fastapi import FastAPI, WebSocket, HTTPException, Query
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn

sys.path.append("vggt/")

from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map

# Initialize FastAPI app
app = FastAPI(title="VGGT Inference API", version="1.0.0")

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; tighten if cookies/auth headers are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model instance
model = None
device = None

# Job storage: {job_id: {"status": "processing/completed/failed", "result": {...}, "progress": 0}}
# NOTE(review): entries (including base64-encoded results) are never evicted,
# so memory grows with the number of jobs served.
jobs: Dict[str, Dict[str, Any]] = {}

# WebSocket connections: {client_id: websocket}
websocket_connections: Dict[str, WebSocket] = {}

# Inference runs in worker threads (see run_inference); this lock keeps GPU
# work to one job at a time, as when it ran directly on the event loop.
_gpu_lock = threading.Lock()


# -------------------------------------------------------------------------
# Request/Response Models
# -------------------------------------------------------------------------
class ImageData(BaseModel):
    filename: str
    data: str  # base64 encoded image


class InferenceRequest(BaseModel):
    images: list[ImageData]
    client_id: str


class InferenceResponse(BaseModel):
    job_id: str
    status: str = "queued"


# -------------------------------------------------------------------------
# Model Loading
# -------------------------------------------------------------------------
def load_model():
    """Load the VGGT-1B weights onto the GPU and switch the model to eval mode.

    Sets the module-level ``model`` and ``device`` globals.

    Raises:
        RuntimeError: if CUDA is unavailable (GPU inference is required).
    """
    global model, device

    print("Initializing and loading VGGT model...")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. GPU is required for VGGT inference.")
    device = "cuda"

    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
    model = model.to(device)
    model.eval()
    print(f"Model loaded successfully on {device}")


# -------------------------------------------------------------------------
# Core Inference Function
# -------------------------------------------------------------------------
def _autocast_dtype() -> torch.dtype:
    """Return the mixed-precision dtype for the current GPU.

    bfloat16 is natively supported from compute capability 8.0 (Ampere) on;
    older GPUs fall back to float16.
    """
    major, _minor = torch.cuda.get_device_capability()
    return torch.bfloat16 if major >= 8 else torch.float16


def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Copy a tensor to host memory as numpy, squeezing axis 0 (a no-op unless size 1)."""
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()  # numpy has no bfloat16 dtype
    return tensor.cpu().numpy().squeeze(0)


def _encode_array(array: np.ndarray) -> str:
    """Serialize *array* in .npy format and base64-encode it for JSON transport."""
    buffer = io.BytesIO()
    np.save(buffer, array, allow_pickle=True)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _run_model(image_names: list[str], job_id: str) -> tuple[list[str], Dict[str, Any]]:
    """Run the VGGT forward pass and decode camera parameters.

    Returns:
        The key names of the raw model output (used for progress messages) and
        a dict holding every output with tensors converted via ``_to_numpy``,
        plus the decoded ``extrinsic`` and ``intrinsic`` matrices.
    """
    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    print(f"Running inference for job {job_id}...")
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=_autocast_dtype()):
        predictions = model(images)
    model_keys = list(predictions.keys())

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(
        predictions["pose_enc"], images.shape[-2:]
    )
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    predictions_numpy = {
        key: _to_numpy(value) if isinstance(value, torch.Tensor) else value
        for key, value in predictions.items()
    }
    # GPU tensors go out of scope on return, so the caller can empty the cache.
    return model_keys, predictions_numpy


def _predict(target_dir: str, job_id: str) -> tuple[list[str], Dict[str, Any]]:
    """Blocking part of a job: load images, run VGGT, post-process and serialize.

    Intended to run in a worker thread so the event loop stays responsive.

    Returns:
        The raw model output key names and the serialized predictions, where
        numpy arrays are base64-encoded .npy payloads.

    Raises:
        ValueError: if ``target_dir/images`` contains no files.
    """
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    print(f"Found {len(image_names)} images for job {job_id}")
    if not image_names:
        raise ValueError("No images found in target directory")

    with _gpu_lock:
        try:
            model_keys, predictions_numpy = _run_model(image_names, job_id)
        finally:
            # Release cached GPU blocks whether or not the forward pass succeeded.
            torch.cuda.empty_cache()

    # Generate world points from depth map (host-side, outside the GPU lock)
    print("Computing world points from depth map...")
    predictions_numpy["world_points_from_depth"] = unproject_depth_map_to_point_map(
        predictions_numpy["depth"],
        predictions_numpy["extrinsic"],
        predictions_numpy["intrinsic"],
    )

    serialized_predictions = {
        key: _encode_array(value) if isinstance(value, np.ndarray) else value
        for key, value in predictions_numpy.items()
    }
    return model_keys, serialized_predictions


async def _notify(client_id: Optional[str], message: Dict[str, Any]) -> None:
    """Best-effort JSON push to a client's WebSocket.

    A disconnected or failing client must not change a job's outcome, so send
    errors are logged and swallowed.
    """
    websocket = websocket_connections.get(client_id) if client_id else None
    if websocket is None:
        return
    try:
        await websocket.send_json(message)
    except Exception as exc:  # delivery is best-effort by design
        print(f"Failed to notify client {client_id}: {exc}")


async def run_inference(job_id: str, target_dir: str, client_id: Optional[str] = None):
    """Run VGGT model inference for one job and record the outcome in ``jobs``.

    Args:
        job_id: key into ``jobs``; the entry must already exist.
        target_dir: job directory whose ``images/`` subfolder holds the inputs.
            It is deleted when the job ends, whether it succeeded or failed.
        client_id: optional WebSocket client that receives progress messages.

    Errors raised during inference are caught and stored on the job as
    ``status="failed"`` with an ``error`` message.
    """
    try:
        jobs[job_id]["status"] = "processing"
        await _notify(
            client_id, {"type": "executing", "data": {"job_id": job_id, "node": "start"}}
        )

        # Heavy GPU/CPU work runs off the event loop.
        model_keys, serialized_predictions = await asyncio.to_thread(
            _predict, target_dir, job_id
        )

        # One progress message per raw model output (sent after the forward pass).
        for key in model_keys:
            await _notify(
                client_id, {"type": "executing", "data": {"job_id": job_id, "node": key}}
            )
            await asyncio.sleep(0.01)  # Small delay for progress updates

        jobs[job_id]["status"] = "completed"
        jobs[job_id]["result"] = {"predictions": serialized_predictions}

        # node=None signals completion to the client.
        await _notify(
            client_id, {"type": "executing", "data": {"job_id": job_id, "node": None}}
        )
        print(f"Job {job_id} completed successfully")

    except Exception as e:
        print(f"Error in job {job_id}: {str(e)}")
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        await _notify(
            client_id, {"type": "error", "data": {"job_id": job_id, "error": str(e)}}
        )

    finally:
        shutil.rmtree(target_dir, ignore_errors=True)
# -------------------------------------------------------------------------
# API Endpoints
# -------------------------------------------------------------------------

# Strong references to in-flight background jobs. The event loop keeps only
# weak references to tasks, so an unreferenced task can be garbage-collected
# before it finishes.
_background_tasks: set = set()


def _unique_image_path(directory: str, filename: str) -> str:
    """Return a path inside *directory* under which to store an uploaded image.

    Only the final path component of the client-supplied *filename* is kept,
    so names such as ``../../x`` or ``/etc/x`` cannot escape *directory*.
    If that name is already taken, a numeric suffix is appended instead of
    overwriting the earlier upload.

    Raises:
        ValueError: if *filename* has no usable base name.
    """
    name = os.path.basename(filename)
    if name in ("", ".", ".."):
        raise ValueError(f"Invalid image filename: {filename!r}")
    stem, ext = os.path.splitext(name)
    candidate = os.path.join(directory, name)
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{counter}{ext}")
        counter += 1
    return candidate


@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    load_model()


@app.get("/")
async def root():
    """Health check endpoint"""
    return {"status": "ok", "service": "VGGT Inference API"}


@app.post("/inference")
async def create_inference(request: InferenceRequest, token: str = Query(...)):
    """
    Submit an inference job

    Args:
        request: InferenceRequest containing images and client_id
        token: Authentication token (currently not validated, for compatibility)

    Returns:
        InferenceResponse with job_id

    Raises:
        HTTPException(400): if no images are supplied, or an image cannot be
            decoded or written to disk.
    """
    if not request.images:
        raise HTTPException(status_code=400, detail="No images provided")

    # Generate unique job ID
    job_id = str(uuid.uuid4())

    # Create temporary directory for images
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"/tmp/vggt_job_{job_id}_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")
    os.makedirs(target_dir_images, exist_ok=True)

    # Decode and save images. binascii.Error (bad base64) is a ValueError.
    try:
        for img_data in request.images:
            img_bytes = base64.b64decode(img_data.data)
            img_path = _unique_image_path(target_dir_images, img_data.filename)
            with open(img_path, "wb") as f:
                f.write(img_bytes)
    except (ValueError, OSError) as e:
        shutil.rmtree(target_dir, ignore_errors=True)
        raise HTTPException(
            status_code=400, detail=f"Failed to process images: {str(e)}"
        ) from e

    # Initialize job
    jobs[job_id] = {
        "status": "queued",
        "result": None,
        "created_at": datetime.now().isoformat(),
    }

    # Start inference in background; keep a reference until the task finishes.
    task = asyncio.create_task(run_inference(job_id, target_dir, request.client_id))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return InferenceResponse(job_id=job_id, status="queued")


@app.get("/result/{job_id}")
async def get_result(job_id: str, token: str = Query(...)):
    """
    Get inference result for a job

    Args:
        job_id: Job ID
        token: Authentication token (currently not validated, for compatibility)

    Returns:
        ``{job_id: result}`` once completed, otherwise ``{job_id: {"status": ...}}``.

    Raises:
        HTTPException(404): unknown job_id.
        HTTPException(500): the job failed; detail carries its error message.
    """
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")

    job = jobs[job_id]

    if job["status"] == "failed":
        raise HTTPException(status_code=500, detail=job.get("error", "Job failed"))

    if job["status"] != "completed":
        return {job_id: {"status": job["status"]}}

    return {job_id: job["result"]}


@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket, clientId: str = Query(...), token: str = Query(...)
):
    """
    WebSocket endpoint for real-time progress updates

    Args:
        websocket: WebSocket connection
        clientId: Client ID
        token: Authentication token (currently not validated, for compatibility)
    """
    await websocket.accept()
    # A reconnect with the same clientId replaces the previously registered socket.
    websocket_connections[clientId] = websocket

    try:
        while True:
            # Keep connection alive
            data = await websocket.receive_text()
            # Echo back for heartbeat
            await websocket.send_text(data)
    except Exception as e:
        print(f"WebSocket error for client {clientId}: {str(e)}")
    finally:
        # Only unregister this socket: if the client has reconnected meanwhile,
        # the newer connection must stay registered.
        if websocket_connections.get(clientId) is websocket:
            del websocket_connections[clientId]


@app.get("/history/{job_id}")
async def get_history(job_id: str, token: str = Query(...)):
    """
    Get job history (alias for /result for compatibility)

    Args:
        job_id: Job ID
        token: Authentication token

    Returns:
        Job history
    """
    return await get_result(job_id, token)


# -------------------------------------------------------------------------
# Main
# -------------------------------------------------------------------------
if __name__ == "__main__":
    # Run server
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")