lidavidsh committed
Commit c788f41 · Parent: 3adf7d3

add api_server.py

Files changed (1)
api_server.py +345 -0
api_server.py ADDED
@@ -0,0 +1,345 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Backend API server for VGGT model inference
+
+import os
+import sys
+import asyncio
+import base64
+import io
+import uuid
+from typing import Dict, Any, Optional
+from datetime import datetime
+import glob
+import shutil
+
+import numpy as np
+import torch
+from fastapi import FastAPI, WebSocket, HTTPException, Query
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+import uvicorn
+
+sys.path.append("vggt/")
+
+from vggt.models.vggt import VGGT
+from vggt.utils.load_fn import load_and_preprocess_images
+from vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from vggt.utils.geometry import unproject_depth_map_to_point_map
+
+# Initialize FastAPI app
+app = FastAPI(title="VGGT Inference API", version="1.0.0")
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Global model instance
+model = None
+device = None
+
+# Job storage: {job_id: {"status": "queued" | "processing" | "completed" | "failed", "result": {...}}}
+jobs: Dict[str, Dict[str, Any]] = {}
+
+# WebSocket connections: {client_id: websocket}
+websocket_connections: Dict[str, WebSocket] = {}
+
+
+# -------------------------------------------------------------------------
+# Request/Response Models
+# -------------------------------------------------------------------------
+class ImageData(BaseModel):
+    filename: str
+    data: str  # base64-encoded image bytes
+
+
+class InferenceRequest(BaseModel):
+    images: list[ImageData]
+    client_id: str
+
+
+class InferenceResponse(BaseModel):
+    job_id: str
+    status: str = "queued"
+
+
+# -------------------------------------------------------------------------
+# Model Loading
+# -------------------------------------------------------------------------
+def load_model():
+    """Load the VGGT model once at startup."""
+    global model, device
+
+    print("Initializing and loading VGGT model...")
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available. GPU is required for VGGT inference.")
+    device = "cuda"
+
+    model = VGGT()
+    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
+    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
+    model = model.to(device)
+    model.eval()
+
+    print(f"Model loaded successfully on {device}")
+
+
+# -------------------------------------------------------------------------
+# Core Inference Function
+# -------------------------------------------------------------------------
+async def run_inference(job_id: str, target_dir: str, client_id: Optional[str] = None):
+    """Run VGGT model inference on the images stored under target_dir."""
+    try:
+        # Update job status
+        jobs[job_id]["status"] = "processing"
+
+        # Notify the client that processing has started
+        if client_id and client_id in websocket_connections:
+            await websocket_connections[client_id].send_json(
+                {"type": "executing", "data": {"job_id": job_id, "node": "start"}}
+            )
+
+        # Load and preprocess images
+        image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
+        print(f"Found {len(image_names)} images for job {job_id}")
+
+        if len(image_names) == 0:
+            raise ValueError("No images found in target directory")
+
+        images = load_and_preprocess_images(image_names).to(device)
+        print(f"Preprocessed images shape: {images.shape}")
+
+        # Run inference
+        print(f"Running inference for job {job_id}...")
+        with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
+            predictions = model(images)
+
+        # Send per-key progress updates via WebSocket
+        for key in predictions:
+            if client_id and client_id in websocket_connections:
+                await websocket_connections[client_id].send_json(
+                    {"type": "executing", "data": {"job_id": job_id, "node": key}}
+                )
+            await asyncio.sleep(0.01)  # Small delay between progress updates
+
+        # Convert pose encoding to extrinsic and intrinsic matrices
+        print("Converting pose encoding to extrinsic and intrinsic matrices...")
+        extrinsic, intrinsic = pose_encoding_to_extri_intri(
+            predictions["pose_enc"], images.shape[-2:]
+        )
+        predictions["extrinsic"] = extrinsic
+        predictions["intrinsic"] = intrinsic
+
+        # Convert tensors to numpy, dropping the batch dimension
+        predictions_numpy = {}
+        for key in predictions.keys():
+            if isinstance(predictions[key], torch.Tensor):
+                predictions_numpy[key] = predictions[key].cpu().numpy().squeeze(0)
+            else:
+                predictions_numpy[key] = predictions[key]
+
+        # Generate world points from depth map
+        print("Computing world points from depth map...")
+        depth_map = predictions_numpy["depth"]
+        world_points = unproject_depth_map_to_point_map(
+            depth_map, predictions_numpy["extrinsic"], predictions_numpy["intrinsic"]
+        )
+        predictions_numpy["world_points_from_depth"] = world_points
+
+        # Serialize predictions: dump each numpy array with np.save and
+        # base64-encode the bytes so the result is JSON-safe
+        serialized_predictions = {}
+        for key, value in predictions_numpy.items():
+            if isinstance(value, np.ndarray):
+                buffer = io.BytesIO()
+                np.save(buffer, value, allow_pickle=True)
+                buffer.seek(0)
+                serialized_predictions[key] = base64.b64encode(buffer.read()).decode(
+                    "utf-8"
+                )
+            else:
+                serialized_predictions[key] = value
+
+        # Store result
+        jobs[job_id]["status"] = "completed"
+        jobs[job_id]["result"] = {"predictions": serialized_predictions}
+
+        # Send completion via WebSocket; node=None signals that the job is done
+        if client_id and client_id in websocket_connections:
+            await websocket_connections[client_id].send_json(
+                {"type": "executing", "data": {"job_id": job_id, "node": None}}
+            )
+
+        # Clean up
+        torch.cuda.empty_cache()
+        shutil.rmtree(target_dir, ignore_errors=True)
+
+        print(f"Job {job_id} completed successfully")
+
+    except Exception as e:
+        print(f"Error in job {job_id}: {str(e)}")
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+
+        if client_id and client_id in websocket_connections:
+            await websocket_connections[client_id].send_json(
+                {"type": "error", "data": {"job_id": job_id, "error": str(e)}}
+            )
+
+
+# -------------------------------------------------------------------------
+# API Endpoints
+# -------------------------------------------------------------------------
+@app.on_event("startup")
+async def startup_event():
+    """Load model on startup"""
+    load_model()
+
+
+@app.get("/")
+async def root():
+    """Health check endpoint"""
+    return {"status": "ok", "service": "VGGT Inference API"}
+
+
+@app.post("/inference")
+async def create_inference(request: InferenceRequest, token: str = Query(...)):
+    """
+    Submit an inference job.
+
+    Args:
+        request: InferenceRequest containing images and client_id
+        token: Authentication token (currently not validated, kept for compatibility)
+
+    Returns:
+        InferenceResponse with the job_id of the queued job
+    """
+    # Generate unique job ID
+    job_id = str(uuid.uuid4())
+
+    # Create temporary directory for images
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+    target_dir = f"/tmp/vggt_job_{job_id}_{timestamp}"
+    target_dir_images = os.path.join(target_dir, "images")
+    os.makedirs(target_dir_images, exist_ok=True)
+
+    # Decode and save images
+    try:
+        for img_data in request.images:
+            img_bytes = base64.b64decode(img_data.data)
+            # basename() guards against path traversal in client-supplied filenames
+            img_path = os.path.join(target_dir_images, os.path.basename(img_data.filename))
+            with open(img_path, "wb") as f:
+                f.write(img_bytes)
+
+        # Initialize job
+        jobs[job_id] = {
+            "status": "queued",
+            "result": None,
+            "created_at": datetime.now().isoformat(),
+        }
+
+        # Start inference in background
+        asyncio.create_task(run_inference(job_id, target_dir, request.client_id))
+
+        return InferenceResponse(job_id=job_id, status="queued")
+
+    except Exception as e:
+        shutil.rmtree(target_dir, ignore_errors=True)
+        raise HTTPException(
+            status_code=400, detail=f"Failed to process images: {str(e)}"
+        )
+
+
+@app.get("/result/{job_id}")
+async def get_result(job_id: str, token: str = Query(...)):
+    """
+    Get the inference result for a job.
+
+    Args:
+        job_id: Job ID
+        token: Authentication token (currently not validated, kept for compatibility)
+
+    Returns:
+        Job result with serialized predictions
+    """
+    if job_id not in jobs:
+        raise HTTPException(status_code=404, detail="Job not found")
+
+    job = jobs[job_id]
+
+    if job["status"] == "failed":
+        raise HTTPException(status_code=500, detail=job.get("error", "Job failed"))
+
+    if job["status"] != "completed":
+        return {job_id: {"status": job["status"]}}
+
+    return {job_id: job["result"]}
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(
+    websocket: WebSocket, clientId: str = Query(...), token: str = Query(...)
+):
+    """
+    WebSocket endpoint for real-time progress updates.
+
+    Args:
+        websocket: WebSocket connection
+        clientId: Client ID
+        token: Authentication token (currently not validated, kept for compatibility)
+    """
+    await websocket.accept()
+    websocket_connections[clientId] = websocket
+
+    try:
+        while True:
+            # Keep the connection alive; echo incoming messages as a heartbeat
+            data = await websocket.receive_text()
+            await websocket.send_text(data)
+    except Exception as e:
+        print(f"WebSocket error for client {clientId}: {str(e)}")
+    finally:
+        websocket_connections.pop(clientId, None)
+
+
+@app.get("/history/{job_id}")
+async def get_history(job_id: str, token: str = Query(...)):
+    """
+    Get job history (alias for /result, kept for compatibility).
+
+    Args:
+        job_id: Job ID
+        token: Authentication token
+
+    Returns:
+        Job history
+    """
+    return await get_result(job_id, token)
+
+
+# -------------------------------------------------------------------------
+# Main
+# -------------------------------------------------------------------------
+if __name__ == "__main__":
+    # Run server
+    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
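
For reviewers, below is a minimal client sketch showing one way to exercise this API end to end. It is illustrative only and not part of the commit: it assumes the server above is running on localhost:7860, uses the third-party requests library, and passes a placeholder token value (the server does not validate tokens yet). Progress events could additionally be consumed over the /ws WebSocket; for simplicity this sketch just polls /result.

#!/usr/bin/env python3
# Hypothetical client for the VGGT inference API above (not part of the commit).
import base64
import io
import os
import time
import uuid

import numpy as np
import requests

BASE = "http://localhost:7860"  # assumed server address
TOKEN = "unused"  # placeholder; the server does not validate tokens yet


def submit(image_paths):
    """POST base64-encoded images to /inference and return the job_id."""
    images = []
    for path in image_paths:
        with open(path, "rb") as f:
            data = base64.b64encode(f.read()).decode("utf-8")
        images.append({"filename": os.path.basename(path), "data": data})
    payload = {"images": images, "client_id": str(uuid.uuid4())}
    resp = requests.post(f"{BASE}/inference", params={"token": TOKEN}, json=payload)
    resp.raise_for_status()
    return resp.json()["job_id"]


def fetch(job_id):
    """Poll /result until the job completes; a failed job surfaces as HTTP 500."""
    while True:
        resp = requests.get(f"{BASE}/result/{job_id}", params={"token": TOKEN})
        resp.raise_for_status()
        body = resp.json()[job_id]
        if "predictions" in body:
            return body["predictions"]
        time.sleep(1.0)


def decode(predictions):
    """Reverse the server's serialization: base64 -> bytes -> np.load."""
    return {
        key: np.load(io.BytesIO(base64.b64decode(value)), allow_pickle=True)
        if isinstance(value, str)
        else value
        for key, value in predictions.items()
    }


if __name__ == "__main__":
    job_id = submit(["img1.png", "img2.png"])
    arrays = decode(fetch(job_id))
    print({key: getattr(arr, "shape", arr) for key, arr in arrays.items()})

The decode step mirrors the server's serialization exactly: each prediction value arrives as a base64 string wrapping an .npy payload, so np.load on the decoded bytes recovers the original array (e.g. depth, extrinsic, intrinsic, world_points_from_depth).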