"""chest2err — self-contained loader.

Usage:
    from chest2err import chest2err_score, chest2err_detail

    score = chest2err_score(ref, cand)    # float in (0, 1]
    detail = chest2err_detail(ref, cand)  # full breakdown

The bundle ships the merged backbone weights, the decoder weights, the
tokenizer, and the concept vocabulary. No additional downloads occur at
inference time; the Qwen3-architecture backbone class is taken from the
`transformers` package and instantiated from the bundled `config.json`.
"""
from __future__ import annotations

import json
import math
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

# Sibling files in this package
from chest2err_modeling import CADAD
from chest2err_collate import encode_pair_for_decoder, collate_decoder_batch
# Directory containing this module; bundled files are resolved relative to it.
PACKAGE_DIR: Path = Path(__file__).resolve().parent

# Error-category id -> display name.
CAT_NAMES: Dict[int, str] = {
    0: "EOS",
    1: "false_prediction",
    2: "omission",
    3: "location",
    4: "severity",
    5: "comparison",
}

# Anatomy id -> display name.
ANAT_NAMES: Dict[int, str] = {
    0: "Lung & Airways",
    1: "Cardiovascular",
    2: "Mediastinum & Hila",
    3: "Upper Abdomen",
    4: "Pleura",
    5: "Bones / Spine",
    6: "Chest Wall",
    7: "Lower Neck",
    8: "Others",
}
def _load_config() -> Dict[str, Any]:
    """Read and parse the bundled ``chest2err_config.json``.

    The file is looked up in ``PACKAGE_DIR`` and decoded explicitly as UTF-8
    so that parsing does not depend on the platform's default locale
    encoding.

    Returns:
        The parsed JSON content of the configuration file.

    Raises:
        FileNotFoundError: If the configuration file is not present.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(PACKAGE_DIR / "chest2err_config.json", encoding="utf-8") as f:
        return json.load(f)
class Chest2Err:
    """Loads the merged backbone + decoder once, then scores report pairs.

    All artefacts (config, concept vocabulary, tokenizer, backbone and
    decoder weights) are read from ``PACKAGE_DIR``.

    Attributes set by ``__init__``:
        cfg: Parsed ``chest2err_config.json``.
        device: Device the model is moved to and batches are sent to.
        max_length: ``cfg["max_length"]``, forwarded to the pair encoder.
        score_temperature: τ in ``score = exp(-K_total / τ)``.
        concept2id / id2concept: Concept vocabulary and its inverse.
        n_concept: Vocabulary size, used as the decoder's concept head size.
        tokenizer: Tokenizer loaded from the bundle.
        model: ``CADAD`` wrapper (backbone + decoder) in eval mode.
    """

    def __init__(
        self,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        attn_implementation: Optional[str] = None,
    ):
        """Load every bundled artefact and build the model on ``device``.

        Args:
            device: Target device for the model and for input batches.
            attn_implementation: If truthy, forwarded to
                ``AutoModel.from_pretrained`` under the same keyword.

        Raises:
            ValueError: If the configured ``score_temperature`` is not a
                positive number (the score would leave (0, 1] or divide by
                zero).
        """
        cfg = _load_config()
        self.cfg = cfg
        self.device = device
        self.max_length = cfg["max_length"]

        # Display temperature τ for the score exp(-K_total/τ). τ=3.0 is the
        # default gentle setting (one error → 0.72); τ=1.0 reproduces the
        # original exp(-K_total). The score is a monotone transform of
        # K_total for any τ > 0, so τ never changes rankings (nor τ_b).
        self.score_temperature = float(cfg.get("score_temperature", 3.0))
        # `not (x > 0)` also rejects NaN, which `x <= 0` would let through.
        if not self.score_temperature > 0:
            raise ValueError(
                "score_temperature must be a positive number, got "
                f"{self.score_temperature!r}"
            )

        # Concept vocab (size determines decoder output head dim).
        # Explicit UTF-8: do not depend on the platform's locale encoding.
        with open(PACKAGE_DIR / "concept2id.json", encoding="utf-8") as f:
            self.concept2id: Dict[str, int] = json.load(f)
        self.n_concept = len(self.concept2id)
        self.id2concept = {v: k for k, v in self.concept2id.items()}

        # Tokenizer + backbone load from bundled files only.
        self.tokenizer = AutoTokenizer.from_pretrained(str(PACKAGE_DIR))
        kw = {"torch_dtype": torch.bfloat16}
        if attn_implementation:
            kw["attn_implementation"] = attn_implementation
        backbone = AutoModel.from_pretrained(str(PACKAGE_DIR), **kw)

        # CADAD wraps the backbone + decoder.
        self.model = CADAD(
            backbone=backbone,
            hidden_size=cfg["hidden_size"],
            n_cat=cfg["n_cat"],
            n_anat=cfg["n_anat"],
            n_concept=self.n_concept,
            n_severity=2,
            decoder_layers=cfg["decoder_layers"],
            decoder_heads=cfg["decoder_heads"],
            decoder_ff=cfg["decoder_ff"],
            dropout=cfg["decoder_dropout"],
            max_decode_steps=cfg["max_decode_steps"],
        )

        # The backbone weights were already loaded by AutoModel.from_pretrained.
        # Now layer the decoder weights on top. strict=False is required
        # because the `backbone.*` keys are absent from decoder_state (they
        # came from the bundled backbone checkpoint).
        decoder_state = load_file(str(PACKAGE_DIR / "decoder.safetensors"))
        missing, unexpected = self.model.load_state_dict(
            decoder_state, strict=False
        )
        # Anything missing outside `backbone.*`, or any unexpected key, means
        # the decoder checkpoint does not match the architecture; parameters
        # left at their random initialisation would silently corrupt scores,
        # so surface the mismatch instead of discarding it.
        missing_decoder = [k for k in missing if not k.startswith("backbone.")]
        if missing_decoder or unexpected:
            warnings.warn(
                "decoder.safetensors does not fully match the CADAD model: "
                f"{len(missing_decoder)} non-backbone key(s) missing "
                f"(e.g. {missing_decoder[:5]}), "
                f"{len(unexpected)} unexpected key(s) "
                f"(e.g. {list(unexpected)[:5]}).",
                RuntimeWarning,
                stacklevel=2,
            )

        self.model = self.model.to(device).eval()

    @torch.inference_mode()
    def score(self, ref: str, cand: str) -> float:
        """Return only the ``"score"`` entry of :meth:`detail` for the pair."""
        return self.detail(ref, cand)["score"]

    @torch.inference_mode()
    def detail(self, ref: str, cand: str) -> Dict[str, Any]:
        """Decode the error tuples for one pair and summarise them.

        Args:
            ref: Reference report text.
            cand: Candidate report text.

        Returns:
            A dict with keys:
              * ``"score"``: ``exp(-K_total / score_temperature)``.
              * ``"K_total"``: number of decoded tuples.
              * ``"tuples"``: one dict per decoded tuple (ids, display
                names, concept string, segment indices).
              * ``"category_counts"``: list of length ``cfg["n_cat"]``;
                category id ``c`` (1-based) is counted at index ``c - 1``.
              * ``"anatomy_counts"``: list of length ``cfg["n_anat"]``
                indexed directly by anatomy id.
        """
        item = encode_pair_for_decoder(
            self.tokenizer, ref, cand, max_length=self.max_length,
        )
        batch = collate_decoder_batch(
            [item], pad_token_id=self.tokenizer.pad_token_id or 0,
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}

        autocast_device = (
            "cuda" if str(self.device).startswith("cuda") else "cpu"
        )
        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            seqs = self.model.decode_greedy(
                batch["input_ids"],
                batch["attention_mask"],
                batch["ref_seg_token_mask"],
                batch["cand_seg_token_mask"],
            )

        # Batch of one: only the first decoded sequence is relevant.
        seq = seqs[0]
        K_total = len(seq)
        score = math.exp(-K_total / self.score_temperature)

        n_cat = self.cfg["n_cat"]
        n_anat = self.cfg["n_anat"]
        cat_counts = [0] * n_cat
        anat_counts = [0] * n_anat
        tuples_out: List[Dict[str, Any]] = []
        for t in seq:
            c = int(t.get("cat", 0))
            a = int(t.get("anat", 0))
            concept_id = int(t.get("concept_id", 0))
            # Category 0 is not counted; ids 1..n_cat map to slots 0..n_cat-1.
            if 1 <= c <= n_cat:
                cat_counts[c - 1] += 1
            if 0 <= a < n_anat:
                anat_counts[a] += 1
            tuples_out.append({
                "cat": c, "cat_name": CAT_NAMES.get(c, str(c)),
                "anat": a, "anat_name": ANAT_NAMES.get(a, str(a)),
                "concept_id": concept_id,
                "concept": self.id2concept.get(concept_id, "<UNK>"),
                "ref_seg_idx": int(t.get("ref_seg_idx", -1)),
                "cand_seg_idx": int(t.get("cand_seg_idx", -1)),
            })

        return {
            "score": score,
            "K_total": K_total,
            "tuples": tuples_out,
            "category_counts": cat_counts,
            "anatomy_counts": anat_counts,
        }
# Module-wide scorer; stays None until the first call to _get().
_INSTANCE: Optional[Chest2Err] = None


def _get() -> Chest2Err:
    """Return the shared ``Chest2Err`` instance, creating it on first use."""
    global _INSTANCE
    if _INSTANCE is not None:
        return _INSTANCE
    _INSTANCE = Chest2Err()
    return _INSTANCE
def chest2err_score(ref: str, cand: str) -> float:
    """Return the chest2err-score ∈ (0, 1] of one (reference, candidate) pair.

    The call is delegated to the shared scorer obtained from ``_get()``.
    """
    scorer = _get()
    return scorer.score(ref, cand)
def chest2err_detail(ref: str, cand: str) -> Dict[str, Any]:
    """Return the full breakdown for one (reference, candidate) pair.

    The result holds the score, K_total, the per-error tuples, and the
    per-category and per-anatomy counts. The call is delegated to the shared
    scorer obtained from ``_get()``.
    """
    scorer = _get()
    return scorer.detail(ref, cand)
# Public API of this module.
__all__ = [
    "Chest2Err",
    "chest2err_score",
    "chest2err_detail",
]