"""Shared MedLayEval model definition.

A LoRA-adapted Qwen2.5-VL backbone with a 5-attribute regression head.

Used by:
    - train.py      (distillation training)
    - validate.py   (held-out correlation report)
    - inference.py  (scoring benchmark VLM outputs)

The five attributes, in fixed order, are: modality, anatomy, finding,
factual, readability.
"""

import torch
import torch.nn as nn

ATTRS = ["modality", "anatomy", "finding", "factual", "readability"]


class VLMRegressor(nn.Module):
    """LoRA-adapted VLM with an attention-mask-pooled regression head.

    The backbone's last hidden state is mean-pooled over non-padding tokens
    and passed through a small MLP. The sigmoid output has one score per
    entry of ``ATTRS``.

    Forward returns shape (B, len(ATTRS)) == (B, 5) sigmoid scores in [0, 1].
    Columns are in ``ATTRS`` order.

    Args:
        vlm: Backbone, called as ``vlm(**inputs, output_hidden_states=True,
            return_dict=True)``. Its output must expose ``hidden_states``.
        hidden_size: Width H of the backbone's last hidden state.
        head_dim: Width of the head's intermediate layer (default 256).
        dropout: Dropout probability inside the head (default 0.1).
    """

    def __init__(
        self,
        vlm,
        hidden_size: int,
        *,
        head_dim: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vlm = vlm
        # Keep the layer order stable so state-dict keys (head.0 / head.3)
        # still match existing checkpoints.
        self.head = nn.Sequential(
            nn.Linear(hidden_size, head_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_dim, len(ATTRS)),
            nn.Sigmoid(),
        )

    def forward(self, target_scores=None, **inputs):
        """Score a batch.

        Args:
            target_scores: Ignored. It is accepted so the train loop can pass
                labels through the same forward signature.
            **inputs: Keyword tensors forwarded unchanged to ``self.vlm``.
                ``attention_mask`` (B, T) selects the tokens to pool. If it is
                absent or None, every token is treated as valid.

        Returns:
            Tensor of shape (B, len(ATTRS)) with sigmoid scores in [0, 1].
        """
        # NOTE(review): if `vlm` is a causal-LM wrapper, this call also
        # computes vocabulary logits that are discarded here.
        out = self.vlm(**inputs, output_hidden_states=True, return_dict=True)
        h = out.hidden_states[-1]  # (B, T, H)

        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            mask = torch.ones(h.shape[:2], device=h.device, dtype=torch.float32)
        else:
            # With a sharded backbone, the last hidden state need not live on
            # the same device as the inputs.
            mask = attention_mask.to(device=h.device, dtype=torch.float32)
        mask = mask.unsqueeze(-1)  # (B, T, 1)

        # Pool in float32. In fp16, summing many tokens can overflow to inf
        # (max ~65504). In bf16, token counts above 256 are not exactly
        # representable, which skews the mean.
        summed = (h.float() * mask).sum(dim=1)  # (B, H)
        count = mask.sum(dim=1).clamp(min=1.0)  # (B, 1); all-padding rows pool to 0
        pooled = summed / count
        return self.head(pooled)