chenhaoq87 commited on
Commit
a7c5735
·
verified ·
1 Parent(s): ffcb6ab

Upload fastapi_app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. fastapi_app.py +115 -0
fastapi_app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple FastAPI REST API for Milk Spoilage Classification
3
+
4
+ This provides a clean REST endpoint for Custom GPT and other integrations.
5
+ """
6
+
7
+ from fastapi import FastAPI
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel, Field
10
+ import joblib
11
+ import numpy as np
12
+ from typing import Dict
13
+
14
+ # Load model
15
+ model = joblib.load("model.joblib")
16
+
17
+ # Create FastAPI app
18
+ app = FastAPI(
19
+ title="Milk Spoilage Classification API",
20
+ description="Predict milk spoilage type based on microbial count data",
21
+ version="1.0.0"
22
+ )
23
+
24
+ # Add CORS middleware
25
+ app.add_middleware(
26
+ CORSMiddleware,
27
+ allow_origins=["*"],
28
+ allow_credentials=True,
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
+
33
+ # Request/Response models
34
+ class PredictionInput(BaseModel):
35
+ spc_d7: float = Field(..., description="Standard Plate Count at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
36
+ spc_d14: float = Field(..., description="Standard Plate Count at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
37
+ spc_d21: float = Field(..., description="Standard Plate Count at Day 21 (log CFU/mL)", ge=0.0, le=10.0)
38
+ tgn_d7: float = Field(..., description="Total Gram-Negative at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
39
+ tgn_d14: float = Field(..., description="Total Gram-Negative at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
40
+ tgn_d21: float = Field(..., description="Total Gram-Negative at Day 21 (log CFU/mL)", ge=0.0, le=10.0)
41
+
42
+ class Config:
43
+ json_schema_extra = {
44
+ "example": {
45
+ "spc_d7": 4.0,
46
+ "spc_d14": 5.0,
47
+ "spc_d21": 6.0,
48
+ "tgn_d7": 3.0,
49
+ "tgn_d14": 4.0,
50
+ "tgn_d21": 5.0
51
+ }
52
+ }
53
+
54
+ class PredictionOutput(BaseModel):
55
+ prediction: str = Field(..., description="Predicted spoilage class")
56
+ probabilities: Dict[str, float] = Field(..., description="Probability for each class")
57
+ confidence: float = Field(..., description="Confidence score (max probability)")
58
+
59
+
60
+ @app.get("/")
61
+ async def root():
62
+ """Root endpoint with API information."""
63
+ return {
64
+ "message": "Milk Spoilage Classification API",
65
+ "endpoints": {
66
+ "predict": "/predict",
67
+ "health": "/health",
68
+ "docs": "/docs"
69
+ }
70
+ }
71
+
72
+
73
+ @app.post("/predict", response_model=PredictionOutput, tags=["Prediction"])
74
+ async def predict(input_data: PredictionInput):
75
+ """
76
+ Predict milk spoilage type based on microbial counts.
77
+
78
+ Returns the predicted class, probabilities for all classes, and confidence score.
79
+ """
80
+ # Prepare features
81
+ features = np.array([[
82
+ input_data.spc_d7, input_data.spc_d14, input_data.spc_d21,
83
+ input_data.tgn_d7, input_data.tgn_d14, input_data.tgn_d21
84
+ ]])
85
+
86
+ # Make prediction
87
+ prediction = model.predict(features)[0]
88
+ probabilities = model.predict_proba(features)[0]
89
+
90
+ # Format response
91
+ prob_dict = {
92
+ str(cls): float(prob)
93
+ for cls, prob in zip(model.classes_, probabilities)
94
+ }
95
+
96
+ return PredictionOutput(
97
+ prediction=str(prediction),
98
+ probabilities=prob_dict,
99
+ confidence=float(max(probabilities))
100
+ )
101
+
102
+
103
+ @app.get("/health", tags=["Health"])
104
+ async def health_check():
105
+ """Health check endpoint."""
106
+ return {
107
+ "status": "healthy",
108
+ "model_loaded": model is not None,
109
+ "classes": model.classes_.tolist()
110
+ }
111
+
112
+
113
+ if __name__ == "__main__":
114
+ import uvicorn
115
+ uvicorn.run(app, host="0.0.0.0", port=7860)