dmusingu commited on
Commit
3be30cf
·
verified ·
1 Parent(s): 9309a90

Update README with model loading code

Browse files
Files changed (1) hide show
  1. README.md +13 -6
README.md CHANGED
@@ -14,9 +14,9 @@ Part of the [LAPVQA collection](https://huggingface.co/collections/dmusingu/lapv
14
 
15
  ## Description
16
 
17
- DiffVQA task head trained on top of the **LAPVQA captioning-pretrained encoder**
18
  ([`lapvqa-pretrain-captioning`](https://huggingface.co/dmusingu/lapvqa-pretrain-captioning)).
19
- The encoder is kept frozen; this file contains the task head only.
20
 
21
  ## Results (test set)
22
 
@@ -24,8 +24,15 @@ The encoder is kept frozen; this file contains the task head only.
24
  |---|---|---|---|
25
  | 0.468 | 0.562 | 0.303 | 0.938 |
26
 
27
- ## Files
28
 
29
- | File | Description |
30
- |---|---|
31
- | `pretrain-captioning_best.pt` | DiffVQA head (encoder not included) |
 
 
 
 
 
 
 
 
14
 
15
  ## Description
16
 
17
+ DiffVQA head trained on the frozen **LAPVQA captioning-pretrained encoder**
18
  ([`lapvqa-pretrain-captioning`](https://huggingface.co/dmusingu/lapvqa-pretrain-captioning)).
19
+ Checkpoint is a plain `DiffVQAHead` state dict (vis_dim=1024).
20
 
21
  ## Results (test set)
22
 
 
24
  |---|---|---|---|
25
  | 0.468 | 0.562 | 0.303 | 0.938 |
26
 
27
+ ## Loading
28
 
29
+ ```python
30
+ import torch
31
+ from lapvqa.diffvqa.model import DiffVQAHead
32
+
33
+ ckpt = torch.load("pretrain-captioning_best.pt", map_location="cpu")
34
+ head = DiffVQAHead(vis_dim=1024)
35
+ head.load_state_dict(ckpt)
36
+ head.eval()
37
+ # pair with encoder_final.pt from lapvqa-pretrain-captioning
38
+ ```