File size: 995 Bytes
8f0e25c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import torch
from PIL import Image
from torchvision import transforms
from models.cnn_model import CatBreedCNN


# Breed labels; predict() returns classes[i] for the model's arg-max output
# index i, so the list length also sets the model's number of outputs.
# NOTE(review): the order presumably must match the label order used when the
# checkpoint was trained — confirm before editing ("Update as needed").
classes = [
    "Bengal",
    "Domestic_Shorthair",
    "Maine_Coon",
    "Ragdoll",
    "Siamese",
]

def predict(image_path, *, model_path="models/cat_cnn.pth"):
    """Classify a single image and return the predicted breed label.

    Builds a ``CatBreedCNN`` with one output per entry in ``classes``, loads
    its weights from *model_path*, preprocesses the image (resize to
    128x128, convert to a tensor, normalize each RGB channel with mean 0.5
    and std 0.5) and returns the label at the arg-max output index.

    Args:
        image_path: Path to the image file. It is converted to RGB before
            preprocessing, so non-RGB modes are accepted.
        model_path: Path to the saved ``state_dict``. The default
            ``"models/cat_cnn.pth"`` is resolved relative to the current
            working directory, not relative to this file.

    Returns:
        The element of ``classes`` selected by the model.

    Raises:
        FileNotFoundError: If *image_path* or *model_path* does not exist.
        RuntimeError: If the saved weights do not fit the model built here
            (for example, ``classes`` changed length after training).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The model itself must live on `device`: map_location below only places
    # the loaded tensors, and load_state_dict copies them into the model's
    # existing parameters. Without this, a CUDA input would meet CPU weights.
    model = CatBreedCNN(num_classes=len(classes)).to(device)
    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])

    # Close the underlying file deterministically; convert() returns a new
    # in-memory image, so the handle is no longer needed afterwards.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        scores = model(batch)
    predicted_index = scores.argmax(dim=1).item()

    return classes[predicted_index]