import os
import shutil
import tempfile
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# Project-local gait analysis engine (pose detection + gait metric computation).
from gait_analysis import EnhancedGaitAnalyzer

# Long-form service description shown on the /docs and /redoc pages.
_SERVICE_DESCRIPTION = (
    "Microservice to process video with YOLO11 and return gait metrics. "
    "Upload a video to extract clinical gait analysis features including "
    "stability, symmetry, and developmental observations."
)

app = FastAPI(
    title="Gait Analysis Inference Microservice",
    description=_SERVICE_DESCRIPTION,
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Permissive CORS: any origin, method and header.
# NOTE(review): the Starlette CORS docs state that credentialed requests are
# not supported with allow_origins=["*"]; consider listing explicit origins or
# turning allow_credentials off. Confirm whether any browser client relies on
# cookies before changing it.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)


# ---------------------------------------------------------------------------
# Response schemas. These exist mainly so Swagger/ReDoc can document the
# payloads; every field below is required (`...`).
# ---------------------------------------------------------------------------


def _score_field(metric: str) -> Any:
    """Return a required Field documenting a 0-100 score where higher is better."""
    return Field(..., description=f"Calculated {metric} score (out of 100). Higher is better.")


def _phase_frames_field(side: str, phase: str) -> Any:
    """Return a required Field documenting a per-foot gait-phase frame count."""
    return Field(..., description=f"Number of frames the {side} foot is in {phase} phase")


class HealthResponse(BaseModel):
    # Payload returned by the health-check endpoint.
    status: str = Field(..., example="ok")
    message: str = Field(..., example="Gait Analysis Inference Service is running.")


class GaitSummary(BaseModel):
    # Headline scores plus qualitative labels for the whole recording.
    stability_score: float = _score_field("stability")
    symmetry_score: float = _score_field("symmetry")
    weight_shift_score: float = _score_field("weight shift")
    strength_score: float = _score_field("strength")
    phase_asymmetry_percent: float = Field(..., description="Percentage of asymmetry between left and right stance times")
    overall_gait_quality: str = Field(..., description="Overall assessment of gait quality (e.g. acceptable, concerning)")
    walking_condition: str = Field(..., description="E.g., Independent or Assisted")


class GaitObservations(BaseModel):
    # Free-text clinical observations, one per body segment / aspect.
    trunk: str = Field(..., description="Observation of trunk movement (e.g. 'Neutral', 'Rotated')")
    head_neck: str = Field(..., description="Observation of head and neck (e.g. 'Neutral & mobile')")
    arms_hands: str = Field(..., description="Observation of arms and hands (e.g. 'Free arm swing')")
    lower_limbs: str = Field(..., description="Observation of lower limbs")
    symmetry: str = Field(..., description="General symmetry observation")
    weight_distribution: str = Field(..., description="Observation of weight distribution")
    postural_control: str = Field(..., description="Observation of postural control")


class WeightDistribution(BaseModel):
    # Left/right loading balance and any detected weight-shift problems.
    overall_imbalance_percent: float = Field(..., description="Percentage of imbalance")
    primary_weight_side: str = Field(..., description="Primary weight side (left/right/balanced)")
    weight_shift_issues: List[str] = Field(..., description="Identified weight shift issues")
    hip_drop_analysis: List[Dict[str, str]] = Field(..., description="Hip drop analysis")


class PosturalStrength(BaseModel):
    # Qualitative strength / stability assessments.
    trunk_sway_severity: str = Field(..., description="Severity of trunk sway")
    shoulder_stability: str = Field(..., description="Shoulder stability observation")
    forward_lean: str = Field(..., description="Forward lean observation")
    knee_stability: str = Field(..., description="Knee stability observation")
    overall_strength_assessment: str = Field(..., description="Overall strength assessment")


class GaitPhases(BaseModel):
    # Stance/swing breakdown, counted in video frames.
    stance_time_l: int = _phase_frames_field("left", "stance")
    stance_time_r: int = _phase_frames_field("right", "stance")
    swing_time_l: int = _phase_frames_field("left", "swing")
    swing_time_r: int = _phase_frames_field("right", "swing")
    double_support_frames: int = Field(..., description="Number of frames where both feet are in stance phase")
    phase_asymmetry: float = Field(..., description="Phase asymmetry percentage")


class GaitAnalysisData(BaseModel):
    # Complete analysis payload (the `data` member of AnalyzeResponse).
    frames: List[Dict[str, Any]] = Field(..., description="List of processed frame data including keypoints and angles")
    summary: GaitSummary = Field(..., description="Summary metrics of the gait analysis")
    weight_distribution: WeightDistribution = Field(..., description="Weight distribution analysis")
    postural_strength: PosturalStrength = Field(..., description="Postural strength assessment")
    gait_phases: GaitPhases = Field(..., description="Gait phase breakdown")
    risk_indicators: List[str] = Field(..., description="List of identified risk factors")
    flags: Dict[str, Any] = Field(..., description="Specific flags or markers triggered during analysis")
    observations: GaitObservations = Field(..., description="Clinical observations of body segments")
    compensatory_strategies: List[str] = Field(..., description="Identified compensatory movement strategies")
    reflex_influence: str = Field(..., description="Notes on primitive reflex influence")
    safety_note: str = Field(..., description="Any constraints or limitations of the analysis")
    clinical_notes: List[str] = Field(..., description="Clinical notes on findings")


class AnalyzeResponse(BaseModel):
    # Top-level envelope documented for the /analyze endpoint.
    status: str = Field(..., example="success")
    data: GaitAnalysisData = Field(..., description="The complete gait analysis results")


# ---------------------------------------------------------------------------
# Shared analyzer: created once at application startup so the YOLO pose model
# is loaded into memory a single time rather than on every request.
# Imports used only by the startup / request handlers in this section.
import threading
import traceback

from fastapi.concurrency import run_in_threadpool

# Set by `startup_event`; None until the model has been loaded.
analyzer = None

# `analyze_video` mutates the shared analyzer (age, independent_walking) and
# then runs it. The heavy work runs in worker threads, so this lock keeps that
# configure-then-process sequence atomic across concurrent requests.
_analysis_lock = threading.Lock()


@app.on_event("startup")
async def startup_event():
    """Create the shared analyzer (and load its pose model) once at startup."""
    global analyzer
    print("Initializing YOLO framework and loading model into memory...")
    # Per-request context (age, independent_walking) is assigned on this
    # instance inside `analyze_video`.
    analyzer = EnhancedGaitAnalyzer(model_name="yolo11n-pose.pt")
    print("Model initialized and ready for inference!")


@app.get(
    "/",
    response_model=HealthResponse,
    tags=["System"],
    summary="Health Check",
    description="Simple health check endpoint to ping from the main backend.",
)
def health_check():
    """Report that the service process is up."""
    return {"status": "ok", "message": "Gait Analysis Inference Service is running."}


def _safe_upload_path(temp_dir: str, filename: str) -> str:
    """Return the on-disk path inside *temp_dir* for an uploaded file.

    Only the extension of the client-supplied *filename* is kept (the analyzer
    may rely on it to pick a decoder). Names such as ``../../x.mp4`` or
    ``/abs/x.mp4`` therefore cannot place the file outside *temp_dir*.
    ``os.path.join`` discards *temp_dir* entirely when given an absolute
    component.
    """
    _, extension = os.path.splitext(os.path.basename(filename))
    return os.path.join(temp_dir, "upload" + extension)


def _save_upload(source, destination: str) -> None:
    """Copy the uploaded file object *source* to *destination* (blocking I/O)."""
    with open(destination, "wb") as buffer:
        shutil.copyfileobj(source, buffer)


def _run_analysis(video_path: str, age: int, independent_walking: bool):
    """Configure the shared analyzer for one request and process *video_path*.

    Blocking and CPU-heavy: call it through ``run_in_threadpool``.
    """
    with _analysis_lock:
        analyzer.age = age
        analyzer.independent_walking = independent_walking
        return analyzer.process_video(video_path)


@app.post(
    "/analyze",
    response_model=AnalyzeResponse,
    tags=["Analysis"],
    summary="Analyze Video",
    description="Upload a video to perform full gait analysis. This endpoint loads the video, detects poses, and calculates clinical gait metrics.",
)
async def analyze_video(
    file: UploadFile = File(..., description="The highly compressed video file (e.g., .mp4, .mov)"),
    age: Optional[int] = Form(None, description="Age of the patient for developmental benchmarks (determines expected normative values). Defaults to 4 if not provided."),
    independent_walking: bool = Form(True, description="Whether the patient is walking independently"),
):
    """Save the uploaded video to a temp dir, run gait analysis, and return the results.

    Raises:
        HTTPException 400: missing file, age out of range, or a ``ValueError``
            raised while processing (treated as bad input).
        HTTPException 503: the analyzer has not been initialised yet.
        HTTPException 500: any other processing failure.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded")

    if age is not None and age > 11:
        raise HTTPException(status_code=400, detail="OVER_AGE: Patient age exceeds the supported range (maximum 11 years) for this analysis.")
    if age is not None and age < 0:
        raise HTTPException(status_code=400, detail="INVALID_AGE: Patient age cannot be negative.")

    if analyzer is None:
        # Startup has not completed (or failed); report this clearly instead of
        # surfacing an AttributeError on None as a 500.
        raise HTTPException(status_code=503, detail="Model is not loaded yet; please retry shortly.")

    # Default age to 4 if not provided
    process_age = age if age is not None else 4

    # Per-request scratch directory; always removed in `finally`.
    temp_dir = tempfile.mkdtemp()
    temp_video_path = _safe_upload_path(temp_dir, file.filename)

    try:
        # 1. Persist the upload. Blocking disk I/O runs off the event loop.
        await run_in_threadpool(_save_upload, file.file, temp_video_path)

        print(f"Received video: {file.filename} (Age: {process_age}, Independent: {independent_walking})")

        # 2-3. Configure the analyzer and run inference in a worker thread so
        # the event loop (and health checks) stay responsive.
        results = await run_in_threadpool(
            _run_analysis, temp_video_path, process_age, independent_walking
        )

        # 4. JSONResponse skips Pydantic validation of the large 'frames' list.
        # The AnalyzeResponse model still documents the shape in Swagger.
        return JSONResponse(content={"status": "success", "data": results})

    except ValueError as ve:
        print(f"Validation Error: {ve}")
        # Treated as a client/input problem.
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        print(f"Error running inference: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # 5. Delete the video so the server does not run out of storage.
        shutil.rmtree(temp_dir, ignore_errors=True)


if __name__ == "__main__":
    import uvicorn

    # Derive the import string from this file's name so `reload=True` works
    # whatever the module is called (it was hard-coded to "api:app").
    _module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{_module_name}:app", host="0.0.0.0", port=8000, reload=True)