import torch
import pandas as pd

from training.model import build_model
from training.utils import get_device

CHECKPOINT_PATH = "checkpoints/best_model.pth"
LABEL_MAP_PATH = "data_processed/label_map.csv"


class ModelWrapper:
    """Load the trained classifier once and make predictions on new images.

    On construction, the label map CSV (columns ``label_id`` and ``label``)
    and the checkpoint state dict are read from ``LABEL_MAP_PATH`` and
    ``CHECKPOINT_PATH``. Both are relative paths, so they are resolved
    against the current working directory.
    """

    def __init__(self):
        self.device = get_device()

        label_df = pd.read_csv(LABEL_MAP_PATH)
        # Maps the index chosen by argmax over the model outputs to a label.
        # NOTE(review): this assumes the label_id values are exactly
        # 0..num_classes-1 and ordered as during training -- confirm against
        # the training pipeline.
        self.id_to_label = dict(zip(label_df["label_id"], label_df["label"]))
        num_classes = len(self.id_to_label)

        self.model = build_model(num_classes, self.device)
        # NOTE(security): torch.load is pickle-based; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(CHECKPOINT_PATH, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # Switch to evaluation mode (matters for layers such as dropout or
        # batch norm, if the model contains any).
        self.model.eval()

    def predict(self, image_tensor):
        """Classify a single image and return its top-1 prediction.

        Args:
            image_tensor: A tensor holding one image. It either has a
                leading batch dimension of size 1, or it is a 3-D tensor
                (treated as one unbatched image), in which case a batch
                dimension is added here.

        Returns:
            A dict with ``label_id`` (int index of the most probable
            class), ``label_name`` (that index looked up in the label map)
            and ``confidence`` (its softmax probability, as a float).

        Raises:
            RuntimeError: if the model output contains more than one row
                (e.g. a batch of several images), since ``Tensor.item()``
                requires a single-element tensor.
            KeyError: if the predicted index is missing from the label map.
        """
        with torch.no_grad():
            image_tensor = image_tensor.to(self.device)
            # The softmax/max below reduce over dim 1, so the model must be
            # called with a batch axis; add one for a bare 3-D image.
            if image_tensor.dim() == 3:
                image_tensor = image_tensor.unsqueeze(0)
            outputs = self.model(image_tensor)
            probs = torch.softmax(outputs, dim=1)
            confidence, pred_id = torch.max(probs, dim=1)

        label_id = int(pred_id.item())
        return {
            "label_id": label_id,
            "label_name": self.id_to_label[label_id],
            "confidence": float(confidence.item()),
        }


# Module-level singleton: the model is loaded once, when this module is imported.
model_wrapper = ModelWrapper()