# HTN NER demo — Hugging Face Space app (page-status header removed from scrape)
| import re, json | |
| import torch | |
| import gradio as gr | |
| from transformers import BertTokenizerFast, BertForTokenClassification | |
# === Model configuration: fine-tuned NER checkpoint pulled from the HF Hub ===
MODEL_ID = "Donlagon007/htn-ner-v1"
# Load tokenizer/model once at import time (CPU is the default in Spaces).
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ID)
model = BertForTokenClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables dropout
# Mapping from class index to BIO label string (e.g. "B-PER", "I-TEST", "O").
id2label = model.config.id2label
| # ---------- Utils ---------- | |
def decode_bio_to_spans(labels):
    """Collapse a token-level BIO label sequence into (type, start, end) spans.

    `labels` is a sequence like ["B-PER", "I-PER", "O", ...]; returned tuples
    hold inclusive token indices. Decoding is lenient: an "I-" tag whose type
    does not continue the currently open span starts a new span, and a "B-"
    tag always starts one (closing any open span first).
    """
    spans = []
    open_type, open_start = None, None
    for idx, label in enumerate(labels):
        if label is None or label == "O":
            # Outside any entity: close whatever span is open.
            if open_type is not None:
                spans.append((open_type, open_start, idx - 1))
                open_type, open_start = None, None
            continue
        prefix, ent_type = label.split("-", 1)
        starts_new = prefix == "B" or ent_type != open_type
        if starts_new:
            if open_type is not None:
                spans.append((open_type, open_start, idx - 1))
            open_type, open_start = ent_type, idx
    # Flush a span that runs to the end of the sequence.
    if open_type is not None:
        spans.append((open_type, open_start, len(labels) - 1))
    return spans
def ner_predict_with_tokens(text, max_length=256):
    """Run token-classification NER on `text`.

    Returns (tokens_info, entities):
      tokens_info — per-token dicts {"token", "label", "start", "end"};
      entities    — character-level spans {"type", "text", "start", "end"}
                    decoded from the BIO labels.
    """
    encoded = tokenizer(
        text,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    with torch.no_grad():
        logits = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        ).logits

    label_ids = logits.argmax(-1).squeeze(0).tolist()
    offset_pairs = encoded["offset_mapping"].squeeze(0).tolist()
    token_ids = encoded["input_ids"].squeeze(0).tolist()

    tokens_info, kept_labels, kept_offsets = [], [], []
    for label_id, (start, end), token_id in zip(label_ids, offset_pairs, token_ids):
        # Special tokens ([CLS]/[SEP]/padding) carry a (0, 0) offset — drop them.
        if start == 0 and end == 0:
            continue
        label = id2label[label_id]
        tokens_info.append({
            "token": tokenizer.convert_ids_to_tokens([token_id])[0],
            "label": label,
            "start": start,
            "end": end,
        })
        kept_labels.append(label)
        kept_offsets.append((start, end))

    entities = []
    for ent_type, first_tok, last_tok in decode_bio_to_spans(kept_labels):
        char_start = kept_offsets[first_tok][0]
        char_end = kept_offsets[last_tok][1]
        entities.append({
            "type": ent_type,
            "text": text[char_start:char_end],
            "start": char_start,
            "end": char_end,
        })
    return tokens_info, entities
| def cn_num(s): | |
| m = re.search(r"(\d+(?:\.\d+)?)", s.replace(" ", "")) | |
| return float(m.group(1)) if m else None | |
def parse_bp(value_text):
    """Parse "SYS/DIA" blood-pressure text into an (int, int) pair.

    Accepts 2-3 digit numbers around a slash (e.g. "152/94mmHg");
    returns (None, None) when no such pair is present.
    """
    compact = value_text.replace(" ", "")
    match = re.search(r"(\d{2,3})\s*/\s*(\d{2,3})", compact)
    if match is None:
        return None, None
    return int(match.group(1)), int(match.group(2))
# Per-test reference thresholds used by status_for().
# Keys are test names as the NER model emits them; each entry carries the
# display unit plus None-guarded predicates over the parsed numeric value:
#   空腹血糖 (fasting glucose): >= 126 mg/dL -> abnormal
#   HbA1c: >= 6.5 %                          -> abnormal
#   LDL: >= 160 mg/dL -> abnormal; 130-159  -> borderline
# NOTE(review): cut-offs match common diabetes/lipid guideline values — confirm
# against the clinical reference this project targets.
THRESHOLDS = {
    "空腹血糖": {"unit":"mg/dL", "abnormal": lambda v: v is not None and v >= 126},
    "HbA1c": {"unit":"%", "abnormal": lambda v: v is not None and v >= 6.5},
    "LDL": {"unit":"mg/dL", "abnormal": lambda v: v is not None and v >= 160,
            "borderline": lambda v: v is not None and 130 <= v < 160},
}
def status_for(test, val):
    """Classify a test's value text against THRESHOLDS.

    Returns "偏高" (borderline), "異常" (abnormal), "正常" (normal), or None
    when the test has no configured thresholds or `val` contains no
    parseable number.
    """
    th = THRESHOLDS.get(test)
    if th is None:
        return None
    v = cn_num(val)
    if v is None:
        # Fix: previously an unparseable value fell through the None-guarded
        # lambdas (all returning False) and was reported "正常"; without a
        # number no status can be assigned. Callers already treat a falsy
        # status as "no status", so this is backward compatible.
        return None
    if "borderline" in th and th["borderline"](v):
        return "偏高"
    return "異常" if th["abnormal"](v) else "正常"
