"""chest2err — self-contained loader. Usage: from chest2err import chest2err_score, chest2err_detail score = chest2err_score(ref, cand) # float in (0, 1] detail = chest2err_detail(ref, cand) # full breakdown The bundle ships the merged backbone weights, the decoder weights, the tokenizer, and the concept vocabulary. No additional downloads occur at inference; the Qwen3-architecture backbone class is taken from the `transformers` package and instantiated from the bundled `config.json`. """ from __future__ import annotations import json import math import os from pathlib import Path from typing import Any, Dict, List, Optional import torch from transformers import AutoModel, AutoTokenizer from safetensors.torch import load_file # Sibling files in this package from chest2err_modeling import CADAD from chest2err_collate import encode_pair_for_decoder, collate_decoder_batch PACKAGE_DIR = Path(__file__).resolve().parent CAT_NAMES = {0: "EOS", 1: "false_prediction", 2: "omission", 3: "location", 4: "severity", 5: "comparison"} ANAT_NAMES = {0: "Lung & Airways", 1: "Cardiovascular", 2: "Mediastinum & Hila", 3: "Upper Abdomen", 4: "Pleura", 5: "Bones / Spine", 6: "Chest Wall", 7: "Lower Neck", 8: "Others"} def _load_config() -> Dict[str, Any]: with open(PACKAGE_DIR / "chest2err_config.json") as f: return json.load(f) class Chest2Err: """Loads the merged backbone + decoder once, then scores pairs.""" def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", attn_implementation: Optional[str] = None): cfg = _load_config() self.cfg = cfg self.device = device self.max_length = cfg["max_length"] # Display temperature τ for the score exp(-K_total/τ). τ=3.0 is the # default gentle setting (one error → 0.72); τ=1.0 reproduces the # original exp(-K_total). Rank-equivalent, so τ never affects τ_b. self.score_temperature = float(cfg.get("score_temperature", 3.0)) # Concept vocab (size determines decoder output head dim) with open(PACKAGE_DIR / "concept2id.json") as f: self.concept2id: Dict[str, int] = json.load(f) self.n_concept = len(self.concept2id) self.id2concept = {v: k for k, v in self.concept2id.items()} # Tokenizer + backbone load from bundled files only. self.tokenizer = AutoTokenizer.from_pretrained(str(PACKAGE_DIR)) kw = {"torch_dtype": torch.bfloat16} if attn_implementation: kw["attn_implementation"] = attn_implementation backbone = AutoModel.from_pretrained(str(PACKAGE_DIR), **kw) # CADAD wraps the backbone + decoder. Construct, then load merged backbone # weights + decoder weights. self.model = CADAD( backbone=backbone, hidden_size=cfg["hidden_size"], n_cat=cfg["n_cat"], n_anat=cfg["n_anat"], n_concept=self.n_concept, n_severity=2, decoder_layers=cfg["decoder_layers"], decoder_heads=cfg["decoder_heads"], decoder_ff=cfg["decoder_ff"], dropout=cfg["decoder_dropout"], max_decode_steps=cfg["max_decode_steps"], ) # The backbone weights were already loaded by AutoModel.from_pretrained. # Now layer the decoder weights on top. decoder_state = load_file(str(PACKAGE_DIR / "decoder.safetensors")) missing, unexpected = self.model.load_state_dict(decoder_state, strict=False) # Expected: many `backbone.*` keys are "missing" from decoder_state # (they came from model.safetensors via from_pretrained). That's fine. self.model = self.model.to(device).eval() @torch.inference_mode() def score(self, ref: str, cand: str) -> float: return self.detail(ref, cand)["score"] @torch.inference_mode() def detail(self, ref: str, cand: str) -> Dict[str, Any]: item = encode_pair_for_decoder( self.tokenizer, ref, cand, max_length=self.max_length, ) batch = collate_decoder_batch([item], pad_token_id=self.tokenizer.pad_token_id or 0) batch = {k: v.to(self.device) for k, v in batch.items()} with torch.autocast( device_type="cuda" if str(self.device).startswith("cuda") else "cpu", dtype=torch.bfloat16, ): seqs = self.model.decode_greedy( batch["input_ids"], batch["attention_mask"], batch["ref_seg_token_mask"], batch["cand_seg_token_mask"], ) seq = seqs[0] K_total = len(seq) score = math.exp(-K_total / self.score_temperature) cat_counts = [0] * self.cfg["n_cat"] anat_counts = [0] * self.cfg["n_anat"] tuples_out: List[Dict[str, Any]] = [] for t in seq: c = int(t.get("cat", 0)) a = int(t.get("anat", 0)) if 1 <= c <= self.cfg["n_cat"]: cat_counts[c - 1] += 1 if 0 <= a < self.cfg["n_anat"]: anat_counts[a] += 1 tuples_out.append({ "cat": c, "cat_name": CAT_NAMES.get(c, str(c)), "anat": a, "anat_name": ANAT_NAMES.get(a, str(a)), "concept_id": int(t.get("concept_id", 0)), "concept": self.id2concept.get(int(t.get("concept_id", 0)), ""), "ref_seg_idx": int(t.get("ref_seg_idx", -1)), "cand_seg_idx": int(t.get("cand_seg_idx", -1)), }) return { "score": score, "K_total": K_total, "tuples": tuples_out, "category_counts": cat_counts, "anatomy_counts": anat_counts, } _INSTANCE: Optional[Chest2Err] = None def _get() -> Chest2Err: global _INSTANCE if _INSTANCE is None: _INSTANCE = Chest2Err() return _INSTANCE def chest2err_score(ref: str, cand: str) -> float: """chest2err-score ∈ (0, 1] for one (reference, candidate) report pair.""" return _get().score(ref, cand) def chest2err_detail(ref: str, cand: str) -> Dict[str, Any]: """Full breakdown: score, K_total, per-error tuples, per-category and per-anatomy counts.""" return _get().detail(ref, cand) __all__ = ["Chest2Err", "chest2err_score", "chest2err_detail"]