"""chest2err — self-contained loader.

Usage:
    from chest2err import chest2err_score, chest2err_detail
    score = chest2err_score(ref, cand)    # float in (0, 1]
    detail = chest2err_detail(ref, cand)  # full breakdown

The bundle ships the merged backbone weights, the decoder weights, the
tokenizer, and the concept vocabulary. No additional downloads occur at
inference; the Qwen3-architecture backbone class is taken from the
`transformers` package and instantiated from the bundled `config.json`.
"""
from __future__ import annotations

import json
import math
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

from chest2err_collate import encode_pair_for_decoder, collate_decoder_batch
from chest2err_modeling import CADAD
|
|
# Directory containing this module; every bundled file is opened relative to it.
PACKAGE_DIR = Path(__file__).resolve().parent

# Error-category id -> display name used in the per-tuple breakdown.
# Id 0 is the "EOS" entry; ids 1..5 are the named error categories.
CAT_NAMES: Dict[int, str] = {
    0: "EOS",
    1: "false_prediction",
    2: "omission",
    3: "location",
    4: "severity",
    5: "comparison",
}

# Anatomy id -> display name used in the per-tuple breakdown.
ANAT_NAMES: Dict[int, str] = {
    0: "Lung & Airways",
    1: "Cardiovascular",
    2: "Mediastinum & Hila",
    3: "Upper Abdomen",
    4: "Pleura",
    5: "Bones / Spine",
    6: "Chest Wall",
    7: "Lower Neck",
    8: "Others",
}
|
|
|
|
def _load_config() -> Dict[str, Any]:
    """Parse ``chest2err_config.json`` from the package directory.

    The file is opened explicitly as UTF-8 so that parsing does not depend
    on the platform's locale encoding.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: if the config file is not present in ``PACKAGE_DIR``.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    config_path = PACKAGE_DIR / "chest2err_config.json"
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
class Chest2Err:
    """Loads the merged backbone + decoder once, then scores report pairs.

    Everything is read from ``PACKAGE_DIR``: the scorer config (through
    ``_load_config``), ``concept2id.json``, the tokenizer and backbone files,
    and ``decoder.safetensors``.

    Args:
        device: Torch device the model is moved to. ``None`` (the default)
            selects ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise;
            the choice is made when the instance is built, not at import time.
        attn_implementation: Forwarded to ``AutoModel.from_pretrained`` as
            ``attn_implementation`` when truthy; otherwise not passed at all.

    Raises:
        ValueError: if the configured ``score_temperature`` is not positive.
        RuntimeError: if no tensor in ``decoder.safetensors`` matches a
            parameter of the assembled model.
    """

    def __init__(self,
                 device: Optional[str] = None,
                 attn_implementation: Optional[str] = None) -> None:
        cfg = _load_config()
        self.cfg = cfg
        self.device = self._resolve_device(device)
        self.max_length = cfg["max_length"]

        # score = exp(-K_total / score_temperature); a non-positive value
        # would divide by zero or push the score above 1, so fail early.
        self.score_temperature = float(cfg.get("score_temperature", 3.0))
        if not self.score_temperature > 0:
            raise ValueError(
                f"score_temperature must be > 0, got {self.score_temperature!r}"
            )

        # Concept vocabulary: name -> id on disk, inverted for reporting.
        with open(PACKAGE_DIR / "concept2id.json", encoding="utf-8") as f:
            self.concept2id: Dict[str, int] = json.load(f)
        self.n_concept = len(self.concept2id)
        self.id2concept = {idx: name for name, idx in self.concept2id.items()}

        # Tokenizer and backbone both come from the package directory.
        self.tokenizer = AutoTokenizer.from_pretrained(str(PACKAGE_DIR))
        backbone_kwargs: Dict[str, Any] = {"torch_dtype": torch.bfloat16}
        if attn_implementation:
            backbone_kwargs["attn_implementation"] = attn_implementation
        backbone = AutoModel.from_pretrained(str(PACKAGE_DIR), **backbone_kwargs)

        self.model = CADAD(
            backbone=backbone,
            hidden_size=cfg["hidden_size"],
            n_cat=cfg["n_cat"],
            n_anat=cfg["n_anat"],
            n_concept=self.n_concept,
            n_severity=2,  # fixed here rather than read from the config
            decoder_layers=cfg["decoder_layers"],
            decoder_heads=cfg["decoder_heads"],
            decoder_ff=cfg["decoder_ff"],
            dropout=cfg["decoder_dropout"],
            max_decode_steps=cfg["max_decode_steps"],
        )

        # NOTE(review): strict=False presumably because decoder.safetensors
        # carries only the decoder-side weights while the backbone was loaded
        # above, so missing keys are expected — confirm. Unexpected keys are
        # never expected, and "nothing matched" would silently leave the
        # decoder randomly initialised, so both are surfaced.
        decoder_state = load_file(str(PACKAGE_DIR / "decoder.safetensors"))
        load_result = self.model.load_state_dict(decoder_state, strict=False)
        unexpected = list(load_result.unexpected_keys)
        if len(decoder_state) - len(unexpected) == 0:
            raise RuntimeError(
                "decoder.safetensors did not provide any tensor matching the "
                "model's parameters; refusing to score with uninitialised weights"
            )
        if unexpected:
            warnings.warn(
                f"decoder.safetensors contains {len(unexpected)} key(s) the "
                f"model does not define (ignored), e.g. {unexpected[:5]}",
                stacklevel=2,
            )

        self.model = self.model.to(self.device).eval()

    @staticmethod
    def _resolve_device(device: Optional[str]) -> str:
        """Return ``device`` unchanged, or pick ``"cuda"``/``"cpu"`` when it is ``None``."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    @torch.inference_mode()
    def score(self, ref: str, cand: str) -> float:
        """Return only the ``"score"`` entry of ``detail(ref, cand)``."""
        return self.detail(ref, cand)["score"]

    @torch.inference_mode()
    def detail(self, ref: str, cand: str) -> Dict[str, Any]:
        """Decode the error tuples for one report pair and summarise them.

        Args:
            ref: Reference report text.
            cand: Candidate report text.

        Returns:
            A dict with:
              * ``"score"``: ``exp(-K_total / score_temperature)``.
              * ``"K_total"``: number of decoded tuples.
              * ``"tuples"``: one dict per decoded tuple (ids, display names,
                concept, and reference/candidate segment indices).
              * ``"category_counts"``: list of length ``n_cat``; category id
                ``c`` (for ``1 <= c <= n_cat``) is counted at index ``c - 1``,
                so id 0 is never counted.
              * ``"anatomy_counts"``: list of length ``n_anat`` indexed by
                anatomy id.
        """
        item = encode_pair_for_decoder(
            self.tokenizer, ref, cand, max_length=self.max_length,
        )
        batch = collate_decoder_batch(
            [item], pad_token_id=self.tokenizer.pad_token_id or 0,
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}

        # Any device that is not CUDA runs under the CPU autocast context.
        autocast_device = "cuda" if str(self.device).startswith("cuda") else "cpu"
        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            seqs = self.model.decode_greedy(
                batch["input_ids"],
                batch["attention_mask"],
                batch["ref_seg_token_mask"],
                batch["cand_seg_token_mask"],
            )
        seq = seqs[0]  # single-item batch
        K_total = len(seq)
        score = math.exp(-K_total / self.score_temperature)

        n_cat = self.cfg["n_cat"]
        n_anat = self.cfg["n_anat"]
        cat_counts = [0] * n_cat
        anat_counts = [0] * n_anat
        tuples_out: List[Dict[str, Any]] = []
        for t in seq:
            c = int(t.get("cat", 0))
            a = int(t.get("anat", 0))
            concept_id = int(t.get("concept_id", 0))
            # Out-of-range ids are still reported in the tuple but not counted.
            if 1 <= c <= n_cat:
                cat_counts[c - 1] += 1
            if 0 <= a < n_anat:
                anat_counts[a] += 1
            tuples_out.append({
                "cat": c, "cat_name": CAT_NAMES.get(c, str(c)),
                "anat": a, "anat_name": ANAT_NAMES.get(a, str(a)),
                "concept_id": concept_id,
                "concept": self.id2concept.get(concept_id, "<UNK>"),
                "ref_seg_idx": int(t.get("ref_seg_idx", -1)),
                "cand_seg_idx": int(t.get("cand_seg_idx", -1)),
            })
        return {
            "score": score,
            "K_total": K_total,
            "tuples": tuples_out,
            "category_counts": cat_counts,
            "anatomy_counts": anat_counts,
        }
|
|
|
|
# Shared scorer behind the module-level helpers; stays None until first use.
_INSTANCE: Optional[Chest2Err] = None


def _get() -> Chest2Err:
    """Return the shared ``Chest2Err``, building it on the first call."""
    global _INSTANCE
    if _INSTANCE is not None:
        return _INSTANCE
    _INSTANCE = Chest2Err()
    return _INSTANCE
|
|
|
|
def chest2err_score(ref: str, cand: str) -> float:
    """Score one (reference, candidate) report pair; chest2err-score ∈ (0, 1].

    Delegates to the shared scorer's ``score`` method.
    """
    scorer = _get()
    return scorer.score(ref, cand)
|
|
|
|
def chest2err_detail(ref: str, cand: str) -> Dict[str, Any]:
    """Return the full breakdown for one (reference, candidate) report pair.

    The result holds the score, ``K_total``, the per-error tuples, and the
    per-category and per-anatomy counts. Delegates to the shared scorer's
    ``detail`` method.
    """
    scorer = _get()
    return scorer.detail(ref, cand)
|
|
|
|
# Public API of this module.
__all__ = [
    "Chest2Err",
    "chest2err_score",
    "chest2err_detail",
]
|
|