Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI application for Panic Detection Model | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from typing import Literal | |
| import tensorflow as tf | |
| import numpy as np | |
| from preprocessing import PanicDetectionPreprocessor | |
| import logging | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Panic Detection API",
    description="API for detecting panic situations based on health metrics",
    version="1.0.0",
)

# Add CORS middleware for the Flutter client.
# FastAPI's CORS documentation states that allow_origins cannot be ["*"] when
# allow_credentials is True. None of the endpoints in this file read cookies
# or auth headers, so credentialed CORS is switched off while origins stay
# wildcarded. If you later restrict allow_origins to real domains and need
# cookies/auth, credentials can be re-enabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your Flutter app's domain
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the trained model and its input preprocessor once, at import time.
# Any failure is logged and re-raised, so the app never starts without them.
try:
    # Use tf.keras (not the standalone `keras` package) to read the legacy
    # H5 file. compile=False loads the weights/architecture without compiling;
    # the model is compiled explicitly right after.
    import tensorflow.keras.models

    model = tensorflow.keras.models.load_model("cnn_model.h5", compile=False)

    # predict() does not require a compiled model; presumably this is kept so
    # evaluate()/fit() work on the loaded model — confirm before removing.
    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )

    preprocessor = PanicDetectionPreprocessor()

    logger.info("Model and preprocessor loaded successfully")
    logger.info("Model input shape: %s", model.input_shape)
except Exception as load_error:
    logger.error("Failed to load model: %s", load_error)
    logger.error("Please ensure the model file (cnn_model.h5) exists")
    raise
# Request model: one set of health metrics to classify.
# Documented with comments rather than a class docstring, because Pydantic
# copies a class docstring into the generated JSON/OpenAPI schema.
class PredictionRequest(BaseModel):
    gender: Literal["Male", "Female"] = Field(
        default=..., description="Gender of the person"
    )
    age: int = Field(default=..., ge=1, le=120, description="Age in years")
    weight: float = Field(default=..., ge=20, le=300, description="Weight in kg")
    heartrate: int = Field(
        default=..., ge=40, le=220, description="Heart rate in bpm"
    )
    stepcount: int = Field(
        default=..., ge=0, le=50000, description="Step count"
    )
    activity: Literal[
        "Running", "Walking", "Sitting", "Standing", "Cycling"
    ] = Field(default=..., description="Current activity")

    # Example payload surfaced in the generated API docs.
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    gender="Male",
                    age=25,
                    weight=70.5,
                    heartrate=120,
                    stepcount=5000,
                    activity="Running",
                )
            ]
        )
    )
# Response model returned by predict().
# Documented with comments rather than a class docstring, because Pydantic
# copies a class docstring into the generated JSON/OpenAPI schema.
class PredictionResponse(BaseModel):
    # Binary outcome of the classification.
    panic_detected: bool
    # Model output score for the panic class.
    panic_probability: float
    # One of "high", "medium" or "low", as assigned by predict().
    confidence: str
    # Echo of the validated request payload.
    input_data: dict

    # Example payload surfaced in the generated API docs.
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    panic_detected=True,
                    panic_probability=0.85,
                    confidence="high",
                    input_data=dict(
                        gender="Male",
                        age=25,
                        weight=70.5,
                        heartrate=120,
                        stepcount=5000,
                        activity="Running",
                    ),
                )
            ]
        )
    )
@app.get("/")
async def root():
    """Report that the API is running and list its main endpoints."""
    # Without the route decorator this handler was never registered, so the
    # endpoints advertised below were unreachable.
    return {
        "message": "Panic Detection API is running",
        "version": "1.0.0",
        "endpoints": {
            "predict": "/predict",
            "health": "/health",
            "model_info": "/model-info",
        },
    }
@app.get("/health")
async def health_check():
    """Report service health and whether the model and preprocessor are loaded."""
    # The module-level loading code re-raises on any failure, so in practice
    # both flags are True whenever this handler is reachable.
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "preprocessor_loaded": preprocessor is not None,
    }
