techSnipe's picture
Upload folder using huggingface_hub
83ee618 verified
import os
import shutil
import tempfile
import threading
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# Import your existing analyzer logic
from gait_analysis import EnhancedGaitAnalyzer
# The ASGI application. Interactive API documentation is exposed at /docs
# (Swagger UI) and /redoc (ReDoc).
app = FastAPI(
    title="Gait Analysis Inference Microservice",
    version="1.0.0",
    description=(
        "Microservice to process video with YOLO11 and return gait metrics. "
        "Upload a video to extract clinical gait analysis features including "
        "stability, symmetry, and developmental observations."
    ),
    docs_url="/docs",
    redoc_url="/redoc",
)
# CORS: accept cross-origin requests from any origin, using any method and any
# request header.
# NOTE(review): the CORS spec does not permit a literal "*" origin together with
# credentials. Starlette reportedly echoes the caller's origin back in that case,
# which effectively allows credentialed requests from any site. If credentials
# are actually needed, replace "*" with the explicit list of trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# --- Pydantic Models for Swagger Documentation ---
class HealthResponse(BaseModel):
    # Body returned by the "/" health-check endpoint.
    status: str = Field(example="ok")
    message: str = Field(example="Gait Analysis Inference Service is running.")
class GaitSummary(BaseModel):
    # Headline metrics for one analysed video. The four *_score fields are
    # documented as 0-100 values where higher is better.
    stability_score: float = Field(description="Calculated stability score (out of 100). Higher is better.")
    symmetry_score: float = Field(description="Calculated symmetry score (out of 100). Higher is better.")
    weight_shift_score: float = Field(description="Calculated weight shift score (out of 100). Higher is better.")
    strength_score: float = Field(description="Calculated strength score (out of 100). Higher is better.")
    phase_asymmetry_percent: float = Field(description="Percentage of asymmetry between left and right stance times")
    # Free-text classifications rather than numeric scores.
    overall_gait_quality: str = Field(description="Overall assessment of gait quality (e.g. acceptable, concerning)")
    walking_condition: str = Field(description="E.g., Independent or Assisted")
class GaitObservations(BaseModel):
    # Per-body-segment textual observations. Every field is a free-text string.
    trunk: str = Field(description="Observation of trunk movement (e.g. 'Neutral', 'Rotated')")
    head_neck: str = Field(description="Observation of head and neck (e.g. 'Neutral & mobile')")
    arms_hands: str = Field(description="Observation of arms and hands (e.g. 'Free arm swing')")
    lower_limbs: str = Field(description="Observation of lower limbs")
    # Whole-body observations.
    symmetry: str = Field(description="General symmetry observation")
    weight_distribution: str = Field(description="Observation of weight distribution")
    postural_control: str = Field(description="Observation of postural control")
class WeightDistribution(BaseModel):
    # Left/right loading analysis for the analysed walk.
    overall_imbalance_percent: float = Field(description="Percentage of imbalance")
    primary_weight_side: str = Field(description="Primary weight side (left/right/balanced)")
    weight_shift_issues: List[str] = Field(description="Identified weight shift issues")
    # Each entry is a flat string-to-string mapping; the keys are not declared here.
    hip_drop_analysis: List[Dict[str, str]] = Field(description="Hip drop analysis")
class PosturalStrength(BaseModel):
    # Qualitative postural-strength findings. Every field is a free-text string.
    trunk_sway_severity: str = Field(description="Severity of trunk sway")
    shoulder_stability: str = Field(description="Shoulder stability observation")
    forward_lean: str = Field(description="Forward lean observation")
    knee_stability: str = Field(description="Knee stability observation")
    overall_strength_assessment: str = Field(description="Overall strength assessment")
class GaitPhases(BaseModel):
    # Stance/swing breakdown. The integer fields are frame counts, not seconds.
    stance_time_l: int = Field(description="Number of frames the left foot is in stance phase")
    stance_time_r: int = Field(description="Number of frames the right foot is in stance phase")
    swing_time_l: int = Field(description="Number of frames the left foot is in swing phase")
    swing_time_r: int = Field(description="Number of frames the right foot is in swing phase")
    double_support_frames: int = Field(description="Number of frames where both feet are in stance phase")
    phase_asymmetry: float = Field(description="Phase asymmetry percentage")
class GaitAnalysisData(BaseModel):
    # Complete result document for one analysed video. It is nested under the
    # "data" key of the /analyze response.
    # Raw per-frame output. Its dict keys are not declared by this schema.
    frames: List[Dict[str, Any]] = Field(description="List of processed frame data including keypoints and angles")
    # Structured sub-reports.
    summary: GaitSummary = Field(description="Summary metrics of the gait analysis")
    weight_distribution: WeightDistribution = Field(description="Weight distribution analysis")
    postural_strength: PosturalStrength = Field(description="Postural strength assessment")
    gait_phases: GaitPhases = Field(description="Gait phase breakdown")
    # Findings and free-text notes.
    risk_indicators: List[str] = Field(description="List of identified risk factors")
    flags: Dict[str, Any] = Field(description="Specific flags or markers triggered during analysis")
    observations: GaitObservations = Field(description="Clinical observations of body segments")
    compensatory_strategies: List[str] = Field(description="Identified compensatory movement strategies")
    reflex_influence: str = Field(description="Notes on primitive reflex influence")
    safety_note: str = Field(description="Any constraints or limitations of the analysis")
    clinical_notes: List[str] = Field(description="Clinical notes on findings")
class AnalyzeResponse(BaseModel):
    # Top-level envelope of a successful /analyze call: a status string plus
    # the full analysis document under "data".
    status: str = Field(example="success")
    data: GaitAnalysisData = Field(description="The complete gait analysis results")
