Update README with model loading code
Browse files
README.md
CHANGED
|
@@ -14,9 +14,9 @@ Part of the [LAPVQA collection](https://huggingface.co/collections/dmusingu/lapv
|
|
| 14 |
|
| 15 |
## Description
|
| 16 |
|
| 17 |
-
DiffVQA
|
| 18 |
([`lapvqa-pretrain-captioning`](https://huggingface.co/dmusingu/lapvqa-pretrain-captioning)).
|
| 19 |
-
|
| 20 |
|
| 21 |
## Results (test set)
|
| 22 |
|
|
@@ -24,8 +24,15 @@ The encoder is kept frozen; this file contains the task head only.
|
|
| 24 |
|---|---|---|---|
|
| 25 |
| 0.468 | 0.562 | 0.303 | 0.938 |
|
| 26 |
|
| 27 |
-
##
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
## Description
|
| 16 |
|
| 17 |
+
DiffVQA head trained on the frozen **LAPVQA captioning-pretrained encoder**
|
| 18 |
([`lapvqa-pretrain-captioning`](https://huggingface.co/dmusingu/lapvqa-pretrain-captioning)).
|
| 19 |
+
Checkpoint is a plain `DiffVQAHead` state dict (vis_dim=1024).
|
| 20 |
|
| 21 |
## Results (test set)
|
| 22 |
|
|
|
|
| 24 |
|---|---|---|---|
|
| 25 |
| 0.468 | 0.562 | 0.303 | 0.938 |
|
| 26 |
|
| 27 |
+
## Loading
|
| 28 |
|
| 29 |
+
```python
|
| 30 |
+
import torch
|
| 31 |
+
from lapvqa.diffvqa.model import DiffVQAHead
|
| 32 |
+
|
| 33 |
+
ckpt = torch.load("pretrain-captioning_best.pt", map_location="cpu")
|
| 34 |
+
head = DiffVQAHead(vis_dim=1024)
|
| 35 |
+
head.load_state_dict(ckpt)
|
| 36 |
+
head.eval()
|
| 37 |
+
# pair with encoder_final.pt from lapvqa-pretrain-captioning
|
| 38 |
+
```
|