| | """ |
| | Multi-Variant FastAPI REST API for Milk Spoilage Classification |
| | |
| | This API supports multiple model variants with different feature subsets. |
| | Perfect for Custom GPT integration - allows selecting the optimal model |
| | based on available data and prediction needs. |
| | """ |
| |
|
| | from fastapi import FastAPI, HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel, Field |
| | import joblib |
| | import numpy as np |
| | from typing import Dict, Optional, List |
| | import os |
| | import json |
| | from pathlib import Path |
| |
|
| | |
# Locate the directory that holds the variant config and the serialized models.
# The relative path is tried first; when it does not exist (relative to the
# current working directory), fall back to a path derived from this file's
# own location.
VARIANTS_DIR = Path("model/variants")
if not VARIANTS_DIR.exists():
    VARIANTS_DIR = Path(__file__).parent.parent.parent / "model" / "variants"

# The variants configuration is mandatory: fail at import time if it is absent.
config_path = VARIANTS_DIR / "variants_config.json"
if not config_path.exists():
    raise FileNotFoundError(f"variants_config.json not found at {config_path}")

# Pin the encoding: without it, open() uses the platform's locale encoding,
# which can mis-decode a UTF-8 JSON file on some systems.
with open(config_path, encoding="utf-8") as f:
    VARIANTS_CONFIG = json.load(f)

# Load one model per configured variant, keyed by variant id. A variant whose
# "<variant_id>.joblib" file is missing is skipped with a warning rather than
# aborting start-up, so MODELS may hold fewer entries than the configuration.
# NOTE(review): joblib.load unpickles the file contents; only load model files
# that come from a trusted source.
MODELS = {}
for variant_id in VARIANTS_CONFIG['variants']:
    model_path = VARIANTS_DIR / f"{variant_id}.joblib"
    if model_path.exists():
        MODELS[variant_id] = joblib.load(model_path)
    else:
        print(f"Warning: Model file not found for variant {variant_id}")

print(f"✓ Loaded {len(MODELS)} model variants: {list(MODELS.keys())}")
| |
|
| | |
# Markdown text passed to FastAPI as the API description.
_API_DESCRIPTION = """
AI-powered milk spoilage classification with multiple model variants.

**10 Model Variants Available:**
- **baseline**: All features (best accuracy: 95.8%)
- **scenario_1_days14_21**: Days 14 & 21 only (94.2%)
- **scenario_3_day21**: Day 21 only (93.7%)
- **scenario_4_day14**: Day 14 only (87.4%)
- **scenario_2_days7_14**: Days 7 & 14 (87.3%)
- **scenario_6_spc_all**: SPC only - all days (78.3%)
- **scenario_8_spc_7_14**: SPC days 7 & 14 (73.3%)
- **scenario_9_tgn_7_14**: TGN days 7 & 14 (73.1%)
- **scenario_7_tgn_all**: TGN only - all days (69.9%)
- **scenario_5_day7**: Day 7 only (62.8%)

Select the variant based on your available data. If you have all measurements,
use 'baseline' for best accuracy. If you only have partial data, choose the
appropriate scenario variant.
"""

# The ASGI application object that the route decorators in this module attach to.
app = FastAPI(
    title="Milk Spoilage Classification API (Multi-Variant)",
    description=_API_DESCRIPTION,
    version="2.0.0",
)

