stefanosgikas commited on
Commit
018a811
Β·
verified Β·
1 Parent(s): 7d80d4e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +45 -31
README.md CHANGED
@@ -70,43 +70,69 @@ A Vision Foundation Model for Affective Computing
70
 
71
  ---
72
 
73
- ## Pre-trained Weights
74
 
75
- Get the weights from the **[GitHub Releases](https://github.com/GkikasStefanos/PainFormer/releases)**.
76
 
77
- | File | Size |
78
- | ---------------- | ------- |
79
- | `painformer.pth` | **TBA** |
 
 
80
 
81
  ```bash
82
- # download the latest checkpoint
83
- auto=https://github.com/GkikasStefanos/PainFormer/releases/latest/download/painformer.pth
84
- curl -L -o painformer.pth "$auto"
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- # optional: verify
87
- sha256sum painformer.pth
88
  ```
89
 
90
- The checkpoint contains **one key**:
91
 
92
- ```text
93
  model_state_dict # PainFormer backbone weights
94
  ```
 
95
  ---
96
 
97
  ## Quick start
98
 
99
- > Assumes **PyTorch β‰₯ 2.0** and **timm β‰₯ 0.9** are installed.
 
 
 
 
 
 
 
 
 
100
 
101
  ### Extract embeddings
102
 
103
  ```python
104
  import torch
105
  from timm.models import create_model
106
- from architecture import painformer
107
  from PIL import Image
108
  from torchvision import transforms
109
 
 
 
 
110
  # ---------------------------------------------------------------
111
  # Setup ---------------------------------------------------------
112
  # ---------------------------------------------------------------
@@ -126,8 +152,8 @@ to_tensor = transforms.Compose([
126
  # ---------------------------------------------------------------
127
  # Load PainFormer -----------------------------------------------
128
  # ---------------------------------------------------------------
129
- model = create_model('painformer').to(device)
130
- state = torch.load('./checkpoints/painformer.pth', map_location=device)
131
  model.load_state_dict(state['model_state_dict'], strict=False)
132
 
133
  # expose embeddings (remove classification head)
@@ -158,15 +184,12 @@ import torch, torch.nn as nn
158
  from timm.models import create_model
159
  from architecture import painformer
160
 
161
- # ---------------------------------------------------------------
162
- # Setup ----------------------------------------------------------
163
- # ---------------------------------------------------------------
164
  device = "cuda" if torch.cuda.is_available() else "cpu"
165
  num_classes = 3 # set to your task
166
 
167
  # Backbone β†’ 160-D embeddings
168
  model = create_model('painformer').to(device)
169
- state = torch.load('painformer.pth', map_location=device)
170
  model.load_state_dict(state['model_state_dict'], strict=False)
171
 
172
  # freeze if you only need fixed embeddings
@@ -182,16 +205,7 @@ head = nn.Sequential(
182
  optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
183
  criterion = nn.CrossEntropyLoss()
184
 
185
- # one step (sketch)
186
- def step(x, y):
187
- model.eval()
188
- with torch.no_grad():
189
- z = model(x) # [B, 160]
190
- logits = head(z) # [B, C]
191
- loss = criterion(logits, y)
192
- return loss, logits
193
-
194
- # --- optional: end-to-end fine-tune ---
195
  for p in model.parameters():
196
  p.requires_grad = True
197
  optimizer = torch.optim.AdamW(
@@ -221,7 +235,7 @@ optimizer = torch.optim.AdamW(
221
 
222
  ## Licence & acknowledgements
223
 
224
- * Code & weights: **MIT Licence** – see [`LICENSE`](./LICENSE)
225
 
226
  ---
227
 
 
70
 
71
  ---
72
 
73
+ ## Pre-trained Checkpoint
74
 
75
+ The checkpoint is stored under `checkpoint/` in this repository.
76
 
77
+ | File | Size |
78
+ | --------------------------- | ------- |
79
+ | `checkpoint/painformer.pth` | **75 MB** |
80
+
81
+ ### Download options
82
 
83
  ```bash
84
+ # direct file download (PainFormer)
85
+ mkdir -p checkpoint
86
+ wget https://huggingface.co/stefanosgikas/PainFormer/resolve/main/checkpoint/painformer.pth
87
+ ```
88
+
89
+ ```python
90
+ from huggingface_hub import hf_hub_download
91
+ ckpt_path = hf_hub_download(
92
+ repo_id="stefanosgikas/PainFormer",
93
+ filename="checkpoint/painformer.pth"
94
+ )
95
+ print(ckpt_path)
96
+ ```
97
+
98
+ Optional integrity check:
99
 
100
+ ```bash
101
+ sha256sum checkpoint/painformer.pth
102
  ```
103
 
104
+ The checkpoint contains:
105
 
106
+ ```
107
  model_state_dict # PainFormer backbone weights
108
  ```
109
+
110
  ---
111
 
112
  ## Quick start
113
 
114
+ Assumes **PyTorch β‰₯ 2.0** and **timm β‰₯ 0.9** are installed.
115
+
116
+ Repository layout expected:
117
+
118
+ ```
119
+ .
120
+ β”œβ”€β”€ docs/ # images for the model card
121
+ β”œβ”€β”€ architecture/ # Python modules (e.g., painformer.py)
122
+ └── checkpoint/ # painformer.pth
123
+ ```
124
 
125
  ### Extract embeddings
126
 
127
  ```python
128
  import torch
129
  from timm.models import create_model
 
130
  from PIL import Image
131
  from torchvision import transforms
132
 
133
+ # model code lives in the local "architecture" folder
134
+ from architecture import painformer # ensures registry / model class is imported
135
+
136
  # ---------------------------------------------------------------
137
  # Setup ---------------------------------------------------------
138
  # ---------------------------------------------------------------
 
152
  # ---------------------------------------------------------------
153
  # Load PainFormer -----------------------------------------------
154
  # ---------------------------------------------------------------
155
+ model = create_model('painformer').to(device) # class registered by architecture/painformer.py
156
+ state = torch.load('checkpoint/painformer.pth', map_location=device)
157
  model.load_state_dict(state['model_state_dict'], strict=False)
158
 
159
  # expose embeddings (remove classification head)
 
184
  from timm.models import create_model
185
  from architecture import painformer
186
 
 
 
 
187
  device = "cuda" if torch.cuda.is_available() else "cpu"
188
  num_classes = 3 # set to your task
189
 
190
  # Backbone β†’ 160-D embeddings
191
  model = create_model('painformer').to(device)
192
+ state = torch.load('checkpoint/painformer.pth', map_location=device)
193
  model.load_state_dict(state['model_state_dict'], strict=False)
194
 
195
  # freeze if you only need fixed embeddings
 
205
  optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
206
  criterion = nn.CrossEntropyLoss()
207
 
208
+ # optional: end-to-end fine-tune
 
 
 
 
 
 
 
 
 
209
  for p in model.parameters():
210
  p.requires_grad = True
211
  optimizer = torch.optim.AdamW(
 
235
 
236
  ## Licence & acknowledgements
237
 
238
+ * Code & weights: **MIT Licence** – see `LICENSE`.
239
 
240
  ---
241