# vggt / api_server.py
# Added by lidavidsh in commit c788f41 ("add api_server.py"), 11.4 kB
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Backend API server for VGGT model inference
import asyncio
import base64
import glob
import io
import json
import os
import shutil
import sys
import threading
import uuid
from datetime import datetime
from typing import Any, Dict, Optional

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Query, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

sys.path.append("vggt/")
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map
# FastAPI application; title/version appear in the generated OpenAPI docs.
app = FastAPI(title="VGGT Inference API", version="1.0.0")

# Permissive CORS policy: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Filled in by load_model() during startup.
model = None
device = None

# job_id -> job record. "status" moves through "queued" -> "processing" ->
# "completed" | "failed"; "result" holds the predictions once completed and
# "error" the message once failed.
jobs: Dict[str, Dict[str, Any]] = {}

# client_id -> open WebSocket used to push progress notifications.
websocket_connections: Dict[str, WebSocket] = {}
# -------------------------------------------------------------------------
# Request/Response Models
# -------------------------------------------------------------------------
class ImageData(BaseModel):
    """A single uploaded image: file name plus base64-encoded file bytes."""

    filename: str = Field(..., description="File name the image is stored under")
    data: str = Field(..., description="base64 encoded image")
class InferenceRequest(BaseModel):
    """Request body of ``POST /inference``."""

    images: list[ImageData] = Field(..., description="Images to run inference on")
    client_id: str = Field(
        ..., description="WebSocket client ID that receives progress messages"
    )
class InferenceResponse(BaseModel):
    """Reply to ``POST /inference`` identifying the newly created job."""

    job_id: str = Field(..., description="Identifier used to poll /result/{job_id}")
    status: str = Field(default="queued", description="Initial job status")
# -------------------------------------------------------------------------
# Model Loading
# -------------------------------------------------------------------------
def load_model():
    """Load the pretrained VGGT-1B checkpoint onto the GPU and publish it globally.

    On success the module-level ``model`` holds the network in eval mode and
    ``device`` is ``"cuda"``. The globals are assigned only after the checkpoint has
    been loaded, so a failed download or load never leaves a half-initialised model
    (or a misleading ``device = "cpu"``) behind.

    Raises:
        RuntimeError: if CUDA is not available (there is no CPU inference path).
    """
    global model, device
    print("Initializing and loading VGGT model...")
    # Fail fast before touching any global state: inference uses CUDA-only APIs.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. GPU is required for VGGT inference.")
    target_device = "cuda"

    net = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    net.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
    net = net.to(target_device)
    net.eval()

    # Publish only a fully loaded model.
    model, device = net, target_device
    print(f"Model loaded successfully on {device}")
# -------------------------------------------------------------------------
# Core Inference Function
# -------------------------------------------------------------------------
# Serialises GPU work across jobs. Inference runs in worker threads so the event
# loop stays responsive; this lock keeps at most one job on the GPU at a time.
_GPU_LOCK = threading.Lock()


async def _notify(client_id: Optional[str], payload: Dict[str, Any]) -> None:
    """Best-effort send of ``payload`` to the WebSocket registered for ``client_id``.

    Does nothing when no client ID is given or no socket is registered. Send failures
    (e.g. the client disconnected mid-job) are logged and swallowed so that they can
    never change a job's recorded outcome.
    """
    if not client_id:
        return
    websocket = websocket_connections.get(client_id)
    if websocket is None:
        return
    try:
        await websocket.send_json(payload)
    except Exception as exc:
        print(f"Failed to send update to client {client_id}: {exc}")


def _encode_array(array: np.ndarray) -> str:
    """Serialise ``array`` in ``.npy`` format and return it as a base64 string."""
    buffer = io.BytesIO()
    np.save(buffer, array, allow_pickle=True)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _predict_sync(job_id: str, image_names: list[str]) -> tuple[list[str], Dict[str, Any]]:
    """GPU stage: preprocess images, run VGGT, derive camera matrices, copy to host.

    Must be called with ``_GPU_LOCK`` held. Returns the keys of the raw model output
    (used for progress messages) and all predictions with tensors converted to numpy.
    """
    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    print(f"Running inference for job {job_id}...")
    # bfloat16 autocast is used on compute capability >= 8.0 (Ampere+) GPUs only;
    # older GPUs fall back to float16, as recommended in the VGGT README.
    amp_dtype = (
        torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    )
    # no_grad and autocast are thread-local, so they are entered in this worker thread.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=amp_dtype):
        predictions = model(images)
    model_keys = list(predictions.keys())

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(
        predictions["pose_enc"], images.shape[-2:]
    )
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    # Move tensors to host memory and squeeze the leading axis (presumably the
    # size-1 batch dimension); non-tensor values are passed through unchanged.
    predictions_numpy = {
        key: value.cpu().numpy().squeeze(0) if isinstance(value, torch.Tensor) else value
        for key, value in predictions.items()
    }
    return model_keys, predictions_numpy


