File size: 1,304 Bytes
64d0ccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import pandas as pd
from training.model import build_model
from training.utils import get_device

# Saved model weights (consumed via load_state_dict in ModelWrapper).
# Relative path: resolved against the current working directory.
CHECKPOINT_PATH: str = "checkpoints/best_model.pth"
# CSV with "label_id" and "label" columns mapping class ids to names.
LABEL_MAP_PATH: str = "data_processed/label_map.csv"

class ModelWrapper:
    """Load the trained classifier once and run top-1 predictions on new inputs.

    Construction selects a device via ``get_device``, reads the label map CSV
    (columns ``label_id`` and ``label``), builds the network with one output
    per label via ``build_model``, loads the checkpoint ``state_dict`` onto
    that device and switches the model to eval mode.

    Args:
        checkpoint_path: Location of the saved ``state_dict``. ``None`` (the
            default) means ``CHECKPOINT_PATH``.
        label_map_path: Location (or readable buffer) of the label map CSV.
            ``None`` (the default) means ``LABEL_MAP_PATH``.

    Raises:
        ValueError: If the label ids are not exactly ``0..K-1`` (unique,
            no gaps).
    """

    def __init__(self, checkpoint_path=None, label_map_path=None):
        # Defaults resolved here (not in the signature) so the module-level
        # constants are read at call time, preserving the original behavior.
        if checkpoint_path is None:
            checkpoint_path = CHECKPOINT_PATH
        if label_map_path is None:
            label_map_path = LABEL_MAP_PATH

        self.device = get_device()

        self.id_to_label = self._load_label_map(label_map_path)
        num_classes = len(self.id_to_label)

        self.model = build_model(num_classes, self.device)
        # The checkpoint is only used as a state_dict (tensors in a mapping),
        # so restrict unpickling to tensor data instead of running arbitrary
        # pickled code from the file. Requires torch >= 1.13.
        state_dict = torch.load(
            checkpoint_path, map_location=self.device, weights_only=True
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()

    @staticmethod
    def _load_label_map(path):
        """Read the label map CSV and return ``{label_id (int): label}``.

        ``predict`` uses the model's argmax index directly as the lookup key,
        and the model is sized by the number of entries. The ids must
        therefore be exactly ``0..K-1``. Any gap or duplicate would
        mis-size the model or raise ``KeyError`` only when that index is
        predicted, so it is rejected here instead.

        Raises:
            ValueError: If ids are not a permutation of ``0..K-1``.
        """
        label_df = pd.read_csv(path)
        ids = [int(label_id) for label_id in label_df["label_id"]]
        if sorted(ids) != list(range(len(ids))):
            raise ValueError(
                f"label ids in {path!r} must be exactly 0..{len(ids) - 1} "
                "(unique, no gaps)"
            )
        return dict(zip(ids, label_df["label"]))

    def predict(self, image_tensor):
        """Classify one input and return its top-1 class and probability.

        Args:
            image_tensor: Tensor fed straight to the model after being moved
                to ``self.device``. The batch (first) dimension must hold
                exactly one sample, because results are extracted with
                ``.item()``. Presumably shaped ``(1, C, H, W)``; TODO confirm
                against ``build_model``.

        Returns:
            dict with ``label_id`` (int class index), ``label_name`` (value
            from the label map) and ``confidence`` (softmax probability of
            the predicted class, as a Python float).
        """
        with torch.no_grad():
            logits = self.model(image_tensor.to(self.device))
            probs = torch.softmax(logits, dim=1)
            confidence, pred_id = torch.max(probs, dim=1)

        label_id = int(pred_id.item())
        return {
            "label_id": label_id,
            "label_name": self.id_to_label[label_id],
            "confidence": float(confidence.item()),
        }


# Shared module-level instance. Building it at import time reads the label
# map and checkpoint from disk and constructs the model, so importing this
# module pays that cost (and raises if those files are unavailable).
model_wrapper: ModelWrapper = ModelWrapper()