chenhaoq87 committed on
Commit
49a17ee
·
verified ·
1 Parent(s): fff3c77

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +253 -0
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-Variant FastAPI REST API for Milk Spoilage Classification
3
+
4
+ This API supports multiple model variants with different feature subsets.
5
+ Perfect for Custom GPT integration - allows selecting the optimal model
6
+ based on available data and prediction needs.
7
+ """
8
+
9
+ from fastapi import FastAPI, HTTPException
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel, Field
12
+ import joblib
13
+ import numpy as np
14
+ from typing import Dict, Optional, List
15
+ import os
16
+ import json
17
+ from pathlib import Path
18
+
19
# Locate the directory that holds the variants config and the model artifacts.
VARIANTS_DIR = Path("model/variants")
if not VARIANTS_DIR.exists():
    # Try alternate path for local development
    VARIANTS_DIR = Path(__file__).parent.parent.parent / "model" / "variants"

# Load variants config; the service cannot start without it.
config_path = VARIANTS_DIR / "variants_config.json"
if not config_path.exists():
    raise FileNotFoundError(f"variants_config.json not found at {config_path}")

# JSON text is UTF-8; pass the encoding explicitly instead of relying on the
# platform's locale default.
with open(config_path, encoding="utf-8") as f:
    VARIANTS_CONFIG = json.load(f)

# Load one model per configured variant. A missing model file is reported but
# does not abort start-up; that variant is simply absent from MODELS.
MODELS = {}
for variant_id in VARIANTS_CONFIG['variants']:
    model_path = VARIANTS_DIR / f"{variant_id}.joblib"
    if model_path.exists():
        # NOTE(review): joblib.load is pickle-based and can execute arbitrary
        # code while loading; only ship trusted model artifacts in this folder.
        MODELS[variant_id] = joblib.load(model_path)
    else:
        print(f"Warning: Model file not found for variant {variant_id}")

print(f"✓ Loaded {len(MODELS)} model variants: {list(MODELS.keys())}")
43
+
44
# Markdown text passed to FastAPI as the API description.
_API_DESCRIPTION = """
AI-powered milk spoilage classification with multiple model variants.

**10 Model Variants Available:**
- **baseline**: All features (best accuracy: 95.8%)
- **scenario_1_days14_21**: Days 14 & 21 only (94.2%)
- **scenario_3_day21**: Day 21 only (93.7%)
- **scenario_4_day14**: Day 14 only (87.4%)
- **scenario_2_days7_14**: Days 7 & 14 (87.3%)
- **scenario_6_spc_all**: SPC only - all days (78.3%)
- **scenario_8_spc_7_14**: SPC days 7 & 14 (73.3%)
- **scenario_9_tgn_7_14**: TGN days 7 & 14 (73.1%)
- **scenario_7_tgn_all**: TGN only - all days (69.9%)
- **scenario_5_day7**: Day 7 only (62.8%)

