Spaces:
Sleeping
Sleeping
import functools
import os
import sys

import torch
from PIL import Image
from torchvision import transforms

# Make the parent of this file's directory importable so the local `models`
# package resolves regardless of how the script is launched.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from models.cnn_model import CatBreedCNN  # noqa: E402  (requires the sys.path tweak above)
# Breed labels; predict() maps the model's argmax output index into this list,
# and its length sets the classifier's number of output classes.
# NOTE(review): the order presumably must match the label indices used during
# training (alphabetical here, as torchvision's ImageFolder would assign) — confirm.
classes = [
    'Bengal',
    'Domestic_Shorthair',
    'Maine_Coon',
    'Ragdoll',
    'Siamese',
]  # Update as needed
@functools.lru_cache(maxsize=4)
def _load_model(model_path, device):
    """Build a CatBreedCNN, load its weights from *model_path*, and cache it.

    Results are cached per ``(model_path, device)`` so repeated predictions do
    not rebuild the network or re-read the checkpoint from disk.
    """
    model = CatBreedCNN(num_classes=len(classes))
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the module's existing parameters and
    # does not relocate them, so move the module to the target device
    # explicitly; otherwise a CUDA input would meet CPU weights.
    model.to(device)
    model.eval()
    return model


def predict(image_path, model_path="models/cat_cnn.pth"):
    """Classify the cat image at *image_path* and return its breed label.

    Args:
        image_path: Path (or file object) of the image to classify; it is
            converted to RGB before preprocessing.
        model_path: Checkpoint holding the CatBreedCNN state dict. The default
            is resolved relative to the current working directory, as before.

    Returns:
        The entry of ``classes`` at the model's highest-scoring output index.

    Raises:
        Whatever ``PIL.Image.open`` / ``torch.load`` raise for a missing or
        unreadable image or checkpoint.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(model_path, device)

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        # Shifts each channel from ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])

    # Context manager closes the underlying file once the pixels are decoded.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    # Add a leading batch dimension of size 1.
    batch = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
    predicted_index = logits.argmax(dim=1).item()
    return classes[predicted_index]