Spaces:
Sleeping
Sleeping
import os
import secrets
import shutil
import uuid

import cv2
from fastapi import BackgroundTasks, FastAPI, File, Header, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse
from ultralytics import YOLO
app = FastAPI()
# NOTE(review): no route decorators appear in this copy of the file; home,
# verify_password and track_video are presumably registered on `app` with
# @app.get / @app.post — confirm against the deployed source.

# YOLO model loaded once at import time and shared by every request.
model = YOLO('yolov8n.pt')

# Working directory (relative to the current working directory) for uploaded
# videos and the tracker's saved output.
TEMP_FOLDER = "temp_files"
os.makedirs(TEMP_FOLDER, exist_ok=True)

# Shared secret expected in the X-Password request header.
# Prefer supplying it through the API_PASSWORD environment variable (e.g. a
# Space secret). The literal fallback only keeps existing clients working; it
# exposes the secret to anyone who can read this file.
# TODO: rotate the password and remove the fallback.
MY_PASSWORD = os.environ.get("API_PASSWORD", "david-shavin")

# Videos whose metadata reports a longer duration are rejected with HTTP 400.
MAX_DURATION_SECONDS = 30
def home():
    """Return a small status payload showing the tracking service is up."""
    banner = "YOLOv8 Tracking API is Active"
    return dict(message=banner)
async def verify_password(x_password: str = Header(None)):
    """Check the X-Password request header against ``MY_PASSWORD``.

    Args:
        x_password: Value of the ``X-Password`` header, or ``None`` if absent.

    Returns:
        ``{"message": "Access Granted"}`` when the header matches.

    Raises:
        HTTPException: 401 ("Incorrect Password") when the header is missing
            or does not match.
    """
    # compare_digest's running time does not depend on where the two values
    # first differ, so response timing cannot reveal a correct prefix of a
    # guess. Comparing bytes avoids its TypeError on non-ASCII str input.
    if x_password is None or not secrets.compare_digest(
        x_password.encode("utf-8"), MY_PASSWORD.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Incorrect Password")
    return {"message": "Access Granted"}
def _remove_quietly(path):
    """Delete the file at ``path``; a file that is already gone is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _probe_video(input_path, target_fps):
    """Enforce the duration limit on a saved video and choose its frame stride.

    Args:
        input_path: Path of the saved upload.
        target_fps: Requested processing rate in frames per second (>= 1).

    Returns:
        The ``vid_stride`` for the tracker. This is ``round(source_fps / target_fps)``,
        clamped to at least 1. It is 1 (every frame) when the source FPS reads
        as non-positive or the metadata cannot be read; that case is
        best-effort and the error is printed.

    Raises:
        HTTPException: 400 if the reported duration exceeds ``MAX_DURATION_SECONDS``.
    """
    try:
        cap = cv2.VideoCapture(input_path)
        try:
            orig_fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()

        # NOTE(review): when the FPS reads as 0 the duration is taken as 0, so
        # the length limit is not enforced for such files.
        duration = frame_count / orig_fps if orig_fps > 0 else 0
        if duration > MAX_DURATION_SECONDS:
            raise HTTPException(status_code=400, detail=f"Video too long ({int(duration)}s). Limit is {MAX_DURATION_SECONDS}s.")

        # Process every stride-th frame. round() rather than int(): a 29.97 fps
        # source at a 10 fps target gets stride 3 (~10 fps) instead of
        # truncating to 2 (~15 fps). Never below 1.
        stride = max(1, round(orig_fps / target_fps)) if orig_fps > 0 else 1
        print(f"Original FPS: {orig_fps}, Target: {target_fps}, Calculated Stride: {stride}")
        return stride
    except HTTPException:
        raise
    except Exception as e:
        print(f"Metadata Error: {e}")
        return 1


async def track_video(
    file: UploadFile = File(...),
    x_password: str = Header(None),
    # Desired processing rate in frames per second (1-30, default 5).
    target_fps: int = Query(5, ge=1, le=30)
):
    """Run YOLO object tracking on an uploaded video and return the annotated video.

    Args:
        file: The uploaded video. It is stored under ``TEMP_FOLDER`` with a
            per-request ``.mp4`` name, whatever its original name.
        x_password: Value of the ``X-Password`` header; must equal ``MY_PASSWORD``.
        target_fps: Approximate frame rate to process at, converted to a
            frame stride (see ``_probe_video``).

    Returns:
        A ``FileResponse`` with the annotated video, served as ``tracked_video.avi``.
        The file is deleted after the response has been sent.

    Raises:
        HTTPException: 401 for a wrong or missing password, 400 if the video is
            too long, 500 if tracking fails or no output video is produced.
    """
    # NOTE(review): this coroutine never awaits, so tracking blocks the event
    # loop and requests run one at a time. Moving the work to a threadpool
    # needs care: the module-level `model` is shared and may not be safe to use
    # from several threads at once. Confirm before changing.

    # 1. Security check. compare_digest's timing does not reveal where the
    # values differ; comparing bytes avoids its TypeError on non-ASCII str input.
    if x_password is None or not secrets.compare_digest(
        x_password.encode("utf-8"), MY_PASSWORD.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized")

    # 2. Per-request file names. With a fixed name, a later request could
    # overwrite an earlier request's output while its FileResponse is still
    # streaming it.
    job_id = uuid.uuid4().hex
    input_path = os.path.join(TEMP_FOLDER, f"{job_id}.mp4")
    # Assumes the tracker saves the annotated video under
    # <project>/<name>/<input stem>.avi.
    # TODO: confirm for the installed ultralytics version and platform.
    output_video_path = os.path.join(TEMP_FOLDER, "tracking_result", f"{job_id}.avi")

    try:
        with open(input_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # 3. Duration limit and frame stride from the video's metadata.
        stride = _probe_video(input_path, target_fps)

        # 4. Processing. stream=True yields per-frame results lazily instead of
        # collecting them all in a list. Consuming the generator drives the
        # tracker, and save=True writes the annotated video.
        try:
            for _ in model.track(
                source=input_path,
                save=True,
                project=TEMP_FOLDER,
                name="tracking_result",
                exist_ok=True,
                vid_stride=stride,  # apply the calculated resampling
                stream=True,
            ):
                pass
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Processing Error: {str(e)}") from e
    finally:
        # The upload is not needed after tracking, rejection, or failure.
        _remove_quietly(input_path)

    if not os.path.exists(output_video_path):
        raise HTTPException(status_code=500, detail="Processing Error: tracked video was not produced")

    # Delete the output only after the response body has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_quietly, output_video_path)
    return FileResponse(
        output_video_path,
        media_type="video/x-msvideo",
        filename="tracked_video.avi",
        background=cleanup,
    )