---
tags:
  - chest-xray
  - radiology
  - object-detection
  - abnormality-detection
  - vindr-cxr
license: apache-2.0
---

LAPVQA — Abnormality Detection

Part of the LAPVQA collection.

Description

DETR-style detection heads for 14-class chest abnormality detection on VinDr-CXR. There is one head per frozen vision encoder, six in total. Each checkpoint is a dict with the keys {state_dict, vis_dim, d_model, num_queries, num_enc, num_dec, encoder, epoch, val_map40, val_map50}.
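Before building a head, it can help to confirm that a loaded file has the structure described above. The helper below is a hypothetical sketch, not part of lapvqa; the string stored under `encoder` and the reading of `val_map40`/`val_map50` as validation mAP at IoU 0.4 and 0.5 are assumptions inferred from the key names.

```python
# Hypothetical helper (not a lapvqa API): check and summarize a checkpoint dict.
REQUIRED_KEYS = ("state_dict", "vis_dim", "d_model", "num_queries",
                 "num_enc", "num_dec", "encoder", "epoch", "val_map40", "val_map50")


def summarize_checkpoint(ckpt):
    """Raise KeyError if any documented key is absent, else return a one-line summary."""
    missing = [k for k in REQUIRED_KEYS if k not in ckpt]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    # ASSUMPTION: val_map40 / val_map50 are validation mAP at IoU 0.4 / 0.5.
    return (f"{ckpt['encoder']}: vis_dim={ckpt['vis_dim']}, d_model={ckpt['d_model']}, "
            f"queries={ckpt['num_queries']}, enc/dec={ckpt['num_enc']}/{ckpt['num_dec']}, "
            f"epoch={ckpt['epoch']}, val mAP@0.4={ckpt['val_map40']:.3f}, "
            f"val mAP@0.5={ckpt['val_map50']:.3f}")
```

Typical use would be `print(summarize_checkpoint(torch.load("owlv2.pt", map_location="cpu")))`.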

Architecture — DetectionHead

vis_proj        : Linear(vis_dim → 256)
encoder         : 2 × TransformerEncoderLayer (self-attn, pre-norm)
object_queries  : Parameter [1, 20, 256]
decoder         : 3 × TransformerDecoderLayer (cross-attn to encoder output)
class_head      : Linear(256 → 15)   # 14 classes + background
box_head        : MLP(256 → 256 → 4) # (cx,cy,w,h) ∈ [0,1]
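The layout above maps onto standard PyTorch modules roughly as follows. This is a minimal sketch, not the lapvqa `DetectionHead`: the class name `DetectionHeadSketch`, `nhead=8`, PyTorch's default feed-forward width and dropout, pre-norm in the decoder, feeding `object_queries` directly as the decoder input, the sigmoid on the box output, and the output keys `pred_logits`/`pred_boxes` are all assumptions. Its parameter names will not necessarily match the released `state_dict`s.

```python
import torch
import torch.nn as nn


class DetectionHeadSketch(nn.Module):
    """Hypothetical re-creation of the documented layout; NOT the lapvqa class."""

    def __init__(self, vis_dim=1024, d_model=256, num_queries=20,
                 num_enc_layers=2, num_dec_layers=3, num_classes=14, nhead=8):
        super().__init__()
        self.vis_proj = nn.Linear(vis_dim, d_model)  # vis_dim -> 256
        enc_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_enc_layers)
        # Learned object queries, shape [1, num_queries, d_model].
        self.object_queries = nn.Parameter(torch.randn(1, num_queries, d_model) * 0.02)
        # ASSUMPTION: decoder is also pre-norm.
        dec_layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True, norm_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_dec_layers)
        self.class_head = nn.Linear(d_model, num_classes + 1)  # 14 classes + background
        self.box_head = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(),
                                      nn.Linear(d_model, 4))

    def forward(self, vis_tokens):
        # vis_tokens: [B, HW, vis_dim] patch tokens from a frozen encoder.
        memory = self.encoder(self.vis_proj(vis_tokens))
        queries = self.object_queries.expand(vis_tokens.size(0), -1, -1)
        hs = self.decoder(queries, memory)  # queries cross-attend to encoder output
        return {
            "pred_logits": self.class_head(hs),           # [B, Q, 15]
            "pred_boxes": self.box_head(hs).sigmoid(),    # [B, Q, 4] (cx, cy, w, h) in [0, 1]
        }
```

Because `vis_proj` absorbs the encoder width, the same head layout serves every encoder in the checkpoint table; only `vis_dim` changes.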

Results (VinDr-CXR test, mAP@IoU=0.4)

Encoder          mAP@0.4 (test)
OWLv2            0.048
SigLIP           ~0.045
CLIP ViT-L/14    ~0.040
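The card does not include its evaluation code. To make concrete what mAP@IoU=0.4 measures, here is a minimal VOC-style, all-point-interpolated AP for a single class; mAP is then the mean over the 14 classes. The function names and input format are hypothetical, and the official VinDr-CXR scoring may differ in details such as matching rules or how classes are averaged.

```python
def iou_xyxy(a, b):
    """Intersection-over-union of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def average_precision(preds, gts, iou_thr=0.4):
    """All-point interpolated AP for ONE class with greedy VOC-style matching.

    preds: list of (image_id, score, box_xyxy)
    gts:   dict image_id -> list of ground-truth boxes (xyxy)
    """
    n_gt = sum(len(boxes) for boxes in gts.values())
    if n_gt == 0:
        return 0.0
    matched = {img: [False] * len(boxes) for img, boxes in gts.items()}
    hits = []  # 1 = true positive, 0 = false positive, in descending score order
    for img, _score, box in sorted(preds, key=lambda p: p[1], reverse=True):
        best_iou, best_j = 0.0, -1
        for j, gt in enumerate(gts.get(img, [])):
            iou = iou_xyxy(box, gt)
            if iou > best_iou:
                best_iou, best_j = iou, j
        if best_j >= 0 and best_iou >= iou_thr and not matched[img][best_j]:
            matched[img][best_j] = True
            hits.append(1)
        else:
            hits.append(0)  # overlap too low, or its best ground truth was already taken
    # Cumulative precision and recall down the ranked list.
    precisions, recalls, tp = [], [], 0
    for k, hit in enumerate(hits, start=1):
        tp += hit
        precisions.append(tp / k)
        recalls.append(tp / n_gt)
    # Precision envelope: make precision non-increasing from right to left.
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])
    # Area under the interpolated precision-recall curve.
    ap, prev_r = 0.0, 0.0
    for p, r in zip(precisions, recalls):
        ap += (r - prev_r) * p
        prev_r = r
    return ap
```

With two images that each hold one ground-truth box (0, 0, 10, 10), and predictions ("a", 0.9, (0, 0, 10, 10)), ("b", 0.8, (5, 0, 15, 10)) and ("b", 0.7, (2, 0, 12, 10)), AP is 5/6 at IoU 0.4 but 1.0 at IoU 0.3, because the 0.8-score box overlaps its target with IoU 1/3. A looser threshold like 0.4 is therefore more forgiving of imprecise localization than the common 0.5.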

Checkpoints

File               Encoder          vis_dim
clip-vit-l14.pt    CLIP ViT-L/14    1024
siglip.pt          SigLIP           1152
florence2.pt       Florence-2       1024
coca.pt            CoCa             768
owlv2.pt           OWLv2            1024
mae-vit-l16.pt     MAE ViT-L/16     1024
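Since `vis_proj` expects inputs of width `vis_dim`, the patch tokens must come from the same encoder the head was trained with. A small guard built from the table might look like this; `VIS_DIM_BY_FILE` and `check_token_width` are hypothetical names, not part of lapvqa.

```python
# Hypothetical lookup copied from the checkpoint table: file -> expected vis_dim.
VIS_DIM_BY_FILE = {
    "clip-vit-l14.pt": 1024,
    "siglip.pt": 1152,
    "florence2.pt": 1024,
    "coca.pt": 768,
    "owlv2.pt": 1024,
    "mae-vit-l16.pt": 1024,
}


def check_token_width(filename, token_shape):
    """Return the expected vis_dim, or raise if tokens [B, HW, D] do not match it."""
    expected = VIS_DIM_BY_FILE[filename]
    if len(token_shape) != 3:
        raise ValueError(f"expected [B, HW, vis_dim] tokens, got shape {tuple(token_shape)}")
    if token_shape[-1] != expected:
        raise ValueError(f"{filename} expects vis_dim={expected}, got {token_shape[-1]}")
    return expected
```

Note that several encoders share `vis_dim=1024`, so a matching width does not prove the tokens came from the right encoder; the `encoder` field in the checkpoint is the better check for that.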

Loading

import torch
from lapvqa.ad.heads import DetectionHead, predict

ckpt = torch.load("owlv2.pt", map_location="cpu")
head = DetectionHead(
    vis_dim=ckpt["vis_dim"],          # 1024 for OWLv2, per the checkpoint table
    d_model=ckpt["d_model"],
    num_queries=ckpt["num_queries"],
    num_enc_layers=ckpt["num_enc"],
    num_dec_layers=ckpt["num_dec"],
)
head.load_state_dict(ckpt["state_dict"])
head.eval()

# vis_tokens: spatial patch tokens from the matching frozen encoder, shape [B, HW, vis_dim].
# The random tensor is only a shape placeholder (HW=576 is arbitrary); replace it with
# real features from the encoder named in ckpt["encoder"].
vis_tokens = torch.randn(1, 576, ckpt["vis_dim"])

with torch.no_grad():
    outputs = head(vis_tokens)
    detections = predict(outputs, score_threshold=0.1, nms_iou=0.5)
# detections[i]: {'boxes': [K, 4] xyxy, 'labels': [K], 'scores': [K]} for image i
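`predict` comes from lapvqa and its internals are not documented here. The following hypothetical, pure-Python sketch shows typical DETR-style post-processing with the same two knobs: softmax over each query's logits, drop the background column, keep queries whose best class score reaches `score_threshold`, convert (cx, cy, w, h) to xyxy, then apply class-wise NMS at `nms_iou`. It assumes background is the last logit (the DETR convention) and that boxes stay in normalized [0, 1] coordinates; whether lapvqa's `predict` behaves the same way is not confirmed.

```python
import math


def cxcywh_to_xyxy(box):
    """(cx, cy, w, h) -> (x1, y1, x2, y2), keeping the same (normalized) units."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def box_iou(a, b):
    """IoU of two xyxy boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def postprocess_sketch(logits, boxes, score_threshold=0.1, nms_iou=0.5):
    """Hypothetical post-processing for ONE image.

    logits: per-query lists of raw class scores (ASSUMED: last index = background)
    boxes:  per-query (cx, cy, w, h) in [0, 1]
    """
    candidates = []
    for q_logits, q_box in zip(logits, boxes):
        m = max(q_logits)  # subtract the max for a numerically stable softmax
        exps = [math.exp(v - m) for v in q_logits]
        total = sum(exps)
        fg_probs = [e / total for e in exps[:-1]]  # drop the background column
        label = max(range(len(fg_probs)), key=fg_probs.__getitem__)
        score = fg_probs[label]
        if score >= score_threshold:
            candidates.append((score, label, cxcywh_to_xyxy(q_box)))
    # Class-wise greedy NMS: highest score first; drop same-class boxes that overlap a kept one.
    candidates.sort(key=lambda c: c[0], reverse=True)
    kept = []
    for score, label, box in candidates:
        if all(k_label != label or box_iou(k_box, box) <= nms_iou
               for _, k_label, k_box in kept):
            kept.append((score, label, box))
    return {
        "boxes": [b for _, _, b in kept],
        "labels": [lab for _, lab, _ in kept],
        "scores": [s for s, _, _ in kept],
    }
```

The returned boxes are normalized; multiplying x coordinates by the image width and y coordinates by the height gives pixel coordinates.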