# File size: 3,220 Bytes
# 83851f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# ---------- Demo Data Example ----------
# Sample JSON body for POST /v1/predict; its keys are the four fields of the
# IrisFeatures request schema.
DEMO_PREDICT_BODY = dict(
    sepal_length=5.1,
    sepal_width=3.5,
    petal_length=1.4,
    petal_width=0.5,
)

# app_ml.py
import logging
import os
import tempfile
from typing import Dict, List

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

APP_VERSION = "1.0.0"
# Directory holding the persisted model bundle. It can be overridden with the
# IRIS_MODEL_DIR environment variable (an empty value falls back to the
# default). The bundle is later deserialized with joblib.load, which unpickles
# it, and /tmp is usually world-writable and may be cleared on reboot. So real
# deployments should point this at a directory only the service can write.
MODEL_DIR = os.environ.get("IRIS_MODEL_DIR") or "/tmp/models"
MODEL_PATH = os.path.join(MODEL_DIR, "iris_rf.joblib")

# Application metadata handed to the FastAPI constructor.
_APP_METADATA = dict(
    title="Class 8 - ML Model Serving (Iris)",
    version=APP_VERSION,
    description="Serve a scikit-learn model via FastAPI with input validation.",
)
app = FastAPI(**_APP_METADATA)

# ---------- Schemas ----------
class IrisFeatures(BaseModel):
    # Request body for POST /v1/predict: the four Iris measurements.
    # Every field is required (``...``) and constrained to 0.0 <= value <= 10.0.
    # NOTE(review): units presumably centimetres, matching sklearn's iris
    # dataset used for training — confirm.
    # Field declaration order is kept as-is: predict() reads these by name, but
    # the generated schema lists them in this order.
    sepal_length: float = Field(..., ge=0.0, le=10.0)
    sepal_width: float = Field(..., ge=0.0, le=10.0)
    petal_length: float = Field(..., ge=0.0, le=10.0)
    petal_width: float = Field(..., ge=0.0, le=10.0)

class PredictResponse(BaseModel):
    # Response body of POST /v1/predict.
    ok: bool  # set to True by predict() on success
    model_version: str  # "version" entry of the loaded model bundle, or "unknown"
    predicted_label: str  # class name with the highest predicted probability
    predicted_class_index: int  # position of that class in the bundle's class_names
    probabilities: Dict[str, float]  # class name -> predicted probability

# ---------- Model utilities ----------
def train_and_save_model(path: str):
    """Train a RandomForest on the Iris dataset and persist it with joblib.

    The saved bundle is a dict with keys ``model``, ``class_names``,
    ``feature_names``, ``version`` (APP_VERSION at training time) and
    ``test_accuracy`` (accuracy on the stratified 20% hold-out split).

    The bundle is first written to a temporary file in the destination
    directory and then moved into place with ``os.replace``. A crash or a
    concurrent process therefore never sees a half-written file at ``path``.
    Because the temporary file comes from ``tempfile.mkstemp``, the final file
    is readable/writable by the owning user only.

    Args:
        path: Destination file. Missing parent directories are created. A bare
            filename (no directory component) is written to the current
            working directory.
    """
    directory = os.path.dirname(path)
    if directory:
        # os.makedirs("") raises FileNotFoundError, so only create a parent
        # directory when the path actually has one.
        os.makedirs(directory, exist_ok=True)

    iris = load_iris()
    X = iris.data
    y = iris.target
    class_names = iris.target_names

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    model = RandomForestClassifier(
        n_estimators=200,
        random_state=42
    )
    model.fit(X_train, y_train)

    payload = {
        "model": model,
        "class_names": class_names.tolist(),
        "feature_names": iris.feature_names,
        "version": APP_VERSION,
        # The hold-out split used to be computed and discarded; keep its
        # accuracy so the bundle carries a basic quality signal.
        "test_accuracy": float(model.score(X_test, y_test)),
    }

    # Create the temp file next to the destination so os.replace stays on one
    # filesystem (and is atomic). Keep the original extension so joblib's
    # extension-based compression choice is the same as for ``path``.
    suffix = os.path.splitext(path)[1]
    fd, tmp_path = tempfile.mkstemp(dir=directory or ".", prefix=".tmp-", suffix=suffix)
    os.close(fd)  # joblib reopens the file by name
    try:
        joblib.dump(payload, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a stray partial temp file behind on failure.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def load_model(path: str):
    """Return the model bundle stored at *path*, training one first if absent.

    The return value is whatever ``joblib.load`` deserializes from *path*. When
    the file is missing, ``train_and_save_model`` writes it before loading.

    NOTE(security): joblib.load unpickles the file, which can execute arbitrary
    code. Only point *path* somewhere untrusted users cannot write. This
    module's default MODEL_DIR is under /tmp; review before production use.
    """
    if os.path.exists(path):
        return joblib.load(path)
    train_and_save_model(path)
    return joblib.load(path)

# Load the bundle once at import time (training it first if no file exists yet)
# and expose its pieces as module-level globals for the request handlers.
MODEL_BUNDLE = load_model(MODEL_PATH)
MODEL, CLASS_NAMES = MODEL_BUNDLE["model"], MODEL_BUNDLE["class_names"]
# A bundle without a "version" entry reports "unknown".
MODEL_VERSION = MODEL_BUNDLE["version"] if "version" in MODEL_BUNDLE else "unknown"

# ---------- Endpoints ----------
@app.get("/health")
def health():
    # Health-check endpoint. ``model_loaded`` is a constant True: this module
    # only finishes importing after MODEL_BUNDLE has been loaded, so a running
    # app always holds a model.
    response = dict(status="ok", model_loaded=True, model_version=MODEL_VERSION)
    return response

@app.post("/v1/predict", response_model=PredictResponse)
def predict(features: IrisFeatures):
    """Classify one iris sample and return per-class probabilities.

    The feature row is built in the order sepal_length, sepal_width,
    petal_length, petal_width. NOTE(review): this is assumed to match the column
    order of ``load_iris().data`` used in training; keep the two in sync.

    Raises:
        HTTPException: status 500 with detail "Prediction failed" if inference
            fails. The underlying exception is logged with its traceback and
            chained as the cause, but not exposed to the client.
    """
    try:
        x = np.array([[
            features.sepal_length,
            features.sepal_width,
            features.petal_length,
            features.petal_width,
        ]], dtype=float)

        proba = MODEL.predict_proba(x)[0]
        if len(proba) != len(CLASS_NAMES):
            # A bundle whose class list disagrees with the estimator would
            # otherwise yield a silently truncated probability map.
            raise ValueError(
                f"model returned {len(proba)} probabilities but the bundle "
                f"lists {len(CLASS_NAMES)} class names"
            )
        idx = int(np.argmax(proba))
        prob_map = {name: float(p) for name, p in zip(CLASS_NAMES, proba)}

        return PredictResponse(
            ok=True,
            model_version=MODEL_VERSION,
            predicted_label=CLASS_NAMES[idx],
            predicted_class_index=idx,
            probabilities=prob_map,
        )
    except Exception as exc:
        # Keep the generic client-facing message, but record the real cause
        # server-side instead of discarding it.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from exc