chenhaoq87 committed on
Commit
d5c82d8
·
verified ·
1 Parent(s): 096b4be

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +487 -0
app.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ E.coli Preharvest Risk Model - FastAPI Inference Application
3
+
4
+ This API provides endpoints for making predictions on E.coli contamination risk
5
+ using the trained machine learning model.
6
+ """
7
+
8
+ from fastapi import FastAPI, HTTPException
9
+ from pydantic import BaseModel, Field
10
+ from typing import List, Dict, Optional
11
+ import joblib
12
+ import pandas as pd
13
+ import numpy as np
14
+ import json
15
+ import os
16
+ from pathlib import Path
17
+ import logging
18
+
19
# Configure logging once at import time and keep a module-scoped logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application that every route decorator in this module attaches to
app = FastAPI(
    title="E.coli Preharvest Risk Prediction API",
    description="API for predicting E.coli contamination risk in preharvest produce",
    version="1.0.0",
)

# Model artifacts shared across requests. They all start out as None; the
# endpoints check for None and answer 503 while an artifact is missing.
MODEL = PREPROCESSOR = FEATURE_NAMES = MODEL_METRICS = MODEL_COMPARISON = None
36
+
37
+
38
# Example readings for a single day. The same values are reused for every
# look-back window so that the OpenAPI example is a complete request body
# (every field of PredictionInput is required).
_EXAMPLE_DAILY_WEATHER = {
    "temperature_avg": 70.0,
    "temperature_max": 85.0,
    "temperature_min": 55.0,
    "humidity_avg": 65.0,
    "humidity_max": 85.0,
    "humidity_min": 45.0,
    "wind_speed_avg": 5.0,
    "wind_speed_max": 12.0,
    "wind_speed_min": 0.0,
    "wind_run_avg": 1.2,
    "wind_run_max": 3.0,
    "wind_run_min": 0.0,
    "wind_chill_avg": 68.0,
    "wind_chill_max": 85.0,
    "wind_chill_min": 55.0,
    "heat_index_avg": 70.0,
    "heat_index_max": 85.0,
    "heat_index_min": 55.0,
    "thw_index_avg": 68.0,
    "thw_index_max": 85.0,
    "thw_index_min": 55.0,
    "rain_avg": 0.0,
    "rain_max": 0.0,
    "rain_min": 0.0,
    "rain_rate_avg": 0.0,
    "rain_rate_max": 0.0,
    "rain_rate_min": 0.0,
    "solar_radiation_max": 850.0,
    "solar_radiation_min": 0.0,
    "ET_avg": 0.15,
    "ET_max": 0.25,
    "ET_min": 0.0,
    "heating_degree_days_avg": 0.0,
    "heating_degree_days_max": 0.0,
    "heating_degree_days_min": 0.0,
    "cooling_degree_days_avg": 5.0,
    "cooling_degree_days_max": 20.0,
    "cooling_degree_days_min": 0.0,
    "wind_direction_mode": "W",
}

# Field-name suffixes of the four weather windows declared on PredictionInput.
_WEATHER_WINDOW_SUFFIXES = ("d0", "d1_before", "d3_before", "d7_before")


def _example_payload() -> dict:
    """Build a full example request covering every PredictionInput field."""
    payload = {
        "org_conv_kiptraq": "Conventional",
        "acres_kiptraq": 10.0,
        "lat": 36.5,
        "lon": -121.5,
        "season": "Fall",
    }
    for suffix in _WEATHER_WINDOW_SUFFIXES:
        for name, value in _EXAMPLE_DAILY_WEATHER.items():
            payload[f"{name}_{suffix}"] = value
    return payload


class PredictionInput(BaseModel):
    """Input schema for prediction requests.

    Holds five farm/context fields followed by four blocks of weather
    readings (suffixes ``d0``, ``d1_before``, ``d3_before``, ``d7_before``).
    Each block has the same 38 numeric fields plus one string
    ``wind_direction_mode`` field. Every field is required.
    """
    org_conv_kiptraq: str = Field(..., description="Organic or Conventional")
    acres_kiptraq: float = Field(..., description="Farm size in acres")
    lat: float = Field(..., description="Latitude")
    lon: float = Field(..., description="Longitude")
    season: str = Field(..., description="Season (Spring, Summer, Fall, Winter)")

    # Day 0 weather features
    temperature_avg_d0: float
    temperature_max_d0: float
    temperature_min_d0: float
    humidity_avg_d0: float
    humidity_max_d0: float
    humidity_min_d0: float
    wind_speed_avg_d0: float
    wind_speed_max_d0: float
    wind_speed_min_d0: float
    wind_run_avg_d0: float
    wind_run_max_d0: float
    wind_run_min_d0: float
    wind_chill_avg_d0: float
    wind_chill_max_d0: float
    wind_chill_min_d0: float
    heat_index_avg_d0: float
    heat_index_max_d0: float
    heat_index_min_d0: float
    thw_index_avg_d0: float
    thw_index_max_d0: float
    thw_index_min_d0: float
    rain_avg_d0: float
    rain_max_d0: float
    rain_min_d0: float
    rain_rate_avg_d0: float
    rain_rate_max_d0: float
    rain_rate_min_d0: float
    # Solar radiation has only max/min fields (no avg) in every window
    solar_radiation_max_d0: float
    solar_radiation_min_d0: float
    ET_avg_d0: float
    ET_max_d0: float
    ET_min_d0: float
    heating_degree_days_avg_d0: float
    heating_degree_days_max_d0: float
    heating_degree_days_min_d0: float
    cooling_degree_days_avg_d0: float
    cooling_degree_days_max_d0: float
    cooling_degree_days_min_d0: float
    wind_direction_mode_d0: str

    # Day 1 before weather features
    temperature_avg_d1_before: float
    temperature_max_d1_before: float
    temperature_min_d1_before: float
    humidity_avg_d1_before: float
    humidity_max_d1_before: float
    humidity_min_d1_before: float
    wind_speed_avg_d1_before: float
    wind_speed_max_d1_before: float
    wind_speed_min_d1_before: float
    wind_run_avg_d1_before: float
    wind_run_max_d1_before: float
    wind_run_min_d1_before: float
    wind_chill_avg_d1_before: float
    wind_chill_max_d1_before: float
    wind_chill_min_d1_before: float
    heat_index_avg_d1_before: float
    heat_index_max_d1_before: float
    heat_index_min_d1_before: float
    thw_index_avg_d1_before: float
    thw_index_max_d1_before: float
    thw_index_min_d1_before: float
    rain_avg_d1_before: float
    rain_max_d1_before: float
    rain_min_d1_before: float
    rain_rate_avg_d1_before: float
    rain_rate_max_d1_before: float
    rain_rate_min_d1_before: float
    solar_radiation_max_d1_before: float
    solar_radiation_min_d1_before: float
    ET_avg_d1_before: float
    ET_max_d1_before: float
    ET_min_d1_before: float
    heating_degree_days_avg_d1_before: float
    heating_degree_days_max_d1_before: float
    heating_degree_days_min_d1_before: float
    cooling_degree_days_avg_d1_before: float
    cooling_degree_days_max_d1_before: float
    cooling_degree_days_min_d1_before: float
    wind_direction_mode_d1_before: str

    # Day 3 before weather features
    temperature_avg_d3_before: float
    temperature_max_d3_before: float
    temperature_min_d3_before: float
    humidity_avg_d3_before: float
    humidity_max_d3_before: float
    humidity_min_d3_before: float
    wind_speed_avg_d3_before: float
    wind_speed_max_d3_before: float
    wind_speed_min_d3_before: float
    wind_run_avg_d3_before: float
    wind_run_max_d3_before: float
    wind_run_min_d3_before: float
    wind_chill_avg_d3_before: float
    wind_chill_max_d3_before: float
    wind_chill_min_d3_before: float
    heat_index_avg_d3_before: float
    heat_index_max_d3_before: float
    heat_index_min_d3_before: float
    thw_index_avg_d3_before: float
    thw_index_max_d3_before: float
    thw_index_min_d3_before: float
    rain_avg_d3_before: float
    rain_max_d3_before: float
    rain_min_d3_before: float
    rain_rate_avg_d3_before: float
    rain_rate_max_d3_before: float
    rain_rate_min_d3_before: float
    solar_radiation_max_d3_before: float
    solar_radiation_min_d3_before: float
    ET_avg_d3_before: float
    ET_max_d3_before: float
    ET_min_d3_before: float
    heating_degree_days_avg_d3_before: float
    heating_degree_days_max_d3_before: float
    heating_degree_days_min_d3_before: float
    cooling_degree_days_avg_d3_before: float
    cooling_degree_days_max_d3_before: float
    cooling_degree_days_min_d3_before: float
    wind_direction_mode_d3_before: str

    # Day 7 before weather features
    temperature_avg_d7_before: float
    temperature_max_d7_before: float
    temperature_min_d7_before: float
    humidity_avg_d7_before: float
    humidity_max_d7_before: float
    humidity_min_d7_before: float
    wind_speed_avg_d7_before: float
    wind_speed_max_d7_before: float
    wind_speed_min_d7_before: float
    wind_run_avg_d7_before: float
    wind_run_max_d7_before: float
    wind_run_min_d7_before: float
    wind_chill_avg_d7_before: float
    wind_chill_max_d7_before: float
    wind_chill_min_d7_before: float
    heat_index_avg_d7_before: float
    heat_index_max_d7_before: float
    heat_index_min_d7_before: float
    thw_index_avg_d7_before: float
    thw_index_max_d7_before: float
    thw_index_min_d7_before: float
    rain_avg_d7_before: float
    rain_max_d7_before: float
    rain_min_d7_before: float
    rain_rate_avg_d7_before: float
    rain_rate_max_d7_before: float
    rain_rate_min_d7_before: float
    solar_radiation_max_d7_before: float
    solar_radiation_min_d7_before: float
    ET_avg_d7_before: float
    ET_max_d7_before: float
    ET_min_d7_before: float
    heating_degree_days_avg_d7_before: float
    heating_degree_days_max_d7_before: float
    heating_degree_days_min_d7_before: float
    cooling_degree_days_avg_d7_before: float
    cooling_degree_days_max_d7_before: float
    cooling_degree_days_min_d7_before: float
    wind_direction_mode_d7_before: str

    class Config:
        # A complete example: an abbreviated one would omit required fields
        # and be rejected when submitted as-is from the interactive docs.
        schema_extra = {
            "example": _example_payload()
        }
261
+
262
+
263
class PredictionOutput(BaseModel):
    """Response body returned by the prediction endpoints."""
    prediction: str = Field(
        ...,
        description="Predicted class: 'Positive' or 'Negative'",
    )
    probability_positive: float = Field(
        ...,
        description="Probability of E.coli positive",
    )
    probability_negative: float = Field(
        ...,
        description="Probability of E.coli negative",
    )
    risk_level: str = Field(
        ...,
        description="Risk level: 'Low', 'Medium', or 'High'",
    )
269
+
270
+
271
class ModelInfo(BaseModel):
    """Response body of the /model_info endpoint."""
    # Name of the selected algorithm and when it was trained
    algorithm: str
    training_date: str
    # Free-form mappings passed through from the stored metrics file
    metrics: Dict
    top_features: Dict
277
+
278
+
279
def _load_json(path):
    """Read and parse one JSON artifact, decoding it explicitly as UTF-8."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


