import io
from typing import Sequence

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms


# 1. Define the model architecture EXACTLY as in your training script.
def create_model(num_classes: int) -> nn.Module:
    """Build a ResNet-18 whose final layer outputs ``num_classes`` logits.

    No pretrained weights are loaded (``weights=None``); the caller is
    expected to load its own trained ``state_dict`` into the returned model.
    """
    model = models.resnet18(weights=None)
    # Replace the classification head so its output size matches the dataset.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


# 2. Define the EXACT same evaluation transforms as in your training script.
transform_eval = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Per-channel mean / std; must match the values used during training.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# 3. Create a function to make a prediction.
def predict(model: nn.Module, image_bytes: bytes, class_names: Sequence[str]) -> dict:
    """Classify a single encoded image.

    Args:
        model: Classifier producing one logit per class for a batch of
            3x224x224 normalized images (e.g. a model from ``create_model``).
            It is switched to eval mode as a side effect.
        image_bytes: Raw encoded image file contents (any format PIL can open).
        class_names: Labels indexed by logit position; ``class_names[i]`` is
            the label for output index ``i``.

    Returns:
        ``{"predicted_id": <class name>, "confidence": <softmax probability>}``.
        NOTE: despite its name, ``"predicted_id"`` holds the class *name*; the
        key is kept unchanged for compatibility with existing callers.

    Raises:
        PIL.UnidentifiedImageError: if ``image_bytes`` is not a readable image.
        IndexError: if the predicted index is outside ``class_names``.
    """
    # Decode and force 3 channels; the context manager closes the source
    # image once the independent RGB copy has been created.
    with Image.open(io.BytesIO(image_bytes)) as img:
        image = img.convert("RGB")

    # Preprocess and add a batch dimension: (C, H, W) -> (1, C, H, W).
    input_tensor = transform_eval(image).unsqueeze(0)

    # Put the input on the same device as the model's weights so models that
    # were moved to a GPU work; parameter-free modules keep the CPU tensor.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    model.eval()
    with torch.no_grad():
        output = model(input_tensor)

    # Softmax over the class logits of the single batch item.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    confidence, predicted_idx = torch.max(probabilities, 0)
    predicted_class = class_names[predicted_idx.item()]

    return {
        "predicted_id": predicted_class,
        "confidence": confidence.item(),
    }