Select the variant based on your available data. If you have all measurements,
use 'baseline' for best accuracy. If you only have partial data, choose the
appropriate scenario variant.
"""

# CORS policy: any origin, method and header is accepted, credentials are not
# allowed, and the max_age value passed to the middleware is 3600.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": False,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "max_age": 3600,
}

# Create FastAPI app
app = FastAPI(
    title="Milk Spoilage Classification API (Multi-Variant)",
    description=_API_DESCRIPTION,
    version="2.0.0",
)

# Add CORS middleware
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
78
+
79
+ # Request/Response models
80
class PredictionInput(BaseModel):
    # Request body for POST /predict.
    #
    # All six measurements are optional at the schema level; which of them are
    # actually required depends on the selected model variant and is enforced
    # later by extract_features(). Each measurement is constrained to the
    # inclusive range 0.0-10.0 (log CFU/mL, per the field descriptions).
    #
    # (Documented with comments rather than a class docstring so the generated
    # JSON schema is not altered.)

    # Pydantic v2 style configuration:
    # - protected_namespaces=() silences the warning caused by the field name
    #   `model_variant` starting with the protected `model_` prefix.
    # - json_schema_extra supplies the example shown in the generated schema.
    model_config = {
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "spc_d7": 2.1,
                "spc_d14": 4.7,
                "spc_d21": 6.4,
                "tgn_d7": 1.0,
                "tgn_d14": 3.7,
                "tgn_d21": 5.3,
                "model_variant": "baseline"
            }
        },
    }

    spc_d7: Optional[float] = Field(None, description="Standard Plate Count at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
    spc_d14: Optional[float] = Field(None, description="Standard Plate Count at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
    spc_d21: Optional[float] = Field(None, description="Standard Plate Count at Day 21 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d7: Optional[float] = Field(None, description="Total Gram-Negative at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d14: Optional[float] = Field(None, description="Total Gram-Negative at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d21: Optional[float] = Field(None, description="Total Gram-Negative at Day 21 (log CFU/mL)", ge=0.0, le=10.0)
    # Identifier of the model variant to run; defaults to "baseline".
    model_variant: str = Field(
        "baseline",
        description="Model variant to use (baseline, scenario_1_days14_21, scenario_3_day21, etc.)"
    )
104
+
105
class VariantInfo(BaseModel):
    # Metadata about a single model variant, returned alongside a prediction.
    # (Comments are used instead of a class docstring so the generated JSON
    # schema stays exactly as it was.)

    # Identifier of the variant.
    variant_id: str

    # Variant name.
    name: str

    # Variant description.
    description: str

    # Names of the input features the variant's model consumes.
    features: List[str]

    # Test accuracy reported for the variant.
    test_accuracy: float
111
+
112
class PredictionOutput(BaseModel):
    # Response body of POST /predict. Every field is required (default=...).
    # (Comments are used instead of a class docstring so the generated JSON
    # schema stays exactly as it was.)

    prediction: str = Field(
        default=...,
        description="Predicted spoilage class",
    )
    probabilities: Dict[str, float] = Field(
        default=...,
        description="Probability for each class",
    )
    confidence: float = Field(
        default=...,
        description="Confidence score (max probability)",
    )
    variant_used: VariantInfo = Field(
        default=...,
        description="Information about the model variant used",
    )
117
+
118
+
119
def extract_features(input_data: "PredictionInput", required_features: List[str]) -> np.ndarray:
    """Extract required features from input data.

    Args:
        input_data: Request payload carrying the optional log-scale
            measurements (``spc_d7`` ... ``tgn_d21``).
        required_features: Upper-case feature names (for example ``'SPC_D7'``)
            that the selected variant's model expects, in model input order.

    Returns:
        A 2-D array with a single row holding the requested features,
        converted from log10 values to raw values (``10 ** value``).

    Raises:
        HTTPException: 500 if ``required_features`` names a feature this API
            does not know (a variant configuration problem), or 400 if the
            request omitted a feature that the variant requires.
    """
    feature_map = {
        'SPC_D7': input_data.spc_d7,
        'SPC_D14': input_data.spc_d14,
        'SPC_D21': input_data.spc_d21,
        'TGN_D7': input_data.tgn_d7,
        'TGN_D14': input_data.tgn_d14,
        'TGN_D21': input_data.tgn_d21,
    }

    # The feature list comes from the variants config; report a bad entry
    # explicitly instead of failing with a bare KeyError below.
    unknown = [f for f in required_features if f not in feature_map]
    if unknown:
        raise HTTPException(
            status_code=500,
            detail=f"Variant configuration references unknown features: {', '.join(unknown)}"
        )

    # Check for missing required features
    missing = [f for f in required_features if feature_map[f] is None]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required features for variant: {', '.join(missing)}"
        )

    # Extract and convert from log to raw CFU/mL
    features = [10 ** feature_map[f] for f in required_features]
    return np.array([features])
141
+
142
+
143
@app.get("/")
async def root():
    """Root endpoint with API information."""
    # Paths of the endpoints advertised to clients.
    endpoint_paths = dict(
        predict="/predict",
        variants="/variants",
        health="/health",
        docs="/docs",
    )
    # Number of variants whose model was loaded.
    loaded_count = len(MODELS)
    return dict(
        message="Milk Spoilage Classification API - Multi-Variant",
        version="2.0.0",
        variants_available=loaded_count,
        endpoints=endpoint_paths,
    )
157
+
158
+
159
@app.get("/variants", tags=["Variants"])
async def list_variants():
    """List all available model variants with their metadata."""
    # Only advertise variants whose model was actually loaded: /predict rejects
    # any variant that is not in MODELS and points clients to this endpoint for
    # the valid choices, so a configured-but-unloaded variant must not appear.
    variants_list = [
        {
            "variant_id": variant_id,
            "name": metadata['name'],
            "description": metadata['description'],
            "features": metadata['features'],
            "test_accuracy": metadata['test_accuracy'],
            "n_features": len(metadata['features'])
        }
        for variant_id, metadata in VARIANTS_CONFIG['variants'].items()
        if variant_id in MODELS
    ]

    # Sort by test accuracy descending
    variants_list.sort(key=lambda x: x['test_accuracy'], reverse=True)

    return {
        "total_variants": len(variants_list),
        "variants": variants_list
    }
180
+
181
+
182
@app.post("/predict", response_model=PredictionOutput, tags=["Prediction"])
async def predict(input_data: PredictionInput):
    """
    Predict milk spoilage type using the specified model variant.

    **How to choose a variant:**
    - If you have all 6 measurements → use 'baseline' (best accuracy)
    - If you only have Day 21 data → use 'scenario_3_day21'
    - If you only have Day 14 data → use 'scenario_4_day14'
    - If you only have SPC measurements → use 'scenario_6_spc_all'
    - etc.

    The API will validate that you've provided all required features for the selected variant.
    """
    # Validate variant exists (only variants with a loaded model are usable).
    if input_data.model_variant not in MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown variant '{input_data.model_variant}'. Use /variants to see available options."
        )

    # Get model and metadata
    model = MODELS[input_data.model_variant]
    variant_meta = VARIANTS_CONFIG['variants'][input_data.model_variant]
    required_features = variant_meta['features']

    # Extract features. Any HTTPException raised by extract_features simply
    # propagates to FastAPI; no catch-and-re-raise wrapper is needed.
    features = extract_features(input_data, required_features)

    # Make prediction: the predicted label plus the per-class probabilities
    # for the single input row.
    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]

    # Format response: pair each class label with its probability.
    prob_dict = {
        str(cls): float(prob)
        for cls, prob in zip(model.classes_, probabilities)
    }

    variant_info = VariantInfo(
        variant_id=input_data.model_variant,
        name=variant_meta['name'],
        description=variant_meta['description'],
        features=required_features,
        test_accuracy=variant_meta['test_accuracy']
    )

    return PredictionOutput(
        prediction=str(prediction),
        probabilities=prob_dict,
        # Confidence is the highest class probability.
        confidence=float(max(probabilities)),
        variant_used=variant_info
    )
238
+
239
+
240
@app.get("/health", tags=["Health"])
async def health_check():
    """Health check endpoint."""
    # Class labels are taken from the baseline model when it is loaded;
    # otherwise an empty list is reported.
    if 'baseline' in MODELS:
        class_labels = MODELS['baseline'].classes_.tolist()
    else:
        class_labels = []

    loaded_variants = list(MODELS.keys())
    return {
        "status": "healthy",
        "models_loaded": len(MODELS),
        "variants": loaded_variants,
        "classes": class_labels,
    }
249
+
250
+
251
if __name__ == "__main__":
    # uvicorn is imported here, so it is only needed when the file is run as a
    # script rather than imported as a module.
    import uvicorn

    # Serve the app on all network interfaces, port 7860.
    uvicorn.run(app=app, host="0.0.0.0", port=7860)