"""Inference helpers for a ResNet-50 classifier loaded from ``reptile_classifier.pth``.

On import this module reads the class labels and the trained weights, then
exposes ``predict`` for scoring a single PIL image on CPU.
"""

import json

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models

# Class labels, looked up by the model's output index in ``predict``.
# NOTE(review): assumes the JSON holds a list of names ordered to match the
# classifier's output units — confirm against the training script.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)


def load_model(checkpoint_path='reptile_classifier.pth'):
    """Build a ResNet-50 with a resized head and load trained weights.

    Args:
        checkpoint_path: Path to a checkpoint dict that contains a
            ``'model_state_dict'`` entry. Defaults to the original file name.

    Returns:
        The model, with tensors mapped to CPU and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # No pretrained download: every weight comes from the checkpoint below.
    # NOTE: ``pretrained`` is deprecated since torchvision 0.13 in favour of
    # ``weights=None``; kept for compatibility with older torchvision releases.
    model = models.resnet50(pretrained=False)
    # Replace the stock classification head with one output per known class.
    model.fc = nn.Linear(model.fc.in_features, len(class_names))
    # SECURITY: torch.load unpickles the file and can execute arbitrary code —
    # only load trusted checkpoints (torch>=1.13 supports weights_only=True).
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout / use running batch-norm statistics
    return model


model = load_model()

# Preprocessing: fixed 224x224 resize, conversion to a float tensor, then
# per-channel normalization with the ImageNet mean/std constants.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(image: Image.Image, *, top_k: int = 3) -> dict:
    """Return the ``top_k`` most probable classes for a single image.

    Args:
        image: Any PIL image. Non-RGB modes (RGBA, greyscale, palette, ...)
            are converted to RGB before preprocessing.
        top_k: How many classes to return. It is capped at the number of
            classes. A value of 0 gives an empty dict.

    Returns:
        A dict mapping class name to softmax probability (Python float),
        built in descending order of probability.
    """
    # Normalize above expects exactly 3 channels. Without this conversion,
    # RGBA (4 channels) or greyscale (1 channel) inputs raise inside Normalize.
    batch = transform(image.convert('RGB')).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    # torch.topk raises if k exceeds the number of elements, so clamp it.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    return {
        class_names[int(idx)]: float(prob)
        for idx, prob in zip(top_indices, top_probs)
    }