| | import torch |
| | from torchvision import models, transforms |
| | from PIL import Image |
| |
|
| | |
# Checkpoint (a state_dict) consumed by model.load_state_dict(); relative paths
# are resolved against the current working directory, not this file.
MODEL_PATH: str = 'models/cat_dog_classifier.pth'
# Human-readable labels; position i names the classifier's i-th output logit.
CLASS_NAMES = [
    'cat',  # index 0
    'dog',  # index 1
]
| |
|
| | |
# Build the MobileNetV2 architecture without downloading any weights; the
# trained parameters are loaded from MODEL_PATH below.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=None`; left as-is to stay compatible with older torchvision releases.
model = models.mobilenet_v2(pretrained=False)
# Replace the final linear layer with one output logit per label. The width is
# derived from CLASS_NAMES so the head and the label list cannot drift apart;
# the checkpoint must have been trained with the same number of classes.
model.classifier[1] = torch.nn.Linear(model.last_channel, len(CLASS_NAMES))
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted checkpoints (torch >= 1.13 also accepts weights_only=True).
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
# Inference mode: disables dropout and makes batch-norm use running statistics.
model.eval()
| |
|
| | |
# Preprocessing applied to every input image before it reaches the model.
# NOTE(review): presumably this must match the preprocessing used when the
# checkpoint was trained — confirm against the training script.
transform = transforms.Compose(
    [
        # Resize so the shorter side is 256 px (aspect ratio preserved) ...
        transforms.Resize(size=256),
        # ... then take the central 224x224 patch.
        transforms.CenterCrop(size=224),
        # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel standardisation with the mean/std values commonly used
        # for ImageNet-style models.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
| |
|
| | |
def predict(image_path):
    """Classify a single image as one of ``CLASS_NAMES``.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts — a filesystem path,
            an ``os.PathLike``, or a binary file object.

    Returns:
        A dict with ``'class'`` (the predicted label taken from ``CLASS_NAMES``)
        and ``'confidence'`` (the softmax probability of that label, as a
        Python float).

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file
        (e.g. ``FileNotFoundError``, ``PIL.UnidentifiedImageError``).
    """
    # Open inside a context manager so the file handle is always released:
    # Pillow keeps the file open for multi-frame formats even after load(), and
    # convert() returns a new, fully loaded image without closing the source.
    with Image.open(image_path) as img:
        # Normalise grayscale / RGBA / palette inputs to 3 channels.
        image = img.convert('RGB')
    # Add a leading batch dimension of size 1.
    input_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(input_tensor)
        predicted_class = outputs.argmax(1).item()
        # Softmax turns the logits into probabilities; report the winner's.
        confidence = torch.softmax(outputs, dim=1)[0][predicted_class].item()

    return {
        'class': CLASS_NAMES[predicted_class],
        'confidence': confidence,
    }
| |
|
if __name__ == "__main__":
    import sys

    # Classify the image named on the command line; with no argument, fall back
    # to the original sample image so existing invocations behave the same.
    _image_path = sys.argv[1] if len(sys.argv) > 1 else 'raw_data/train/dog.0.jpg'
    print(predict(_image_path))
| |
|