# CORS policy: every origin, method and header is allowed, credentials are
# not, and max_age is set to 3600.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": False,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "max_age": 3600,
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
| |
|
| | |
class PredictionInput(BaseModel):
    # Request body for /predict: up to six optional measurements plus the id of
    # the model variant that should score them. Each measurement is documented
    # as log CFU/mL and is constrained to the closed range [0.0, 10.0]; a
    # measurement that is not supplied stays None.
    #
    # Configuration is declared through `model_config` (the pydantic v2 form)
    # instead of a nested `class Config`, which pydantic v2 reports as
    # deprecated. `protected_namespaces` is cleared because the field name
    # `model_variant` starts with the `model_` prefix that some pydantic v2
    # releases reserve and warn about.
    model_config = {
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "spc_d7": 2.1,
                "spc_d14": 4.7,
                "spc_d21": 6.4,
                "tgn_d7": 1.0,
                "tgn_d14": 3.7,
                "tgn_d21": 5.3,
                "model_variant": "baseline",
            }
        },
    }

    # Standard Plate Count readings.
    spc_d7: Optional[float] = Field(
        None, description="Standard Plate Count at Day 7 (log CFU/mL)", ge=0.0, le=10.0
    )
    spc_d14: Optional[float] = Field(
        None, description="Standard Plate Count at Day 14 (log CFU/mL)", ge=0.0, le=10.0
    )
    spc_d21: Optional[float] = Field(
        None, description="Standard Plate Count at Day 21 (log CFU/mL)", ge=0.0, le=10.0
    )

    # Total Gram-Negative readings.
    tgn_d7: Optional[float] = Field(
        None, description="Total Gram-Negative at Day 7 (log CFU/mL)", ge=0.0, le=10.0
    )
    tgn_d14: Optional[float] = Field(
        None, description="Total Gram-Negative at Day 14 (log CFU/mL)", ge=0.0, le=10.0
    )
    tgn_d21: Optional[float] = Field(
        None, description="Total Gram-Negative at Day 21 (log CFU/mL)", ge=0.0, le=10.0
    )

    # Which variant to run; defaults to "baseline".
    model_variant: str = Field(
        "baseline",
        description="Model variant to use (baseline, scenario_1_days14_21, scenario_3_day21, etc.)"
    )
| |
|
# Metadata describing one model variant; embedded in PredictionOutput as
# `variant_used`. All five fields are required.
class VariantInfo(BaseModel):
    variant_id: str  # key of the variant (predict() passes the requested model_variant)
    name: str  # taken from the variant's 'name' config entry by predict()
    description: str  # taken from the variant's 'description' config entry
    features: List[str]  # feature names the variant requires
    test_accuracy: float  # taken from the variant's 'test_accuracy' config entry
| |
|
# Response body of /predict. Every field is required (declared with `...`).
class PredictionOutput(BaseModel):
    # Predicted label, stringified by predict().
    prediction: str = Field(..., description="Predicted spoilage class")
    # Maps each class label (as a string) to its probability.
    probabilities: Dict[str, float] = Field(..., description="Probability for each class")
    # predict() fills this with the largest value found in `probabilities`.
    confidence: float = Field(..., description="Confidence score (max probability)")
    # Metadata of the variant that produced this prediction.
    variant_used: VariantInfo = Field(..., description="Information about the model variant used")
| |
|
| |
|
def extract_features(input_data: PredictionInput, required_features: List[str]) -> np.ndarray:
    """Build the single-row feature matrix for one prediction request.

    Args:
        input_data: Request payload; its six measurement attributes
            (spc_d7 ... tgn_d21) are read, each either a number or None.
        required_features: Upper-case feature names (e.g. 'SPC_D7') in the
            order the selected variant expects them.

    Returns:
        Array of shape (1, len(required_features)) holding 10 ** value for
        each requested measurement, in the requested order.

    Raises:
        HTTPException: status 500 when a requested name is not one of the six
            known features (a configuration problem, not a caller error);
            status 400 when a requested measurement was not supplied.
    """
    feature_map = {
        'SPC_D7': input_data.spc_d7,
        'SPC_D14': input_data.spc_d14,
        'SPC_D21': input_data.spc_d21,
        'TGN_D7': input_data.tgn_d7,
        'TGN_D14': input_data.tgn_d14,
        'TGN_D21': input_data.tgn_d21,
    }

    # Guard against feature names that do not exist in the mapping; without
    # this check the lookups below would fail with a bare KeyError.
    unknown = [name for name in required_features if name not in feature_map]
    if unknown:
        raise HTTPException(
            status_code=500,
            detail=f"Variant configuration references unknown features: {', '.join(unknown)}"
        )

    # Every feature the variant needs must have been supplied by the caller.
    missing = [name for name in required_features if feature_map[name] is None]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required features for variant: {', '.join(missing)}"
        )

    # The request fields are described as log CFU/mL; the model is fed
    # 10 ** value. NOTE(review): presumably the models were trained on the
    # un-logged counts - confirm against the training code.
    features = [10 ** feature_map[name] for name in required_features]
    return np.array([features])
