george114 commited on
Commit
a40ccfd
·
verified ·
1 Parent(s): 6c86165

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ---------- Demo Data Example ----------
2
+ DEMO_PREDICT_BODY = {
3
+ "sepal_length": 5.1,
4
+ "sepal_width": 3.5,
5
+ "petal_length": 1.4,
6
+ "petal_width": 0.2
7
+ }
8
+
9
+ # app_ml.py
10
+ from fastapi import FastAPI, HTTPException
11
+ from pydantic import BaseModel, Field
12
+ from typing import List, Dict
13
+ import os
14
+
15
+ import numpy as np
16
+ import joblib
17
+
18
+ from sklearn.datasets import load_iris
19
+ from sklearn.ensemble import RandomForestClassifier
20
+ from sklearn.model_selection import train_test_split
21
+
22
+ APP_VERSION = "1.0.0"
23
+ MODEL_DIR = "/tmp/models"
24
+ MODEL_PATH = os.path.join(MODEL_DIR, "iris_rf.joblib")
25
+
26
+ app = FastAPI(
27
+ title="Class 8 - ML Model Serving (Iris)",
28
+ version=APP_VERSION,
29
+ description="Serve a scikit-learn model via FastAPI with input validation."
30
+ )
31
+
32
+ # ---------- Schemas ----------
33
+ class IrisFeatures(BaseModel):
34
+ sepal_length: float = Field(..., ge=0.0, le=10.0)
35
+ sepal_width: float = Field(..., ge=0.0, le=10.0)
36
+ petal_length: float = Field(..., ge=0.0, le=10.0)
37
+ petal_width: float = Field(..., ge=0.0, le=10.0)
38
+
39
+ class PredictResponse(BaseModel):
40
+ ok: bool
41
+ model_version: str
42
+ predicted_label: str
43
+ predicted_class_index: int
44
+ probabilities: Dict[str, float]
45
+
46
+ # ---------- Model utilities ----------
47
+ def train_and_save_model(path: str):
48
+ os.makedirs(os.path.dirname(path), exist_ok=True)
49
+
50
+ iris = load_iris()
51
+ X = iris.data
52
+ y = iris.target
53
+ class_names = iris.target_names
54
+
55
+ X_train, X_test, y_train, y_test = train_test_split(
56
+ X, y, test_size=0.2, random_state=42, stratify=y
57
+ )
58
+
59
+ model = RandomForestClassifier(
60
+ n_estimators=200,
61
+ random_state=42
62
+ )
63
+ model.fit(X_train, y_train)
64
+
65
+ payload = {
66
+ "model": model,
67
+ "class_names": class_names.tolist(),
68
+ "feature_names": iris.feature_names,
69
+ "version": APP_VERSION
70
+ }
71
+ joblib.dump(payload, path)
72
+
73
+ def load_model(path: str):
74
+ if not os.path.exists(path):
75
+ train_and_save_model(path)
76
+ return joblib.load(path)
77
+
78
+ MODEL_BUNDLE = load_model(MODEL_PATH)
79
+ MODEL = MODEL_BUNDLE["model"]
80
+ CLASS_NAMES = MODEL_BUNDLE["class_names"]
81
+ MODEL_VERSION = MODEL_BUNDLE.get("version", "unknown")
82
+
83
+ # ---------- Endpoints ----------
84
+ @app.get("/health")
85
+ def health():
86
+ return {"status": "ok", "model_loaded": True, "model_version": MODEL_VERSION}
87
+
88
+ @app.post("/v1/predict", response_model=PredictResponse)
89
+ def predict(features: IrisFeatures):
90
+ try:
91
+ x = np.array([[
92
+ features.sepal_length,
93
+ features.sepal_width,
94
+ features.petal_length,
95
+ features.petal_width
96
+ ]], dtype=float)
97
+
98
+ proba = MODEL.predict_proba(x)[0]
99
+ idx = int(np.argmax(proba))
100
+ label = CLASS_NAMES[idx]
101
+
102
+ prob_map = {CLASS_NAMES[i]: float(proba[i]) for i in range(len(CLASS_NAMES))}
103
+
104
+ return PredictResponse(
105
+ ok=True,
106
+ model_version=MODEL_VERSION,
107
+ predicted_label=label,
108
+ predicted_class_index=idx,
109
+ probabilities=prob_map
110
+ )
111
+ except Exception:
112
+ raise HTTPException(status_code=500, detail="Prediction failed")