---
license: mit
datasets:
- MichaelMM2000/animals10
---

# AnimalNet18

**AnimalNet18** is an animal image classification model trained on the [Animals-10](https://huggingface.co/datasets/MichaelMM2000/animals10) dataset.
It assigns each input image to one of the dataset's 10 animal categories.

---

## Dataset

- **Source**: [MichaelMM2000/animals10](https://huggingface.co/datasets/MichaelMM2000/animals10)
- **Number of classes**: 10 (e.g., dog, cat, horse, elephant, butterfly, …)

---

## Architecture

- Backbone: **ResNet-18** (PyTorch)
- Classification head: the final `fc` layer is replaced with a 10-way `nn.Linear`
- Input size: `224x224`
- Optimizer: Adam
- Loss: CrossEntropy

---

## Usage

### 1. Load the model from Hugging Face and run a prediction

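```python
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from huggingface_hub import hf_hub_download

# Download the checkpoint (a ResNet-18 state_dict) from the Hub
path = hf_hub_download(repo_id="CatHann/AnimalNet18", filename="AnimalNet18.pth")

# Rebuild the architecture: ResNet-18 without pretrained weights, 10-way head
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
model.eval()

# Preprocessing: resize to 224x224, convert to a tensor, normalize with ImageNet statistics
tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# "test.jpg" is a placeholder path; convert("RGB") handles grayscale and RGBA images
img = tfm(Image.open("test.jpg").convert("RGB")).unsqueeze(0)  # shape: (1, 3, 224, 224)

# Inference only: disable gradient tracking
with torch.no_grad():
    pred = model(img).argmax(dim=1).item()

# pred is a class index in 0-9; mapping it to a name requires the label order used in training
print("Predicted class index:", pred)
```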