CCS229_ALA / src /predict.py
Gillie2004's picture
Upload 4 files
a82dfe3 verified
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import torch
from PIL import Image
from torchvision import transforms
from models.cnn_model import CatBreedCNN
# Breed labels: the class index predicted by the model is looked up in this list.
# NOTE(review): presumably the order must match the label indices used when the
# model was trained -- confirm before editing. Update as needed.
classes = [
    'Bengal',
    'Domestic_Shorthair',
    'Maine_Coon',
    'Ragdoll',
    'Siamese',
]
def predict(image_path):
    """Return the predicted breed name for the image at *image_path*.

    Builds a ``CatBreedCNN`` with one output per entry in ``classes``, loads
    its weights from ``models/cat_cnn.pth`` (resolved against the current
    working directory), and runs a single forward pass on the image.

    The image is converted to RGB, resized to 128x128, turned into a tensor
    and normalized with mean 0.5 / std 0.5 on each of the three channels.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The entry of ``classes`` whose index has the highest model output.

    Errors raised while opening the image or loading the checkpoint are not
    caught here and propagate to the caller.

    Note: the model is rebuilt and the checkpoint re-read on every call.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = CatBreedCNN(num_classes=len(classes))
    model.load_state_dict(torch.load("models/cat_cnn.pth", map_location=device))
    # load_state_dict copies values into the existing (CPU) parameters, so the
    # model must be moved explicitly; otherwise the input tensor below lives on
    # `device` while the weights stay on the CPU and the forward pass fails
    # whenever CUDA is available.
    model.to(device)
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])

    # Open inside a context manager so the underlying file handle is released;
    # convert() yields an independent in-memory image usable after the close.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # unsqueeze(0) adds the leading batch dimension for the single image.
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image)

    predicted_index = output.argmax(dim=1).item()
    return classes[predicted_index]