def _infer_sync(job_id: str, target_dir: str) -> tuple[list[str], Dict[str, Any]]:
    """Blocking pipeline for one job; intended to run via ``asyncio.to_thread``.

    Reads every file in ``target_dir/images`` (sorted by path), runs the model,
    adds ``world_points_from_depth`` and base64-encodes every numpy array in ``.npy``
    format.

    Returns:
        ``(model_keys, serialized_predictions)``.

    Raises:
        ValueError: if ``target_dir/images`` contains no files.
    """
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    print(f"Found {len(image_names)} images for job {job_id}")
    if not image_names:
        raise ValueError("No images found in target directory")

    with _GPU_LOCK:
        try:
            model_keys, predictions_numpy = _predict_sync(job_id, image_names)
        finally:
            # Release cached GPU blocks whether or not this job succeeded.
            torch.cuda.empty_cache()

    print("Computing world points from depth map...")
    predictions_numpy["world_points_from_depth"] = unproject_depth_map_to_point_map(
        predictions_numpy["depth"],
        predictions_numpy["extrinsic"],
        predictions_numpy["intrinsic"],
    )

    # NOTE(review): non-array values are returned as-is and must be JSON-serialisable
    # for /result to work - confirm the model output contains only tensors.
    serialized_predictions = {
        key: _encode_array(value) if isinstance(value, np.ndarray) else value
        for key, value in predictions_numpy.items()
    }
    return model_keys, serialized_predictions


async def run_inference(job_id: str, target_dir: str, client_id: Optional[str] = None):
    """Run VGGT model inference on images for one job and record the outcome.

    Updates ``jobs[job_id]`` to "processing", then to "completed" (with
    ``result = {"predictions": ...}``) or "failed" (with ``error``). If ``client_id``
    has a registered WebSocket, it receives ``executing`` messages with
    ``node="start"``, one per model output key, and ``node=None`` on completion, or an
    ``error`` message on failure. Notification failures never affect the job status.
    The job directory ``target_dir`` is always deleted afterwards.

    Args:
        job_id: Key of an existing entry in ``jobs``.
        target_dir: Job directory whose ``images`` subfolder holds the inputs.
        client_id: Optional WebSocket client to notify.
    """
    try:
        jobs[job_id]["status"] = "processing"
        await _notify(
            client_id, {"type": "executing", "data": {"job_id": job_id, "node": "start"}}
        )

        # Heavy blocking work runs off the event loop so other requests and
        # WebSocket heartbeats are still served while the model runs.
        model_keys, serialized_predictions = await asyncio.to_thread(
            _infer_sync, job_id, target_dir
        )

        # Progress updates: one message per model output key.
        for key in model_keys:
            await _notify(
                client_id, {"type": "executing", "data": {"job_id": job_id, "node": key}}
            )
            await asyncio.sleep(0.01)  # Small delay for progress updates

        jobs[job_id]["status"] = "completed"
        jobs[job_id]["result"] = {"predictions": serialized_predictions}

        # node=None indicates completion.
        await _notify(
            client_id, {"type": "executing", "data": {"job_id": job_id, "node": None}}
        )
        print(f"Job {job_id} completed successfully")
    except Exception as e:
        print(f"Error in job {job_id}: {str(e)}")
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
        await _notify(
            client_id, {"type": "error", "data": {"job_id": job_id, "error": str(e)}}
        )
    finally:
        # Delete uploaded images on success AND failure (previously leaked on failure).
        shutil.rmtree(target_dir, ignore_errors=True)
# -------------------------------------------------------------------------
# API Endpoints
# -------------------------------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Startup hook: load the VGGT weights once, before any request is served."""
    load_model()
@app.get("/")
async def root():
    """Liveness probe reporting that the API process is up."""
    health = {"status": "ok", "service": "VGGT Inference API"}
    return health
# Strong references to in-flight inference tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected mid-run.
_background_tasks: set[asyncio.Task] = set()


