chest2err / chest2err.py
lukeingawesome's picture
Soften chest2err-score: add display temperature tau=3.0 (default)
06de0a9 verified
Raw
History Blame Contribute Delete
6.52 kB
"""chest2err — self-contained loader.
Usage:
from chest2err import chest2err_score, chest2err_detail
score = chest2err_score(ref, cand) # float in (0, 1]
detail = chest2err_detail(ref, cand) # full breakdown
The bundle ships the merged backbone weights, the decoder weights, the
tokenizer, and the concept vocabulary. No additional downloads occur at
inference; the Qwen3-architecture backbone class is taken from the
`transformers` package and instantiated from the bundled `config.json`.
"""
from __future__ import annotations

import json
import math
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

# Sibling files in this package
from chest2err_modeling import CADAD
from chest2err_collate import encode_pair_for_decoder, collate_decoder_batch
# Directory containing this module; the config, vocab, tokenizer and weight
# files are all read from here.
PACKAGE_DIR = Path(__file__).resolve().parents[0]

# Error-category id -> name. Ids are the positions in this tuple (0 is EOS).
CAT_NAMES = dict(enumerate((
    "EOS",
    "false_prediction",
    "omission",
    "location",
    "severity",
    "comparison",
)))

# Anatomy-region id -> name. Ids are the positions in this tuple.
ANAT_NAMES = dict(enumerate((
    "Lung & Airways",
    "Cardiovascular",
    "Mediastinum & Hila",
    "Upper Abdomen",
    "Pleura",
    "Bones / Spine",
    "Chest Wall",
    "Lower Neck",
    "Others",
)))
def _load_config() -> Dict[str, Any]:
with open(PACKAGE_DIR / "chest2err_config.json") as f:
return json.load(f)
class Chest2Err:
    """Load the merged backbone + decoder once, then score report pairs.

    The score of a (reference, candidate) pair is ``exp(-K_total / tau)``.
    ``K_total`` is the number of error tuples produced by greedy decoding.
    ``tau`` is ``score_temperature`` from ``chest2err_config.json``
    (default 3.0).
    """

    def __init__(self,
                 device: str = "cuda" if torch.cuda.is_available() else "cpu",
                 attn_implementation: Optional[str] = None):
        """Build the tokenizer, backbone and decoder from the bundled files.

        Args:
            device: Device the model and input batches are moved to. The
                default is chosen once, when this class is defined:
                ``"cuda"`` if CUDA is available, else ``"cpu"``.
            attn_implementation: If truthy, forwarded as
                ``attn_implementation`` to ``AutoModel.from_pretrained``.

        Raises:
            ValueError: If the configured ``score_temperature`` is not a
                finite positive number.
        """
        cfg = _load_config()
        self.cfg = cfg
        self.device = device
        self.max_length = cfg["max_length"]
        # Display temperature τ for the score exp(-K_total/τ). τ=3.0 is the
        # default gentle setting (one error → 0.72); τ=1.0 reproduces the
        # original exp(-K_total). Rank-equivalent, so τ never affects τ_b.
        # Validated here so a bad config fails at load time rather than as a
        # ZeroDivisionError or an out-of-(0, 1] score during scoring.
        self.score_temperature = self._parse_temperature(
            cfg.get("score_temperature", 3.0))

        # Concept vocab (its size determines the decoder output head dim).
        with open(PACKAGE_DIR / "concept2id.json", encoding="utf-8") as f:
            self.concept2id: Dict[str, int] = json.load(f)
        self.n_concept = len(self.concept2id)
        self.id2concept = {v: k for k, v in self.concept2id.items()}

        # Tokenizer + backbone load from bundled files only.
        self.tokenizer = AutoTokenizer.from_pretrained(str(PACKAGE_DIR))
        kw = {"torch_dtype": torch.bfloat16}
        if attn_implementation:
            kw["attn_implementation"] = attn_implementation
        backbone = AutoModel.from_pretrained(str(PACKAGE_DIR), **kw)

        # CADAD wraps the backbone + decoder.
        self.model = CADAD(
            backbone=backbone,
            hidden_size=cfg["hidden_size"],
            n_cat=cfg["n_cat"],
            n_anat=cfg["n_anat"],
            n_concept=self.n_concept,
            n_severity=2,
            decoder_layers=cfg["decoder_layers"],
            decoder_heads=cfg["decoder_heads"],
            decoder_ff=cfg["decoder_ff"],
            dropout=cfg["decoder_dropout"],
            max_decode_steps=cfg["max_decode_steps"],
        )
        # The backbone weights were already loaded by AutoModel.from_pretrained;
        # layer the decoder weights on top. strict=False because `backbone.*`
        # keys are absent from decoder.safetensors by design. Any other
        # mismatch means decoder weights silently failed to load, so report it.
        decoder_state = load_file(str(PACKAGE_DIR / "decoder.safetensors"))
        missing, unexpected = self.model.load_state_dict(decoder_state, strict=False)
        self._warn_on_state_dict_mismatch(missing, unexpected)
        self.model = self.model.to(device).eval()

    @staticmethod
    def _parse_temperature(raw: Any) -> float:
        """Coerce ``raw`` to float and require a finite value greater than 0.

        Raises:
            ValueError: If the value is zero, negative, infinite or NaN.
                Non-numeric strings also raise ValueError from ``float()``.
        """
        tau = float(raw)
        if not (math.isfinite(tau) and tau > 0.0):
            raise ValueError(
                f"score_temperature must be a finite positive number, got {raw!r}")
        return tau

    @staticmethod
    def _warn_on_state_dict_mismatch(missing: List[str],
                                     unexpected: List[str]) -> None:
        """Warn when the decoder checkpoint does not line up with the model.

        Missing ``backbone.*`` keys are expected and ignored, because those
        weights come from ``AutoModel.from_pretrained``. Any other missing
        key means that parameter kept its freshly constructed value. Any
        unexpected key means a checkpoint tensor was dropped.
        """
        decoder_missing = [k for k in missing if not k.startswith("backbone.")]
        problems: List[str] = []
        if decoder_missing:
            problems.append(f"missing keys {decoder_missing}")
        if unexpected:
            problems.append(f"unexpected keys {list(unexpected)}")
        if problems:
            warnings.warn(
                "chest2err: decoder.safetensors does not match the model ("
                + "; ".join(problems) + ")",
                RuntimeWarning,
                stacklevel=3,
            )

    def _score_from_count(self, k_total: int) -> float:
        """Map an error count to ``exp(-k_total / score_temperature)``."""
        return math.exp(-k_total / self.score_temperature)

    @torch.inference_mode()
    def score(self, ref: str, cand: str) -> float:
        """Return only the scalar score for one (reference, candidate) pair."""
        return self.detail(ref, cand)["score"]

    @torch.inference_mode()
    def detail(self, ref: str, cand: str) -> Dict[str, Any]:
        """Decode the errors for one pair and return the full breakdown.

        Returns:
            A dict with keys ``score``, ``K_total``, ``tuples`` (one dict
            per decoded error), ``category_counts`` and ``anatomy_counts``.
        """
        item = encode_pair_for_decoder(
            self.tokenizer, ref, cand, max_length=self.max_length,
        )
        batch = collate_decoder_batch([item],
                                      pad_token_id=self.tokenizer.pad_token_id or 0)
        batch = {k: v.to(self.device) for k, v in batch.items()}
        # Autocast needs a device *type*; anything not on CUDA uses "cpu".
        device_type = "cuda" if str(self.device).startswith("cuda") else "cpu"
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            seqs = self.model.decode_greedy(
                batch["input_ids"],
                batch["attention_mask"],
                batch["ref_seg_token_mask"],
                batch["cand_seg_token_mask"],
            )
        return self._summarize(seqs[0])

    def _summarize(self, seq: Any) -> Dict[str, Any]:
        """Turn one decoded error sequence into the public result dict.

        Each item of ``seq`` is read via ``.get("cat" / "anat" / "concept_id"
        / "ref_seg_idx" / "cand_seg_idx")``. Presumably each item is one
        dict per decoded error tuple; confirm against chest2err_modeling.
        """
        n_cat = self.cfg["n_cat"]
        n_anat = self.cfg["n_anat"]
        cat_counts = [0] * n_cat
        anat_counts = [0] * n_anat
        tuples_out: List[Dict[str, Any]] = []
        for t in seq:
            c = int(t.get("cat", 0))
            a = int(t.get("anat", 0))
            concept_id = int(t.get("concept_id", 0))
            # Category 0 is EOS (see CAT_NAMES), so category c is tallied at
            # index c - 1. Out-of-range ids are listed but not counted.
            if 1 <= c <= n_cat:
                cat_counts[c - 1] += 1
            if 0 <= a < n_anat:
                anat_counts[a] += 1
            tuples_out.append({
                "cat": c, "cat_name": CAT_NAMES.get(c, str(c)),
                "anat": a, "anat_name": ANAT_NAMES.get(a, str(a)),
                "concept_id": concept_id,
                "concept": self.id2concept.get(concept_id, "<UNK>"),
                "ref_seg_idx": int(t.get("ref_seg_idx", -1)),
                "cand_seg_idx": int(t.get("cand_seg_idx", -1)),
            })
        k_total = len(seq)
        return {
            "score": self._score_from_count(k_total),
            "K_total": k_total,
            "tuples": tuples_out,
            "category_counts": cat_counts,
            "anatomy_counts": anat_counts,
        }
# Lazily created scorer shared by the module-level helper functions.
_INSTANCE: Optional[Chest2Err] = None


def _get() -> Chest2Err:
    """Return the shared Chest2Err, constructing it on the first call."""
    global _INSTANCE
    instance = _INSTANCE
    if instance is None:
        instance = Chest2Err()
        _INSTANCE = instance
    return instance
def chest2err_score(ref: str, cand: str) -> float:
    """Score one (reference, candidate) report pair.

    The value is ``exp(-K_total / tau)``, in (0, 1] for a positive ``tau``.
    A pair with no decoded errors scores exactly 1.0. The module-level
    scorer is built on first use.
    """
    scorer = _get()
    return scorer.score(ref, cand)
def chest2err_detail(ref: str, cand: str) -> Dict[str, Any]:
    """Return the full breakdown for one (reference, candidate) pair.

    The dict holds ``score``, ``K_total``, ``tuples`` (one entry per decoded
    error), ``category_counts`` and ``anatomy_counts``.
    """
    scorer = _get()
    return scorer.detail(ref, cand)


__all__ = [
    "Chest2Err",
    "chest2err_score",
    "chest2err_detail",
]