import joblib
import os
from pydantic import BaseModel
from typing import List, Optional


class IrisInput(BaseModel):
    """The four numeric features for one iris sample, in model input order."""

    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float


class IrisPrediction(BaseModel):
    """A single classification result: the class label and its integer id."""

    class_name: str
    class_id: int


class IrisModel:
    """Wrapper around a joblib-serialized iris classifier.

    The model file is loaded eagerly in ``__init__``. ``predict`` maps the
    classifier's integer output onto ``class_names``.
    """

    # Default location of the serialized model. It is a relative path, so it is
    # resolved against the current working directory. Pass ``model_path`` to
    # ``__init__`` to load from somewhere else.
    model_path = os.path.join("models", "iris_model.joblib")

    def __init__(self, model_path: Optional[str] = None):
        """Create the wrapper and load the model immediately.

        Args:
            model_path: Optional path to the ``.joblib`` file. When omitted,
                the class-level default ``models/iris_model.joblib`` is used.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        self.model = None
        if model_path is not None:
            self.model_path = model_path
        # Index i is the label for class id i returned by the classifier.
        self.class_names = ["setosa", "versicolor", "virginica"]
        self.load_model()

    def load_model(self):
        """Deserialize the model from ``self.model_path`` into ``self.model``.

        NOTE: ``joblib.load`` unpickles its input, so it can execute arbitrary
        code. Only point ``model_path`` at trusted files.

        Raises:
            FileNotFoundError: If no file exists at ``self.model_path``.
        """
        model_path = self.model_path
        try:
            self.model = joblib.load(model_path)
        except FileNotFoundError as err:
            # Keep the original exception type and message so callers that
            # catch FileNotFoundError are unaffected; chain the cause.
            raise FileNotFoundError(
                f"Model not found at {model_path}. Please train the model first."
            ) from err

    def predict(self, input_data: IrisInput) -> IrisPrediction:
        """Classify one sample.

        Args:
            input_data: The four measurements for a single flower.

        Returns:
            An ``IrisPrediction`` with the label from ``class_names`` and the
            integer class id.

        Raises:
            FileNotFoundError: If the model was not loaded and the file is missing.
            ValueError: If the classifier returns an id outside ``class_names``.
        """
        # Explicit None check: some estimators define __len__, so a truthiness
        # test is not a reliable "not loaded yet" check.
        if self.model is None:
            self.load_model()

        # The estimator expects a 2-D batch, so wrap the single row in a list.
        features = [[
            input_data.sepal_length,
            input_data.sepal_width,
            input_data.petal_length,
            input_data.petal_width,
        ]]
        # Convert once (e.g. from a numpy integer) and reuse the same value for
        # both the lookup and the returned id.
        class_id = int(self.model.predict(features)[0])

        # Without this guard, a negative id would silently wrap around to the
        # wrong label.
        if not 0 <= class_id < len(self.class_names):
            raise ValueError(
                f"Model returned unknown class id {class_id}; "
                f"expected 0..{len(self.class_names) - 1}."
            )

        return IrisPrediction(
            class_name=self.class_names[class_id],
            class_id=class_id,
        )