Spaces:
Sleeping
Sleeping
# ---------- Demo Data Example ----------
# Sample request payload: one value per field of the IrisFeatures schema.
DEMO_PREDICT_BODY = dict(
    sepal_length=5.1,
    sepal_width=3.5,
    petal_length=1.4,
    petal_width=0.5,
)
| # app_ml.py | |
import logging
import os
from typing import List, Dict

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Version string reported by the API and stamped into the saved model bundle.
APP_VERSION = "1.0.0"

# Location of the serialized model bundle on disk.
MODEL_DIR = "/tmp/models"
MODEL_PATH = os.path.join(MODEL_DIR, "iris_rf.joblib")
# The FastAPI application object; its metadata shows up in the generated docs.
app = FastAPI(
    title="Class 8 - ML Model Serving (Iris)",
    description="Serve a scikit-learn model via FastAPI with input validation.",
    version=APP_VERSION,
)
# ---------- Schemas ----------
class IrisFeatures(BaseModel):
    """Request body for a prediction: the four flower measurements.

    Every field is required and must lie in the closed range [0.0, 10.0].
    """

    # Sepal measurements
    sepal_length: float = Field(..., ge=0.0, le=10.0)
    sepal_width: float = Field(..., ge=0.0, le=10.0)

    # Petal measurements
    petal_length: float = Field(..., ge=0.0, le=10.0)
    petal_width: float = Field(..., ge=0.0, le=10.0)
class PredictResponse(BaseModel):
    """Response body describing the outcome of a single prediction."""

    # Success flag set by the handler.
    ok: bool
    # Version string taken from the loaded model bundle.
    model_version: str
    # Name and index of the highest-probability class.
    predicted_label: str
    predicted_class_index: int
    # Mapping of class name -> predicted probability.
    probabilities: Dict[str, float]
# ---------- Model utilities ----------
def train_and_save_model(path: str) -> None:
    """Train a random-forest classifier on the Iris dataset and save it to *path*.

    The saved object is a dict bundle with the keys ``"model"``,
    ``"class_names"``, ``"feature_names"`` and ``"version"``.

    Args:
        path: Destination file for the joblib-serialized bundle. Its parent
            directory is created when it does not exist yet.
    """
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    target_dir = os.path.dirname(path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    iris = load_iris()

    # Fit on a stratified 80% split; the held-out 20% is not used here.
    X_train, _, y_train, _ = train_test_split(
        iris.data,
        iris.target,
        test_size=0.2,
        random_state=42,
        stratify=iris.target,
    )

    model = RandomForestClassifier(n_estimators=200, random_state=42)
    model.fit(X_train, y_train)

    payload = {
        "model": model,
        # Converted to a plain list so the bundle holds no numpy array here.
        "class_names": iris.target_names.tolist(),
        "feature_names": iris.feature_names,
        "version": APP_VERSION,
    }
    joblib.dump(payload, path)
def load_model(path: str):
    """Return the model bundle stored at *path*, training it first if absent.

    Args:
        path: Location of the joblib-serialized bundle.

    Returns:
        Whatever object ``joblib.load`` reads back from *path*.
    """
    bundle_missing = not os.path.exists(path)
    if bundle_missing:
        train_and_save_model(path)

    # NOTE(review): joblib.load unpickles the file, so whoever can write to
    # *path* can run code in this process - confirm the directory is trusted.
    bundle = joblib.load(path)
    return bundle
# Load (or train on first run) the bundle once at import time and unpack the
# pieces the endpoints need.
MODEL_BUNDLE = load_model(MODEL_PATH)
MODEL, CLASS_NAMES = MODEL_BUNDLE["model"], MODEL_BUNDLE["class_names"]
# Bundles without a version entry report "unknown".
MODEL_VERSION = MODEL_BUNDLE.get("version", "unknown")
# ---------- Endpoints ----------
# NOTE(review): the route path is assumed from the function name - confirm
# against existing clients.
@app.get("/health")
def health():
    """Report service status and the version of the loaded model.

    Returns:
        A dict with the keys ``"status"``, ``"model_loaded"`` and
        ``"model_version"``.
    """
    return {"status": "ok", "model_loaded": True, "model_version": MODEL_VERSION}
# NOTE(review): the route path is assumed from the function name - confirm
# against existing clients.
@app.post("/predict", response_model=PredictResponse)
def predict(features: IrisFeatures):
    """Classify one set of flower measurements with the loaded model.

    Args:
        features: Validated request body holding the four measurements.

    Returns:
        A ``PredictResponse`` with the winning label, its index, and the
        probability of every class.

    Raises:
        HTTPException: With status 500 when anything in the prediction
            pipeline fails; the underlying error is logged and chained.
    """
    try:
        # Single-row 2-D array. The column order is assumed to match the
        # order the model was trained with - confirm if the training changes.
        x = np.array(
            [[
                features.sepal_length,
                features.sepal_width,
                features.petal_length,
                features.petal_width,
            ]],
            dtype=float,
        )
        proba = MODEL.predict_proba(x)[0]
        idx = int(np.argmax(proba))
        # Assumes probability columns line up with CLASS_NAMES positions.
        label = CLASS_NAMES[idx]
        prob_map = {name: float(proba[i]) for i, name in enumerate(CLASS_NAMES)}
        return PredictResponse(
            ok=True,
            model_version=MODEL_VERSION,
            predicted_label=label,
            predicted_class_index=idx,
            probabilities=prob_map,
        )
    except Exception as exc:
        # Keep the client-facing message generic, but record the traceback so
        # the failure can actually be diagnosed.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from exc