---
tags:
- chest-xray
- radiology
- report-generation
- mimic-cxr
license: apache-2.0
---

# LAPVQA: Radiology Report Generation (Frozen Off-the-shelf Encoders)

Part of the [LAPVQA collection](https://huggingface.co/collections/dmusingu/lapvqa).

## Description

Autoregressive decoder heads for **Radiology Report Generation** on MIMIC-CXR,
one per encoder, each trained on top of one of five **frozen** off-the-shelf encoders.
Each checkpoint is a Python dict with the keys
`{state_dict, vis_dim, d_model, num_layers, nhead, encoder, epoch, val_bleu4}`.

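Because the hyperparameters are stored alongside the weights, it can be useful to check a checkpoint's metadata before building the head. The helper below is a hypothetical sketch: `checkpoint_summary` and `EXPECTED_KEYS` are not part of `lapvqa`, and the only thing it assumes is the list of key names given above.

```python
# Hypothetical helper, not part of lapvqa: validates and summarizes a checkpoint dict.
EXPECTED_KEYS = frozenset({
    "state_dict", "vis_dim", "d_model", "num_layers",
    "nhead", "encoder", "epoch", "val_bleu4",
})


def checkpoint_summary(ckpt):
    """Return every checkpoint field except the weights, after checking the keys."""
    missing = EXPECTED_KEYS - set(ckpt)
    if missing:
        raise KeyError(f"checkpoint is missing keys: {sorted(missing)}")
    # Drop the (large) state_dict so the result is cheap to print or log.
    return {key: value for key, value in ckpt.items() if key != "state_dict"}
```

It would typically be called on the result of `torch.load("siglip.pt", map_location="cpu")`.
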
## Architecture: `ReportGenerationHead`

```
vis_proj : Linear(vis_dim → 512)
tok_emb  : Embedding(50257, 512)   # GPT-2 vocab, weight-tied with lm_head
pos_emb  : Embedding(150, 512)
decoder  : 6 × TransformerDecoderLayer (pre-norm)
lm_head  : Linear(512 → 50257, bias=False)
```

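To make the layout above concrete, here is a minimal PyTorch sketch of a head with those five components. It is not the `lapvqa` implementation. The class name `ReportHeadSketch`, the default `nhead=8`, the default feed-forward width, the use of projected visual tokens as cross-attention memory, and the absence of a final LayerNorm are all assumptions. Its parameter names are therefore unlikely to match the released `state_dict`.

```python
import torch
import torch.nn as nn


class ReportHeadSketch(nn.Module):
    """Illustrative stand-in for ReportGenerationHead; details are assumptions."""

    def __init__(self, vis_dim, d_model=512, num_layers=6, nhead=8,
                 vocab_size=50257, max_len=150):
        super().__init__()
        self.vis_proj = nn.Linear(vis_dim, d_model)
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerDecoderLayer(
            d_model, nhead, batch_first=True, norm_first=True  # pre-norm
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # weight tying

    def forward(self, vis_tokens, input_ids):
        # vis_tokens: [B, N, vis_dim]; input_ids: [B, T] with T <= max_len
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)
        tgt = self.tok_emb(input_ids) + self.pos_emb(positions)
        memory = self.vis_proj(vis_tokens)
        # Causal mask: -inf above the diagonal blocks attention to future tokens.
        causal = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=input_ids.device),
            diagonal=1,
        )
        hidden = self.decoder(tgt, memory, tgt_mask=causal)
        return self.lm_head(hidden)  # [B, T, vocab_size]
```

Tying `lm_head.weight` to `tok_emb.weight` works because both matrices have shape `[50257, 512]`, so the output projection reuses the token embedding table instead of learning a second one.
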
## Results (MIMIC-CXR test set)

| Encoder | BLEU-4 | ROUGE-L | RadGraph-s |
|---|---|---|---|
| SigLIP | 0.036 | 0.168 | 0.211 |
| Florence-2 | 0.035 | 0.169 | 0.205 |
| CLIP ViT-L/14 | 0.034 | 0.168 | 0.197 |
| OWLv2 | 0.034 | 0.169 | 0.197 |
| CoCa | 0.030 | 0.160 | 0.193 |

Checkpoint files and the visual feature width (`vis_dim`) each head expects:

| File | Encoder | vis_dim |
|---|---|---|
| `siglip.pt` | SigLIP | 1152 |
| `florence2.pt` | Florence-2 | 1024 |
| `clip-vit-l14.pt` | CLIP ViT-L/14 | 1024 |
| `owlv2.pt` | OWLv2 | 1024 |
| `coca.pt` | CoCa | 768 |

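Feeding features from one encoder into a head trained on another fails at `vis_proj` whenever the widths differ, so a small guard can catch the mismatch early. The sketch below copies the file names and `vis_dim` values from the table above. The names `CHECKPOINTS` and `check_vis_tokens_shape` are hypothetical and not part of `lapvqa`.

```python
# File name and expected feature width per encoder, copied from the table above.
CHECKPOINTS = {
    "SigLIP": ("siglip.pt", 1152),
    "Florence-2": ("florence2.pt", 1024),
    "CLIP ViT-L/14": ("clip-vit-l14.pt", 1024),
    "OWLv2": ("owlv2.pt", 1024),
    "CoCa": ("coca.pt", 768),
}


def check_vis_tokens_shape(encoder, shape):
    """Hypothetical guard: return the checkpoint file if `shape` is [B, N, vis_dim]."""
    filename, vis_dim = CHECKPOINTS[encoder]
    if len(shape) != 3 or shape[-1] != vis_dim:
        raise ValueError(
            f"{encoder} head expects [B, N, {vis_dim}] features, got {tuple(shape)}"
        )
    return filename
```

For example, `check_vis_tokens_shape("CoCa", vis_tokens.shape)` returns `"coca.pt"` when the last dimension is 768 and raises `ValueError` otherwise.
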
## Loading

```python
import torch
import tiktoken
from lapvqa.rrg.heads import ReportGenerationHead

# Hyperparameters are stored in the checkpoint next to the weights.
ckpt = torch.load("siglip.pt", map_location="cpu")
head = ReportGenerationHead(
    vis_dim=ckpt["vis_dim"],
    d_model=ckpt["d_model"],
    num_layers=ckpt["num_layers"],
    nhead=ckpt["nhead"],
)
head.load_state_dict(ckpt["state_dict"])
head.eval()

# GPT-2 tokenizer; the end-of-text token serves as both BOS and EOS.
enc = tiktoken.get_encoding("gpt2")
bos_id = eos_id = enc.eot_token

# vis_tokens: [B, N, vis_dim], patch tokens from the frozen encoder
# (compute these with the matching encoder before calling generate).
token_ids = head.generate(vis_tokens, bos_id=bos_id, eos_id=eos_id, max_len=150)
reports = [enc.decode(ids) for ids in token_ids]
```

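Because `bos_id` and `eos_id` are the same GPT-2 end-of-text token, generated sequences may carry that token at the start, at the end, or as padding, depending on what `generate` returns (which this card does not document). The helper below is a hypothetical post-processing step, not part of `lapvqa`. It assumes each sequence can be iterated as integers.

```python
def trim_special(ids, special_id):
    """Drop a leading BOS, then cut at the first EOS (both equal `special_id`)."""
    ids = [int(i) for i in ids]  # also accepts a 1-D tensor or array of ids
    if ids and ids[0] == special_id:
        ids = ids[1:]  # leading BOS
    if special_id in ids:
        ids = ids[: ids.index(special_id)]  # EOS and anything after it
    return ids
```

With it, decoding becomes `reports = [enc.decode(trim_special(ids, eos_id)) for ids in token_ids]`.
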