File size: 9,974 Bytes
03cf9b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78247e7
03cf9b3
 
 
 
78247e7
 
03cf9b3
 
 
 
78247e7
 
03cf9b3
 
 
 
 
 
 
78247e7
03cf9b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78247e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03cf9b3
 
 
 
 
78247e7
 
 
 
 
 
 
03cf9b3
 
78247e7
03cf9b3
 
78247e7
 
b8e8104
78247e7
 
03cf9b3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import re, json
import torch
import gradio as gr
from transformers import BertTokenizerFast, BertForTokenClassification

# === Model configuration: pulled from the Hugging Face Hub ===
MODEL_ID = "Donlagon007/htn-ner-v1"

# Load tokenizer/model once at import time (CPU is the default in Spaces).
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ID)
model = BertForTokenClassification.from_pretrained(MODEL_ID)
model.eval()
# Mapping from predicted label id -> BIO tag string (e.g. "B-PER", "O").
id2label = model.config.id2label

# ---------- Utils ----------
def decode_bio_to_spans(labels):
    """Convert a sequence of BIO labels into entity spans.

    Args:
        labels: list of label strings such as "B-PER", "I-PER", "O"
            (None entries are treated like "O").

    Returns:
        List of (entity_type, first_index, last_index) tuples with
        inclusive token indices.
    """
    spans = []
    open_type = None
    open_start = None

    def close_span(end_idx):
        # Flush the currently open span, if any.
        nonlocal open_type, open_start
        if open_type is not None:
            spans.append((open_type, open_start, end_idx))
            open_type, open_start = None, None

    for idx, label in enumerate(labels):
        if label is None or label == "O":
            close_span(idx - 1)
            continue
        prefix, ent_type = label.split("-", 1)
        # "B" always starts a new span; an "I" whose type disagrees with the
        # open span is treated as an implicit begin.
        if prefix == "B" or (prefix == "I" and ent_type != open_type):
            close_span(idx - 1)
            open_type, open_start = ent_type, idx
    close_span(len(labels) - 1)
    return spans