def pair_tests_values(entities):
    """Greedily pair each TEST entity with the next VALUE within 40 characters.

    Entities are processed in order of their character start. Returns
    (pairs, lone): pairs is a list of {"test", "value"} dicts; lone collects
    TEST entities that never received a value, each as
    {"test", "start", "value": None}. Non-TEST/VALUE entities are ignored,
    as are VALUE entities with no eligible pending test.
    """
    pairs = []
    lone = []
    pending = None
    for ent in sorted(entities, key=lambda item: item["start"]):
        kind = ent["type"]
        if kind == "TEST":
            # A new test supersedes any still-unpaired one.
            if pending is not None:
                lone.append(pending)
            pending = {"test": ent["text"], "start": ent["start"], "value": None}
        elif kind == "VALUE" and pending is not None and ent["start"] - pending["start"] < 40:
            pending["value"] = ent["text"]
            pairs.append({"test": pending["test"], "value": ent["text"]})
            pending = None
    if pending is not None:
        lone.append(pending)
    return pairs, lone
def _basic_fields(entities, text):
    """Extract the patient name (longest PER span), all AGE values, and sex."""
    name, ages, sex = None, [], None
    for e in entities:
        if e["type"] == "PER" and (name is None or len(e["text"]) > len(name)):
            name = e["text"]
        elif e["type"] == "AGE":
            v = cn_num(e["text"])
            if v is not None:
                ages.append(int(v))
        elif e["type"] == "SEX":
            sex = e["text"]
    if not sex:
        # Fallback: look for an explicit 男/女 character in the raw text.
        m_sex = re.search(r"(男|女)", text)
        if m_sex:
            sex = m_sex.group(1)
    return name, ages, sex


def _disease_risks(pairs, entities):
    """Derive the set of disease-risk labels from paired results and entities."""
    risks = set()
    # Diabetes: fasting glucose >= 126 mg/dL or HbA1c >= 6.5%.
    fpg = next((cn_num(p["value"]) for p in pairs if p["test"] == "空腹血糖"), None)
    a1c = next((cn_num(p["value"]) for p in pairs if p["test"] == "HbA1c"), None)
    if (fpg is not None and fpg >= 126) or (a1c is not None and a1c >= 6.5):
        risks.add("糖尿病")
    # Hyperlipidemia via LDL (>= 160 full, 130-159 mild).
    ldl = next((cn_num(p["value"]) for p in pairs if p["test"] == "LDL"), None)
    if ldl is not None and ldl >= 160:
        risks.add("高血脂")
    elif ldl is not None and ldl >= 130:
        risks.add("高血脂(輕度)")
    # Hypertension via a BP reading (>= 140/90) or an explicit DISEASE mention.
    bp_val = next((p["value"] for p in pairs if p["test"] in ("診間血壓", "家庭血壓", "24小時動態血壓")), None)
    if bp_val:
        systolic, diastolic = parse_bp(bp_val)  # renamed from sys/dia: avoid shadowing the sys module name
        if systolic and diastolic and (systolic >= 140 or diastolic >= 90):
            risks.add("高血壓")
    if any(e["type"] == "DISEASE" and "高血壓" in e["text"] for e in entities):
        risks.add("高血壓")
    return risks


def _recommendations(entities):
    """Turn DRUG / DRUG_CLASS / TREATMENT entities into recommendation strings."""
    recs = []
    for e in entities:
        if e["type"] == "DRUG":
            recs.append(f"開始服用 {e['text']}")
        elif e["type"] == "DRUG_CLASS":
            recs.append(f"考慮 {e['text']} 類藥物")
        elif e["type"] == "TREATMENT":
            t = e["text"]
            # Generic diet advice is normalized, except salt-specific guidance.
            if "飲食" in t and "低鹽" not in t:
                t = "控制飲食"
            recs.append(f"建議{t}")
    return recs


def _compose_summary(name, age, key_findings, risks, recs):
    """Build the one-sentence Chinese summary (at most 3 findings / 3 recs)."""
    name_disp = name if name else "病人"
    age_disp = f"{age}歲" if age is not None else ""
    abns = [f"{k['test']} {k['value']}" for k in key_findings if k.get("status") in ("異常", "偏高")]
    parts = [f"{name_disp}({age_disp})"] if age_disp else [name_disp]
    if abns:
        parts.append("檢查顯示 " + "、".join(abns[:3]))
    if "糖尿病" in risks:
        parts.append("符合糖尿病診斷")
    if "高血脂" in risks or "高血脂(輕度)" in risks:
        parts.append("另見 LDL 偏高")
    if recs:
        parts.append("建議:" + "、".join(recs[:3]))
    return ",".join(parts) + "。"