@app.on_event("startup")
async def load_model_artifacts():
    """Load model artifacts on application startup.

    Reads the model, preprocessor, feature names, metrics and model
    comparison from the ``model`` directory (resolved relative to the
    current working directory) into the module-level globals.

    Raises:
        Exception: Any loading error is logged with its traceback and
            re-raised, so the failure is not silently ignored.
    """
    global MODEL, PREPROCESSOR, FEATURE_NAMES, MODEL_METRICS, MODEL_COMPARISON

    model_dir = Path("model")

    try:
        # NOTE: joblib.load unpickles its input, so these files must come
        # from a trusted source.
        model_path = model_dir / "best_model.joblib"
        MODEL = joblib.load(model_path)
        logger.info(f"Loaded model from {model_path}")

        preprocessor_path = model_dir / "preprocessor.joblib"
        PREPROCESSOR = joblib.load(preprocessor_path)
        logger.info(f"Loaded preprocessor from {preprocessor_path}")

        # Column names used to align request data in preprocess_input
        FEATURE_NAMES = _load_json(model_dir / "feature_names.json")
        logger.info(f"Loaded {len(FEATURE_NAMES)} feature names")

        MODEL_METRICS = _load_json(model_dir / "model_metrics.json")
        logger.info(f"Loaded model metrics for {MODEL_METRICS.get('winning_algorithm', 'unknown')}")

        MODEL_COMPARISON = _load_json(model_dir / "model_comparison.json")
        logger.info(f"Loaded comparison for {len(MODEL_COMPARISON)} models")

        logger.info("All model artifacts loaded successfully!")

    except Exception as e:
        # logger.exception keeps the traceback, which a plain error log drops
        logger.exception(f"Error loading model artifacts: {e}")
        raise