| |
|
| |
|
@app.get("/")
async def root():
    """Root endpoint with API information."""
    # Map of route names to the paths this service exposes.
    endpoint_index = {
        "predict": "/predict",
        "variants": "/variants",
        "health": "/health",
        "docs": "/docs",
    }
    # "variants_available" counts the models actually loaded into MODELS.
    service_info = {
        "message": "Milk Spoilage Classification API - Multi-Variant",
        "version": "2.0.0",
        "variants_available": len(MODELS),
        "endpoints": endpoint_index,
    }
    return service_info
| |
|
| |
|
@app.get("/variants", tags=["Variants"])
async def list_variants():
    """List all available model variants with their metadata."""
    # One entry per configured variant. "available" reports whether the
    # variant's model was actually loaded into MODELS: /predict only accepts
    # loaded variants, so a configured variant without a loaded model would
    # otherwise be advertised here and then rejected there.
    variants_list = [
        {
            "variant_id": variant_id,
            "name": metadata['name'],
            "description": metadata['description'],
            "features": metadata['features'],
            "test_accuracy": metadata['test_accuracy'],
            "n_features": len(metadata['features']),
            "available": variant_id in MODELS,
        }
        for variant_id, metadata in VARIANTS_CONFIG['variants'].items()
    ]

    # Highest test accuracy first.
    variants_list.sort(key=lambda x: x['test_accuracy'], reverse=True)

    return {
        "total_variants": len(variants_list),
        "variants": variants_list
    }
| |
|
| |
|
@app.post("/predict", response_model=PredictionOutput, tags=["Prediction"])
async def predict(input_data: PredictionInput):
    """
    Predict milk spoilage type using the specified model variant.

    **How to choose a variant:**
    - If you have all 6 measurements → use 'baseline' (best accuracy)
    - If you only have Day 21 data → use 'scenario_3_day21'
    - If you only have Day 14 data → use 'scenario_4_day14'
    - If you only have SPC measurements → use 'scenario_6_spc_all'
    - etc.

    The API will validate that you've provided all required features for the selected variant.
    """
    # Only variants whose model was loaded into MODELS can be served;
    # anything else is rejected with a 400.
    variant_id = input_data.model_variant
    if variant_id not in MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown variant '{variant_id}'. Use /variants to see available options."
        )

    model = MODELS[variant_id]
    variant_meta = VARIANTS_CONFIG['variants'][variant_id]
    required_features = variant_meta['features']

    # extract_features raises HTTPException itself when a required measurement
    # is missing; the exception is simply allowed to propagate (the former
    # try/except that re-raised it unchanged was a no-op).
    features = extract_features(input_data, required_features)

    # The row holds a single sample, so take element 0 of each result.
    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]

    # Pair each class label with its probability, using plain str/float so the
    # values serialize to JSON.
    prob_dict = {
        str(cls): float(prob)
        for cls, prob in zip(model.classes_, probabilities)
    }

    variant_info = VariantInfo(
        variant_id=variant_id,
        name=variant_meta['name'],
        description=variant_meta['description'],
        features=required_features,
        test_accuracy=variant_meta['test_accuracy']
    )

    return PredictionOutput(
        prediction=str(prediction),
        probabilities=prob_dict,
        confidence=float(max(probabilities)),
        variant_used=variant_info
    )
| |
|
| |
|
@app.get("/health", tags=["Health"])
async def health_check():
    """Health check endpoint."""
    # Class labels are read from the 'baseline' model only; when that variant
    # is not loaded the list is left empty.
    if 'baseline' in MODELS:
        class_labels = MODELS['baseline'].classes_.tolist()
    else:
        class_labels = []

    loaded_ids = list(MODELS)
    return {
        "status": "healthy",
        "models_loaded": len(MODELS),
        "variants": loaded_ids,
        "classes": class_labels,
    }
| |
|
| |
|
if __name__ == "__main__":
    # Executed directly: serve the app with uvicorn on all interfaces, port
    # 7860. uvicorn is imported here so that merely importing this module does
    # not require it.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")
| |
|