@app.get("/model-info")
async def model_info():
    """Describe the model's inputs and output and the accepted categorical values."""
    # NOTE(review): the valid gender/activity lists come from the
    # preprocessor; presumably they match the Literal choices declared on
    # PredictionRequest — confirm they stay in sync.
    return {
        "model_type": "CNN (Convolutional Neural Network)",
        "input_features": [
            "gender", "age", "weight", "heartrate", "stepcount", "activity"
        ],
        "output": "Binary classification (Panic/No Panic)",
        "valid_genders": preprocessor.get_valid_genders(),
        "valid_activities": preprocessor.get_valid_activities(),
    }
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """
    Predict panic situation based on input health metrics

    Args:
        request: PredictionRequest containing health metrics

    Returns:
        PredictionResponse with panic detection results

    Raises:
        HTTPException: 400 if the preprocessor rejects the input,
            500 for any other failure.
    """
    # NOTE(review): model.predict() is synchronous; inside an `async def` it
    # blocks the event loop for the duration of inference. Consider offloading
    # it if concurrent latency matters.
    try:
        # This file uses Pydantic v2 (`model_config`), where `.dict()` is
        # deprecated in favour of `.model_dump()`. Dump once and reuse.
        request_data = request.model_dump()
        # NOTE(review): this logs health metrics at INFO level — confirm that
        # is acceptable for the deployment's privacy requirements.
        logger.info("Received prediction request: %s", request_data)

        # Only input validation in the preprocessor is a client error (400).
        # A ValueError raised later (e.g. Keras rejecting an incompatible
        # input shape) is a server fault and must surface as a 500.
        try:
            preprocessed_data = preprocessor.preprocess(
                gender=request.gender,
                age=request.age,
                weight=request.weight,
                heartrate=request.heartrate,
                stepcount=request.stepcount,
                activity=request.activity,
            )
        except ValueError as ve:
            logger.error("Validation error: %s", ve)
            raise HTTPException(status_code=400, detail=str(ve)) from ve

        # Make prediction
        prediction = model.predict(preprocessed_data, verbose=0)
        # Take element [0][0] of the output — assumes a single sample with a
        # single probability-like output unit; TODO confirm against the model.
        panic_probability = float(prediction[0][0])

        # Determine panic detection (threshold = 0.5)
        panic_detected = panic_probability >= 0.5

        # Confidence reflects distance from the 0.5 decision boundary.
        if panic_probability >= 0.8 or panic_probability <= 0.2:
            confidence = "high"
        elif panic_probability >= 0.6 or panic_probability <= 0.4:
            confidence = "medium"
        else:
            confidence = "low"

        response = PredictionResponse(
            panic_detected=panic_detected,
            panic_probability=round(panic_probability, 4),
            confidence=confidence,
            input_data=request_data,
        )

        logger.info("Prediction result: %s", response.model_dump())
        return response
    except HTTPException:
        # Already mapped to a status code above; don't rewrap as a 500.
        raise
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e
# NOTE(review): route path inferred — root() does not advertise a batch
# endpoint; confirm the path the client expects.
@app.post("/batch-predict")
async def batch_predict(requests: list[PredictionRequest]):
    """
    Predict panic for multiple inputs

    Items are processed sequentially; the first failing item aborts the batch.

    Args:
        requests: List of PredictionRequest

    Returns:
        Dict with "predictions" (one PredictionResponse per input, in input
        order) and "count" (number of predictions).

    Raises:
        HTTPException: the status raised by predict() for a failing item
            (e.g. 400 for rejected input), or 500 for unexpected errors.
    """
    try:
        results = [await predict(req) for req in requests]
        return {"predictions": results, "count": len(results)}
    except HTTPException:
        # predict() already chose the right status (e.g. 400); previously this
        # was caught by the generic handler below and turned into a 500.
        raise
    except Exception as e:
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {str(e)}"
        ) from e
if __name__ == "__main__":
    # Local entry point. Port 7860 is presumably chosen to match the Hugging
    # Face Spaces default — confirm before changing it.
    from uvicorn import run as run_server

    run_server(app, port=7860, host="0.0.0.0")