"""Predict a cat's breed from an image using a trained ``CatBreedCNN``."""

import functools
import os
import sys

# Make the project root (the parent of this file's directory) importable so
# that ``models.cnn_model`` resolves no matter where this script is run from.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import torch
from PIL import Image
from torchvision import transforms

from models.cnn_model import CatBreedCNN

# Output labels: ``classes[i]`` is returned when the model's i-th output is the
# largest. The order presumably must match the label order used at training
# time -- confirm against the training script. Update as needed.
classes = ['Bengal', 'Domestic_Shorthair', 'Maine_Coon', 'Ragdoll', 'Siamese', ]

# Checkpoint location; a relative path is resolved against the current working
# directory (not against this file).
_DEFAULT_MODEL_PATH = "models/cat_cnn.pth"

# Preprocessing applied to every input image before inference. It is stateless,
# so it is built once at import time instead of on every call.
_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])


@functools.lru_cache(maxsize=4)
def _load_model(model_path, device_type):
    """Build a ``CatBreedCNN``, load weights from *model_path*, return it in eval mode.

    Results are cached per ``(model_path, device_type)`` so repeated
    predictions do not re-read the checkpoint. Call
    ``_load_model.cache_clear()`` to pick up a checkpoint that changed on disk.
    """
    device = torch.device(device_type)
    model = CatBreedCNN(num_classes=len(classes))
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True if the installed torch supports it).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the model's existing parameters, which
    # were created on the CPU; move the model explicitly so it matches the
    # device the input tensor is sent to in predict().
    model.to(device)
    model.eval()
    return model


def predict(image_path, model_path=_DEFAULT_MODEL_PATH):
    """Return the predicted breed name for the image at *image_path*.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (path or file object).
        model_path: Checkpoint holding a ``CatBreedCNN`` state dict. Defaults to
            ``"models/cat_cnn.pth"`` relative to the current working directory.

    Returns:
        The entry of ``classes`` at the index of the model's largest output.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(model_path, device.type)

    # ``with`` closes the underlying file; convert() returns an independent
    # image, so it stays usable after the file handle is released.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension expected by the model.
    batch = _transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(batch)
    predicted_index = output.argmax(dim=1).item()
    return classes[predicted_index]