CatHann commited on
Commit
9ec30dc
·
verified ·
1 Parent(s): 118e318

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +52 -3
README.md CHANGED
@@ -1,3 +1,52 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ datasets:
4
+ - MichaelMM2000/animals10
5
+ ---
6
+
7
+ # AnimalNet18
8
+
9
+ **AnimalNet18** is an animal image classification model trained on the [Animals-10](https://huggingface.co/datasets/MichaelMM2000/animals10) dataset.
10
+ The goal of the model is to classify images into common animal categories in the dataset.
11
+
12
+ ---
13
+
14
+ ## Dataset
15
+ - **Source**: [MichaelMM2000/animals10](https://huggingface.co/datasets/MichaelMM2000/animals10)
16
+ - **Number of classes**: 10 (e.g., dog, cat, horse, elephant, butterfly, …)
17
+
18
+ ---
19
+
20
+ ## Architecture
21
+ - Backbone: **ResNet-18** (PyTorch)
22
+ - Input size: `224x224`
23
+ - Optimizer: Adam
24
+ - Loss: CrossEntropy
25
+
26
+ ---
27
+
28
+ ## Usage
29
+
30
+ ### 1. Load the model from Hugging Face
31
+ ```python
32
+ import torch, torch.nn as nn
33
+ from torchvision import models, transforms
34
+ from PIL import Image
35
+ from huggingface_hub import hf_hub_download
36
+
37
+ # Load model
38
+ path = hf_hub_download("CatHann/AnimalNet18", "AnimalNet18.pth")
39
+ model = models.resnet18(pretrained=False)
40
+ model.fc = nn.Linear(model.fc.in_features, 10)
41
+ model.load_state_dict(torch.load(path, map_location="cpu"))
42
+ model.eval()
43
+
44
+ # Transform & predict
45
+ tfm = transforms.Compose([
46
+ transforms.Resize((224,224)),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
49
+ ])
50
+ img = tfm(Image.open("test.jpg")).unsqueeze(0)
51
+ pred = model(img).argmax(1).item()
52
+ print("Predicted class:", pred)