| |
# Sample feature payload; its keys mirror the fields of IrisFeatures.
# NOTE(review): not referenced by any code visible in this file — presumably
# kept as a copy-paste example for POST /v1/predict; confirm before removing.
DEMO_PREDICT_BODY = dict(
    sepal_length=5.1,
    sepal_width=3.5,
    petal_length=1.4,
    petal_width=0.2,
)
|
|
| |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, Field |
| from typing import List, Dict |
| import os |
|
|
| import numpy as np |
| import joblib |
|
|
| from sklearn.datasets import load_iris |
| from sklearn.ensemble import RandomForestClassifier |
| from sklearn.model_selection import train_test_split |
|
|
# Version string passed to the FastAPI app and stamped into saved model bundles.
APP_VERSION = "1.0.0"

# On-disk location of the serialized model bundle.
MODEL_DIR = "/tmp/models"
MODEL_PATH = os.path.join(MODEL_DIR, "iris_rf.joblib")
|
|
# Application object that the route decorators in this file register on.
app = FastAPI(
    version=APP_VERSION,
    title="Class 8 - ML Model Serving (Iris)",
    description="Serve a scikit-learn model via FastAPI with input validation.",
)
|
|
| |
class IrisFeatures(BaseModel):
    # Request body accepted by the predict endpoint: four required floats,
    # each constrained to the closed interval [0.0, 10.0].
    # Documented with comments rather than a docstring because pydantic
    # publishes a model docstring as the schema description.

    # Sepal measurements.
    sepal_length: float = Field(..., ge=0.0, le=10.0)
    sepal_width: float = Field(..., ge=0.0, le=10.0)

    # Petal measurements.
    petal_length: float = Field(..., ge=0.0, le=10.0)
    petal_width: float = Field(..., ge=0.0, le=10.0)
|
|
class PredictResponse(BaseModel):
    # Response body of the predict endpoint.
    # Documented with comments rather than a docstring because pydantic
    # publishes a model docstring as the schema description.

    ok: bool  # set to True by predict() on success
    model_version: str  # version recorded in the loaded model bundle

    predicted_label: str  # class name with the highest probability
    predicted_class_index: int  # position of that class in the class-name list

    probabilities: Dict[str, float]  # class name -> predicted probability
|
|
| |
def train_and_save_model(path: str) -> None:
    """Train a random-forest classifier on the Iris dataset and save it.

    The file written to *path* is a joblib dump of a dict with the keys
    ``"model"``, ``"class_names"``, ``"feature_names"`` and ``"version"``.

    Args:
        path: Destination file for the dump. Missing parent directories are
            created; a bare filename (no directory part) is written relative
            to the current working directory.
    """
    # os.path.dirname() returns "" for a bare filename and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when there is one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    iris = load_iris()

    # Only the training portion is used here; the held-out 20% is discarded.
    X_train, _X_test, y_train, _y_test = train_test_split(
        iris.data,
        iris.target,
        test_size=0.2,
        random_state=42,
        stratify=iris.target,
    )

    # Fixed random_state on both the split and the forest.
    model = RandomForestClassifier(n_estimators=200, random_state=42)
    model.fit(X_train, y_train)

    payload = {
        "model": model,
        "class_names": iris.target_names.tolist(),
        "feature_names": iris.feature_names,
        "version": APP_VERSION,
    }
    joblib.dump(payload, path)
|
|
def load_model(path: str):
    """Return the model bundle stored at *path*, training it first if absent.

    Args:
        path: Location of the joblib dump.

    Returns:
        Whatever object ``joblib.load`` yields for the file at *path*.
    """
    # NOTE(review): joblib.load unpickles the file, so *path* must point at a
    # location only trusted code can write to.
    if os.path.exists(path):
        return joblib.load(path)

    # No saved bundle yet: build one, then read it back from disk.
    train_and_save_model(path)
    return joblib.load(path)
|
|
# Loaded once at import time; load_model() trains and saves a model first when
# no file exists at MODEL_PATH.
MODEL_BUNDLE = load_model(MODEL_PATH)
MODEL, CLASS_NAMES = MODEL_BUNDLE["model"], MODEL_BUNDLE["class_names"]
# A bundle without a "version" key reports "unknown".
MODEL_VERSION = MODEL_BUNDLE.get("version", "unknown")
|
|
| |
# GET /health — static status payload plus the loaded model's version.
# "model_loaded" is a constant True: it is not derived from any runtime check.
# Documented with comments rather than a docstring because FastAPI publishes a
# handler docstring as the operation description.
@app.get("/health")
def health():
    status_payload = {
        "status": "ok",
        "model_loaded": True,
        "model_version": MODEL_VERSION,
    }
    return status_payload
|
|
# POST /v1/predict — classify a single sample described by IrisFeatures.
#
# Returns a PredictResponse carrying the most probable class (name and index)
# and a class-name -> probability map. Any exception raised while predicting
# or building the response is logged with its traceback and reported to the
# client as HTTP 500 with detail "Prediction failed".
#
# Documented with comments rather than a docstring because FastAPI publishes a
# handler docstring as the operation description.
@app.post("/v1/predict", response_model=PredictResponse)
def predict(features: IrisFeatures):
    # Function-scope import keeps this handler self-contained.
    import logging

    try:
        # One row, four columns.
        # NOTE(review): assumes this order matches the column order of the
        # training matrix — confirm against the bundle's "feature_names".
        x = np.array(
            [[
                features.sepal_length,
                features.sepal_width,
                features.petal_length,
                features.petal_width,
            ]],
            dtype=float,
        )

        proba = MODEL.predict_proba(x)[0]
        idx = int(np.argmax(proba))

        prob_map = {
            name: float(proba[i]) for i, name in enumerate(CLASS_NAMES)
        }

        return PredictResponse(
            ok=True,
            model_version=MODEL_VERSION,
            predicted_label=CLASS_NAMES[idx],
            predicted_class_index=idx,
            probabilities=prob_map,
        )
    except Exception as exc:
        # The client still gets a generic 500, but the real cause is no longer
        # discarded: it is logged with its traceback and chained to the error.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(
            status_code=500, detail="Prediction failed"
        ) from exc