MedLayEval / model.py
anonymous-medical's picture
Upload folder using huggingface_hub
67d9005 verified
"""Shared MedLayEval model definition.
A LoRA-adapted Qwen2.5-VL backbone with a 5-attribute regression head.
Used by:
- train.py (distillation training)
- validate.py (held-out correlation report)
- inference.py (scoring benchmark VLM outputs)
The five attributes, in fixed order, are:
modality, anatomy, finding, factual, readability.
"""
import torch
import torch.nn as nn
# Names of the five scored attributes, in the fixed order given in the module
# docstring. The regression head in this file emits five values per example.
ATTRS = ["modality", "anatomy", "finding", "factual", "readability"]
class VLMRegressor(nn.Module):
    """VLM backbone with an attention-mask-pooled regression head.

    The wrapped ``vlm`` is called with ``output_hidden_states=True`` and
    ``return_dict=True``; its last hidden-state layer is mean-pooled over the
    positions selected by ``attention_mask`` and fed through a small MLP that
    ends in a sigmoid.

    Forward returns shape (B, 5) sigmoid scores in [0, 1].
    """

    def __init__(self, vlm, hidden_size: int):
        """Wrap ``vlm`` and build the regression head.

        Args:
            vlm: Callable backbone. It must accept the keyword arguments
                ``output_hidden_states`` and ``return_dict`` and return an
                object exposing ``hidden_states`` (a sequence of tensors).
            hidden_size: Size of the last dimension of the backbone's hidden
                states; used as the input width of the head.
        """
        super().__init__()
        self.vlm = vlm
        # nn.Sequential names parameters by position (head.0, head.3), so the
        # layer order is kept as-is to stay compatible with saved weights.
        self.head = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 5),  # one output per scored attribute
            nn.Sigmoid(),
        )

    def forward(self, target_scores=None, **inputs):
        """Score a batch.

        Args:
            target_scores: Ignored. Accepting it lets the train loop pass
                labels through the same forward signature.
            **inputs: Keyword arguments forwarded unchanged to the backbone.
                Must contain ``attention_mask`` of shape (B, T).

        Returns:
            A float32 tensor of shape (B, 5) with values in [0, 1].

        Raises:
            KeyError: If ``attention_mask`` is not present in ``inputs``.
        """
        out = self.vlm(**inputs, output_hidden_states=True, return_dict=True)
        h = out.hidden_states[-1]  # (B, T, H)
        mask = inputs["attention_mask"].unsqueeze(-1)  # (B, T, 1)

        # Reduce in float32 regardless of the backbone dtype. In bf16 the
        # token count is rounded to 8 significant bits (e.g. 257 -> 256) and
        # in fp16 the summed activations can overflow to inf on long
        # sequences; either would corrupt the pooled representation.
        # Multiplying by a 0/1 mask is exact in any float dtype, so only the
        # reductions need the wider type.
        summed = (h * mask.to(h.dtype)).sum(dim=1, dtype=torch.float32)  # (B, H)
        # clamp guards against division by zero for an all-padding row.
        count = mask.sum(dim=1, dtype=torch.float32).clamp(min=1)  # (B, 1)
        pooled = summed / count
        return self.head(pooled)