"""
Multi-Variant FastAPI REST API for Milk Spoilage Classification

This API supports multiple model variants with different feature subsets.
Perfect for Custom GPT integration - allows selecting the optimal model
based on available data and prediction needs.
"""
|
|
|
|
|
import json
import os
from pathlib import Path
from typing import Dict, Optional, List

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, Field
|
|
|
|
|
|
|
|
# Locate the directory holding the serialized model variants and their config.
# A path relative to the current working directory is tried first; when it is
# absent, fall back to "model/variants" under the third ancestor directory of
# this file.
VARIANTS_DIR = Path("model") / "variants"
if not VARIANTS_DIR.exists():
    VARIANTS_DIR = Path(__file__).parent.parent.parent.joinpath("model", "variants")
|
|
|
|
|
|
|
|
# Variant metadata is read from a JSON file stored next to the model
# artifacts; the service cannot start without it, so fail loudly at import.
config_path = VARIANTS_DIR / "variants_config.json"
if not config_path.exists():
    raise FileNotFoundError(f"variants_config.json not found at {config_path}")

# JSON is UTF-8 by specification; pin the encoding so parsing does not depend
# on the platform's locale default (which breaks on non-ASCII metadata).
with open(config_path, encoding="utf-8") as f:
    VARIANTS_CONFIG = json.load(f)
|
|
|
|
|
|
|
|
# Load every configured variant whose "<variant_id>.joblib" file is present.
# Variants without a file are skipped with a warning rather than aborting
# startup, so MODELS may hold fewer entries than the config lists.
# NOTE: joblib.load unpickles the file; only trusted artifacts belong here.
MODELS = {}
for variant_id in VARIANTS_CONFIG['variants']:
    model_path = VARIANTS_DIR / f"{variant_id}.joblib"
    if not model_path.exists():
        print(f"Warning: Model file not found for variant {variant_id}")
        continue
    MODELS[variant_id] = joblib.load(model_path)

