File size: 6,523 Bytes
47f2d5a
 
 
 
23c824a
 
47f2d5a
23c824a
 
 
 
47f2d5a
 
 
 
 
23c824a
47f2d5a
23c824a
47f2d5a
 
 
 
 
23c824a
47f2d5a
23c824a
47f2d5a
 
 
23c824a
 
 
 
 
 
47f2d5a
 
 
 
 
 
 
 
 
23c824a
 
 
47f2d5a
 
 
 
06de0a9
 
 
 
47f2d5a
23c824a
 
 
 
 
47f2d5a
23c824a
 
 
 
 
 
 
 
 
 
 
 
 
47f2d5a
23c824a
 
47f2d5a
 
 
23c824a
47f2d5a
 
 
23c824a
 
 
 
 
 
 
 
47f2d5a
 
 
23c824a
47f2d5a
 
 
23c824a
 
47f2d5a
23c824a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06de0a9
23c824a
47f2d5a
 
23c824a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47f2d5a
 
 
23c824a
47f2d5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""chest2err — self-contained loader.

Usage:
    from chest2err import chest2err_score, chest2err_detail
    score  = chest2err_score(ref, cand)         # float in (0, 1]
    detail = chest2err_detail(ref, cand)        # full breakdown

The bundle ships the merged backbone weights, the decoder weights, the
tokenizer, and the concept vocabulary. No additional downloads occur at
inference; the Qwen3-architecture backbone class is taken from the
`transformers` package and instantiated from the bundled `config.json`.
"""
from __future__ import annotations

import json
import math
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from transformers import AutoModel, AutoTokenizer
from safetensors.torch import load_file

# Sibling files in this package
from chest2err_modeling import CADAD
from chest2err_collate import encode_pair_for_decoder, collate_decoder_batch

# Directory containing this module; every bundled asset is resolved against it.
PACKAGE_DIR: Path = Path(__file__).resolve().parent

# Error-category id -> display name. Id 0 is the EOS marker.
CAT_NAMES: Dict[int, str] = {
    0: "EOS",
    1: "false_prediction",
    2: "omission",
    3: "location",
    4: "severity",
    5: "comparison",
}

# Anatomy id -> display name.
ANAT_NAMES: Dict[int, str] = {
    0: "Lung & Airways",
    1: "Cardiovascular",
    2: "Mediastinum & Hila",
    3: "Upper Abdomen",
    4: "Pleura",
    5: "Bones / Spine",
    6: "Chest Wall",
    7: "Lower Neck",
    8: "Others",
}


def _load_config() -> Dict[str, Any]:
    """Read and parse ``chest2err_config.json`` from the package directory.

    The file is opened explicitly as UTF-8 (the encoding JSON is exchanged
    in) so that parsing does not depend on the platform's locale encoding.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: if the config file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    config_path = PACKAGE_DIR / "chest2err_config.json"
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)


