File size: 1,224 Bytes
e3e1326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import joblib
import os
from pydantic import BaseModel
from typing import List

class IrisInput(BaseModel):
    """Request schema: the four measurements of a single iris flower.

    ``IrisModel.predict`` passes these fields to the classifier as one row,
    in the same order they are declared here: sepal length, sepal width,
    petal length, petal width.
    """

    # NOTE(review): units are not stated in this file; presumably centimetres
    # as in the classic iris dataset — confirm against the training data.
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float

class IrisPrediction(BaseModel):
    """Response schema: the classifier's verdict for one flower."""

    # Species name, looked up in ``IrisModel.class_names`` using ``class_id``.
    class_name: str
    # Integer class id produced by the underlying model's ``predict``.
    class_id: int

class IrisModel:
    """Wrapper around a persisted iris classifier.

    The classifier is loaded with ``joblib`` from ``model_path`` when the
    wrapper is constructed. ``predict`` turns an ``IrisInput`` into a
    single-row feature matrix and maps the predicted integer class id onto a
    species name from ``class_names``.
    """

    # Relative path, so it is resolved against the current working directory,
    # not this module's location. Callers running from elsewhere should pass
    # an explicit ``model_path``.
    DEFAULT_MODEL_PATH = os.path.join("models", "iris_model.joblib")

    def __init__(self, model_path: str = DEFAULT_MODEL_PATH):
        """Create the wrapper and eagerly load the model.

        Args:
            model_path: Location of the ``joblib`` file to load. Defaults to
                ``models/iris_model.joblib`` relative to the current working
                directory, matching the previous hard-coded behaviour.

        Raises:
            FileNotFoundError: If no file exists at ``model_path``.
        """
        self.model = None
        self.model_path = model_path
        # Position i holds the species name for predicted class id i.
        self.class_names = ["setosa", "versicolor", "virginica"]
        self.load_model()

    def load_model(self):
        """Load the classifier from ``self.model_path`` into ``self.model``.

        Raises:
            FileNotFoundError: If no file exists at ``self.model_path``.
        """
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only point model_path at trusted, locally produced artifacts.
        try:
            # EAFP: avoids the race between an exists() check and the load.
            self.model = joblib.load(self.model_path)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Model not found at {self.model_path}. Please train the model first."
            ) from err

    def predict(self, input_data: IrisInput) -> IrisPrediction:
        """Classify a single flower.

        Args:
            input_data: The four measurements. They are sent to the model as
                [sepal_length, sepal_width, petal_length, petal_width].
                NOTE(review): assumes the model was trained with this column
                order — confirm against the training script.

        Returns:
            The predicted class id together with its species name.

        Raises:
            FileNotFoundError: If the model has to be (re)loaded and the file
                is missing.
            IndexError: If the model predicts a class id outside
                ``class_names``.
        """
        # Compare against None explicitly: a loaded estimator may define
        # __len__/__bool__, so truthiness is not a reliable "loaded" check.
        if self.model is None:
            self.load_model()

        features = [[
            input_data.sepal_length,
            input_data.sepal_width,
            input_data.petal_length,
            input_data.petal_width,
        ]]

        # One input row, so take the first (only) predicted label.
        class_id = int(self.model.predict(features)[0])
        # Bounds-check explicitly: a negative id would otherwise index from
        # the end of the list and silently return the wrong species.
        if not 0 <= class_id < len(self.class_names):
            raise IndexError(
                f"Predicted class id {class_id} is outside the known classes "
                f"0..{len(self.class_names) - 1}."
            )
        return IrisPrediction(
            class_name=self.class_names[class_id],
            class_id=class_id,
        )