Spaces:
Sleeping
Sleeping
File size: 9,403 Bytes
83ee618 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 | import os
import asyncio
import shutil
import tempfile
import threading
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# Import your existing analyzer logic
from gait_analysis import EnhancedGaitAnalyzer
# Application instance. The title/description/version below are what the
# interactive docs served at /docs (Swagger UI) and /redoc display.
app = FastAPI(
    title="Gait Analysis Inference Microservice",
    version="1.0.0",
    description=(
        "Microservice to process video with YOLO11 and return gait metrics. "
        "Upload a video to extract clinical gait analysis features including "
        "stability, symmetry, and developmental observations."
    ),
    docs_url="/docs",
    redoc_url="/redoc",
)
# Standard CORS configuration.
# Origins, methods and headers are all wildcards, so credentials must stay
# disabled: credentialed CORS cannot be combined with "*", and with
# allow_credentials=True Starlette instead reflects whatever Origin the
# browser sends, i.e. it grants every site credentialed access. None of the
# endpoints in this file read cookies or auth headers, so nothing needs
# credentialed cross-origin requests. If that changes, list explicit origins
# here before turning allow_credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Pydantic Models for Swagger Documentation ---
# These models define the response schemas shown in the OpenAPI docs.
# Field names, declaration order and descriptions all appear in that schema.

# Response body of the GET / health-check endpoint.
class HealthResponse(BaseModel):
    # Liveness indicator; the health check always returns "ok".
    status: str = Field(..., example="ok")
    # Human-readable confirmation message.
    message: str = Field(..., example="Gait Analysis Inference Service is running.")
# Headline metrics for one analyzed video (GaitAnalysisData.summary).
# NOTE(review): the "out of 100" ranges are stated only in the descriptions;
# no ge/le constraints are declared here, so they are not validated.
class GaitSummary(BaseModel):
    stability_score: float = Field(..., description="Calculated stability score (out of 100). Higher is better.")
    symmetry_score: float = Field(..., description="Calculated symmetry score (out of 100). Higher is better.")
    weight_shift_score: float = Field(..., description="Calculated weight shift score (out of 100). Higher is better.")
    strength_score: float = Field(..., description="Calculated strength score (out of 100). Higher is better.")
    # NOTE(review): appears to duplicate GaitPhases.phase_asymmetry; confirm
    # both are produced by the same computation.
    phase_asymmetry_percent: float = Field(..., description="Percentage of asymmetry between left and right stance times")
    # Free-text labels, not enums: any string is accepted.
    overall_gait_quality: str = Field(..., description="Overall assessment of gait quality (e.g. acceptable, concerning)")
    walking_condition: str = Field(..., description="E.g., Independent or Assisted")
# Free-text clinical observations per body segment
# (GaitAnalysisData.observations). Every field is a plain descriptive string.
class GaitObservations(BaseModel):
    trunk: str = Field(..., description="Observation of trunk movement (e.g. 'Neutral', 'Rotated')")
    head_neck: str = Field(..., description="Observation of head and neck (e.g. 'Neutral & mobile')")
    arms_hands: str = Field(..., description="Observation of arms and hands (e.g. 'Free arm swing')")
    lower_limbs: str = Field(..., description="Observation of lower limbs")
    symmetry: str = Field(..., description="General symmetry observation")
    weight_distribution: str = Field(..., description="Observation of weight distribution")
    postural_control: str = Field(..., description="Observation of postural control")
# Left/right weight-bearing analysis (GaitAnalysisData.weight_distribution).
class WeightDistribution(BaseModel):
    overall_imbalance_percent: float = Field(..., description="Percentage of imbalance")
    # Expected values per the description: left / right / balanced.
    # The type is str, so other values are not rejected.
    primary_weight_side: str = Field(..., description="Primary weight side (left/right/balanced)")
    weight_shift_issues: List[str] = Field(..., description="Identified weight shift issues")
    # Each entry is a str->str mapping.
    # NOTE(review): the keys of each entry are not defined here; confirm
    # against EnhancedGaitAnalyzer's output.
    hip_drop_analysis: List[Dict[str, str]] = Field(..., description="Hip drop analysis")
