MichaelMM2000/animals10
Updated • 2
AnimalNet18 is a ResNet-18 image classification model trained on the Animals-10 dataset.
It classifies an image into one of the dataset's 10 animal categories.
# AnimalNet18 inference example: download the checkpoint, classify one image.
# Model input resolution: 224x224 (every image is resized to this below).
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download

NUM_CLASSES = 10  # size of the checkpoint's classifier head (Animals-10)
IMAGE_SIZE = (224, 224)
# Standard ImageNet per-channel mean/std as used by torchvision ResNets;
# presumably the checkpoint was trained with the same normalization — confirm.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Load model
path = hf_hub_download("CatHann/AnimalNet18", "AnimalNet18.pth")
# Build an uninitialised ResNet-18 skeleton; all weights come from the checkpoint.
# `weights=None` replaces the deprecated `pretrained=False` (torchvision >= 0.13).
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered file fetched from the Hub cannot execute arbitrary code on load.
state_dict = torch.load(path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Transform & predict
tfm = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# convert("RGB") guarantees 3 channels: grayscale, RGBA or palette images would
# otherwise fail in Normalize / the first conv layer. `with` closes the file.
with Image.open("test.jpg") as image:
    img = tfm(image.convert("RGB")).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
# Inference only: skip autograd bookkeeping.
with torch.inference_mode():
    pred = model(img).argmax(1).item()
# NOTE(review): `pred` is a class *index*; the index -> animal-name mapping used
# at training time is not visible here — confirm it before printing names.
print("Predicted class:", pred)