#!/usr/bin/env python """ FastAPI Server for VisioTrack on Hugging Face Spaces REST API for object tracking in videos """ from fastapi import FastAPI, File, UploadFile, Form, HTTPException from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware import cv2 import torch import numpy as np import tempfile import os import subprocess import shutil from pathlib import Path from siamrpn import TrackerSiamRPN import logging # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Initialize FastAPI app app = FastAPI( title="VisioTrack API", description="Object tracking API using SiamRPN", version="1.0.0", docs_url="/", # Swagger UI at root redoc_url="/redoc" ) # Enable CORS for frontend integration app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Model configuration MODEL_PATH = "model.pth" tracker = None device = None def load_tracker(): """Load the SiamRPN tracker with GPU support""" global tracker, device if tracker is None: if not os.path.exists(MODEL_PATH): raise FileNotFoundError(f"Model file '{MODEL_PATH}' not found!") device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') tracker = TrackerSiamRPN(net_path=MODEL_PATH) logger.info(f"✓ Tracker loaded on {device}") return tracker def process_video_tracking(video_path: str, bbox_x: int, bbox_y: int, bbox_w: int, bbox_h: int): """ Process video with object tracking Args: video_path: Path to input video bbox_x, bbox_y, bbox_w, bbox_h: Bounding box coordinates Returns: tuple: (output_path, message, metadata) """ try: tracker_instance = load_tracker() cap = cv2.VideoCapture(video_path) if not cap.isOpened(): return None, "Could not open video file", None # Get video properties fps = int(cap.get(cv2.CAP_PROP_FPS)) if fps == 0: fps = 30 width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) logger.info(f"Video: {width}x{height} @ {fps}fps, {total_frames} frames") ret, frame = cap.read() if not ret: return None, "Could not read first frame", None # Validate bounding box if bbox_w <= 0 or bbox_h <= 0: return None, "Invalid bounding box dimensions", None if (bbox_x < 0 or bbox_y < 0 or bbox_x + bbox_w > width or bbox_y + bbox_h > height): return None, f"Bounding box out of bounds (frame: {width}x{height})", None bbox = [bbox_x, bbox_y, bbox_w, bbox_h] # Initialize tracker tracker_instance.init(frame, bbox) # Create temporary output file temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='_temp.mp4') temp_output.close() # Use XVID codec for initial write fourcc = cv2.VideoWriter_fourcc(*'XVID') writer = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height)) if not writer.isOpened(): return None, "Could not create video writer", None # Draw first frame with initial bbox x, y, w, h = [int(v) for v in bbox] cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 3) cv2.putText(frame, 'Frame: 1', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) writer.write(frame) # Process remaining frames frame_count = 1 while True: ret, frame = cap.read() if not ret: break frame_count += 1 # Update tracker bbox = tracker_instance.update(frame) # Draw tracking result x, y, w, h = [int(v) for v in bbox] x = max(0, min(x, width - 1)) y = max(0, min(y, height - 1)) w = max(1, min(w, width - x)) h = max(1, min(h, height - y)) cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 3) cv2.putText(frame, f'Frame: {frame_count}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) writer.write(frame) if frame_count % 30 == 0: logger.info(f"Processed {frame_count}/{total_frames} frames") cap.release() writer.release() # Re-encode with H.264 for browser compatibility final_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') final_output.close() try: logger.info("Re-encoding video for browser compatibility...") subprocess.run([ 'ffmpeg', '-i', temp_output.name, '-c:v', 'libx264', '-preset', 'fast', '-crf', '23', '-pix_fmt', 'yuv420p', '-movflags', '+faststart', '-y', final_output.name ], check=True, capture_output=True, text=True) os.unlink(temp_output.name) logger.info("✓ Video re-encoded successfully") except (subprocess.CalledProcessError, FileNotFoundError) as e: logger.warning(f"FFmpeg encoding failed: {e}, using original") shutil.move(temp_output.name, final_output.name) metadata = { 'frames_processed': frame_count, 'resolution': f"{width}x{height}", 'fps': fps, 'device': str(device) } return final_output.name, f"Successfully tracked {frame_count} frames", metadata except Exception as e: logger.error(f"Tracking error: {str(e)}") return None, f"Error: {str(e)}", None @app.get("/health") async def health_check(): """ Health check endpoint (required by HF Spaces) """ return JSONResponse({ 'status': 'healthy', 'gpu_available': torch.cuda.is_available(), 'gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else None, 'model_loaded': tracker is not None }) @app.post("/track") async def track_video( video: UploadFile = File(..., description="Video file to process"), bbox_x: int = Form(..., description="X coordinate of bounding box"), bbox_y: int = Form(..., description="Y coordinate of bounding box"), bbox_w: int = Form(..., description="Width of bounding box"), bbox_h: int = Form(..., description="Height of bounding box") ): """ Main tracking endpoint Upload a video and bounding box coordinates to track an object. Returns the processed video with tracking visualization. """ temp_input = None output_path = None try: # Validate file type if not video.content_type.startswith('video/'): raise HTTPException(status_code=400, detail="File must be a video") # Save uploaded video temp_input = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') content = await video.read() temp_input.write(content) temp_input.close() logger.info(f"Processing video: {video.filename}") logger.info(f"Bounding box: ({bbox_x}, {bbox_y}, {bbox_w}, {bbox_h})") # Process video output_path, message, metadata = process_video_tracking( temp_input.name, bbox_x, bbox_y, bbox_w, bbox_h ) if output_path is None: raise HTTPException(status_code=400, detail=message) # Return processed video return FileResponse( output_path, media_type='video/mp4', filename='tracked_video.mp4', headers={ 'X-Frames-Processed': str(metadata['frames_processed']), 'X-Resolution': metadata['resolution'], 'X-FPS': str(metadata['fps']) } ) except HTTPException: raise except Exception as e: logger.error(f"Error: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) finally: # Cleanup temporary files if temp_input and os.path.exists(temp_input.name): try: os.unlink(temp_input.name) except: pass @app.get("/info") async def get_info(): """ Get API information and usage instructions """ return { 'name': 'VisioTrack API', 'version': '1.0.0', 'description': 'Object tracking API using SiamRPN', 'endpoints': { '/health': 'Health check', '/track': 'Track object in video (POST with multipart/form-data)', '/info': 'API information', '/': 'Interactive API documentation (Swagger UI)' }, 'usage': { 'method': 'POST', 'endpoint': '/track', 'content_type': 'multipart/form-data', 'parameters': { 'video': 'Video file', 'bbox_x': 'X coordinate (int)', 'bbox_y': 'Y coordinate (int)', 'bbox_w': 'Width (int)', 'bbox_h': 'Height (int)' } }, 'example_curl': ''' curl -X POST "https://your-space.hf.space/track" \\ -F "video=@video.mp4" \\ -F "bbox_x=100" \\ -F "bbox_y=100" \\ -F "bbox_w=200" \\ -F "bbox_h=200" \\ -o tracked_video.mp4 ''' } @app.on_event("startup") async def startup_event(): """Load model on startup""" logger.info("=" * 50) logger.info("VisioTrack FastAPI Server Starting...") logger.info("=" * 50) try: load_tracker() logger.info("✓ Model loaded successfully") except Exception as e: logger.error(f"✗ Failed to load model: {e}") logger.info("=" * 50) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)