David Shavin
updating backend for different configs
72ab437
import os
import secrets
import shutil
import tempfile

import cv2
from fastapi import BackgroundTasks, FastAPI, UploadFile, File, Header, HTTPException, Query
from fastapi.responses import FileResponse
from ultralytics import YOLO
app = FastAPI()
# YOLOv8-nano weights, loaded once at import time and shared by every request.
model = YOLO('yolov8n.pt')
# Scratch directory for uploaded videos and tracking output.
TEMP_FOLDER = "temp_files"
os.makedirs(TEMP_FOLDER, exist_ok=True)
# Shared secret expected in the `x-password` request header. Set it through the
# API_PASSWORD environment variable so the real value does not live in source
# control; the literal fallback keeps existing deployments working unchanged.
MY_PASSWORD = os.environ.get("API_PASSWORD", "david-shavin")
# Uploaded videos longer than this (per container metadata) are rejected with 400.
MAX_DURATION_SECONDS = 30
@app.get("/")
def home():
    """Root endpoint: return a fixed status message (usable as a liveness check)."""
    status_payload = {"message": "YOLOv8 Tracking API is Active"}
    return status_payload
@app.post("/verify")
async def verify_password(x_password: str = Header(None)):
    """Check the ``x-password`` header against the configured ``MY_PASSWORD``.

    Returns:
        ``{"message": "Access Granted"}`` when the password matches.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    # compare_digest avoids content-based short-circuiting, so response timing
    # does not reveal how much of a guess was correct. Comparing UTF-8 bytes
    # (rather than str) also accepts non-ASCII input without a TypeError.
    if x_password is None or not secrets.compare_digest(
        x_password.encode("utf-8"), MY_PASSWORD.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Incorrect Password")
    return {"message": "Access Granted"}
# Content-Type to send for each annotated-video extension; anything else is
# served as a generic binary download.
_OUTPUT_MEDIA_TYPES = {".avi": "video/x-msvideo", ".mp4": "video/mp4"}


def _probe_video(path):
    """Return ``(fps, duration_seconds)`` read from the video's metadata via OpenCV.

    ``duration_seconds`` is 0 when OpenCV reports a non-positive FPS.
    """
    cap = cv2.VideoCapture(path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the capture even if a metadata read raises.
        cap.release()
    duration = frame_count / fps if fps > 0 else 0
    return fps, duration


def _compute_stride(orig_fps, target_fps):
    """Return the frame stride that best approximates ``target_fps``.

    E.g. 30 fps -> 5 fps gives 6 (process every 6th frame). Rounding instead of
    truncating keeps NTSC-style rates on target (29.97 -> 5 fps gives 6, not 5).
    The stride is never below 1, and falls back to 1 when the source FPS is unknown.
    """
    if orig_fps <= 0:
        return 1
    return max(1, round(orig_fps / target_fps))


def _find_output_video(output_dir, stem="input"):
    """Return the path of the annotated video written for ``stem`` in ``output_dir``.

    The container extension is looked up rather than hard-coded because it may
    differ between Ultralytics platforms/versions (e.g. ``.avi`` vs ``.mp4``).

    Raises:
        FileNotFoundError: if no such file (or directory) exists.
    """
    for entry in sorted(os.listdir(output_dir)):
        if os.path.splitext(entry)[0] == stem:
            return os.path.join(output_dir, entry)
    raise FileNotFoundError(f"No tracked video found in {output_dir}")


@app.post("/track-video")
async def track_video(
    file: UploadFile = File(...),
    x_password: str = Header(None),
    # Desired processing rate; frames are subsampled to approximate it (default 5).
    target_fps: int = Query(5, ge=1, le=30)
):
    """Run YOLOv8 tracking on an uploaded video and return the annotated video.

    Args:
        file: Uploaded video, saved as ``input.mp4`` in a per-request scratch dir.
        x_password: ``x-password`` header value; must equal ``MY_PASSWORD``.
        target_fps: Approximate processing rate (1-30). Frames are skipped with a
            stride of ``round(source_fps / target_fps)`` (minimum 1).

    Returns:
        A FileResponse with the annotated video named ``tracked_video<ext>``.
        The scratch directory is deleted after the response has been sent.

    Raises:
        HTTPException: 401 for a missing/wrong password; 400 when the video is
            longer than ``MAX_DURATION_SECONDS``; 500 when tracking fails or no
            output video is produced.
    """
    # 1. Security check (constant-time comparison; a missing header is rejected).
    if x_password is None or not secrets.compare_digest(
        x_password.encode("utf-8"), MY_PASSWORD.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized")

    # A private scratch directory per request, so concurrent requests (or a
    # previous response that is still streaming its file) never share paths.
    work_dir = tempfile.mkdtemp(dir=TEMP_FOLDER)
    try:
        # 2. Save the upload.
        input_path = os.path.join(work_dir, "input.mp4")
        with open(input_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # 3. Enforce the duration limit and derive the resampling stride.
        try:
            orig_fps, duration = _probe_video(input_path)
            if duration > MAX_DURATION_SECONDS:
                raise HTTPException(status_code=400, detail=f"Video too long ({int(duration)}s). Limit is {MAX_DURATION_SECONDS}s.")
            stride = _compute_stride(orig_fps, target_fps)
            print(f"Original FPS: {orig_fps}, Target: {target_fps}, Calculated Stride: {stride}")
        except HTTPException:
            raise
        except Exception as e:
            # Unreadable metadata is not fatal: process every frame instead.
            print(f"Metadata Error: {e}")
            stride = 1

        # 4. Processing.
        # NOTE(review): model.track blocks the event loop for the whole run, so
        # other requests wait meanwhile. Moving it to a worker thread needs care:
        # the shared `model` may not be safe to use from several threads — confirm.
        try:
            # stream=True yields per-frame results lazily instead of keeping
            # them all in RAM; draining the generator writes the saved video.
            for _ in model.track(
                source=input_path,
                save=True,
                project=work_dir,
                name="tracking_result",
                exist_ok=True,
                vid_stride=stride,  # apply the calculated resampling
                stream=True,
            ):
                pass
            output_video_path = _find_output_video(os.path.join(work_dir, "tracking_result"))
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Processing Error: {str(e)}") from e

        suffix = os.path.splitext(output_video_path)[1].lower()
        media_type = _OUTPUT_MEDIA_TYPES.get(suffix, "application/octet-stream")
        # Delete the scratch directory only after the file has been streamed.
        cleanup = BackgroundTasks()
        cleanup.add_task(shutil.rmtree, work_dir, ignore_errors=True)
        return FileResponse(
            output_video_path,
            media_type=media_type,
            filename=f"tracked_video{suffix}",
            background=cleanup,
        )
    except BaseException:
        # Any failure before the response is handed off: clean up immediately.
        shutil.rmtree(work_dir, ignore_errors=True)
        raise