"""Gradio demo: Chinese hypertension-domain NER → structured health-check report.

Loads a fine-tuned BERT token-classification model from the Hugging Face Hub,
extracts entities (tests, values, drugs, ...) from Chinese health-check text,
pairs lab tests with their values, derives disease risks from clinical
cut-offs, and renders both a human-readable report and JSON outputs.
"""

import json
import re

import torch
import gradio as gr
from transformers import BertTokenizerFast, BertForTokenClassification

# === Model configuration (Hugging Face Hub id) ===
MODEL_ID = "Donlagon007/htn-ner-v1"

# Load tokenizer/model once at import time (CPU is the default on Spaces).
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ID)
model = BertForTokenClassification.from_pretrained(MODEL_ID)
model.eval()
id2label = model.config.id2label


# ---------- Utils ----------
def decode_bio_to_spans(labels):
    """Collapse a BIO label sequence into (type, start_idx, end_idx) spans.

    Indices are inclusive token positions within *labels*.  An "I-" tag whose
    type differs from the currently open span starts a new span (lenient
    decoding), matching the original behavior.
    """
    spans, cur_type, start = [], None, None
    for i, lab in enumerate(labels):
        if lab == "O" or lab is None:
            if cur_type is not None:
                spans.append((cur_type, start, i - 1))
            cur_type, start = None, None
            continue
        tag, typ = lab.split("-", 1)
        if tag == "B":
            if cur_type is not None:
                spans.append((cur_type, start, i - 1))
            cur_type, start = typ, i
        elif tag == "I":
            if cur_type != typ:
                if cur_type is not None:
                    spans.append((cur_type, start, i - 1))
                cur_type, start = typ, i
    if cur_type is not None:
        spans.append((cur_type, start, len(labels) - 1))
    return spans