def _safe_image_name(filename: str) -> str:
    """Reduce a client-supplied filename to a bare name inside the job directory.

    Directory components ("/" and "\\" separators) are stripped so a name such as
    ``../../etc/x`` or ``/abs/path.png`` cannot escape the job's image folder.

    Raises:
        ValueError: if nothing usable remains (empty, "." or "..").
    """
    name = os.path.basename(filename.replace("\\", "/"))
    if name in ("", ".", ".."):
        raise ValueError(f"Invalid image filename: {filename!r}")
    return name


@app.post("/inference")
async def create_inference(request: InferenceRequest, token: str = Query(...)):
    """
    Submit an inference job
    Args:
        request: InferenceRequest containing images and client_id
        token: Authentication token (currently not validated, for compatibility)
    Returns:
        InferenceResponse with job_id
    Raises:
        HTTPException(400): if an image cannot be decoded or written, or its
            filename is unusable.
    """
    # Generate unique job ID
    job_id = str(uuid.uuid4())
    # Create temporary directory for images
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"/tmp/vggt_job_{job_id}_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")
    os.makedirs(target_dir_images, exist_ok=True)

    # Decode and save images; only client-data errors are mapped to 400.
    try:
        for img_data in request.images:
            img_bytes = base64.b64decode(img_data.data)
            img_path = os.path.join(target_dir_images, _safe_image_name(img_data.filename))
            with open(img_path, "wb") as f:
                f.write(img_bytes)
    except Exception as e:
        shutil.rmtree(target_dir, ignore_errors=True)
        raise HTTPException(
            status_code=400, detail=f"Failed to process images: {str(e)}"
        ) from e

    # Initialize job
    jobs[job_id] = {
        "status": "queued",
        "result": None,
        "created_at": datetime.now().isoformat(),
    }
    # Start inference in background, keeping a reference until it finishes.
    task = asyncio.create_task(run_inference(job_id, target_dir, request.client_id))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return InferenceResponse(job_id=job_id, status="queued")
@app.get("/result/{job_id}")
async def get_result(job_id: str, token: str = Query(...)):
    """Look up the outcome of an inference job.

    Args:
        job_id: Identifier returned by ``POST /inference``.
        token: Authentication token (accepted but not checked).

    Returns:
        ``{job_id: result}`` for a completed job (``result`` holds the predictions),
        otherwise ``{job_id: {"status": <current status>}}``.

    Raises:
        HTTPException(404): unknown ``job_id``.
        HTTPException(500): the job failed; ``detail`` carries its error message.
    """
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    status = job["status"]
    if status == "failed":
        raise HTTPException(status_code=500, detail=job.get("error", "Job failed"))

    payload = job["result"] if status == "completed" else {"status": status}
    return {job_id: payload}
@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket, clientId: str = Query(...), token: str = Query(...)
):
    """
    WebSocket endpoint for real-time progress updates

    Registers the socket under ``clientId`` (replacing any previous socket with the
    same ID) and echoes every received text frame back as a heartbeat until the
    connection ends.

    Args:
        websocket: WebSocket connection
        clientId: Client ID
        token: Authentication token (currently not validated, for compatibility)
    """
    await websocket.accept()
    websocket_connections[clientId] = websocket
    try:
        while True:
            # Keep connection alive
            data = await websocket.receive_text()
            # Echo back for heartbeat
            await websocket.send_text(data)
    except Exception as e:
        print(f"WebSocket error for client {clientId}: {str(e)}")
    finally:
        # Unregister only if this socket is still the registered one; otherwise a
        # reconnect with the same clientId would lose its live registration here.
        if websocket_connections.get(clientId) is websocket:
            del websocket_connections[clientId]
@app.get("/history/{job_id}")
async def get_history(job_id: str, token: str = Query(...)):
    """Alias of ``GET /result/{job_id}``, kept for clients that poll ``/history``.

    Args:
        job_id: Job ID to look up.
        token: Authentication token, forwarded unchanged (not validated).

    Returns:
        Exactly what ``get_result`` returns for the same job.
    """
    history = await get_result(job_id=job_id, token=token)
    return history
# -------------------------------------------------------------------------
# Main
# -------------------------------------------------------------------------
if __name__ == "__main__":
    # Run server. VGGT_HOST / VGGT_PORT override the defaults (0.0.0.0:7860).
    uvicorn.run(
        app,
        host=os.environ.get("VGGT_HOST", "0.0.0.0"),
        port=int(os.environ.get("VGGT_PORT", "7860")),
        log_level="info",
    )