Harshilforworks commited on
Commit
bae8d6e
·
verified ·
1 Parent(s): 4f45196

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -20
app.py CHANGED
@@ -16,21 +16,8 @@ import logging
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
- # Initialize FastAPI app
20
- app = FastAPI(
21
- title="MediGuard Disease Prediction API",
22
- description="AI-powered disease prediction using stacking ensemble",
23
- version="1.0.0"
24
- )
25
-
26
- # CORS middleware
27
- app.add_middleware(
28
- CORSMiddleware,
29
- allow_origins=["*"],
30
- allow_credentials=True,
31
- allow_methods=["*"],
32
- allow_headers=["*"],
33
- )
34
 
35
  # Model directory
36
  MODEL_DIR = Path(".")
@@ -61,7 +48,7 @@ class PatientData(BaseModel):
61
  features: List[float] = Field(
62
  ...,
63
  description="List of biomarker values in the correct order",
64
- example=[13.2, 165, 245, 280, 7.5, 4.8, 42, 88, 28, 33, 18, 32.5, 145, 92, 210, 7.8, 145, 38, 35, 28, 78, 1.1, 0.01, 2.8]
65
  )
66
 
67
 
@@ -79,11 +66,15 @@ class HealthResponse(BaseModel):
79
  feature_count: int
80
 
81
 
82
- @app.on_event("startup")
83
- async def load_models():
84
- """Load all trained models on startup"""
 
 
 
85
  global rf_model, nn_model, meta_model, scaler, label_encoder, feature_cols
86
 
 
87
  try:
88
  logger.info("Loading models...")
89
 
@@ -108,6 +99,29 @@ async def load_models():
108
  except Exception as e:
109
  logger.error(f"❌ Error loading models: {e}")
110
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
  def predict_disease(patient_features: np.ndarray):
@@ -157,7 +171,7 @@ def predict_disease(patient_features: np.ndarray):
157
  return disease, confidence, top_3
158
 
159
 
160
- @app.get("/", response_model=Dict[str, str])
161
  async def root():
162
  """Root endpoint"""
163
  return {
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
+ # Initialize FastAPI app with lifespan (will be defined below)
20
+ # We need to define lifespan first, then create app
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Model directory
23
  MODEL_DIR = Path(".")
 
48
  features: List[float] = Field(
49
  ...,
50
  description="List of biomarker values in the correct order",
51
+ json_schema_extra={"example": [13.2, 165, 245, 280, 7.5, 4.8, 42, 88, 28, 33, 18, 32.5, 145, 92, 210, 7.8, 145, 38, 35, 28, 78, 1.1, 0.01, 2.8]}
52
  )
53
 
54
 
 
66
  feature_count: int
67
 
68
 
69
+ from contextlib import asynccontextmanager
70
+ from typing import AsyncGenerator
71
+
72
+ @asynccontextmanager
73
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
74
+ """Lifespan context manager for startup and shutdown events"""
75
  global rf_model, nn_model, meta_model, scaler, label_encoder, feature_cols
76
 
77
+ # Startup
78
  try:
79
  logger.info("Loading models...")
80
 
 
99
  except Exception as e:
100
  logger.error(f"❌ Error loading models: {e}")
101
  raise
102
+
103
+ yield
104
+
105
+ # Shutdown (cleanup if needed)
106
+ logger.info("Shutting down...")
107
+
108
+
109
+ # Initialize FastAPI app with lifespan
110
+ app = FastAPI(
111
+ title="MediGuard Disease Prediction API",
112
+ description="AI-powered disease prediction using stacking ensemble",
113
+ version="1.0.0",
114
+ lifespan=lifespan
115
+ )
116
+
117
+ # CORS middleware
118
+ app.add_middleware(
119
+ CORSMiddleware,
120
+ allow_origins=["*"],
121
+ allow_credentials=True,
122
+ allow_methods=["*"],
123
+ allow_headers=["*"],
124
+ )
125
 
126
 
127
  def predict_disease(patient_features: np.ndarray):
 
171
  return disease, confidence, top_3
172
 
173
 
174
+ @app.get("/")
175
  async def root():
176
  """Root endpoint"""
177
  return {