Upload README.md with huggingface_hub
Browse files
README.md
CHANGED
|
@@ -74,13 +74,18 @@ weighted avg 0.969 0.967 0.967 30
|
|
| 74 |
```python
|
| 75 |
import timm
|
| 76 |
import torch
|
|
|
|
| 77 |
from PIL import Image
|
| 78 |
from safetensors.torch import load_file
|
| 79 |
from torchvision import transforms
|
| 80 |
|
| 81 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=2)
|
| 83 |
-
model.load_state_dict(load_file(
|
| 84 |
model.eval()
|
| 85 |
|
| 86 |
# Preprocess
|
|
@@ -96,10 +101,12 @@ input_tensor = transform(image).unsqueeze(0)
|
|
| 96 |
|
| 97 |
with torch.no_grad():
|
| 98 |
output = model(input_tensor)
|
|
|
|
| 99 |
pred = output.argmax(1).item()
|
|
|
|
| 100 |
|
| 101 |
classes = ['other', 'index_card']
|
| 102 |
-
print(f"Prediction: {classes[pred]}")
|
| 103 |
```
|
| 104 |
|
| 105 |
## Training
|
|
|
|
| 74 |
```python
|
| 75 |
import timm
|
| 76 |
import torch
|
| 77 |
+
from huggingface_hub import hf_hub_download
|
| 78 |
from PIL import Image
|
| 79 |
from safetensors.torch import load_file
|
| 80 |
from torchvision import transforms
|
| 81 |
|
| 82 |
+
# Download and load model from Hub
|
| 83 |
+
weights_path = hf_hub_download(
|
| 84 |
+
repo_id="davanstrien/nls-index-card-classifier",
|
| 85 |
+
filename="classifier.safetensors"
|
| 86 |
+
)
|
| 87 |
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=2)
|
| 88 |
+
model.load_state_dict(load_file(weights_path))
|
| 89 |
model.eval()
|
| 90 |
|
| 91 |
# Preprocess
|
|
|
|
| 101 |
|
| 102 |
with torch.no_grad():
|
| 103 |
output = model(input_tensor)
|
| 104 |
+
probs = torch.softmax(output, dim=1)
|
| 105 |
pred = output.argmax(1).item()
|
| 106 |
+
confidence = probs[0, pred].item()
|
| 107 |
|
| 108 |
classes = ['other', 'index_card']
|
| 109 |
+
print(f"Prediction: {classes[pred]} ({confidence:.1%})")
|
| 110 |
```
|
| 111 |
|
| 112 |
## Training
|