Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import tempfile | |
| from typing import List, Dict, Any, Optional | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| # Import your existing analyzer logic | |
| from gait_analysis import EnhancedGaitAnalyzer | |
# FastAPI application object. Interactive API docs are exposed at /docs
# (Swagger UI) and /redoc (ReDoc).
app = FastAPI(
    title="Gait Analysis Inference Microservice",
    version="1.0.0",
    description=(
        "Microservice to process video with YOLO11 and return gait metrics. "
        "Upload a video to extract clinical gait analysis features including "
        "stability, symmetry, and developmental observations."
    ),
    redoc_url="/redoc",
    docs_url="/docs",
)

# Fully permissive CORS: every origin, HTTP method and request header is allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the CORS model; confirm that clients of this microservice
# really need credentialed cross-origin requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# --- Pydantic Models for Swagger Documentation ---


# Body returned by the health-check endpoint.
class HealthResponse(BaseModel):
    # Short status token; the documented example value is "ok".
    status: str = Field(default=..., example="ok")
    # Human-readable confirmation that the service is up.
    message: str = Field(default=..., example="Gait Analysis Inference Service is running.")
# Headline metrics for one analysed video. Every field is required.
class GaitSummary(BaseModel):
    # Scores described as "out of 100", where higher is better.
    stability_score: float = Field(
        default=..., description="Calculated stability score (out of 100). Higher is better."
    )
    symmetry_score: float = Field(
        default=..., description="Calculated symmetry score (out of 100). Higher is better."
    )
    weight_shift_score: float = Field(
        default=..., description="Calculated weight shift score (out of 100). Higher is better."
    )
    strength_score: float = Field(
        default=..., description="Calculated strength score (out of 100). Higher is better."
    )
    # Left/right stance-time asymmetry, as a percentage.
    phase_asymmetry_percent: float = Field(
        default=..., description="Percentage of asymmetry between left and right stance times"
    )
    # Free-text labels.
    overall_gait_quality: str = Field(
        default=..., description="Overall assessment of gait quality (e.g. acceptable, concerning)"
    )
    walking_condition: str = Field(default=..., description="E.g., Independent or Assisted")
# Free-text clinical observations, one required string per body segment or aspect.
class GaitObservations(BaseModel):
    # Upper body.
    trunk: str = Field(
        default=..., description="Observation of trunk movement (e.g. 'Neutral', 'Rotated')"
    )
    head_neck: str = Field(
        default=..., description="Observation of head and neck (e.g. 'Neutral & mobile')"
    )
    arms_hands: str = Field(
        default=..., description="Observation of arms and hands (e.g. 'Free arm swing')"
    )
    # Lower body and whole-body aspects.
    lower_limbs: str = Field(default=..., description="Observation of lower limbs")
    symmetry: str = Field(default=..., description="General symmetry observation")
    weight_distribution: str = Field(default=..., description="Observation of weight distribution")
    postural_control: str = Field(default=..., description="Observation of postural control")
# Left/right weight-bearing analysis.
class WeightDistribution(BaseModel):
    # Imbalance between sides, as a percentage.
    overall_imbalance_percent: float = Field(default=..., description="Percentage of imbalance")
    # Which side carries more weight; the description lists left/right/balanced.
    primary_weight_side: str = Field(
        default=..., description="Primary weight side (left/right/balanced)"
    )
    # Lists of findings; the schema of each hip-drop entry is string->string.
    weight_shift_issues: List[str] = Field(default=..., description="Identified weight shift issues")
    hip_drop_analysis: List[Dict[str, str]] = Field(default=..., description="Hip drop analysis")
# Qualitative strength and stability findings, all required strings.
class PosturalStrength(BaseModel):
    trunk_sway_severity: str = Field(default=..., description="Severity of trunk sway")
    shoulder_stability: str = Field(default=..., description="Shoulder stability observation")
    forward_lean: str = Field(default=..., description="Forward lean observation")
    knee_stability: str = Field(default=..., description="Knee stability observation")
    # Roll-up of the findings above.
    overall_strength_assessment: str = Field(default=..., description="Overall strength assessment")
# Gait-cycle breakdown. The timing fields are frame counts, not seconds.
class GaitPhases(BaseModel):
    # Stance phase per foot.
    stance_time_l: int = Field(
        default=..., description="Number of frames the left foot is in stance phase"
    )
    stance_time_r: int = Field(
        default=..., description="Number of frames the right foot is in stance phase"
    )
    # Swing phase per foot.
    swing_time_l: int = Field(
        default=..., description="Number of frames the left foot is in swing phase"
    )
    swing_time_r: int = Field(
        default=..., description="Number of frames the right foot is in swing phase"
    )
    # Frames with both feet down.
    double_support_frames: int = Field(
        default=..., description="Number of frames where both feet are in stance phase"
    )
    phase_asymmetry: float = Field(default=..., description="Phase asymmetry percentage")
# Complete analysis payload: raw per-frame data plus every derived assessment.
class GaitAnalysisData(BaseModel):
    # Per-frame records; the exact dict schema is not modelled here.
    frames: List[Dict[str, Any]] = Field(
        default=..., description="List of processed frame data including keypoints and angles"
    )
    # Structured sub-reports.
    summary: GaitSummary = Field(default=..., description="Summary metrics of the gait analysis")
    weight_distribution: WeightDistribution = Field(
        default=..., description="Weight distribution analysis"
    )
    postural_strength: PosturalStrength = Field(
        default=..., description="Postural strength assessment"
    )
    gait_phases: GaitPhases = Field(default=..., description="Gait phase breakdown")
    # Findings and free-form flags.
    risk_indicators: List[str] = Field(default=..., description="List of identified risk factors")
    flags: Dict[str, Any] = Field(
        default=..., description="Specific flags or markers triggered during analysis"
    )
    observations: GaitObservations = Field(
        default=..., description="Clinical observations of body segments"
    )
    compensatory_strategies: List[str] = Field(
        default=..., description="Identified compensatory movement strategies"
    )
    # Narrative notes and caveats.
    reflex_influence: str = Field(default=..., description="Notes on primitive reflex influence")
    safety_note: str = Field(
        default=..., description="Any constraints or limitations of the analysis"
    )
    clinical_notes: List[str] = Field(default=..., description="Clinical notes on findings")