320
+
321
+
322
def preprocess_input(input_data: PredictionInput) -> pd.DataFrame:
    """
    Convert input data to a one-row DataFrame aligned with FEATURE_NAMES.

    Args:
        input_data: Pydantic model with input features

    Returns:
        pd.DataFrame: One row whose columns are exactly FEATURE_NAMES, in
        that order. Columns absent from the encoded input are filled with 0;
        encoded columns not listed in FEATURE_NAMES are dropped.

    Raises:
        RuntimeError: If FEATURE_NAMES has not been loaded yet.
    """
    if FEATURE_NAMES is None:
        raise RuntimeError("Feature names not loaded")

    df = pd.DataFrame([input_data.dict()])

    # One-hot encode the categorical inputs; each becomes "<column>_<value>"
    categorical_cols = ['org_conv_kiptraq', 'season',
                        'wind_direction_mode_d0', 'wind_direction_mode_d1_before',
                        'wind_direction_mode_d3_before', 'wind_direction_mode_d7_before']

    df = pd.get_dummies(df, columns=categorical_cols, drop_first=False)

    # A single reindex adds the missing columns (as 0), drops the extra ones
    # and fixes the order, without inserting columns one at a time.
    return df.reindex(columns=list(FEATURE_NAMES), fill_value=0)