print(f"✓ Loaded {len(MODELS)} model variants: {list(MODELS)}")
|
|
|
|
|
|
|
|
# Application object. The description is Markdown rendered on the generated
# docs page; the variant list in it is static text, not derived from the
# loaded config.
app = FastAPI(
    title="Milk Spoilage Classification API (Multi-Variant)",
    version="2.0.0",
    description="""
AI-powered milk spoilage classification with multiple model variants.

**10 Model Variants Available:**
- **baseline**: All features (best accuracy: 95.8%)
- **scenario_1_days14_21**: Days 14 & 21 only (94.2%)
- **scenario_3_day21**: Day 21 only (93.7%)
- **scenario_4_day14**: Day 14 only (87.4%)
- **scenario_2_days7_14**: Days 7 & 14 (87.3%)
- **scenario_6_spc_all**: SPC only - all days (78.3%)
- **scenario_8_spc_7_14**: SPC days 7 & 14 (73.3%)
- **scenario_9_tgn_7_14**: TGN days 7 & 14 (73.1%)
- **scenario_7_tgn_all**: TGN only - all days (69.9%)
- **scenario_5_day7**: Day 7 only (62.8%)

Select the variant based on your available data. If you have all measurements,
use 'baseline' for best accuracy. If you only have partial data, choose the
appropriate scenario variant.
""",
)
|
|
|
|
|
|
|
|
# Open CORS policy: any origin, method and header is accepted, with
# credentials disabled. max_age is the preflight cache lifetime in seconds.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
    max_age=3600,
)
|
|
|
|
|
|
|
|
class PredictionInput(BaseModel):
    """Request body for the prediction endpoint.

    Each measurement is optional in the schema because different model
    variants need different subsets of them. Values are bounded to the
    range 0-10 (log CFU/mL).
    """

    # Declarative config replaces the class-based ``Config``, which pydantic v2
    # deprecates. ``protected_namespaces`` is cleared because the public field
    # ``model_variant`` starts with ``model_``, a prefix that some pydantic 2.x
    # releases reserve by default and warn about; renaming the field would
    # break existing clients.
    model_config = ConfigDict(
        protected_namespaces=(),
        json_schema_extra={
            "example": {
                "spc_d7": 2.1,
                "spc_d14": 4.7,
                "spc_d21": 6.4,
                "tgn_d7": 1.0,
                "tgn_d14": 3.7,
                "tgn_d21": 5.3,
                "model_variant": "baseline"
            }
        },
    )

    spc_d7: Optional[float] = Field(None, description="Standard Plate Count at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
    spc_d14: Optional[float] = Field(None, description="Standard Plate Count at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
    spc_d21: Optional[float] = Field(None, description="Standard Plate Count at Day 21 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d7: Optional[float] = Field(None, description="Total Gram-Negative at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d14: Optional[float] = Field(None, description="Total Gram-Negative at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d21: Optional[float] = Field(None, description="Total Gram-Negative at Day 21 (log CFU/mL)", ge=0.0, le=10.0)
    # Identifier of the model variant to run; defaults to the full-feature model.
    model_variant: str = Field(
        "baseline",
        description="Model variant to use (baseline, scenario_1_days14_21, scenario_3_day21, etc.)"
    )
|
|
|
|
|
# Metadata describing a single model variant, as embedded in prediction
# responses. Documented with comments rather than a docstring so the
# generated JSON schema stays exactly as before.
class VariantInfo(BaseModel):
    # Identifier of the variant (the value clients pass as ``model_variant``).
    variant_id: str
    # Human-readable name and longer description of the variant.
    name: str
    description: str
    # Names of the input features the variant's model consumes.
    features: List[str]
    # Accuracy figure recorded for the variant.
    test_accuracy: float
|
|
|
|
|
# Response body of the prediction endpoint. All four fields are required.
# Documented with comments rather than a docstring so the generated JSON
# schema stays exactly as before.
class PredictionOutput(BaseModel):
    prediction: str = Field(
        ...,
        description="Predicted spoilage class",
    )
    probabilities: Dict[str, float] = Field(
        ...,
        description="Probability for each class",
    )
    confidence: float = Field(
        ...,
        description="Confidence score (max probability)",
    )
    variant_used: VariantInfo = Field(
        ...,
        description="Information about the model variant used",
    )
|
|
|
|
|
|
|
|
def extract_features(input_data: PredictionInput, required_features: List[str]) -> np.ndarray:
    """Build the single-row feature matrix for one prediction request.

    Args:
        input_data: Parsed request; each measurement attribute may be ``None``.
        required_features: Upper-case feature names (e.g. ``"SPC_D7"``) in the
            order the result columns should follow.

    Returns:
        Array of shape ``(1, len(required_features))`` holding ``10 ** value``
        for every requested measurement.

    Raises:
        HTTPException: Status 400 when a requested measurement was not supplied.
        KeyError: When a requested name is not one of the six known features.
    """
    available = {
        'SPC_D7': input_data.spc_d7,
        'SPC_D14': input_data.spc_d14,
        'SPC_D21': input_data.spc_d21,
        'TGN_D7': input_data.tgn_d7,
        'TGN_D14': input_data.tgn_d14,
        'TGN_D21': input_data.tgn_d21,
    }

    # Collect every absent measurement first so the client sees all of them
    # in a single error response.
    absent = []
    for name in required_features:
        if available[name] is None:
            absent.append(name)

    if absent:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required features for variant: {', '.join(absent)}"
        )

    # Request fields are documented as log CFU/mL; raising 10 to each value
    # presumably restores the linear scale the models expect - TODO confirm.
    row = [10 ** available[name] for name in required_features]
    return np.array([row])
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Root endpoint with API information."""
    # Paths of the other endpoints, keyed by a short label.
    endpoint_paths = {
        "predict": "/predict",
        "variants": "/variants",
        "health": "/health",
        "docs": "/docs"
    }
    # variants_available counts the models actually loaded into MODELS.
    loaded_count = len(MODELS)
    return {
        "message": "Milk Spoilage Classification API - Multi-Variant",
        "version": "2.0.0",
        "variants_available": loaded_count,
        "endpoints": endpoint_paths
    }
|
|
|
|
|
|
|
|
@app.get("/variants", tags=["Variants"])
async def list_variants():
    """List all available model variants with their metadata."""
    # Only advertise variants whose model actually loaded. The prediction
    # endpoint rejects any variant missing from MODELS and points clients
    # here, so listing an unloaded variant would send them in a circle.
    variants_list = [
        {
            "variant_id": variant_id,
            "name": metadata['name'],
            "description": metadata['description'],
            "features": metadata['features'],
            "test_accuracy": metadata['test_accuracy'],
            "n_features": len(metadata['features'])
        }
        for variant_id, metadata in VARIANTS_CONFIG['variants'].items()
        if variant_id in MODELS
    ]

    # Most accurate variant first; ties keep their config order.
    variants_list.sort(key=lambda entry: entry['test_accuracy'], reverse=True)

    return {
        "total_variants": len(variants_list),
        "variants": variants_list
    }
|
|
|
|
|
|
|
|
@app.post("/predict", response_model=PredictionOutput, tags=["Prediction"])
async def predict(input_data: PredictionInput):
    """
    Predict milk spoilage type using the specified model variant.

    **How to choose a variant:**
    - If you have all 6 measurements → use 'baseline' (best accuracy)
    - If you only have Day 21 data → use 'scenario_3_day21'
    - If you only have Day 14 data → use 'scenario_4_day14'
    - If you only have SPC measurements → use 'scenario_6_spc_all'
    - etc.

    The API will validate that you've provided all required features for the selected variant.
    """
    # The docstring above is published as the endpoint description, so its
    # wording is kept as-is; implementation notes live in comments.
    variant_id = input_data.model_variant

    # Reject variants that have no loaded model (unknown id or missing file).
    if variant_id not in MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown variant '{input_data.model_variant}'. Use /variants to see available options."
        )

    model = MODELS[variant_id]
    variant_meta = VARIANTS_CONFIG['variants'][variant_id]
    required_features = variant_meta['features']

    # extract_features raises HTTPException(400) itself when a required
    # measurement is missing; it propagates unchanged, so no try/except
    # wrapper is needed here.
    features = extract_features(input_data, required_features)

    # NOTE(review): inference runs synchronously inside this coroutine; if a
    # model is slow this blocks the event loop - consider a plain ``def``
    # endpoint or a worker thread.
    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]

    # Pair each class label with its probability, converted to plain Python
    # types so the response serializes cleanly.
    prob_dict = {
        str(label): float(score)
        for label, score in zip(model.classes_, probabilities)
    }

    variant_info = VariantInfo(
        variant_id=variant_id,
        name=variant_meta['name'],
        description=variant_meta['description'],
        features=required_features,
        test_accuracy=variant_meta['test_accuracy']
    )

    return PredictionOutput(
        prediction=str(prediction),
        probabilities=prob_dict,
        confidence=float(max(probabilities)),
        variant_used=variant_info
    )
|
|
|
|
|
|
|
|
@app.get("/health", tags=["Health"])
async def health_check():
    """Health check endpoint."""
    # Class labels are read from the baseline model when it is loaded;
    # otherwise an empty list is reported.
    if 'baseline' in MODELS:
        class_labels = MODELS['baseline'].classes_.tolist()
    else:
        class_labels = []

    loaded_ids = list(MODELS.keys())
    return {
        "status": "healthy",
        "models_loaded": len(MODELS),
        "variants": loaded_ids,
        "classes": class_labels
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the module is executed
    # directly as a script.
    import uvicorn

    # Serve on all network interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|
|
|