# --------------------------------------------------
# Process-wide analyzer shared by every request. It is built once in the startup
# hook below, so the YOLO pose weights are loaded a single time instead of on
# each request (avoiding a repeated "cold start"). It is None until startup runs.
analyzer: Optional["EnhancedGaitAnalyzer"] = None


@app.on_event("startup")
async def startup_event():
    """Create the shared gait analyzer, loading its pose model, when the app starts."""
    global analyzer
    weights = "yolo11n-pose.pt"
    print("Initializing YOLO framework and loading model into memory...")
    # Built with default settings. The /analyze handler assigns the per-request
    # values (age, independent walking) before each run.
    analyzer = EnhancedGaitAnalyzer(model_name=weights)
    print("Model initialized and ready for inference!")
@app.get(
    "/",
    summary="Health Check",
    description="Simple health check endpoint to ping from the main backend.",
    tags=["System"],
    response_model=HealthResponse,
)
def health_check():
    """Report that the service process is up and able to answer HTTP requests.

    Returns a constant payload; it does not check whether the model is loaded.
    """
    payload = {
        "status": "ok",
        "message": "Gait Analysis Inference Service is running.",
    }
    return payload
# Serializes use of the shared `analyzer`. Each request sets its `age` and
# `independent_walking` attributes and then calls `process_video`. Now that this
# work runs in worker threads, two requests must never interleave those steps.
_ANALYZER_LOCK = threading.Lock()


def _safe_upload_name(filename: str) -> str:
    """Return a file name that is safe to join onto the request's temp directory.

    The name comes from the client. ``os.path.join`` discards the directory
    entirely when given an absolute path, and ``..`` components would escape it.
    Either way the write could land outside ``temp_dir``, where the cleanup
    would never remove it. Only the final path component is kept; Windows
    separators are normalized first.
    """
    name = os.path.basename(filename.replace("\\", "/"))
    if name in ("", ".", ".."):
        return "upload"
    return name


def _save_upload(source, destination: str) -> None:
    """Stream the uploaded file object ``source`` to ``destination`` on disk."""
    with open(destination, "wb") as buffer:
        shutil.copyfileobj(source, buffer)


def _run_analysis(video_path: str, age: int, independent_walking: bool) -> Any:
    """Configure the shared analyzer for one request and process ``video_path``.

    Intended to run in a worker thread. The lock makes the attribute updates
    and the ``process_video`` call atomic with respect to concurrent requests.
    """
    with _ANALYZER_LOCK:
        analyzer.age = age
        analyzer.independent_walking = independent_walking
        return analyzer.process_video(video_path)


@app.post(
    "/analyze",
    response_model=AnalyzeResponse,
    tags=["Analysis"],
    summary="Analyze Video",
    description="Upload a video to perform full gait analysis. This endpoint loads the video, detects poses, and calculates clinical gait metrics."
)
async def analyze_video(
    file: UploadFile = File(..., description="The highly compressed video file (e.g., .mp4, .mov)"),
    age: Optional[int] = Form(None, description="Age of the patient for developmental benchmarks (determines expected normative values). Defaults to 4 if not provided."),
    independent_walking: bool = Form(True, description="Whether the patient is walking independently")
):
    """Save the uploaded video to a temp dir, run gait analysis on it, and return the results.

    Raises:
        HTTPException 400: no file name; age above 11 or below 0; or the
            analysis raised ``ValueError`` (treated as bad input).
        HTTPException 503: the shared analyzer has not been initialized.
        HTTPException 500: any other failure while saving or analysing.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded")
    if age is not None and age > 11:
        raise HTTPException(status_code=400, detail="OVER_AGE: Patient age exceeds the supported range (maximum 11 years) for this analysis.")
    if age is not None and age < 0:
        raise HTTPException(status_code=400, detail="INVALID_AGE: Patient age cannot be negative.")
    if analyzer is None:
        # Startup has not run (or failed), so there is no model to run.
        raise HTTPException(status_code=503, detail="Model is not loaded yet; please retry shortly.")

    # Default age to 4 if not provided
    process_age = age if age is not None else 4

    # Per-request temporary directory; it is removed in `finally` whatever happens.
    temp_dir = tempfile.mkdtemp()
    temp_video_path = os.path.join(temp_dir, _safe_upload_name(file.filename))
    try:
        # Disk writes and model inference are blocking. Run them in the thread
        # pool so the event loop stays free to serve other requests (such as
        # the health check) while a video is processed.
        await run_in_threadpool(_save_upload, file.file, temp_video_path)
        print(f"Received video: {file.filename} (Age: {process_age}, Independent: {independent_walking})")

        results = await run_in_threadpool(_run_analysis, temp_video_path, process_age, independent_walking)

        # JSONResponse is used directly so the (potentially large) 'frames' list
        # is not re-validated by Pydantic. AnalyzeResponse still drives the
        # Swagger documentation.
        return JSONResponse(content={"status": "success", "data": results})
    except ValueError as ve:
        print(f"Validation Error: {ve}")
        # ValueError from the pipeline is treated as a client/input problem.
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        print(f"Error running inference: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Delete the uploaded video so the server does not run out of storage.
        shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
    import uvicorn

    # Local development entry point. With reload=True, uvicorn needs an import
    # string rather than the app object. Build "<module>:app" from this file's
    # own name instead of hard-coding "api", which breaks if the file is renamed.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        f"{module_name}:app",
        host="0.0.0.0",
        # The PORT environment variable overrides the default port of 8000.
        port=int(os.environ.get("PORT", "8000")),
        reload=True,
    )