# xray-classification / backend / model_loader.py
# Uploaded by Flamekizer11 ("Upload 27 files", commit 64d0ccc)
import torch
import pandas as pd
from training.model import build_model
from training.utils import get_device
# Filesystem locations read by ModelWrapper.__init__. Both are relative
# paths, so they resolve against the process's current working directory.
CHECKPOINT_PATH: str = "checkpoints/best_model.pth"  # saved model state_dict
LABEL_MAP_PATH: str = "data_processed/label_map.csv"  # label_id -> label table
class ModelWrapper:
    """Load the trained classifier once and run single-image predictions.

    On construction this reads the label map CSV, builds the network with
    ``build_model``, loads the saved ``state_dict`` from the checkpoint and
    switches the model to evaluation mode.

    Args:
        checkpoint_path: Path to the saved ``state_dict``. Defaults to the
            module-level ``CHECKPOINT_PATH`` (looked up at call time).
        label_map_path: Path to a CSV with ``label_id`` and ``label``
            columns. Defaults to the module-level ``LABEL_MAP_PATH``.
        device: Device to place the model on. Defaults to ``get_device()``.
    """

    def __init__(self, checkpoint_path=None, label_map_path=None, device=None):
        # Resolve defaults lazily so the module constants are read at call
        # time, exactly as the original zero-argument constructor did.
        if checkpoint_path is None:
            checkpoint_path = CHECKPOINT_PATH
        if label_map_path is None:
            label_map_path = LABEL_MAP_PATH
        self.device = get_device() if device is None else device

        label_df = pd.read_csv(label_map_path)
        # Store ids as plain Python ints so lookups with int(pred_id) do not
        # depend on the CSV column's dtype (numpy scalars, strings, ...).
        self.id_to_label = {
            int(label_id): label
            for label_id, label in zip(label_df["label_id"], label_df["label"])
        }

        num_classes = len(self.id_to_label)
        self.model = build_model(num_classes, self.device)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints; consider passing weights_only=True (torch >= 1.13).
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def predict(self, image_tensor):
        """Classify one preprocessed input and return the top class.

        Args:
            image_tensor: Model input tensor; it is moved to ``self.device``
                before the forward pass. Presumably a batch of exactly one
                image (e.g. ``(1, C, H, W)``) -- TODO confirm against callers.

        Returns:
            dict with ``"label_id"`` (int), ``"label_name"`` (the matching
            entry from the label map) and ``"confidence"`` (float softmax
            probability of the predicted class).

        Raises:
            RuntimeError: If the batch holds more than one sample, because
                ``Tensor.item()`` needs a single-element tensor.
            KeyError: If the predicted id is absent from the label map.
        """
        with torch.no_grad():
            image_tensor = image_tensor.to(self.device)
            outputs = self.model(image_tensor)
            # Softmax over dim 1 (assumed to be the class axis of the model's
            # output scores) turns them into probabilities.
            probs = torch.softmax(outputs, dim=1)
            confidence, pred_id = torch.max(probs, dim=1)

        label_id = int(pred_id.item())
        return {
            "label_id": label_id,
            "label_name": self.id_to_label[label_id],
            "confidence": float(confidence.item()),
        }
# Module-level singleton: the model is built and its checkpoint loaded once,
# when this module is first imported, and the same instance is then shared by
# every importer (Python caches imported modules).
model_wrapper: ModelWrapper = ModelWrapper()