"""Bovine breed classifier: a ResNet-18 with a custom head that maps one image to a breed label."""

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# ----------------------------
# Labels (all breeds)
# ----------------------------
# Classifier output index i corresponds to breeds[i]. This order must match
# the class order used when the checkpoint was trained.
breeds = [
    "Alambadi", "Amritmahal", "Ayrshire", "Banni", "Bargur", "Bhadawari",
    "Brown_Swiss", "Dangi", "Deoni", "Gir", "Guernsey", "Hallikar", "Hariana",
    "Holstein_Friesian", "Jaffrabadi", "Jersey", "Kangayam", "Kankrej",
    "Kasargod", "Kenkatha", "Kherigarh", "Khillari", "Krishna_Valley",
    "Malnad_gidda", "Mehsana", "Murrah", "Nagori", "Nagpuri", "Nili_Ravi",
    "Nimari", "Ongole", "Pulikulam", "Rathi", "Red_Dane", "Red_Sindhi",
    "Sahiwal", "Surti", "Tharparkar", "Toda", "Umblachery", "Vechur",
]

# Default checkpoint location, resolved relative to the current working directory.
MODEL_PATH = "bovine_model.pth"


# ----------------------------
# Load Model
# ----------------------------
def load_model(model_path: str = MODEL_PATH) -> nn.Module:
    """Build a ResNet-18 with a ``len(breeds)``-way head and load trained weights.

    Args:
        model_path: Path to a state_dict saved with ``torch.save``.
            Defaults to ``"bovine_model.pth"``.

    Returns:
        The model with weights loaded onto the CPU, switched to eval mode.
    """
    # ``num_classes`` builds the final ``fc`` layer with one output per breed.
    # No pretrained weights are requested, which is the default on every
    # torchvision version. This avoids the deprecated ``pretrained=`` argument.
    # All parameters are overwritten by the checkpoint below anyway.
    net = models.resnet18(num_classes=len(breeds))
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints (on torch >= 1.13, consider passing weights_only=True).
    state_dict = torch.load(model_path, map_location="cpu")
    net.load_state_dict(state_dict)
    net.eval()  # disable dropout / use running BatchNorm statistics
    return net


# Loaded once at import time so every predict() call reuses the same network.
model = load_model()

# ----------------------------
# Image Preprocessing
# ----------------------------
# Resize((224, 224)) forces a square output and ignores the aspect ratio.
# NOTE(review): no mean/std Normalize is applied. This must match the
# preprocessing used at training time, so confirm before adding one.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


# ----------------------------
# Prediction Function
# ----------------------------
def predict(image: Image.Image) -> dict:
    """Classify a single PIL image and return its predicted breed.

    Args:
        image: Any PIL image. Grayscale, palette and RGBA images are
            converted to RGB first.

    Returns:
        ``{"breed": name}``, where ``name`` is the entry of ``breeds`` with
        the highest logit.
    """
    # ToTensor keeps the source channel count. Without this conversion, "L",
    # "P" or "RGBA" images would yield 1- or 4-channel tensors and fail in the
    # model's 3-channel first convolution.
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch)
    index = int(logits.argmax(dim=1).item())
    return {"breed": breeds[index]}