def extract_structured(text):
    """Run NER on `text` and assemble the structured health-check record.

    Returns (tokens, entities, structured); `structured` has keys
    name, age, sex, key_findings, disease_risk, recommendations, summary.
    """
    tokens, entities = ner_predict_with_tokens(text)
    name, ages, sex = _basic_fields(entities, text)

    pairs, _ = pair_tests_values(entities)
    key_findings = []
    for p in pairs:
        row = {"test": p["test"], "value": p["value"]}
        st = status_for(p["test"], p["value"])
        if st:
            row["status"] = st
        key_findings.append(row)

    risks = _disease_risks(pairs, entities)
    recs = _recommendations(entities)
    age = max(ages) if ages else None  # keep the largest AGE mention

    structured = {
        "name": name or None,
        "age": age,
        "sex": sex,
        "key_findings": key_findings,
        "disease_risk": sorted(risks),
        "recommendations": recs,
        "summary": _compose_summary(name, age, key_findings, risks, recs),
    }
    return tokens, entities, structured
| # ---------- Human-readable report ---------- | |
def make_readable_report(structured: dict) -> str:
    """Render the structured extraction result as a human-readable Chinese report.

    Sections: header (name/age/sex), abnormal findings, other findings,
    disease risks, recommendations, and the summary sentence. Empty sections
    are omitted.
    """
    name = structured.get("name") or "病人"
    age = structured.get("age")
    sex = structured.get("sex")

    # Header: name plus (age, sex) — "年齡不詳" when age is unknown.
    head = f"【健檢摘要】{name}"
    head += f"({age}歲" if age is not None else "(年齡不詳"
    head += f",{sex})" if sex else ")"

    abn_lines, nor_lines = [], []
    for item in structured.get("key_findings", []):
        test, value = item.get("test"), item.get("value")
        status = item.get("status")
        if status in ("異常", "偏高"):
            abn_lines.append(f".{test}: {value}({status})")
        elif status == "正常":
            nor_lines.append(f".{test}: {value}(正常)")
        else:
            nor_lines.append(f".{test}: {value}")

    risks = structured.get("disease_risk", [])
    recs = structured.get("recommendations", [])
    summary = structured.get("summary") or ""

    sections = [head, ""]
    if abn_lines:
        sections += ["【異常/偏高項目】", *abn_lines, ""]
    if nor_lines:
        sections += ["【其他檢測】", *nor_lines, ""]
    if risks:
        sections += ["【疾病風險/診斷】", "." + "、".join(risks), ""]
    if recs:
        sections += ["【建議】", "." + "、".join(recs), ""]
    if summary:
        sections += ["【摘要敘述】", summary, ""]
    return "\n".join(sections).strip()
| # ---------- Gradio UI ---------- | |
# Example input pre-filled in the textbox (Chinese clinical note).
EXAMPLE = "李偉(65歲,男),有高血壓與糖尿病。\n診間血壓152/94mmHg,空腹血糖138mg/dL,HbA1c 7.1%。\n建議使用ARB類藥物並低鹽飲食。"

def run(text):
    """Gradio callback: analyze `text` and return the four output panes.

    Returns (human report, structured JSON, entities JSON, tokens JSON);
    JSON is pretty-printed with non-ASCII characters kept readable.
    """
    tokens, entities, structured = extract_structured(text)
    report = make_readable_report(structured)

    def _dump(obj):
        # ensure_ascii=False keeps Chinese text legible in the UI.
        return json.dumps(obj, ensure_ascii=False, indent=2)

    return report, _dump(structured), _dump(entities), _dump(tokens)
# Gradio UI wiring. NOTE: components render in definition order, so the
# statement order below defines the on-page layout — do not reorder.
with gr.Blocks(title="HTN NER (Chinese)") as demo:
    gr.Markdown("## Hypertension NER → Human Report / JSON / Entities / Tokens")
    # Input textbox pre-filled with the example note.
    inp = gr.Textbox(label="輸入文字 (中文)", lines=6, value=EXAMPLE)
    btn = gr.Button("Analyze")
    # Four output panes, matching run()'s 4-tuple return order.
    out_report = gr.Textbox(label="Doctor Report", lines=12)
    out_struct = gr.Code(label="Structured JSON")
    out_entities = gr.Code(label="Entities (spans)")
    out_tokens = gr.Code(label="Token-level (B/I/O)")
    btn.click(run, inputs=inp, outputs=[out_report, out_struct, out_entities, out_tokens])
if __name__ == "__main__":
    # Bind 0.0.0.0 so the server is reachable from outside the Spaces container.
    demo.launch(server_name="0.0.0.0")