# chest2vec_labeler / modeling_chest2vec_labeler.py
# Uploaded to the Hugging Face Hub by lukeingawesome via huggingface_hub
# (commit 6410e2a, verified; 18.2 kB).
"""
Chest2Vec CT Report Labeler — HuggingFace `AutoModel` wrapper.
A weakly-supervised multi-label classifier that maps a free-text chest-CT report to a
137-leaf chest-imaging taxonomy with a ternary status per label
(negative / uncertain / positive).
Architecture: `Qwen/Qwen3-Embedding-0.6B` encoder (LoRA merged in) → left-padding-aware
last-token (EOS) pooling → L2-normalization → a single linear ternary head
(`hidden=1024 → 137 × 3`).
Usage:
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("chest2vec/chest2vec_labeler", trust_remote_code=True).eval()
tok = AutoTokenizer.from_pretrained("chest2vec/chest2vec_labeler", trust_remote_code=True)
reports = ["Bibasilar atelectasis with small bilateral pleural effusions. Cardiomegaly."]
print(model.label_reports(reports, tokenizer=tok)) # -> [{'Pleural effusion': 'positive', ...}]
# CheXbert / SRR-BERT-style report comparison (label both, compare):
res = model.score_reports(gt_reports, pred_reports, tokenizer=tok)
print(res["micro"]["f1"], res["macro"]["f1"], res["weighted"]["f1"])
"""
from typing import Dict, List, Optional, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
# Positions of the three classes along the head's last (softmax) axis.
NEGATIVE, UNCERTAIN, POSITIVE = range(3)
_CLASS_ORDER = (NEGATIVE, UNCERTAIN, POSITIVE)
# Class index -> ternary label value (positive = 1, negative = 0, uncertain = -1).
_CLASS_TO_VALUE = dict(zip(_CLASS_ORDER, (0, -1, 1)))
# Class index -> human-readable status name.
_CLASS_TO_NAME = dict(zip(_CLASS_ORDER, ("negative", "uncertain", "positive")))
class Chest2VecLabelerConfig(PretrainedConfig):
    """Hugging Face configuration for `Chest2VecLabelerModel`.

    Holds:
      * `encoder_config`: serialized backbone config used to rebuild the encoder;
      * `base_model`: fallback source for the tokenizer;
      * head sizes (`hidden_size`, `n_labels`, `num_classes_per_label`);
      * `labels`: ordered leaf label names;
      * `instruction`: prefix used when tokenizing reports;
      * `max_len` / `default_threshold`: inference defaults;
      * `label_hierarchy`: nested mapping anatomy -> upper group -> leaf names.

    Missing optional containers become fresh empty ones, so no default is shared
    between instances.
    """

    model_type = "chest2vec_labeler"

    def __init__(
        self,
        encoder_config: Optional[dict] = None,
        base_model: str = "Qwen/Qwen3-Embedding-0.6B",
        hidden_size: int = 1024,
        n_labels: int = 137,
        num_classes_per_label: int = 3,
        labels: Optional[List[str]] = None,
        instruction: str = "Given the following chest CT report, extract the presence/absence of entities",
        max_len: int = 512,
        default_threshold: float = 0.5,
        label_hierarchy: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Assignment order is kept stable so the serialized config keeps its key order.
        self.encoder_config = encoder_config if encoder_config else {}
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.n_labels = n_labels
        self.num_classes_per_label = num_classes_per_label
        self.labels = labels if labels else []
        self.instruction = instruction
        self.max_len = max_len
        self.default_threshold = default_threshold
        self.label_hierarchy = label_hierarchy if label_hierarchy else {}
@dataclass
class LabelerOutput(ModelOutput):
    """Output of `Chest2VecLabelerModel.forward`: ternary logits plus the pooled embedding."""

    # Field order matters: ModelOutput also exposes fields positionally / via to_tuple().
    logits: Optional[torch.FloatTensor] = None  # [B, num_labels, num_classes_per_label] (3 classes by default)
    embedding: Optional[torch.FloatTensor] = None  # [B, hidden] L2-normalized pooled embedding (float32)
def _build_encoder(encoder_config: dict, attn_implementation: str = "sdpa"):
    """Instantiate the backbone from a serialized encoder config dict.

    Checkpoint metadata keys (architectures, auto_map, versions, name/path, dtype) are
    stripped before the config is built. `model_type` defaults to "qwen3", and the
    config dtype is pinned to float32. The requested attention implementation is
    applied on a best-effort basis.

    The model comes from `AutoModel.from_config`, so its weights are freshly
    initialized here. Presumably they are overwritten when the checkpoint is loaded
    by the caller; verify against the loading path.
    """
    metadata_keys = ("architectures", "auto_map", "transformers_version", "_name_or_path", "torch_dtype")
    cfg_kwargs = {key: value for key, value in dict(encoder_config).items() if key not in metadata_keys}
    model_type = cfg_kwargs.pop("model_type", "qwen3")

    cfg = AutoConfig.for_model(model_type, **cfg_kwargs)
    cfg.torch_dtype = "float32"
    try:
        cfg._attn_implementation = attn_implementation
    except Exception:
        # Best effort only: keep the config's own default if this cannot be set.
        pass

    try:
        encoder = AutoModel.from_config(cfg, attn_implementation=attn_implementation)
    except TypeError:
        # Fallback for from_config() signatures that reject `attn_implementation`.
        encoder = AutoModel.from_config(cfg)
    return encoder
def _last_token_pool(last_hidden_states: torch.Tensor,
                     attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    """Left-padding-aware last-token (EOS) pooling — matches the training pipeline.

    Args:
        last_hidden_states: hidden states indexed as [batch, position, ...]
            (assumed [B, T, H]).
        attention_mask: [B, T] mask with 1 for real tokens and 0 for padding. `None`
            means no padding, i.e. every sequence ends at the final position.

    Returns:
        For each row, the hidden state at its last real token ([B, H]).
    """
    if attention_mask is None:
        # Without a mask there is no padding, so the final position is the last token.
        return last_hidden_states[:, -1]
    # If every row has a real token in the final column, the batch is left-padded
    # (or unpadded) and that column holds each row's last token.
    if bool(attention_mask[:, -1].sum() == attention_mask.shape[0]):
        return last_hidden_states[:, -1]
    # Otherwise assume right padding: the last real token sits at (n_real - 1).
    last_idx = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
    return last_hidden_states[rows, last_idx]
class Chest2VecLabelerModel(PreTrainedModel):
    """Chest-CT report labeler: encoder -> last-token pooling -> L2-norm -> linear ternary head.

    `forward` maps token tensors to logits of shape [B, n_labels, num_classes_per_label].
    The other public methods accept raw report strings:

    * `predict_proba` / `predict` / `label_reports` label reports;
    * `aggregate_hierarchy` rolls leaf scores up the label hierarchy;
    * `score_reports` / `per_label_best_f1` evaluate.
    """

    config_class = Chest2VecLabelerConfig
    base_model_prefix = "model"

    def __init__(self, config: Chest2VecLabelerConfig):
        super().__init__(config)
        self.model = _build_encoder(config.encoder_config, getattr(config, "attn_implementation", "sdpa"))
        self.head = nn.Linear(config.hidden_size, config.n_labels * config.num_classes_per_label)
        self.num_labels = config.n_labels
        self.num_classes_per_label = config.num_classes_per_label
        self._tokenizer = None  # created lazily by _get_tokenizer()
        self.post_init()

    # ---- core forward (token tensors in, logits out) ----
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, **kwargs):
        """Encode token ids and return per-label logits plus the pooled embedding.

        Args:
            input_ids: [B, T] token ids (left- or right-padded).
            attention_mask: [B, T] with 1 for real tokens and 0 for padding. If omitted
                (and `input_ids` is given), every position is treated as a real token.
            position_ids: optional. When absent, positions are derived from the mask so
                that padded positions do not advance the position counter.
            **kwargs: accepted for Hugging Face API compatibility; ignored.

        Returns:
            LabelerOutput with `logits` [B, num_labels, num_classes_per_label] and the
            L2-normalized float32 `embedding`.
        """
        if attention_mask is None and input_ids is not None:
            # The pooling step needs a mask; no mask means "no padding".
            attention_mask = torch.ones_like(input_ids)
        if position_ids is None and attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 0)
        out = self.model(input_ids=input_ids, attention_mask=attention_mask,
                         position_ids=position_ids, use_cache=False, return_dict=True)
        h = out.last_hidden_state if hasattr(out, "last_hidden_state") else out.hidden_states[-1]
        emb = _last_token_pool(h, attention_mask)
        emb = F.normalize(emb.float(), p=2, dim=-1)
        # Feed the head in its own dtype: a float32 embedding into bf16/fp16 Linear
        # weights (e.g. after loading with torch_dtype=torch.bfloat16) would raise.
        logits = self.head(emb.to(self.head.weight.dtype))
        logits = logits.view(emb.size(0), self.num_labels, self.num_classes_per_label)
        return LabelerOutput(logits=logits, embedding=emb)

    # ---- tokenization (matches training: Instruct/Query + reserved EOS + left pad) ----
    def _get_tokenizer(self, tokenizer=None):
        """Return `tokenizer` if given, otherwise a cached tokenizer loaded on first use.

        The cached tokenizer is loaded from `config._name_or_path`, falling back to
        `config.base_model` when that is empty. It uses left padding, and the EOS token
        serves as the pad token when none is defined.
        """
        if tokenizer is not None:
            return tokenizer
        if self._tokenizer is None:
            from transformers import AutoTokenizer
            src = self.config._name_or_path or self.config.base_model
            self._tokenizer = AutoTokenizer.from_pretrained(src, padding_side="left", trust_remote_code=True)
            if self._tokenizer.pad_token_id is None:
                self._tokenizer.pad_token = self._tokenizer.eos_token
        return self._tokenizer

    def _encode(self, tok, reports: List[str], max_len: int):
        """Tokenize a batch of reports into left-padded id / mask tensors.

        Each report is prefixed with an "Instruct: <instruction>" line and a "Query: "
        marker; the bare report is used when the instruction is empty. The text is
        tokenized without special tokens and truncated to `max_len - 1` tokens. One
        "<|endoftext|>" id is then appended; the pad id is used instead when that lookup
        returns None or a negative id.

        Returns:
            (input_ids, attention_mask): LongTensors of shape [len(reports), T], where T is
            the longest sequence in the batch.
        """
        instr = self.config.instruction.strip()
        texts = [(f"Instruct: {instr}\nQuery: {str(r).strip()}" if instr else str(r).strip()) for r in reports]
        pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
        eod_id = tok.convert_tokens_to_ids("<|endoftext|>")
        if eod_id is None or eod_id < 0:
            eod_id = pad_id
        enc = tok(texts, add_special_tokens=False, truncation=True, max_length=max_len - 1,
                  padding=False, return_attention_mask=False)
        ids = [x + [eod_id] for x in enc["input_ids"]]
        T = max((len(x) for x in ids), default=1)
        input_ids = [[pad_id] * (T - len(x)) + x for x in ids]
        attn = [[0] * (T - len(x)) + [1] * len(x) for x in ids]
        return (torch.tensor(input_ids, dtype=torch.long), torch.tensor(attn, dtype=torch.long))

    def _iter_batch_logits(self, reports, tokenizer=None, batch_size: int = 16,
                           max_len: Optional[int] = None, device=None):
        """Yield float32 logits [b, num_labels, num_classes] on `device`, one batch at a time.

        Shared by `predict_proba` and `predict`. Callers run it under `torch.no_grad()`.
        The model is switched to eval mode.
        """
        tok = self._get_tokenizer(tokenizer)
        max_len = max_len or self.config.max_len
        device = device or next(self.parameters()).device
        self.eval()
        for start in range(0, len(reports), batch_size):
            ids, mask = self._encode(tok, reports[start:start + batch_size], max_len)
            yield self(input_ids=ids.to(device), attention_mask=mask.to(device)).logits.float()

    # ---- high-level prediction API ----
    @torch.no_grad()
    def predict_proba(self, reports: List[str], tokenizer=None, batch_size: int = 16,
                      max_len: Optional[int] = None, device=None) -> torch.Tensor:
        """Return [N, num_labels] probability of the POSITIVE class for each label.

        A single string is treated as one report. An empty input yields a [0, num_labels]
        tensor instead of an error. The result is a float32 CPU tensor.
        """
        if isinstance(reports, str):
            reports = [reports]
        chunks = [torch.softmax(logits, dim=-1)[:, :, POSITIVE].cpu()
                  for logits in self._iter_batch_logits(reports, tokenizer, batch_size, max_len, device)]
        if not chunks:
            return torch.zeros((0, self.num_labels), dtype=torch.float32)
        return torch.cat(chunks, dim=0)

    @torch.no_grad()
    def predict(self, reports: List[str], tokenizer=None, threshold: Optional[float] = None,
                batch_size: int = 16, max_len: Optional[int] = None, device=None,
                return_ternary: bool = False) -> Dict[str, Any]:
        """Label reports and return numpy arrays.

        Returns:
            dict with
              * 'labels': list of label names;
              * 'proba': [N, L] positive-class probability;
              * 'positive': [N, L] 0/1, where 1 means proba >= threshold;
              * 'threshold': the threshold used (`config.default_threshold` if None);
              * 'ternary' (only with `return_ternary`): [N, L] argmax class mapped to
                {1 positive, 0 negative, -1 uncertain}.
            An empty input yields arrays with zero rows.
        """
        if isinstance(reports, str):
            reports = [reports]
        thr = self.config.default_threshold if threshold is None else threshold
        # Lookup table: class index -> ternary value (vectorized _CLASS_TO_VALUE).
        class_to_value = torch.tensor([_CLASS_TO_VALUE[k] for k in range(len(_CLASS_TO_VALUE))],
                                      dtype=torch.long)
        proba, ternary = [], []
        for logits in self._iter_batch_logits(reports, tokenizer, batch_size, max_len, device):
            logits = logits.cpu()
            proba.append(torch.softmax(logits, dim=-1)[:, :, POSITIVE])
            if return_ternary:
                ternary.append(class_to_value[logits.argmax(-1)])
        proba = torch.cat(proba, dim=0) if proba else torch.zeros((0, self.num_labels), dtype=torch.float32)
        res = {"labels": list(self.config.labels), "proba": proba.numpy(),
               "positive": (proba >= thr).int().numpy(), "threshold": thr}
        if return_ternary:
            ternary = (torch.cat(ternary, dim=0) if ternary
                       else torch.zeros((0, self.num_labels), dtype=torch.long))
            res["ternary"] = ternary.numpy()
        return res

    def label_reports(self, reports: List[str], tokenizer=None, threshold: Optional[float] = None,
                      **kw) -> List[Dict[str, str]]:
        """Return, per report, a dict {label_name: 'positive'} for labels at or above threshold.

        Extra keyword arguments are forwarded to `predict`.
        """
        out = self.predict(reports, tokenizer=tokenizer, threshold=threshold, **kw)
        names = out["labels"]
        return [{names[j]: "positive" for j in range(len(names)) if row[j]} for row in out["positive"]]

    # ---- hierarchy roll-up (leaf -> upper -> anatomy), max over children ----
    def aggregate_hierarchy(self, leaf_prob):
        """Roll leaf positive-probabilities up to upper and anatomy levels (max over children).

        Mirrors the training-time evaluation:
          * each upper group's score is the max over its child leaves that exist in
            `config.labels` (0 if none do);
          * each anatomy score is the max over its upper groups and the section's
            `<anatomy>_others` leaf, when present, floored at 0.

        Args:
            leaf_prob: [N, num_labels] array-like, columns ordered as `config.labels`.

        Returns:
            (upper_prob [N, U], upper_names, anatomy_prob [N, A], anatomy_names), with
            numpy arrays and names in `config.label_hierarchy` iteration order.
        """
        import numpy as np
        leaf_prob = np.asarray(leaf_prob, dtype=np.float32)
        hierarchy = self.config.label_hierarchy or {}
        col_of = {name: i for i, name in enumerate(self.config.labels)}
        n_rows = leaf_prob.shape[0]
        upper_names, upper_cols, anat_names, anat_cols = [], [], [], []
        for anat, groups in hierarchy.items():
            anat_names.append(anat)
            anat_score = np.full(n_rows, -1.0, dtype=np.float32)
            for upper, leaves in groups.items():
                upper_names.append(upper)
                cols = [col_of[leaf] for leaf in leaves if leaf in col_of]
                upper_score = leaf_prob[:, cols].max(axis=1) if cols else np.zeros(n_rows, dtype=np.float32)
                upper_cols.append(upper_score)
                anat_score = np.maximum(anat_score, upper_score)
            others = f"{anat}_others"
            if others in col_of:
                anat_score = np.maximum(anat_score, leaf_prob[:, col_of[others]])
            anat_cols.append(np.maximum(anat_score, 0.0))
        up = np.column_stack(upper_cols) if upper_cols else np.zeros((n_rows, 0), dtype=np.float32)
        an = np.column_stack(anat_cols) if anat_cols else np.zeros((n_rows, 0), dtype=np.float32)
        return up, upper_names, an, anat_names

    # ---- CheXbert / SRR-BERT-style report-comparison F1 (leaf / upper / anatomy) ----
    @torch.no_grad()
    def score_reports(self, gt_reports: List[str], pred_reports: List[str], tokenizer=None,
                      threshold: Optional[float] = None, batch_size: int = 16,
                      max_len: Optional[int] = None, device=None,
                      levels=("leaf", "upper", "anatomy")) -> Dict[str, Any]:
        """
        Label both GT and predicted reports, then compute label-agreement F1 (CheXbert-style)
        at the requested hierarchy levels.

        The model's labels on `gt_reports` are treated as truth and its labels on
        `pred_reports` as prediction. Both are binarized with `threshold`
        (`config.default_threshold` if None). Levels: "leaf" = `config.labels`,
        "upper" = container groups, "anatomy" = sections.

        Returns:
            {"n_reports", "threshold",
             <level>: {"n_labels", "micro"/"macro"/"weighted": {"precision", "recall", "f1"},
                       "per_label": {name: {"precision", "recall", "f1", "support_gt"}}}}

        Raises:
            ValueError: if the two report lists differ in length.
        """
        from sklearn.metrics import precision_recall_fscore_support
        if len(gt_reports) != len(pred_reports):
            raise ValueError("gt_reports and pred_reports must have the same length")
        thr = self.config.default_threshold if threshold is None else threshold
        kw = dict(tokenizer=tokenizer, batch_size=batch_size, max_len=max_len, device=device)
        gt_leaf = self.predict_proba(gt_reports, **kw).numpy()
        pr_leaf = self.predict_proba(pred_reports, **kw).numpy()
        level_inputs = {"leaf": (gt_leaf, pr_leaf, list(self.config.labels))}
        if "upper" in levels or "anatomy" in levels:
            gu, un, ga, an = self.aggregate_hierarchy(gt_leaf)
            pu, _, pa, _ = self.aggregate_hierarchy(pr_leaf)
            level_inputs["upper"] = (gu, pu, un)
            level_inputs["anatomy"] = (ga, pa, an)
        res: Dict[str, Any] = {"n_reports": len(gt_reports), "threshold": thr}
        for lvl in levels:
            gp, pp, names = level_inputs[lvl]
            y_true = (gp >= thr).astype(int)
            y_pred = (pp >= thr).astype(int)
            block: Dict[str, Any] = {"n_labels": len(names)}
            for avg in ("micro", "macro", "weighted"):
                p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average=avg, zero_division=0)
                block[avg] = {"precision": float(p), "recall": float(r), "f1": float(f)}
            p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None,
                                                         labels=list(range(len(names))), zero_division=0)
            block["per_label"] = {names[j]: {"precision": float(p[j]), "recall": float(r[j]),
                                             "f1": float(f[j]), "support_gt": int(s[j])}
                                  for j in range(len(names))}
            res[lvl] = block
        return res

    # ---- per-label best F1 (threshold swept to maximize F1) vs ground-truth labels ----
    def _to_positive_matrix(self, gt, names):
        """Coerce ground-truth labels to a binary positive matrix (1 where the label == 1).

        * pandas DataFrame: one output column per entry of `names`. Values are coerced to
          numeric, NaN counts as 0, and columns missing from the frame stay all-zero. The
          result is [len(gt), len(names)].
        * numpy array / torch tensor / nested list: compared elementwise with 1. Shape is
          preserved and columns are not checked against `names`.
        """
        import numpy as np
        try:
            import pandas as pd
        except ImportError:
            pd = None
        if pd is not None and isinstance(gt, pd.DataFrame):
            out = np.zeros((len(gt), len(names)), dtype=int)
            for j, c in enumerate(names):
                if c in gt.columns:
                    out[:, j] = (pd.to_numeric(gt[c], errors="coerce").fillna(0).values == 1).astype(int)
            return out
        arr = gt.detach().cpu().numpy() if hasattr(gt, "detach") else np.asarray(gt)
        return (arr == 1).astype(int)

    @torch.no_grad()
    def per_label_best_f1(self, reports: List[str], gt, tokenizer=None, level: str = "leaf",
                          min_pos: int = 30, batch_size: int = 16, max_len: Optional[int] = None,
                          device=None) -> Dict[str, Any]:
        """
        For each label, sweep the decision threshold and report the **F1-maximizing** operating
        point (best F1 + the threshold that achieves it), evaluated against ground-truth labels.

        Args:
            reports: report strings (a single string counts as one report).
            gt: ground-truth labels for `reports`: a DataFrame with the leaf label columns,
                or an array with one row per report. Only entries equal to 1 count as positive.
            level: "leaf", "upper" or "anatomy". Hierarchy levels roll up both predictions
                and ground truth with `aggregate_hierarchy`.
            min_pos: labels with at least this many positives enter `macro_best_f1_min_pos`.

        Returns:
            dict with per-label {"best_f1", "best_threshold", "n_pos"}, the macro best F1
            over all labels, and over labels with >= `min_pos` positives. Labels without
            both classes in the ground truth get best_f1 0.0 and threshold None.

        Raises:
            ValueError: on an unknown `level`, or if `gt` does not have one row per report.
        """
        import numpy as np
        from sklearn.metrics import precision_recall_curve
        if level not in ("leaf", "upper", "anatomy"):
            raise ValueError(f"level must be 'leaf', 'upper' or 'anatomy', got {level!r}")
        leaf_names = list(self.config.labels)
        gt_leaf = self._to_positive_matrix(gt, leaf_names)
        n_reports = 1 if isinstance(reports, str) else len(reports)
        if gt_leaf.ndim != 2 or gt_leaf.shape[0] != n_reports:
            raise ValueError(f"gt must be a 2-D label matrix with {n_reports} rows, got shape {gt_leaf.shape}")
        pr_leaf = self.predict_proba(reports, tokenizer=tokenizer, batch_size=batch_size,
                                     max_len=max_len, device=device).numpy()
        if level == "leaf":
            prob, names, gtb = pr_leaf, leaf_names, gt_leaf
        else:
            pu, un, pa, an = self.aggregate_hierarchy(pr_leaf)
            gu, _, ga, _ = self.aggregate_hierarchy(gt_leaf.astype(np.float32))
            if level == "upper":
                prob, names, gtb = pu, un, (gu >= 0.5).astype(int)
            else:
                prob, names, gtb = pa, an, (ga >= 0.5).astype(int)
        per: Dict[str, Any] = {}
        all_best, ge_best = [], []
        for j, lab in enumerate(names):
            t = gtb[:, j].astype(int)
            s = prob[:, j].astype(float)
            npos = int(t.sum())
            if npos == 0 or len(np.unique(t)) < 2:
                bf, bt = 0.0, None
            else:
                p, r, cuts = precision_recall_curve(t, s)
                # The last (precision, recall) point has no threshold, so drop it.
                f1 = (2 * p * r / (p + r + 1e-12))[:-1]
                bi = int(np.nanargmax(f1))
                bf, bt = float(f1[bi]), float(cuts[bi])
            per[lab] = {"best_f1": bf, "best_threshold": bt, "n_pos": npos}
            all_best.append(bf)
            if npos >= min_pos:
                ge_best.append(bf)
        return {"level": level, "min_pos": min_pos,
                "macro_best_f1": float(np.mean(all_best)) if all_best else 0.0,
                "macro_best_f1_min_pos": float(np.mean(ge_best)) if ge_best else 0.0,
                "n_labels_min_pos": len(ge_best), "per_label": per}
def report_f1(gt_reports: List[str], pred_reports: List[str], model=None, tokenizer=None,
              model_id: str = "chest2vec/chest2vec_labeler", **kw) -> Dict[str, Any]:
    """Score GT vs predicted reports with the labeler, loading it if not supplied.

    When `model` is None, `Chest2VecLabelerModel.from_pretrained(model_id)` is loaded and
    put in eval mode. Remaining keyword arguments are forwarded to `score_reports`.
    """
    labeler = model if model is not None else Chest2VecLabelerModel.from_pretrained(model_id).eval()
    return labeler.score_reports(gt_reports, pred_reports, tokenizer=tokenizer, **kw)