Spaces:
Sleeping
Sleeping
File size: 1,224 Bytes
e3e1326 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import joblib
import os
from pydantic import BaseModel
from typing import List
class IrisInput(BaseModel):
    """Pydantic model holding the four Iris flower measurements for one sample.

    Each field is declared as ``float``, so pydantic validates the incoming
    values against that type. The declaration order matches the feature order
    that ``IrisModel.predict`` assembles before calling the underlying model.
    """

    # NOTE(review): units are presumably centimetres, as in the classic Iris
    # dataset — confirm against the data the model was trained on.
    sepal_length: float  # first feature passed to the model
    sepal_width: float  # second feature passed to the model
    petal_length: float  # third feature passed to the model
    petal_width: float  # fourth feature passed to the model
class IrisPrediction(BaseModel):
    """Pydantic model for a single classification result from ``IrisModel.predict``."""

    # Species label looked up from IrisModel.class_names using class_id,
    # e.g. "setosa", "versicolor" or "virginica".
    class_name: str
    # Integer class index as predicted by the underlying model.
    class_id: int
class IrisModel:
    """Wrap a persisted classifier and turn its output into ``IrisPrediction``s.

    The model is loaded from a joblib file when the wrapper is constructed.
    It is used through its ``predict`` method, which is called with a list
    containing one row of four features.
    NOTE(review): presumably a scikit-learn estimator — confirm against the
    training script.
    """

    # Default model location. It is relative, so it is resolved against the
    # current working directory at load time. Declared on the class so that
    # load_model() still has a path on instances that bypass __init__.
    model_path = os.path.join("models", "iris_model.joblib")

    def __init__(self, model_path=None):
        """Create the wrapper and eagerly load the model.

        Args:
            model_path: Optional path to the joblib model file. When omitted,
                ``models/iris_model.joblib`` (relative to the working
                directory) is used, as before.

        Raises:
            FileNotFoundError: If no model file exists at the chosen path.
        """
        if model_path is not None:
            self.model_path = model_path
        self.model = None
        # Index i of this list is the species name for predicted class id i.
        self.class_names = ["setosa", "versicolor", "virginica"]
        self.load_model()

    def load_model(self):
        """Load the model from ``self.model_path`` into ``self.model``.

        Raises:
            FileNotFoundError: If ``self.model_path`` is not an existing file.
        """
        model_path = self.model_path
        # isfile (rather than exists) so a directory at this path also gets
        # the explicit "train the model first" error instead of failing
        # inside joblib.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model not found at {model_path}. Please train the model first.")
        # NOTE: joblib.load unpickles the file, which can execute arbitrary
        # code. Only load model files that come from a trusted source.
        self.model = joblib.load(model_path)

    def predict(self, input_data: IrisInput) -> IrisPrediction:
        """Classify one set of Iris measurements.

        Args:
            input_data: Object exposing ``sepal_length``, ``sepal_width``,
                ``petal_length`` and ``petal_width`` attributes.

        Returns:
            An ``IrisPrediction`` with the integer class id predicted by the
            model and the matching entry of ``self.class_names``.

        Raises:
            FileNotFoundError: If the model is not loaded yet and the model
                file is missing.
            IndexError: If the model predicts a class id outside
                ``self.class_names``.
        """
        # Explicit None check: some estimators define __len__, so relying on
        # their truthiness is fragile.
        if self.model is None:
            self.load_model()

        # A single-sample batch. The feature order must match training order.
        features = [[
            input_data.sepal_length,
            input_data.sepal_width,
            input_data.petal_length,
            input_data.petal_width,
        ]]
        # Convert once, e.g. from a NumPy integer scalar, and reuse it for
        # both the lookup and the result.
        class_id = int(self.model.predict(features)[0])

        # A negative id would otherwise wrap around in list indexing and
        # silently return the wrong species.
        if not 0 <= class_id < len(self.class_names):
            raise IndexError(f"Model predicted unknown class id {class_id}")

        return IrisPrediction(
            class_name=self.class_names[class_id],
            class_id=class_id,
        )
|