def ner_predict_with_tokens(text, max_length=256):
    """Run the NER model on *text* and return token- and span-level results.

    Args:
        text: input string.
        max_length: tokenizer truncation length.

    Returns:
        (tokens_info, entities): tokens_info is a list of
        {"token", "label", "start", "end"} dicts for real tokens;
        entities is a list of {"type", "text", "start", "end"} dicts
        with character offsets into *text*.
    """
    encoded = tokenizer(
        text,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
        label_ids = outputs.logits.argmax(-1).squeeze(0).tolist()

    char_offsets = encoded["offset_mapping"].squeeze(0).tolist()
    token_ids = encoded["input_ids"].squeeze(0).tolist()

    tokens_info = []
    kept_labels = []
    kept_offsets = []
    # Drop [CLS]/[SEP]/padding, which carry a (0, 0) character offset.
    for label_id, (start, end), token_id in zip(label_ids, char_offsets, token_ids):
        if start == 0 and end == 0:
            continue
        token_str = tokenizer.convert_ids_to_tokens([token_id])[0]
        label = id2label[label_id]
        tokens_info.append({"token": token_str, "label": label, "start": start, "end": end})
        kept_labels.append(label)
        kept_offsets.append((start, end))

    entities = []
    for ent_type, first_tok, last_tok in decode_bio_to_spans(kept_labels):
        char_start = kept_offsets[first_tok][0]
        char_end = kept_offsets[last_tok][1]
        entities.append(
            {"type": ent_type, "text": text[char_start:char_end], "start": char_start, "end": char_end}
        )
    return tokens_info, entities

def cn_num(s):
    """Extract the first integer or decimal number from *s*.

    Spaces are stripped first so "1 3 8" parses as 138.
    Returns the value as float, or None when no digits are present.
    """
    compact = s.replace(" ", "")
    found = re.search(r"\d+(?:\.\d+)?", compact)
    if found is None:
        return None
    return float(found.group(0))

def parse_bp(value_text):
    """Parse a blood-pressure reading like "152/94" into two ints.

    Returns:
        (systolic, diastolic) as ints, or (None, None) when no
        2-3 digit "/" 2-3 digit pattern is found.
    """
    compact = value_text.replace(" ", "")
    found = re.search(r"(\d{2,3})\s*/\s*(\d{2,3})", compact)
    if found is None:
        return None, None
    return int(found.group(1)), int(found.group(2))

# Clinical cut-offs per test name (keys are the Chinese TEST entity texts).
# "abnormal" flags a definite finding; LDL additionally has a "borderline"
# band for 130 <= v < 160. All predicates tolerate v is None.
THRESHOLDS = {
    "空腹血糖": {"unit":"mg/dL", "abnormal": lambda v: v is not None and v >= 126},
    "HbA1c":   {"unit":"%",     "abnormal": lambda v: v is not None and v >= 6.5},
    "LDL":     {"unit":"mg/dL", "abnormal": lambda v: v is not None and v >= 160,
                "borderline": lambda v: v is not None and 130 <= v < 160},
}

def status_for(test, val):
    """Classify *val* for *test* against THRESHOLDS.

    Returns "偏高" (borderline), "異常" (abnormal), "正常" (normal),
    or None when the test has no configured thresholds.
    """
    thresholds = THRESHOLDS.get(test)
    if thresholds is None:
        return None
    value = cn_num(val)
    borderline = thresholds.get("borderline")
    if borderline is not None and borderline(value):
        return "偏高"
    if thresholds["abnormal"](value):
        return "異常"
    return "正常"

def pair_tests_values(entities):
    """Greedily pair each TEST entity with the next VALUE within 40 chars.

    Args:
        entities: list of dicts carrying "type", "text" and "start" keys.

    Returns:
        (pairs, lone): pairs is a list of {"test", "value"} dicts in
        reading order; lone collects TEST entities that were never paired.
    """
    ordered = sorted(entities, key=lambda ent: ent["start"])
    pairs = []
    lone = []
    pending = None
    for ent in ordered:
        kind = ent["type"]
        if kind == "TEST":
            # A new TEST arriving before the previous one was paired
            # orphans the old one.
            if pending is not None:
                lone.append(pending)
            pending = {"test": ent["text"], "start": ent["start"], "value": None}
        elif kind == "VALUE" and pending is not None and (ent["start"] - pending["start"]) < 40:
            pending["value"] = ent["text"]
            pairs.append({"test": pending["test"], "value": ent["text"]})
            pending = None
    if pending is not None:
        lone.append(pending)
    return pairs, lone

def extract_structured(text):
    """Run NER on *text* and derive a structured health-check record.

    Returns:
        (tokens, entities, structured): tokens/entities come straight from
        ner_predict_with_tokens; structured is a dict with name/age/sex,
        key_findings, disease_risk, recommendations, and a one-line
        Chinese summary.
    """
    tokens, entities = ner_predict_with_tokens(text)

    # --- basic demographic fields --------------------------------------
    name, ages, sex = None, [], None
    for e in entities:
        if e["type"] == "PER" and (name is None or len(e["text"]) > len(name)):
            # Keep the longest PER mention as the canonical name.
            name = e["text"]
        elif e["type"] == "AGE":
            v = cn_num(e["text"])
            if v is not None:
                ages.append(int(v))
        elif e["type"] == "SEX":
            sex = e["text"]

    # Fallback: scan the raw text for 男/女 when no SEX entity was found.
    if not sex:
        m_sex = re.search(r"(男|女)", text)
        if m_sex:
            sex = m_sex.group(1)

    pairs, _ = pair_tests_values(entities)
    key_findings = []
    for p in pairs:
        st = status_for(p["test"], p["value"])
        row = {"test": p["test"], "value": p["value"]}
        if st:
            row["status"] = st
        key_findings.append(row)

    risks = set()
    # Diabetes: fasting glucose >= 126 mg/dL or HbA1c >= 6.5 %.
    fpg = next((cn_num(p["value"]) for p in pairs if p["test"] == "空腹血糖"), None)
    a1c = next((cn_num(p["value"]) for p in pairs if p["test"] == "HbA1c"), None)
    if (fpg is not None and fpg >= 126) or (a1c is not None and a1c >= 6.5):
        risks.add("糖尿病")
    # Hyperlipidemia via LDL: >= 160 abnormal, 130-159 mild.
    ldl = next((cn_num(p["value"]) for p in pairs if p["test"] == "LDL"), None)
    if ldl is not None and ldl >= 160:
        risks.add("高血脂")
    elif ldl is not None and ldl >= 130:
        risks.add("高血脂(輕度)")
    # Hypertension via a measured BP (>= 140/90) or an explicit DISEASE mention.
    bp_val = next((p["value"] for p in pairs if p["test"] in ["診間血壓","家庭血壓","24小時動態血壓"]), None)
    if bp_val:
        # Renamed from `sys`/`dia` to avoid confusion with the stdlib module.
        systolic, diastolic = parse_bp(bp_val)
        if systolic and diastolic and (systolic >= 140 or diastolic >= 90):
            risks.add("高血壓")
    if any(e["type"] == "DISEASE" and "高血壓" in e["text"] for e in entities):
        risks.add("高血壓")

    # --- recommendations from drug/treatment entities -------------------
    recs = []
    for e in entities:
        if e["type"] == "DRUG":
            recs.append(f"開始服用 {e['text']}")
        elif e["type"] == "DRUG_CLASS":
            recs.append(f"考慮 {e['text']} 類藥物")
        elif e["type"] == "TREATMENT":
            t = e["text"]
            if "飲食" in t and "低鹽" not in t:
                t = "控制飲食"
            recs.append(f"建議{t}")

    # --- one-line summary ------------------------------------------------
    age = max(ages) if ages else None
    name_disp = name if name else "病人"
    age_disp = f"{age}歲" if age is not None else ""
    abns = [f"{k['test']} {k['value']}" for k in key_findings if k.get("status") in ("異常","偏高")]
    # BUG FIX: the original emitted f"{name_disp}{age_disp})" — an unmatched
    # closing parenthesis ("李偉65歲)"). Wrap the age in a full pair.
    parts = [f"{name_disp}({age_disp})"] if age_disp else [name_disp]
    if abns:
        parts.append("檢查顯示 " + "、".join(abns[:3]))
    if "糖尿病" in risks:
        parts.append("符合糖尿病診斷")
    if "高血脂" in risks or "高血脂(輕度)" in risks:
        parts.append("另見 LDL 偏高")
    if recs:
        parts.append("建議:" + "、".join(recs[:3]))
    summary = ",".join(parts) + "。"

    structured = {
        "name": name or None,
        "age": age,
        "sex": sex,
        "key_findings": key_findings,
        "disease_risk": sorted(risks),
        "recommendations": recs,
        "summary": summary
    }
    return tokens, entities, structured

# ---------- Human-readable report ----------
def make_readable_report(structured: dict) -> str:
    """Render the structured extraction as a human-readable Chinese report.

    Args:
        structured: dict as produced by extract_structured (name, age, sex,
            key_findings, disease_risk, recommendations, summary).

    Returns:
        Multi-section plain-text report; empty sections are omitted.
    """
    name = structured.get("name") or "病人"
    age  = structured.get("age")
    sex  = structured.get("sex")
    head = f"【健檢摘要】{name}"
    if age is not None: head += f"({age}歲"
    else: head += "(年齡不詳"
    if sex: head += f",{sex})"
    else: head += ")"

    # Split key findings into abnormal/borderline vs everything else.
    kfs = structured.get("key_findings", [])
    abn_lines, nor_lines = [], []
    for k in kfs:
        t, v = k.get("test"), k.get("value")
        st = k.get("status")
        if st in ("異常","偏高"):
            # BUG FIX: original emitted f".{t}: {v}{st})" — an unmatched
            # closing parenthesis. Wrap the status in a full pair, matching
            # the normal-status line below.
            abn_lines.append(f".{t}: {v}({st})")
        elif st == "正常":
            nor_lines.append(f".{t}: {v}(正常)")
        else:
            nor_lines.append(f".{t}: {v}")

    risks = structured.get("disease_risk", [])
    recs  = structured.get("recommendations", [])
    summary = structured.get("summary") or ""

    sections = [head, ""]
    if abn_lines:
        sections += ["【異常/偏高項目】"] + abn_lines + [""]
    if nor_lines:
        sections += ["【其他檢測】"] + nor_lines + [""]
    if risks:
        sections += ["【疾病風險/診斷】", "." + "、".join(risks), ""]
    if recs:
        sections += ["【建議】", "." + "、".join(recs), ""]
    if summary:
        sections += ["【摘要敘述】", summary, ""]
    return "\n".join(sections).strip()

# ---------- Gradio UI ----------
# Demo input: 65-year-old male with hypertension and diabetes, office BP
# 152/94 mmHg, fasting glucose 138 mg/dL, HbA1c 7.1%; ARB-class drug and
# low-salt diet recommended.
EXAMPLE = "李偉(65歲,男),有高血壓與糖尿病。\n診間血壓152/94mmHg,空腹血糖138mg/dL,HbA1c 7.1%。\n建議使用ARB類藥物並低鹽飲食。"

def run(text):
    """Gradio callback: analyze *text* and return the four output panes.

    Returns:
        Tuple of (human report, structured JSON, entity JSON, token JSON).
    """
    def pretty(obj):
        # ensure_ascii=False keeps CJK characters readable in the JSON panes.
        return json.dumps(obj, ensure_ascii=False, indent=2)

    tokens, entities, structured = extract_structured(text)
    return (
        make_readable_report(structured),
        pretty(structured),
        pretty(entities),
        pretty(tokens),
    )

# Build the demo UI: one input textbox, one button, four stacked outputs
# (readable report plus three JSON views) wired to run().
with gr.Blocks(title="HTN NER (Chinese)") as demo:
    gr.Markdown("## Hypertension NER → Human Report / JSON / Entities / Tokens")
    inp = gr.Textbox(label="輸入文字 (中文)", lines=6, value=EXAMPLE)
    btn = gr.Button("Analyze")
    out_report   = gr.Textbox(label="Doctor Report", lines=12)
    out_struct   = gr.Code(label="Structured JSON")
    out_entities = gr.Code(label="Entities (spans)")
    out_tokens   = gr.Code(label="Token-level (B/I/O)")
    # Output order must match the tuple returned by run().
    btn.click(run, inputs=inp, outputs=[out_report, out_struct, out_entities, out_tokens])

if __name__ == "__main__":
    # Bind to 0.0.0.0 so the app is reachable from outside the container.
    demo.launch(server_name="0.0.0.0")