"""Cat/dog image classification with a fine-tuned MobileNetV2.

Importing this module builds a MobileNetV2 whose final classifier layer has
one output per entry in ``CLASS_NAMES``, loads the weights stored at
``MODEL_PATH`` onto the CPU, and puts the model in evaluation mode.
``predict`` then classifies a single image file.
"""

import torch
from torchvision import models, transforms
from PIL import Image

# ==== Configuration ====
MODEL_PATH = 'models/cat_dog_classifier.pth'
CLASS_NAMES = ['cat', 'dog']  # Make sure this order matches your training dataset

# ==== Load Model ====
# No weight argument is passed: the architecture is created with random
# initialisation and every parameter is then overwritten by the checkpoint.
# Omitting the argument works with both the older ``pretrained=`` and the
# newer ``weights=`` torchvision signatures.
model = models.mobilenet_v2()
# Size the classification head from CLASS_NAMES so the two cannot drift apart.
model.classifier[1] = torch.nn.Linear(model.last_channel, len(CLASS_NAMES))
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
model.eval()  # Set to evaluation mode

# ==== Image Preprocessing ====
# NOTE(review): the Normalize constants look like the usual ImageNet channel
# mean/std -- confirm they match the preprocessing used during training.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# ==== Inference Function ====
def predict(image_path):
    """Classify the image at *image_path*.

    Args:
        image_path: Location of the image; passed straight to
            ``PIL.Image.open``.

    Returns:
        A dict with two keys: ``'class'``, the ``CLASS_NAMES`` entry with the
        highest model output, and ``'confidence'``, the softmax probability
        of that class as a Python float.
    """
    # Open inside a context manager so the file handle is released promptly;
    # convert() yields an independent image that is safe to use afterwards.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    input_tensor = transform(image).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        outputs = model(input_tensor)

    predicted_class = outputs.argmax(1).item()
    confidence = torch.softmax(outputs, dim=1)[0][predicted_class].item()

    return {
        'class': CLASS_NAMES[predicted_class],
        'confidence': confidence,
    }


if __name__ == "__main__":
    print(predict('raw_data/train/dog.0.jpg'))