File size: 4,188 Bytes
a7c5735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5c2352
a7c5735
 
 
09fb524
a7c5735
 
09fb524
 
 
a7c5735
09fb524
a7c5735
 
 
 
354fad4
 
 
 
 
 
a7c5735
 
 
 
354fad4
 
 
 
 
 
a7c5735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354fad4
a7c5735
 
354fad4
 
 
 
 
 
 
 
 
a7c5735
354fad4
 
a7c5735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""

Simple FastAPI REST API for Milk Spoilage Classification.

This provides a clean REST endpoint for Custom GPT and other integrations.

"""

import os
from typing import Dict

import joblib
import numpy as np
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Load the trained classifier once, at import time. If loading fails the module
# fails to import, so the app never starts without a model.
# MODEL_PATH may override the default file; a relative path is resolved against
# the current working directory.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code --
# only point this at a trusted model artifact.
MODEL_PATH = os.environ.get("MODEL_PATH", "model.joblib")
model = joblib.load(MODEL_PATH)

# Create FastAPI app
app = FastAPI(
    title="Milk Spoilage Classification API",
    description="AI-powered spoilage classification system for predicting milk spoilage type based on bacterial contamination and microbial count analysis",
    version="1.0.0"
)

# CORS is fully open so Custom GPT and any browser origin can call the API.
# allow_credentials stays False: credentialed CORS requests cannot be answered
# with a wildcard ("*") origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    max_age=3600,  # let browsers cache preflight responses for one hour
)

# Request/Response models
class PredictionInput(BaseModel):
    """Microbial counts for one milk sample as log10 CFU/mL (SPC and TGN at days 7, 14, 21)."""

    spc_d7: float = Field(..., description="Standard Plate Count at Day 7 (log CFU/mL, base 10)", ge=0.0, le=10.0)
    spc_d14: float = Field(..., description="Standard Plate Count at Day 14 (log CFU/mL, base 10)", ge=0.0, le=10.0)
    spc_d21: float = Field(..., description="Standard Plate Count at Day 21 (log CFU/mL, base 10)", ge=0.0, le=10.0)
    tgn_d7: float = Field(..., description="Total Gram-Negative at Day 7 (log CFU/mL, base 10)", ge=0.0, le=10.0)
    tgn_d14: float = Field(..., description="Total Gram-Negative at Day 14 (log CFU/mL, base 10)", ge=0.0, le=10.0)
    tgn_d21: float = Field(..., description="Total Gram-Negative at Day 21 (log CFU/mL, base 10)", ge=0.0, le=10.0)

    # Pydantic v2 configuration. A plain dict is accepted (ConfigDict is a
    # TypedDict), and this replaces the deprecated inner ``class Config``.
    # The example is surfaced in the generated OpenAPI/JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "spc_d7": 2.1,
                "spc_d14": 4.7,
                "spc_d21": 6.4,
                "tgn_d7": 1.0,
                "tgn_d14": 3.7,
                "tgn_d21": 5.3
            }
        }
    }

class PredictionOutput(BaseModel):
    # Response body of the /predict endpoint.

    # Class label produced by model.predict, rendered as a string.
    prediction: str = Field(default=..., description="Predicted spoilage class")
    # Stringified class label -> probability from model.predict_proba.
    probabilities: Dict[str, float] = Field(default=..., description="Probability for each class")
    # The largest value found in ``probabilities``.
    confidence: float = Field(default=..., description="Confidence score (max probability)")


@app.get("/")
async def root():
    """Return a short API summary listing the paths of the main endpoints."""
    endpoint_index = {
        "predict": "/predict",
        "health": "/health",
        "docs": "/docs",
    }
    summary = {"message": "Milk Spoilage Classification API"}
    summary["endpoints"] = endpoint_index
    return summary


def _predict_sync(input_data: PredictionInput) -> PredictionOutput:
    """Run the model on one validated request (blocking; call off the event loop).

    Args:
        input_data: Request body holding log10 CFU/mL counts.

    Returns:
        PredictionOutput with the predicted class, the probability of every
        class in ``model.classes_``, and the maximum probability as confidence.
    """
    # Feature column order SPC d7/d14/d21 then TGN d7/d14/d21; presumably this
    # matches the model's training columns -- keep in sync if that changes.
    log_counts = (
        input_data.spc_d7,
        input_data.spc_d14,
        input_data.spc_d21,
        input_data.tgn_d7,
        input_data.tgn_d14,
        input_data.tgn_d21,
    )
    # The API takes log10 values but the model consumes raw CFU/mL.
    features = np.array([[10 ** value for value in log_counts]])

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]

    prob_dict = {
        str(cls): float(prob)
        for cls, prob in zip(model.classes_, probabilities)
    }

    return PredictionOutput(
        prediction=str(prediction),
        probabilities=prob_dict,
        confidence=float(max(probabilities)),
    )


@app.post("/predict", response_model=PredictionOutput, tags=["Prediction"])
async def predict(input_data: PredictionInput):
    """
    Predict milk spoilage type based on microbial counts.

    Accepts log CFU/mL values (base 10) and converts to raw CFU/mL for the model.
    Returns the predicted class, probabilities for all classes, and confidence score.
    """
    # Model inference is synchronous CPU work. Running it in FastAPI's
    # threadpool keeps it from blocking the event loop for other requests.
    # Assumes concurrent predict/predict_proba calls on the shared model are
    # safe (typical for fitted scikit-learn estimators) -- confirm for this model.
    return await run_in_threadpool(_predict_sync, input_data)


@app.get("/health", tags=["Health"])
async def health_check():
    """Report service status, model presence, and the model's class labels."""
    # `model` is assigned at import time. A failed load aborts the import, so
    # this flag is effectively always True while the app is running.
    loaded = model is not None
    class_labels = model.classes_.tolist()
    return {"status": "healthy", "model_loaded": loaded, "classes": class_labels}


if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")