def ner_predict_with_tokens(text, max_length=256):
    """Run token-level NER over *text*.

    Returns ``(tokens_info, entities)``:
      tokens_info – per-token dicts {token, label, start, end} (char offsets)
      entities    – merged BIO spans as {type, text, start, end}
    """
    enc = tokenizer(
        text,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    with torch.no_grad():
        out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    pred_ids = out.logits.argmax(-1).squeeze(0).tolist()
    offsets = enc["offset_mapping"].squeeze(0).tolist()
    input_ids = enc["input_ids"].squeeze(0).tolist()

    tokens_info, kept_labels, kept_offsets = [], [], []
    # Drop [CLS]/[SEP]/padding: special tokens carry a (0, 0) offset mapping.
    for lid, (st, ed), tid in zip(pred_ids, offsets, input_ids):
        if st == ed == 0:
            continue
        tok = tokenizer.convert_ids_to_tokens([tid])[0]
        lab = id2label[lid]
        tokens_info.append({"token": tok, "label": lab, "start": st, "end": ed})
        kept_labels.append(lab)
        kept_offsets.append((st, ed))

    entities = []
    for typ, s_tok, e_tok in decode_bio_to_spans(kept_labels):
        cs, ce = kept_offsets[s_tok][0], kept_offsets[e_tok][1]
        entities.append({"type": typ, "text": text[cs:ce], "start": cs, "end": ce})
    return tokens_info, entities


def cn_num(s):
    """Extract the first number in *s* as float; None when no digit is found."""
    m = re.search(r"(\d+(?:\.\d+)?)", s.replace(" ", ""))
    return float(m.group(1)) if m else None


def parse_bp(value_text):
    """Parse a 'SYS/DIA' blood-pressure string → (systolic, diastolic) or (None, None)."""
    m = re.search(r"(\d{2,3})\s*/\s*(\d{2,3})", value_text.replace(" ", ""))
    if m:
        return int(m.group(1)), int(m.group(2))
    return None, None


# Clinical cut-offs per lab test; keys are the test names emitted by the NER model.
THRESHOLDS = {
    "空腹血糖": {"unit": "mg/dL", "abnormal": lambda v: v is not None and v >= 126},
    "HbA1c": {"unit": "%", "abnormal": lambda v: v is not None and v >= 6.5},
    "LDL": {
        "unit": "mg/dL",
        "abnormal": lambda v: v is not None and v >= 160,
        "borderline": lambda v: v is not None and 130 <= v < 160,
    },
}


def status_for(test, val):
    """Classify *val* for *test*: '異常' / '偏高' / '正常', or None when unknown.

    Returns None both for tests without thresholds and — bug fix — for values
    that cannot be parsed as a number (previously an unparseable value was
    silently reported as '正常').
    """
    if test not in THRESHOLDS:
        return None
    th = THRESHOLDS[test]
    v = cn_num(val)
    if v is None:  # value present but non-numeric: make no "normal" claim
        return None
    if "borderline" in th and th["borderline"](v):
        return "偏高"
    return "異常" if th["abnormal"](v) else "正常"


def pair_tests_values(entities):
    """Greedily pair each TEST entity with the next VALUE within 40 characters.

    Returns ``(pairs, lone)``: pairs of {"test", "value"} dicts, and the list
    of TEST entities left without a matching value.
    """
    ents = sorted(entities, key=lambda x: x["start"])
    pairs, lone = [], []
    last_test = None
    for e in ents:
        if e["type"] == "TEST":
            if last_test:
                lone.append(last_test)
            last_test = {"test": e["text"], "start": e["start"], "value": None}
        elif e["type"] == "VALUE" and last_test and (e["start"] - last_test["start"]) < 40:
            last_test["value"] = e["text"]
            pairs.append({"test": last_test["test"], "value": e["text"]})
            last_test = None
    if last_test:
        lone.append(last_test)
    return pairs, lone


def extract_structured(text):
    """Full pipeline: NER → demographics, key findings, risks, recommendations, summary.

    Returns ``(tokens, entities, structured)`` where *structured* is a dict with
    keys name/age/sex/key_findings/disease_risk/recommendations/summary.
    """
    tokens, entities = ner_predict_with_tokens(text)

    # --- Basic demographics ---
    name, ages, sex = None, [], None
    for e in entities:
        if e["type"] == "PER" and (name is None or len(e["text"]) > len(name)):
            name = e["text"]  # prefer the longest PER mention as the name
        elif e["type"] == "AGE":
            v = cn_num(e["text"])
            if v is not None:
                ages.append(int(v))
        elif e["type"] == "SEX":
            sex = e["text"]
    # Fallback: detect sex straight from the raw text when NER missed it.
    if not sex:
        m_sex = re.search(r"(男|女)", text)
        if m_sex:
            sex = m_sex.group(1)

    # --- Key findings (test/value pairs, annotated with a status when known) ---
    pairs, _ = pair_tests_values(entities)
    key_findings = []
    for p in pairs:
        st = status_for(p["test"], p["value"])
        row = {"test": p["test"], "value": p["value"]}
        if st:
            row["status"] = st
        key_findings.append(row)

    # --- Disease risks ---
    risks = set()
    # 4.1 Diabetes: fasting glucose >= 126 mg/dL or HbA1c >= 6.5 %.
    fpg = next((cn_num(p["value"]) for p in pairs if p["test"] == "空腹血糖"), None)
    a1c = next((cn_num(p["value"]) for p in pairs if p["test"] == "HbA1c"), None)
    if (fpg is not None and fpg >= 126) or (a1c is not None and a1c >= 6.5):
        risks.add("糖尿病")
    # 4.2 Hyperlipidemia via LDL (>=160 abnormal, 130-159 mild).
    ldl = next((cn_num(p["value"]) for p in pairs if p["test"] == "LDL"), None)
    if ldl is not None and ldl >= 160:
        risks.add("高血脂")
    elif ldl is not None and ldl >= 130:
        risks.add("高血脂(輕度)")
    # 4.3 Hypertension via a BP reading (>=140/90) or an explicit DISEASE mention.
    bp_val = next(
        (p["value"] for p in pairs if p["test"] in ["診間血壓", "家庭血壓", "24小時動態血壓"]),
        None,
    )
    if bp_val:
        sbp, dbp = parse_bp(bp_val)  # renamed from sys/dia to avoid shadowing the sys module
        if sbp and dbp and (sbp >= 140 or dbp >= 90):
            risks.add("高血壓")
    if any(e["type"] == "DISEASE" and "高血壓" in e["text"] for e in entities):
        risks.add("高血壓")

    # --- Recommendations from DRUG / DRUG_CLASS / TREATMENT entities ---
    recs = []
    for e in entities:
        if e["type"] == "DRUG":
            recs.append(f"開始服用 {e['text']}")
        elif e["type"] == "DRUG_CLASS":
            recs.append(f"考慮 {e['text']} 類藥物")
        elif e["type"] == "TREATMENT":
            t = e["text"]
            # Generic diet mentions (without "low salt") collapse to "control diet".
            if "飲食" in t and "低鹽" not in t:
                t = "控制飲食"
            recs.append(f"建議{t}")

    # --- One-line summary sentence ---
    age = max(ages) if ages else None
    name_disp = name if name else "病人"
    age_disp = f"{age}歲" if age is not None else ""
    abns = [f"{k['test']} {k['value']}" for k in key_findings if k.get("status") in ("異常", "偏高")]
    parts = [f"{name_disp}({age_disp})"] if age_disp else [name_disp]
    if abns:
        parts.append("檢查顯示 " + "、".join(abns[:3]))
    if "糖尿病" in risks:
        parts.append("符合糖尿病診斷")
    if "高血脂" in risks or "高血脂(輕度)" in risks:
        parts.append("另見 LDL 偏高")
    if recs:
        parts.append("建議:" + "、".join(recs[:3]))
    summary = ",".join(parts) + "。"

    structured = {
        "name": name or None,
        "age": age if age is not None else None,
        "sex": sex,
        "key_findings": key_findings,
        "disease_risk": sorted(list(risks)),
        "recommendations": recs,
        "summary": summary,
    }
    return tokens, entities, structured


# ---------- Human-readable report ----------
def make_readable_report(structured: dict) -> str:
    """Render the structured dict as a sectioned, human-readable Chinese report."""
    name = structured.get("name") or "病人"
    age = structured.get("age")
    sex = structured.get("sex")
    head = f"【健檢摘要】{name}"
    if age is not None:
        head += f"({age}歲"
    else:
        head += "(年齡不詳"
    if sex:
        head += f",{sex})"
    else:
        head += ")"

    # Split key findings into abnormal/elevated vs the rest.
    kfs = structured.get("key_findings", [])
    abn_lines, nor_lines = [], []
    for k in kfs:
        t, v = k.get("test"), k.get("value")
        st = k.get("status")
        if st in ("異常", "偏高"):
            abn_lines.append(f".{t}: {v}({st})")
        elif st == "正常":
            nor_lines.append(f".{t}: {v}(正常)")
        else:
            nor_lines.append(f".{t}: {v}")

    risks = structured.get("disease_risk", [])
    recs = structured.get("recommendations", [])
    summary = structured.get("summary") or ""

    sections = [head, ""]
    if abn_lines:
        sections += ["【異常/偏高項目】"] + abn_lines + [""]
    if nor_lines:
        sections += ["【其他檢測】"] + nor_lines + [""]
    if risks:
        sections += ["【疾病風險/診斷】", "." + "、".join(risks), ""]
    if recs:
        sections += ["【建議】", "." + "、".join(recs), ""]
    if summary:
        sections += ["【摘要敘述】", summary, ""]
    return "\n".join(sections).strip()


# ---------- Gradio UI ----------
EXAMPLE = "李偉(65歲,男),有高血壓與糖尿病。\n診間血壓152/94mmHg,空腹血糖138mg/dL,HbA1c 7.1%。\n建議使用ARB類藥物並低鹽飲食。"


def run(text):
    """Analyze *text*; return (human report, structured JSON, entities JSON, tokens JSON)."""
    tokens, entities, structured = extract_structured(text)
    human_report = make_readable_report(structured)
    return (
        human_report,
        json.dumps(structured, ensure_ascii=False, indent=2),
        json.dumps(entities, ensure_ascii=False, indent=2),
        json.dumps(tokens, ensure_ascii=False, indent=2),
    )


with gr.Blocks(title="HTN NER (Chinese)") as demo:
    gr.Markdown("## Hypertension NER → Human Report / JSON / Entities / Tokens")
    inp = gr.Textbox(label="輸入文字 (中文)", lines=6, value=EXAMPLE)
    btn = gr.Button("Analyze")
    out_report = gr.Textbox(label="Doctor Report", lines=12)
    out_struct = gr.Code(label="Structured JSON")
    out_entities = gr.Code(label="Entities (spans)")
    out_tokens = gr.Code(label="Token-level (B/I/O)")
    btn.click(run, inputs=inp, outputs=[out_report, out_struct, out_entities, out_tokens])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")