# Source: Hugging Face upload by chenhaoq87
# Commit message: "Upload app.py with huggingface_hub" (183b3c1, verified)
"""
Simple FastAPI REST API for Milk Spoilage Classification
This provides a clean REST endpoint for Custom GPT and other integrations.
"""
from pathlib import Path
from typing import Dict

import joblib
import numpy as np
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# Load the trained classifier once at import time so every request reuses it.
# The path is resolved next to this file rather than the current working
# directory, so the app starts no matter where it is launched from.
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code;
# only deploy a model.joblib that comes from a trusted source.
MODEL_PATH = Path(__file__).resolve().parent / "model.joblib"
model = joblib.load(MODEL_PATH)
# The FastAPI application object. Its title, description and version are
# published in the auto-generated OpenAPI schema and /docs page.
app = FastAPI(
    version="1.0.0",
    title="Milk Spoilage Classification API",
    description="Predict milk spoilage type based on microbial count data",
)
# Allow cross-origin calls from any site so browser-based clients can use the API.
# Credentials are deliberately disabled. Nothing in this app reads cookies or
# auth headers, and the CORS rules (see FastAPI's CORSMiddleware docs) forbid
# combining a wildcard origin with credentials. With credentials enabled,
# Starlette reflects the caller's Origin instead, which lets any website make
# credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request/Response models
class PredictionInput(BaseModel):
    # Request body for /predict: six microbial counts in log CFU/mL, namely
    # Standard Plate Count (spc_*) and Total Gram-Negative (tgn_*) at days 7,
    # 14 and 21. The ge/le constraints make Pydantic reject any value outside
    # [0, 10] before the handler runs.
    spc_d7: float = Field(..., description="Standard Plate Count at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
    spc_d14: float = Field(..., description="Standard Plate Count at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
    spc_d21: float = Field(..., description="Standard Plate Count at Day 21 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d7: float = Field(..., description="Total Gram-Negative at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d14: float = Field(..., description="Total Gram-Negative at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d21: float = Field(..., description="Total Gram-Negative at Day 21 (log CFU/mL)", ge=0.0, le=10.0)

    # Pydantic v2 model configuration. It replaces the nested `class Config`,
    # which v2 deprecates. The example below is merged into the generated JSON
    # schema, so it appears in the OpenAPI docs.
    model_config = {
        "json_schema_extra": {
            "example": {
                "spc_d7": 4.0,
                "spc_d14": 5.0,
                "spc_d21": 6.0,
                "tgn_d7": 3.0,
                "tgn_d14": 4.0,
                "tgn_d21": 5.0,
            }
        }
    }
class PredictionOutput(BaseModel):
    # Response body of /predict. Field order is kept so the serialized JSON
    # and schema stay the same.

    # Class label as a string.
    prediction: str = Field(
        ...,
        description="Predicted spoilage class",
    )
    # Maps every class label (as a string) to its predicted probability.
    probabilities: Dict[str, float] = Field(
        ...,
        description="Probability for each class",
    )
    # Largest value in `probabilities`.
    confidence: float = Field(
        ...,
        description="Confidence score (max probability)",
    )
@app.get("/")
async def root():
    """Root endpoint with API information."""
    # Service name plus the paths of the other public routes.
    endpoint_paths = {
        "predict": "/predict",
        "health": "/health",
        "docs": "/docs",
    }
    return {"message": "Milk Spoilage Classification API", "endpoints": endpoint_paths}
@app.post("/predict", response_model=PredictionOutput, tags=["Prediction"])
async def predict(input_data: PredictionInput):
    """
    Predict milk spoilage type based on microbial counts.
    Returns the predicted class, probabilities for all classes, and confidence score.
    """
    # Build a single-row (1, 6) feature matrix.
    # NOTE(review): this column order (SPC d7/d14/d21, then TGN d7/d14/d21) is
    # assumed to match the training data; confirm against the training code.
    features = np.array([[
        input_data.spc_d7, input_data.spc_d14, input_data.spc_d21,
        input_data.tgn_d7, input_data.tgn_d14, input_data.tgn_d21,
    ]])

    # Run inference once and take the label from the probabilities. This way
    # `prediction` is always the class whose probability is reported as
    # `confidence`. A separate model.predict() call can disagree with
    # argmax(predict_proba) for some estimators, and it doubles the work.
    # NOTE(review): this is synchronous CPU work inside an `async def` handler,
    # so it blocks the event loop while it runs.
    probabilities = model.predict_proba(features)[0]
    best = int(np.argmax(probabilities))

    # Keys are stringified so they match the Dict[str, float] response schema.
    prob_dict = {
        str(cls): float(prob)
        for cls, prob in zip(model.classes_, probabilities)
    }
    return PredictionOutput(
        prediction=str(model.classes_[best]),
        probabilities=prob_dict,
        confidence=float(probabilities[best]),
    )
@app.get("/health", tags=["Health"])
async def health_check():
    """Health check endpoint."""
    # The model is loaded at import time, so this handler only runs once that
    # load has succeeded. The class labels are read from the fitted estimator.
    loaded = model is not None
    known_classes = model.classes_.tolist()
    return {
        "status": "healthy",
        "model_loaded": loaded,
        "classes": known_classes,
    }
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly
    # as a script.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",  # listen on all network interfaces
        port=7860,
    )