# Envelope for the analysis endpoint: a status token plus the full results.
# NOTE(review): the analysis handler returns a raw JSONResponse, so this model
# presumably serves only to document the schema, not to validate output.
class AnalyzeResponse(BaseModel):
    status: str = Field(default=..., example="success")
    data: GaitAnalysisData = Field(
        default=..., description="The complete gait analysis results"
    )
# --------------------------------------------------
# Process-wide analyzer shared by all requests. It remains None until the
# startup hook constructs it, so the heavy YOLO pose model is loaded into
# memory once rather than being reloaded for every request.
analyzer: Optional["EnhancedGaitAnalyzer"] = None
async def startup_event():
    """Build the shared ``EnhancedGaitAnalyzer`` once at application start.

    Stores the instance in the module-level ``analyzer`` using the
    ``yolo11n-pose.pt`` weights. Request-specific settings (age, walking
    mode) are not fixed here; the request handler overwrites them per call.
    """
    # NOTE(review): no registration decorator is visible for this hook in the
    # reviewed source; presumably it is wired up as a FastAPI startup handler.
    global analyzer
    print("Initializing YOLO framework and loading model into memory...")
    pose_weights = "yolo11n-pose.pt"
    analyzer = EnhancedGaitAnalyzer(model_name=pose_weights)
    print("Model initialized and ready for inference!")
def health_check():
    # Liveness probe: returns a fixed payload shaped like HealthResponse.
    # NOTE(review): no route decorator is visible in the reviewed source;
    # presumably registered as a GET route with response_model=HealthResponse.
    payload = {
        "status": "ok",
        "message": "Gait Analysis Inference Service is running.",
    }
    return payload
async def analyze_video(
    file: UploadFile = File(..., description="The highly compressed video file (e.g., .mp4, .mov)"),
    age: Optional[int] = Form(None, description="Age of the patient for developmental benchmarks (determines expected normative values). Defaults to 4 if not provided."),
    independent_walking: bool = Form(True, description="Whether the patient is walking independently")
):
    # Run gait analysis on one uploaded video and return the analyzer's results.
    #
    # Form inputs:
    #   file                -- uploaded video; saved to a private temp dir and
    #                          handed to analyzer.process_video by path.
    #   age                 -- patient age in years; must be 0..11, defaults to 4.
    #   independent_walking -- copied onto the shared analyzer for this request.
    #
    # Responses:
    #   200 {"status": "success", "data": <analyzer results>}.
    #   400 for a missing filename, an out-of-range age, or a ValueError raised
    #       while saving/processing.
    #   500 for any other failure.
    #
    # NOTE(review): no route decorator is visible in the reviewed source;
    # presumably this is registered as a POST route with
    # response_model=AnalyzeResponse.
    # NOTE(review): this `async def` does blocking disk I/O and inference with no
    # `await`, so it blocks the event loop for the whole call. That same fact
    # serialises handler executions within a worker, which is what keeps the
    # per-request mutation of the shared `analyzer` race-free. Offloading this
    # work to a thread would require guarding `analyzer` with a lock.
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded")
    if age is not None and age > 11:
        raise HTTPException(status_code=400, detail="OVER_AGE: Patient age exceeds the supported range (maximum 11 years) for this analysis.")
    if age is not None and age < 0:
        raise HTTPException(status_code=400, detail="INVALID_AGE: Patient age cannot be negative.")

    # Default age to 4 if not provided.
    process_age = age if age is not None else 4

    # file.filename is client-controlled. Keep only its final path component so
    # that names like "../../x" or "C:\\x" cannot place the upload outside
    # temp_dir. os.path.join would also discard temp_dir entirely for an
    # absolute name such as "/etc/x". Names that resolve to a directory fall
    # back to a fixed name.
    safe_name = os.path.basename(file.filename.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    # Private temporary directory for the incoming video; removed in `finally`.
    temp_dir = tempfile.mkdtemp()
    temp_video_path = os.path.join(temp_dir, safe_name)

    try:
        # 1. Save the incoming stream to a temporary disk file.
        with open(temp_video_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        print(f"Received video: {file.filename} (Age: {process_age}, Independent: {independent_walking})")

        # 2. Apply this request's settings to the shared analyzer.
        analyzer.age = process_age
        analyzer.independent_walking = independent_walking

        # 3. Process the video.
        results = analyzer.process_video(temp_video_path)

        # 4. Return the results directly as JSONResponse. This skips Pydantic
        # serialization of the large 'frames' list; the OpenAPI schema is
        # still described by the AnalyzeResponse model.
        return JSONResponse(content={"status": "success", "data": results})

    except ValueError as ve:
        print(f"Validation Error: {ve}")
        # Client/input problem -> 400 Bad Request; keep the original cause chained.
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        print(f"Error running inference: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # 5. Always delete the temp directory so uploads don't accumulate on disk.
        shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
    # Local-testing entry point: serve the "api:app" import string on all
    # interfaces at port 8000, reloading when source files change.
    import uvicorn

    uvicorn.run(
        "api:app",
        reload=True,
        host="0.0.0.0",
        port=8000,
    )