355
+
356
+
357
@app.get("/")
async def root():
    """Describe the API: its name, version and a summary of the endpoints."""
    endpoint_summary = {
        "/predict": "POST - Single prediction",
        "/predict_batch": "POST - Batch predictions",
        "/model_info": "GET - Model information and metrics",
        "/health": "GET - Health check",
    }
    return {
        "message": "E.coli Preharvest Risk Prediction API",
        "version": "1.0.0",
        "endpoints": endpoint_summary,
    }
370
+
371
+
372
@app.get("/health")
async def health_check():
    """Report service health; answers 503 while no model is loaded."""
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Fall back to 'unknown' when metrics are missing or empty
    algorithm = 'unknown'
    if MODEL_METRICS:
        algorithm = MODEL_METRICS.get('winning_algorithm', 'unknown')

    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "algorithm": algorithm,
    }
383
+
384
+
385
@app.get("/model_info", response_model=ModelInfo)
async def get_model_info():
    """Return the stored algorithm name, training date, metrics and top features."""
    if MODEL_METRICS is None:
        raise HTTPException(status_code=503, detail="Model metrics not loaded")

    lookup = MODEL_METRICS.get
    algorithm = lookup('winning_algorithm', 'unknown')
    training_date = lookup('training_date', 'unknown')
    metrics = lookup('metrics', {})
    top_features = lookup('top_features', {})

    return ModelInfo(
        algorithm=algorithm,
        training_date=training_date,
        metrics=metrics,
        top_features=top_features,
    )
397
+
398
+
399
@app.get("/model_comparison")
async def get_model_comparison():
    """Return the stored model comparison together with the winning algorithm."""
    if MODEL_COMPARISON is None:
        raise HTTPException(status_code=503, detail="Model comparison not loaded")

    # The winner comes from the metrics; 'unknown' when they are unavailable
    if MODEL_METRICS:
        winner = MODEL_METRICS.get('winning_algorithm', 'unknown')
    else:
        winner = 'unknown'

    return {"comparison": MODEL_COMPARISON, "winner": winner}
409
+
410
+
411
@app.post("/predict", response_model=PredictionOutput)
async def predict(input_data: PredictionInput):
    """
    Make a single prediction.

    Args:
        input_data: Input features for prediction

    Returns:
        PredictionOutput: Predicted label, class probabilities and a risk
        level derived from the positive-class probability
        (< 0.3 Low, < 0.7 Medium, otherwise High).

    Raises:
        HTTPException: 503 if the model is not loaded, 500 if preprocessing
            or inference fails.
    """
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        df = preprocess_input(input_data)

        prediction = MODEL.predict(df)[0]
        probabilities = MODEL.predict_proba(df)[0]

        # Probability columns follow the model's classes_ order, so look the
        # labels up there when they are available. Otherwise keep the
        # positional assumption ['Negative', 'Positive'].
        class_labels = [str(label) for label in getattr(MODEL, "classes_", [])]
        if "Positive" in class_labels and "Negative" in class_labels:
            prob_negative = float(probabilities[class_labels.index("Negative")])
            prob_positive = float(probabilities[class_labels.index("Positive")])
        else:
            prob_negative = float(probabilities[0])
            prob_positive = float(probabilities[1])

        if prob_positive < 0.3:
            risk_level = "Low"
        elif prob_positive < 0.7:
            risk_level = "Medium"
        else:
            risk_level = "High"

        return PredictionOutput(
            # Convert the numpy scalar to a plain str for the response model
            prediction=str(prediction),
            probability_positive=prob_positive,
            probability_negative=prob_negative,
            risk_level=risk_level
        )

    except Exception as e:
        logger.exception(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
456
+
457
+
458
@app.post("/predict_batch")
async def predict_batch(input_data: List[PredictionInput]):
    """
    Make batch predictions.

    Each item is run through the single-prediction handler in turn.

    Args:
        input_data: List of input features for prediction

    Returns:
        dict: ``{"predictions": [...], "count": n}`` where each prediction
        is the dict form of a PredictionOutput.

    Raises:
        HTTPException: 503 if the model is not loaded; an HTTPException from
            an individual prediction is passed through unchanged; any other
            failure becomes a 500.
    """
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        results = []
        for data in input_data:
            result = await predict(data)
            results.append(result.dict())

        return {"predictions": results, "count": len(results)}

    except HTTPException:
        # Already a proper HTTP error from predict(); re-wrapping it would
        # nest the messages and overwrite its status code.
        raise
    except Exception as e:
        logger.exception(f"Batch prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}") from e
483
+
484
+
485
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8000
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)