davanstrien HF Staff commited on
Commit
59f7af6
·
verified ·
1 Parent(s): c71f17a

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +10 -3
README.md CHANGED
@@ -74,13 +74,18 @@ weighted avg 0.969 0.967 0.967 30
74
  ```python
75
  import timm
76
  import torch
 
77
  from PIL import Image
78
  from safetensors.torch import load_file
79
  from torchvision import transforms
80
 
81
- # Load model
 
 
 
 
82
  model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=2)
83
- model.load_state_dict(load_file('classifier.safetensors'))
84
  model.eval()
85
 
86
  # Preprocess
@@ -96,10 +101,12 @@ input_tensor = transform(image).unsqueeze(0)
96
 
97
  with torch.no_grad():
98
  output = model(input_tensor)
 
99
  pred = output.argmax(1).item()
 
100
 
101
  classes = ['other', 'index_card']
102
- print(f"Prediction: {classes[pred]}")
103
  ```
104
 
105
  ## Training
 
74
  ```python
75
  import timm
76
  import torch
77
+ from huggingface_hub import hf_hub_download
78
  from PIL import Image
79
  from safetensors.torch import load_file
80
  from torchvision import transforms
81
 
82
+ # Download and load model from Hub
83
+ weights_path = hf_hub_download(
84
+ repo_id="davanstrien/nls-index-card-classifier",
85
+ filename="classifier.safetensors"
86
+ )
87
  model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=2)
88
+ model.load_state_dict(load_file(weights_path))
89
  model.eval()
90
 
91
  # Preprocess
 
101
 
102
  with torch.no_grad():
103
  output = model(input_tensor)
104
+ probs = torch.softmax(output, dim=1)
105
  pred = output.argmax(1).item()
106
+ confidence = probs[0, pred].item()
107
 
108
  classes = ['other', 'index_card']
109
+ print(f"Prediction: {classes[pred]} ({confidence:.1%})")
110
  ```
111
 
112
  ## Training