Spaces:
Sleeping
Sleeping
| import torch | |
| import pandas as pd | |
| from training.model import build_model | |
| from training.utils import get_device | |
# Path of the saved weights file that ModelWrapper passes to torch.load.
CHECKPOINT_PATH = "checkpoints/best_model.pth"
# CSV read by ModelWrapper; its "label_id" and "label" columns are used to
# build the id -> label-name lookup.
LABEL_MAP_PATH = "data_processed/label_map.csv"
# Wrapper class for loading the model and making predictions on new data instances
class ModelWrapper:
    """Hold a loaded model plus its label lookup and classify inputs with it.

    Construction reads the label map CSV at ``LABEL_MAP_PATH``, builds the
    network through ``build_model`` and restores the weights stored at
    ``CHECKPOINT_PATH``. The model is left in evaluation mode.
    """

    def __init__(self):
        """Resolve the device, read the label map and load the checkpoint."""
        self.device = get_device()

        # The CSV must provide "label_id" and "label" columns; each row
        # becomes one id -> label entry.
        label_table = pd.read_csv(LABEL_MAP_PATH)
        self.id_to_label = {
            label_id: label
            for label_id, label in zip(
                label_table["label_id"], label_table["label"]
            )
        }

        # The number of output classes is taken from the size of the label map.
        self.model = build_model(len(self.id_to_label), self.device)

        # NOTE(review): torch.load is pickle-based; depending on the installed
        # PyTorch version's weights_only default this can execute code from
        # the file, so only point CHECKPOINT_PATH at trusted checkpoints.
        state_dict = torch.load(CHECKPOINT_PATH, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def predict(self, image_tensor):
        """Classify one input and return the most probable label.

        Args:
            image_tensor: Tensor forwarded to the model after being moved to
                ``self.device``. Presumably a batch containing a single
                preprocessed image -- TODO confirm against callers.

        Returns:
            A dict with ``"label_id"`` (int), ``"label_name"`` (the entry of
            ``self.id_to_label`` for that id) and ``"confidence"`` (float,
            the softmax probability of the predicted class).

        Note:
            The model output must have exactly one row: ``.item()`` raises
            when the per-row maximum holds more than one element.
        """
        with torch.no_grad():
            batch = image_tensor.to(self.device)
            logits = self.model(batch)
            # Softmax over dim 1 so each row of the output sums to 1.
            probabilities = torch.softmax(logits, dim=1)
            top_prob, top_index = probabilities.max(dim=1)

        class_id = int(top_index.item())
        return {
            "label_id": class_id,
            "label_name": self.id_to_label[class_id],
            "confidence": float(top_prob.item()),
        }
# Module-level shared instance. Constructing it here means the label map is
# read and the checkpoint is loaded as soon as this module is imported.
model_wrapper = ModelWrapper()