# Panic_Detection / app.py
# Author: Shib-Sankar-Das
# Upgrade to TensorFlow 2.18 with latest packages (commit 80b19f7)
"""
FastAPI application for Panic Detection Model
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Literal
import tensorflow as tf
import numpy as np
from preprocessing import PanicDetectionPreprocessor
import logging
# Configure logging: root logger at INFO so model-load and per-request
# messages below are emitted; module-level logger per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app (title/description/version surface in the OpenAPI docs)
app = FastAPI(
    title="Panic Detection API",
    description="API for detecting panic situations based on health metrics",
    version="1.0.0"
)
# Add CORS middleware for Flutter app
# NOTE(review): browsers reject `Access-Control-Allow-Origin: *` together with
# credentials, so allow_credentials=True is ineffective with wildcard origins —
# pin the Flutter app's real origin in production (TODO confirm deployment URL).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your Flutter app's domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load model and preprocessor at import time so the process fails fast
# (and never serves requests) when the artifacts are missing or broken.
try:
    # Load H5 model using legacy Keras format
    # Force use of tf.keras instead of standalone keras
    import tensorflow.keras.models
    # compile=False skips restoring the saved training config, which can be
    # incompatible across TF versions; we recompile explicitly just below.
    model = tensorflow.keras.models.load_model('cnn_model.h5', compile=False)
    # Recompile the model — binary_crossentropy matches the single-unit
    # binary output consumed by /predict (presumably sigmoid; TODO confirm).
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    preprocessor = PanicDetectionPreprocessor()
    logger.info("Model and preprocessor loaded successfully")
    logger.info(f"Model input shape: {model.input_shape}")
except Exception as e:
    logger.error(f"Failed to load model: {str(e)}")
    logger.error("Please ensure the model file (cnn_model.h5) exists")
    # Re-raise so startup aborts instead of running with no model loaded.
    raise
# Request model
# Request model
class PredictionRequest(BaseModel):
    """Validated input payload for a single panic prediction.

    Field constraints (ge/le bounds, Literal choices) are enforced by
    pydantic before the request reaches the endpoint, so /predict only
    sees physiologically plausible values.
    """
    gender: Literal['Male', 'Female'] = Field(..., description="Gender of the person")
    age: int = Field(..., ge=1, le=120, description="Age in years")
    weight: float = Field(..., ge=20, le=300, description="Weight in kg")
    heartrate: int = Field(..., ge=40, le=220, description="Heart rate in bpm")
    stepcount: int = Field(..., ge=0, le=50000, description="Step count")
    activity: Literal['Running', 'Walking', 'Sitting', 'Standing', 'Cycling'] = Field(
        ..., description="Current activity"
    )
    # Pydantic v2 configuration: example payload shown in the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "examples": [{
                "gender": "Male",
                "age": 25,
                "weight": 70.5,
                "heartrate": 120,
                "stepcount": 5000,
                "activity": "Running"
            }]
        }
    }
# Response model
class PredictionResponse(BaseModel):
    """Result payload returned by the /predict endpoint."""
    panic_detected: bool      # True when panic_probability >= 0.5
    panic_probability: float  # model output, rounded to 4 decimal places
    confidence: str           # one of "high" / "medium" / "low"
    input_data: dict          # echo of the request payload for traceability
    # Pydantic v2 configuration: example payload shown in the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "examples": [{
                "panic_detected": True,
                "panic_probability": 0.85,
                "confidence": "high",
                "input_data": {
                    "gender": "Male",
                    "age": 25,
                    "weight": 70.5,
                    "heartrate": 120,
                    "stepcount": 5000,
                    "activity": "Running"
                }
            }]
        }
    }
@app.get("/")
async def root():
    """Root endpoint: report liveness, version, and available routes."""
    available_endpoints = {
        "predict": "/predict",
        "health": "/health",
        "model_info": "/model-info",
    }
    return {
        "message": "Panic Detection API is running",
        "version": "1.0.0",
        "endpoints": available_endpoints,
    }
@app.get("/health")
async def health_check():
    """Health check endpoint: confirm model and preprocessor are loaded."""
    model_ready = model is not None
    preprocessor_ready = preprocessor is not None
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "preprocessor_loaded": preprocessor_ready,
    }
@app.get("/model-info")
async def model_info():
    """Describe the model: architecture, features, and accepted categories."""
    feature_names = [
        "gender", "age", "weight", "heartrate", "stepcount", "activity"
    ]
    return {
        "model_type": "CNN (Convolutional Neural Network)",
        "input_features": feature_names,
        "output": "Binary classification (Panic/No Panic)",
        "valid_genders": preprocessor.get_valid_genders(),
        "valid_activities": preprocessor.get_valid_activities(),
    }
def _confidence_level(probability: float) -> str:
    """Map a probability to a confidence bucket by distance from the 0.5 boundary."""
    if probability >= 0.8 or probability <= 0.2:
        return "high"
    if probability >= 0.6 or probability <= 0.4:
        return "medium"
    return "low"


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """
    Predict panic situation based on input health metrics.

    Args:
        request: PredictionRequest containing health metrics

    Returns:
        PredictionResponse with panic detection results

    Raises:
        HTTPException: 400 when the preprocessor rejects a value,
            500 on any unexpected prediction failure.
    """
    try:
        # Pydantic v2 (this file uses v2 `model_config` syntax): model_dump()
        # replaces the deprecated .dict() and avoids deprecation warnings.
        input_data = request.model_dump()
        # Lazy %-style args so formatting only happens if the level is enabled.
        logger.info("Received prediction request: %s", input_data)

        # Encode/scale the raw fields into the array the CNN expects.
        preprocessed_data = preprocessor.preprocess(
            gender=request.gender,
            age=request.age,
            weight=request.weight,
            heartrate=request.heartrate,
            stepcount=request.stepcount,
            activity=request.activity
        )

        # Single-sample inference; output indexed as (1, 1) — presumably a
        # sigmoid probability (TODO confirm against the model architecture).
        prediction = model.predict(preprocessed_data, verbose=0)
        panic_probability = float(prediction[0][0])

        # Binary decision at the conventional 0.5 threshold.
        panic_detected = panic_probability >= 0.5

        response = PredictionResponse(
            panic_detected=panic_detected,
            panic_probability=round(panic_probability, 4),
            confidence=_confidence_level(panic_probability),
            input_data=input_data
        )
        logger.info("Prediction result: %s", response.model_dump())
        return response
    except ValueError as ve:
        # The preprocessor raises ValueError for out-of-range / unknown inputs.
        logger.error("Validation error: %s", ve)
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
@app.post("/batch-predict")
async def batch_predict(requests: list[PredictionRequest]):
    """
    Predict panic for multiple inputs by delegating to /predict per item.

    Args:
        requests: List of PredictionRequest

    Returns:
        Dict with the list of prediction results and the processed count.

    Raises:
        HTTPException: per-item errors from predict() (e.g. 400 validation)
            propagated unchanged, or 500 on unexpected batch failures.
    """
    try:
        results = [await predict(req) for req in requests]
        return {"predictions": results, "count": len(results)}
    except HTTPException:
        # Bug fix: predict() raises HTTPException(400) for bad items; the
        # previous blanket `except Exception` rewrapped it as a generic 500.
        # Re-raise so the original status code reaches the client.
        raise
    except Exception as e:
        logger.error("Batch prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}")
# Entry point for running the server directly (e.g. local development or a
# Hugging Face Space); port 7860 is the Spaces default.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)