# Qualitative postural-strength findings (GaitAnalysisData.postural_strength).
# All fields are descriptive strings rather than numeric scores.
class PosturalStrength(BaseModel):
    trunk_sway_severity: str = Field(..., description="Severity of trunk sway")
    shoulder_stability: str = Field(..., description="Shoulder stability observation")
    forward_lean: str = Field(..., description="Forward lean observation")
    knee_stability: str = Field(..., description="Knee stability observation")
    overall_strength_assessment: str = Field(..., description="Overall strength assessment")
# Stance/swing breakdown per foot (GaitAnalysisData.gait_phases).
# Durations are frame counts, not seconds. This model carries no frame rate,
# so converting to time requires the source video's FPS.
class GaitPhases(BaseModel):
    stance_time_l: int = Field(..., description="Number of frames the left foot is in stance phase")
    stance_time_r: int = Field(..., description="Number of frames the right foot is in stance phase")
    swing_time_l: int = Field(..., description="Number of frames the left foot is in swing phase")
    swing_time_r: int = Field(..., description="Number of frames the right foot is in swing phase")
    double_support_frames: int = Field(..., description="Number of frames where both feet are in stance phase")
    phase_asymmetry: float = Field(..., description="Phase asymmetry percentage")
# Full result of one video analysis: the "data" member of AnalyzeResponse.
class GaitAnalysisData(BaseModel):
    # Per-frame records. The schema is left open (Dict[str, Any]).
    # NOTE(review): keys are presumably keypoints/angles as the description
    # says; verify against EnhancedGaitAnalyzer.process_video.
    frames: List[Dict[str, Any]] = Field(..., description="List of processed frame data including keypoints and angles")
    summary: GaitSummary = Field(..., description="Summary metrics of the gait analysis")
    weight_distribution: WeightDistribution = Field(..., description="Weight distribution analysis")
    postural_strength: PosturalStrength = Field(..., description="Postural strength assessment")
    gait_phases: GaitPhases = Field(..., description="Gait phase breakdown")
    risk_indicators: List[str] = Field(..., description="List of identified risk factors")
    # Open-ended mapping: flag names and value types are not fixed here.
    flags: Dict[str, Any] = Field(..., description="Specific flags or markers triggered during analysis")
    observations: GaitObservations = Field(..., description="Clinical observations of body segments")
    compensatory_strategies: List[str] = Field(..., description="Identified compensatory movement strategies")
    reflex_influence: str = Field(..., description="Notes on primitive reflex influence")
    safety_note: str = Field(..., description="Any constraints or limitations of the analysis")
    clinical_notes: List[str] = Field(..., description="Clinical notes on findings")
# Envelope returned by POST /analyze.
# The /analyze handler returns a JSONResponse directly, so this model
# documents the response in the OpenAPI schema but is not used to validate
# the actual payload.
class AnalyzeResponse(BaseModel):
    # "success" on the happy path; errors are reported as HTTP errors instead.
    status: str = Field(..., example="success")
    data: GaitAnalysisData = Field(..., description="The complete gait analysis results")
# --------------------------------------------------
# Process-wide analyzer shared by every request. It stays None until the
# startup hook builds it. Building it once there, rather than per request,
# keeps the heavy YOLO model in memory and avoids a reload on each call.
analyzer: "Optional[EnhancedGaitAnalyzer]" = None
@app.on_event("startup")
async def startup_event():
    """Create the shared gait analyzer once, when the server starts.

    The new instance is published through the module-level ``analyzer``
    global, which the request handlers reuse instead of reloading the model.
    """
    global analyzer
    print("Initializing YOLO framework and loading model into memory...")
    # Only the model weights are chosen here. Per-request settings such as
    # age and independent_walking are assigned by the /analyze handler.
    loaded_analyzer = EnhancedGaitAnalyzer(model_name="yolo11n-pose.pt")
    analyzer = loaded_analyzer
    print("Model initialized and ready for inference!")
@app.get(
    "/",
    response_model=HealthResponse,
    tags=["System"],
    summary="Health Check",
    description="Simple health check endpoint to ping from the main backend.",
)
def health_check():
    """Return a fixed liveness payload. It touches neither the model nor the disk."""
    return dict(
        status="ok",
        message="Gait Analysis Inference Service is running.",
    )
# Guards the shared module-level ``analyzer``. Each request overwrites the
# analyzer's ``age`` and ``independent_walking`` attributes and then calls
# ``process_video``. Because that work runs in worker threads, the
# configure-then-process sequence must never interleave between requests.
_analyzer_lock = threading.Lock()