class Chest2Err:
    """Loads the merged backbone + decoder once, then scores report pairs.

    Construction reads the bundled config, concept vocabulary, tokenizer,
    backbone and decoder weights from ``PACKAGE_DIR``. Afterwards
    :meth:`score` and :meth:`detail` greedily decode one
    (reference, candidate) pair per call.

    Attributes:
        cfg: Parsed ``chest2err_config.json``.
        device: Device string the model and input batches are moved to.
        max_length: ``cfg["max_length"]``, forwarded to the pair encoder.
        score_temperature: τ in ``score = exp(-K_total / τ)``; always > 0.
        concept2id / id2concept: Concept vocabulary and its inverse.
        n_concept: Size of the concept vocabulary.
        tokenizer: Tokenizer loaded from the package directory.
        model: ``CADAD`` wrapper (backbone + decoder) in eval mode.
    """

    def __init__(self,
                 device: str = "cuda" if torch.cuda.is_available() else "cpu",
                 attn_implementation: Optional[str] = None):
        """Load every bundled artifact and build the model on ``device``.

        Args:
            device: Target device string. The default is evaluated once, at
                import time: ``"cuda"`` when CUDA is available, else ``"cpu"``.
            attn_implementation: Optional value forwarded to
                ``AutoModel.from_pretrained`` as ``attn_implementation``;
                omitted from the call when falsy.

        Raises:
            ValueError: if the configured ``score_temperature`` is not > 0.
            KeyError: if a required key is absent from the config.
        """
        # Local import keeps this block self-contained (no file-level change).
        import warnings

        cfg = _load_config()
        self.cfg = cfg
        self.device = device
        self.max_length = cfg["max_length"]
        # Display temperature τ for the score exp(-K_total/τ). τ=3.0 is the
        # default gentle setting (one error → 0.72); τ=1.0 reproduces the
        # original exp(-K_total). Any τ > 0 is a monotone rescaling, so
        # rankings (and hence rank correlations such as τ_b) are unaffected.
        self.score_temperature = float(cfg.get("score_temperature", 3.0))
        # Fail early: τ == 0 would raise ZeroDivisionError at scoring time and
        # τ < 0 would push scores above 1. NaN also fails this comparison.
        if not self.score_temperature > 0:
            raise ValueError(
                "score_temperature must be > 0, got "
                f"{self.score_temperature!r}"
            )

        # Concept vocab (size determines decoder output head dim). Opened as
        # UTF-8 explicitly so concept names decode the same on every platform.
        with open(PACKAGE_DIR / "concept2id.json", encoding="utf-8") as f:
            self.concept2id: Dict[str, int] = json.load(f)
        self.n_concept = len(self.concept2id)
        self.id2concept = {v: k for k, v in self.concept2id.items()}

        # Tokenizer + backbone load from bundled files only.
        self.tokenizer = AutoTokenizer.from_pretrained(str(PACKAGE_DIR))
        kw = {"torch_dtype": torch.bfloat16}
        if attn_implementation:
            kw["attn_implementation"] = attn_implementation
        backbone = AutoModel.from_pretrained(str(PACKAGE_DIR), **kw)

        # CADAD wraps the backbone + decoder. Construct, then load the decoder
        # weights on top of the already-initialised backbone.
        self.model = CADAD(
            backbone=backbone,
            hidden_size=cfg["hidden_size"],
            n_cat=cfg["n_cat"],
            n_anat=cfg["n_anat"],
            n_concept=self.n_concept,
            n_severity=2,
            decoder_layers=cfg["decoder_layers"],
            decoder_heads=cfg["decoder_heads"],
            decoder_ff=cfg["decoder_ff"],
            dropout=cfg["decoder_dropout"],
            max_decode_steps=cfg["max_decode_steps"],
        )

        # The backbone weights were already loaded by AutoModel.from_pretrained.
        # Now layer the decoder weights on top.
        decoder_state = load_file(str(PACKAGE_DIR / "decoder.safetensors"))
        missing, unexpected = self.model.load_state_dict(decoder_state, strict=False)
        # Expected: many `backbone.*` keys are "missing" from decoder_state
        # (they came from model.safetensors via from_pretrained). Anything else
        # means the decoder checkpoint and the model disagree, which would
        # otherwise silently leave weights at their initial values.
        # NOTE(review): relies on the backbone living under the `backbone.`
        # prefix inside CADAD, as the comment above states; confirm.
        decoder_missing = [k for k in missing if not k.startswith("backbone.")]
        if decoder_missing or unexpected:
            warnings.warn(
                "decoder.safetensors does not match the model: "
                f"{len(decoder_missing)} non-backbone key(s) missing "
                f"(first: {decoder_missing[:5]}), "
                f"{len(unexpected)} unexpected key(s) "
                f"(first: {list(unexpected)[:5]})",
                RuntimeWarning,
                stacklevel=2,
            )

        self.model = self.model.to(device).eval()

    @torch.inference_mode()
    def score(self, ref: str, cand: str) -> float:
        """Return only the ``"score"`` entry of :meth:`detail` for the pair."""
        return self.detail(ref, cand)["score"]

    @torch.inference_mode()
    def detail(self, ref: str, cand: str) -> Dict[str, Any]:
        """Decode the error tuples for one pair and summarise them.

        Args:
            ref: Reference report text.
            cand: Candidate report text.

        Returns:
            A dict with:

            * ``"score"``: ``exp(-K_total / score_temperature)``.
            * ``"K_total"``: number of decoded tuples.
            * ``"tuples"``: one dict per decoded tuple with ``cat``,
              ``cat_name``, ``anat``, ``anat_name``, ``concept_id``,
              ``concept``, ``ref_seg_idx`` and ``cand_seg_idx``.
            * ``"category_counts"``: list of length ``cfg["n_cat"]``; slot
              ``c - 1`` counts tuples with category ``c`` (category 0 is
              never counted).
            * ``"anatomy_counts"``: list of length ``cfg["n_anat"]`` indexed
              by anatomy id.
        """
        item = encode_pair_for_decoder(
            self.tokenizer, ref, cand, max_length=self.max_length,
        )
        # Falls back to pad id 0 when the tokenizer has no (or a zero) pad id.
        batch = collate_decoder_batch(
            [item], pad_token_id=self.tokenizer.pad_token_id or 0,
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}

        # autocast only receives "cuda" or "cpu": any non-CUDA device string
        # is treated as "cpu" here.
        autocast_device = "cuda" if str(self.device).startswith("cuda") else "cpu"
        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            seqs = self.model.decode_greedy(
                batch["input_ids"],
                batch["attention_mask"],
                batch["ref_seg_token_mask"],
                batch["cand_seg_token_mask"],
            )
        # The batch holds a single pair, so only the first sequence is used.
        seq = seqs[0]
        K_total = len(seq)
        score = math.exp(-K_total / self.score_temperature)

        n_cat = self.cfg["n_cat"]
        n_anat = self.cfg["n_anat"]
        cat_counts = [0] * n_cat
        anat_counts = [0] * n_anat
        tuples_out: List[Dict[str, Any]] = []
        for t in seq:
            # Missing fields default to 0 (ids) or -1 (segment indices).
            c = int(t.get("cat", 0))
            a = int(t.get("anat", 0))
            concept_id = int(t.get("concept_id", 0))
            if 1 <= c <= n_cat:
                cat_counts[c - 1] += 1
            if 0 <= a < n_anat:
                anat_counts[a] += 1
            tuples_out.append({
                "cat": c, "cat_name": CAT_NAMES.get(c, str(c)),
                "anat": a, "anat_name": ANAT_NAMES.get(a, str(a)),
                "concept_id": concept_id,
                "concept": self.id2concept.get(concept_id, "<UNK>"),
                "ref_seg_idx": int(t.get("ref_seg_idx", -1)),
                "cand_seg_idx": int(t.get("cand_seg_idx", -1)),
            })
        return {
            "score": score,
            "K_total": K_total,
            "tuples": tuples_out,
            "category_counts": cat_counts,
            "anatomy_counts": anat_counts,
        }


# Process-wide scorer, created on first use by _get().
_INSTANCE: Optional[Chest2Err] = None


def _get() -> Chest2Err:
    """Return the shared ``Chest2Err``, constructing it on the first call."""
    global _INSTANCE
    scorer = _INSTANCE
    if scorer is None:
        scorer = _INSTANCE = Chest2Err()
    return scorer


def chest2err_score(ref: str, cand: str) -> float:
    """Score one (reference, candidate) report pair; result lies in (0, 1].

    Delegates to the shared scorer's ``score`` method.
    """
    scorer = _get()
    return scorer.score(ref, cand)


def chest2err_detail(ref: str, cand: str) -> Dict[str, Any]:
    """Return the full breakdown for one (reference, candidate) report pair.

    The result carries the score, K_total, the per-error tuples, and the
    per-category and per-anatomy counts, as produced by the shared scorer's
    ``detail`` method.
    """
    scorer = _get()
    return scorer.detail(ref, cand)


# Public API of this module.
__all__ = [
    "Chest2Err",
    "chest2err_score",
    "chest2err_detail",
]