Spaces:
Sleeping
Sleeping
| import joblib | |
| import os | |
| from pydantic import BaseModel | |
| from typing import List | |
class IrisInput(BaseModel):
    """Request payload carrying the four measurements of one iris flower.

    The fields are declared in the same order in which ``IrisModel.predict``
    assembles its feature row. Units are presumably centimetres, as in the
    classic Iris dataset; confirm against the training data.
    """

    # Sepal measurements.
    sepal_length: float
    sepal_width: float
    # Petal measurements.
    petal_length: float
    petal_width: float
class IrisPrediction(BaseModel):
    """Response payload describing the species predicted for one flower."""

    # Human-readable species label looked up from the predicted id.
    class_name: str
    # Integer class id returned by the classifier.
    class_id: int
class IrisModel:
    """Load a joblib-persisted iris classifier and serve single predictions.

    The model is loaded eagerly on construction, so creating an instance
    fails fast with ``FileNotFoundError`` when the model file is missing.
    """

    # Default model location. It is a relative path, so it is resolved
    # against the process's current working directory, not this module's.
    DEFAULT_MODEL_PATH = os.path.join("models", "iris_model.joblib")

    def __init__(self, model_path: str = DEFAULT_MODEL_PATH):
        """Create the wrapper and immediately load the model from disk.

        Args:
            model_path: Path to the joblib model file. Defaults to
                ``models/iris_model.joblib`` relative to the working directory.

        Raises:
            FileNotFoundError: If no file exists at ``model_path``.
        """
        self.model_path = model_path
        self.model = None
        # Position i holds the species name reported for class id i.
        self.class_names = ["setosa", "versicolor", "virginica"]
        self.load_model()

    def load_model(self):
        """Load the model stored at ``self.model_path`` into ``self.model``.

        Raises:
            FileNotFoundError: If no file exists at ``self.model_path``.
        """
        model_path = self.model_path
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"Model not found at {model_path}. Please train the model first."
            )
        # joblib.load unpickles the file, which can execute arbitrary code:
        # only point model_path at files produced by a trusted training run.
        self.model = joblib.load(model_path)

    def predict(self, input_data: "IrisInput") -> "IrisPrediction":
        """Predict the iris species for a single set of measurements.

        Args:
            input_data: The four flower measurements.

        Returns:
            The predicted class id together with its species name.

        Raises:
            FileNotFoundError: If the model must be (re)loaded and is missing.
            ValueError: If the model returns an id outside ``class_names``.
        """
        # Compare against None explicitly: a loaded estimator that defines
        # __len__ (or is otherwise falsy) must not trigger a reload.
        if self.model is None:
            self.load_model()
        features = [[
            input_data.sepal_length,
            input_data.sepal_width,
            input_data.petal_length,
            input_data.petal_width,
        ]]
        # One row was submitted, so take the first (only) prediction.
        class_id = int(self.model.predict(features)[0])
        # Reject unknown ids instead of letting a negative id silently wrap
        # around to a species taken from the end of the list.
        if not 0 <= class_id < len(self.class_names):
            raise ValueError(
                f"Model predicted unknown class id {class_id}; "
                f"expected 0..{len(self.class_names) - 1}."
            )
        return IrisPrediction(
            class_name=self.class_names[class_id],
            class_id=class_id,
        )