dmusingu committed on
Commit
fbbce50
·
verified ·
1 Parent(s): e6e5c1c

Update README with model loading code

Files changed (1)
  1. README.md +46 -16
README.md CHANGED
@@ -13,31 +13,61 @@ Part of the [LAPVQA collection](https://huggingface.co/collections/dmusingu/lapv
 
  ## Description
  
- Lightweight task heads for **closed-domain Visual Question Answering** on MIMIC-Diff-VQA,
  trained on top of five **frozen** off-the-shelf vision encoders.
- Each `.pt` file contains the task head weights for one encoder variant;
- the underlying encoder weights are not included and must be loaded separately.
  
- ## Setup
  
- The head takes the frozen encoder's patch tokens as input and is trained with cross-entropy
- over answer vocabulary. The encoder is kept frozen throughout training.
  
- ## Results (test set, overall BLEU-4)
  
- | Encoder (frozen) | BLEU-1 | BLEU-4 | ROUGE-L | RadGraph-s |
  |---|---|---|---|---|
  | CLIP ViT-L/14 | 0.602 | 0.243 | 0.725 | 0.222 |
  | SigLIP | 0.586 | 0.253 | 0.717 | 0.214 |
  | Florence-2 | 0.575 | 0.207 | 0.700 | 0.217 |
  | CoCa | 0.532 | 0.173 | 0.642 | 0.170 |
  
- ## Files
  
- | File | Encoder backbone |
- |---|---|
- | `clip-vit-l14_best.pt` | CLIP ViT-L/14 |
- | `siglip_best.pt` | SigLIP (ViT-SO400M-14-384) |
- | `florence2_best.pt` | Florence-2 |
- | `coca_best.pt` | CoCa |
- | `owlv2_best.pt` | OWLv2 |
 
 
 
 
  ## Description
  
+ Lightweight task heads for **Visual Question Answering** on MIMIC-Diff-VQA,
  trained on top of five **frozen** off-the-shelf vision encoders.
+ Each `.pt` file contains only the task head weights; load the encoder separately.
  
+ ## Architecture — `VQAHead`
  
+ ```
+ vis_proj : Linear(vis_dim → 512)
+ tok_emb : Embedding(50257, 512) # GPT-2 vocab, weight-tied with lm_head
+ pos_emb : Embedding(150, 512)
+ decoder : 6 × TransformerDecoderLayer (pre-norm, cross-attn to visual tokens)
+ lm_head : Linear(512 → 50257, bias=False)
+ ```
  
+ | File | Encoder | vis_dim |
+ |---|---|---|
+ | `clip-vit-l14_best.pt` | CLIP ViT-L/14 | 1024 |
+ | `siglip_best.pt` | SigLIP ViT-SO400M-14-384 | 1152 |
+ | `florence2_best.pt` | Florence-2 | 1024 |
+ | `coca_best.pt` | CoCa | 768 |
+ | `owlv2_best.pt` | OWLv2 | 1024 |
 
+ ## Results (test set, overall)
+
+ | Encoder | BLEU-1 | BLEU-4 | ROUGE-L | RadGraph-s |
  |---|---|---|---|---|
  | CLIP ViT-L/14 | 0.602 | 0.243 | 0.725 | 0.222 |
  | SigLIP | 0.586 | 0.253 | 0.717 | 0.214 |
  | Florence-2 | 0.575 | 0.207 | 0.700 | 0.217 |
  | CoCa | 0.532 | 0.173 | 0.642 | 0.170 |
 
+ ## Loading
+
+ ```python
+ import torch
+ import tiktoken
+ from lapvqa.vqa.model import VQAHead
+
+ # checkpoint is a plain state dict
+ ckpt = torch.load("clip-vit-l14_best.pt", map_location="cpu")
+ head = VQAHead(vis_dim=1024)
+ head.load_state_dict(ckpt)
+ head.eval()
+
+ # vis_tokens: [B, N, vis_dim] — patch tokens from the frozen encoder
+ # prompt_ids: [B, Q] — tokenised question (GPT-2 tokeniser)
+ enc = tiktoken.get_encoding("gpt2")
+ bos_id, eos_id = enc.eot_token, enc.eot_token
  
+ answers = head.generate(
+     vis_tokens=vis_tokens,
+     prompt_ids=prompt_ids,
+     bos_id=bos_id,
+     eos_id=eos_id,
+     max_new_tokens=64,
+ )
+ decoded = [enc.decode(ids) for ids in answers]
+ ```
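
The loading snippet hard-codes `vis_dim=1024` for the CLIP checkpoint, while the file table in the updated README lists a different width for some encoders. A minimal sketch of looking the width up by file name, assuming the table values are correct; `VIS_DIMS` and `vis_dim_for` are hypothetical names for illustration and are not part of `lapvqa`:

```python
from pathlib import Path

# File name -> encoder feature width, copied from the README table in this commit.
VIS_DIMS = {
    "clip-vit-l14_best.pt": 1024,
    "siglip_best.pt": 1152,
    "florence2_best.pt": 1024,
    "coca_best.pt": 768,
    "owlv2_best.pt": 1024,
}


def vis_dim_for(checkpoint_path: str) -> int:
    """Return the vis_dim listed for a checkpoint, keyed by its file name."""
    name = Path(checkpoint_path).name
    try:
        return VIS_DIMS[name]
    except KeyError:
        raise ValueError(
            f"unknown checkpoint {name!r}; expected one of {sorted(VIS_DIMS)}"
        ) from None


# Hypothetical usage with the loading snippet: VQAHead(vis_dim=vis_dim_for(path))
print(vis_dim_for("checkpoints/siglip_best.pt"))  # 1152
```

Since `vis_proj` maps `vis_dim` to 512, passing the wrong width would most likely surface as a size-mismatch error when `load_state_dict` runs, so resolving it from the file name up front avoids that failure.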