def _save_upload(upload: UploadFile, destination: str) -> None:
    """Copy the uploaded file's contents to ``destination`` (blocking I/O)."""
    with open(destination, "wb") as buffer:
        shutil.copyfileobj(upload.file, buffer)


def _run_analysis(video_path: str, age: int, independent_walking: bool) -> Any:
    """Configure the shared analyzer for one request and run it on ``video_path``.

    This call blocks, so callers should run it off the event loop. The lock is
    held across both the attribute updates and ``process_video``, so one
    request can never be processed with another request's settings.
    """
    with _analyzer_lock:
        analyzer.age = age
        analyzer.independent_walking = independent_walking
        return analyzer.process_video(video_path)


@app.post(
    "/analyze",
    response_model=AnalyzeResponse,
    tags=["Analysis"],
    summary="Analyze Video",
    description="Upload a video to perform full gait analysis. This endpoint loads the video, detects poses, and calculates clinical gait metrics.",
)
async def analyze_video(
    file: UploadFile = File(..., description="The highly compressed video file (e.g., .mp4, .mov)"),
    age: Optional[int] = Form(None, description="Age of the patient for developmental benchmarks (determines expected normative values). Defaults to 4 if not provided."),
    independent_walking: bool = Form(True, description="Whether the patient is walking independently"),
):
    """Run gait analysis on an uploaded video and return the analyzer's results.

    Args:
        file: The uploaded video. Only its extension is reused for the
            temporary file on disk.
        age: Patient age in years, accepted range 0-11. ``None`` falls back to 4.
        independent_walking: Whether the patient walks without assistance.

    Returns:
        A JSONResponse with ``{"status": "success", "data": <process_video result>}``.

    Raises:
        HTTPException: 400 for a missing file, an out-of-range age, or a
            ``ValueError`` raised during processing. 503 if the model has not
            been loaded yet. 500 for any other failure.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded")
    if age is not None and age > 11:
        raise HTTPException(status_code=400, detail="OVER_AGE: Patient age exceeds the supported range (maximum 11 years) for this analysis.")
    if age is not None and age < 0:
        raise HTTPException(status_code=400, detail="INVALID_AGE: Patient age cannot be negative.")
    # startup_event populates the global analyzer. If a request arrives before
    # that, fail with a clear 503 instead of an AttributeError on None.
    if analyzer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet. Retry shortly.")
    # Default age to 4 if not provided
    process_age = age if age is not None else 4

    # Create a temporary directory to store the incoming video.
    temp_dir = tempfile.mkdtemp()
    # Never build a path from the client-supplied filename. A value like
    # "../../x" escapes temp_dir, and os.path.join drops temp_dir entirely
    # for absolute paths; either way the cleanup below would miss the file.
    # Keep only the extension, which the video decoder may use as a hint.
    suffix = os.path.splitext(os.path.basename(file.filename))[1]
    temp_video_path = os.path.join(temp_dir, "upload" + suffix)
    loop = asyncio.get_running_loop()
    try:
        # 1. Save the upload to disk, and 2-3. configure the analyzer and run
        # inference. Both steps block, so they run in the default thread pool,
        # which keeps the event loop (and the health check) responsive.
        await loop.run_in_executor(None, _save_upload, file, temp_video_path)
        print(f"Received video: {file.filename} (Age: {process_age}, Independent: {independent_walking})")
        results = await loop.run_in_executor(
            None, _run_analysis, temp_video_path, process_age, independent_walking
        )
        # 4. Return the JSON payload to the main backend. JSONResponse is used
        # directly to skip response-model serialization of the large 'frames'
        # list; the AnalyzeResponse model still documents the shape in Swagger.
        return JSONResponse(content={"status": "success", "data": results})
    except ValueError as ve:
        print(f"Validation Error: {ve}")
        # Return a 400 Bad Request since this is a client/input issue
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        print(f"Error running inference: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # 5. Delete the video so the server doesn't run out of storage.
        # ignore_errors also covers a directory that is already gone.
        shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
    import uvicorn

    # Local development server. With reload=True, uvicorn needs the app as a
    # "module:attribute" import string rather than the app object; the string
    # below assumes this module is saved as api.py.
    dev_server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("api:app", **dev_server_options)
|