# NER_hyper / app.py
# Author: Donlagon007 — "Update app.py" (commit 78247e7, verified)
import re, json
import torch
import gradio as gr
from transformers import BertTokenizerFast, BertForTokenClassification
# === Model configuration: token-classification checkpoint on the HF Hub ===
MODEL_ID = "Donlagon007/htn-ner-v1"
# Load tokenizer/model once at import time (CPU is the default device on Spaces).
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ID)
model = BertForTokenClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables dropout etc.
# Mapping from predicted class id -> BIO label string (e.g. "B-TEST").
id2label = model.config.id2label
# ---------- Utils ----------
def decode_bio_to_spans(labels):
    """Convert a token-level BIO label sequence into entity spans.

    ``labels`` is a sequence of strings such as "B-TEST", "I-TEST", "O"
    (``None`` is treated like "O").  Returns a list of tuples
    ``(entity_type, first_token_idx, last_token_idx)`` with inclusive
    token indices.  A stray "I-X" that does not continue an open "X" span
    is treated as the start of a new span.
    """
    spans = []
    open_type = None
    open_start = None

    def flush(end_idx):
        # Close the currently open span, if any.
        if open_type is not None:
            spans.append((open_type, open_start, end_idx))

    for idx, label in enumerate(labels):
        if label is None or label == "O":
            flush(idx - 1)
            open_type = open_start = None
            continue
        prefix, ent_type = label.split("-", 1)
        # "B" always starts a new span; "I" restarts only on a type mismatch.
        if prefix == "B" or ent_type != open_type:
            flush(idx - 1)
            open_type, open_start = ent_type, idx
    flush(len(labels) - 1)
    return spans
def ner_predict_with_tokens(text, max_length=256):
    """Run token classification on ``text``.

    Returns ``(tokens_info, entities)``: ``tokens_info`` lists every kept
    subword with its predicted BIO label and character offsets, and
    ``entities`` lists the decoded character-level spans together with
    their surface text from ``text``.
    """
    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    with torch.no_grad():
        logits = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
        ).logits

    label_ids = logits.argmax(-1).squeeze(0).tolist()
    char_offsets = encoding["offset_mapping"].squeeze(0).tolist()
    token_ids = encoding["input_ids"].squeeze(0).tolist()

    tokens_info = []
    kept_labels = []
    kept_offsets = []
    for label_id, (start, end), token_id in zip(label_ids, char_offsets, token_ids):
        # Special tokens ([CLS]/[SEP]/padding) carry a (0, 0) offset — skip them.
        if start == 0 and end == 0:
            continue
        label = id2label[label_id]
        tokens_info.append({
            "token": tokenizer.convert_ids_to_tokens([token_id])[0],
            "label": label,
            "start": start,
            "end": end,
        })
        kept_labels.append(label)
        kept_offsets.append((start, end))

    entities = []
    for ent_type, first_tok, last_tok in decode_bio_to_spans(kept_labels):
        char_start = kept_offsets[first_tok][0]
        char_end = kept_offsets[last_tok][1]
        entities.append({
            "type": ent_type,
            "text": text[char_start:char_end],
            "start": char_start,
            "end": char_end,
        })
    return tokens_info, entities
def cn_num(s):
    """Return the first number (integer or decimal) in ``s`` as a float.

    Whitespace is stripped before matching; returns ``None`` when ``s``
    contains no digits.
    """
    compact = s.replace(" ", "")
    match = re.search(r"(\d+(?:\.\d+)?)", compact)
    if match is None:
        return None
    return float(match.group(1))
def parse_bp(value_text):
    """Parse a "SYS/DIA" blood-pressure reading out of ``value_text``.

    Returns ``(systolic, diastolic)`` as ints, or ``(None, None)`` when
    no 2-3 digit "/"-separated pair is present.
    """
    compact = value_text.replace(" ", "")
    match = re.search(r"(\d{2,3})\s*/\s*(\d{2,3})", compact)
    if match is None:
        return None, None
    return int(match.group(1)), int(match.group(2))
# Clinical decision thresholds, keyed by the test name as emitted by the NER
# model. Each "abnormal"/"borderline" callable takes the parsed numeric value
# (possibly None when unparseable) and returns a bool.
THRESHOLDS = {
    # Fasting plasma glucose: >= 126 mg/dL is abnormal.
    "空腹血糖": {"unit":"mg/dL", "abnormal": lambda v: v is not None and v >= 126},
    # HbA1c: >= 6.5 % is abnormal.
    "HbA1c": {"unit":"%", "abnormal": lambda v: v is not None and v >= 6.5},
    # LDL cholesterol: >= 160 abnormal; 130-159 borderline-high.
    "LDL": {"unit":"mg/dL", "abnormal": lambda v: v is not None and v >= 160,
            "borderline": lambda v: v is not None and 130 <= v < 160},
}
def status_for(test, val):
    """Classify the raw value string ``val`` for test ``test``.

    Returns "偏高" (borderline-high), "異常" (abnormal) or "正常" (normal);
    returns ``None`` when the test has no configured thresholds or the
    value cannot be parsed as a number.
    """
    if test not in THRESHOLDS:
        return None
    th = THRESHOLDS[test]
    v = cn_num(val)
    # Bug fix: an unparseable value used to fall through to "正常" (the
    # threshold lambdas return False for None). Report no status instead
    # of a false reassurance.
    if v is None:
        return None
    if "borderline" in th and th["borderline"](v):
        return "偏高"
    return "異常" if th["abnormal"](v) else "正常"
def pair_tests_values(entities):
    """Greedily pair each TEST entity with the next VALUE within 40 chars.

    ``entities`` is a list of dicts with at least "type", "text" and
    "start".  Returns ``(pairs, lone)``: ``pairs`` is a list of
    ``{"test", "value"}`` dicts; ``lone`` collects TEST entities that
    never received a value (as ``{"test", "start", "value": None}``).
    """
    ordered = sorted(entities, key=lambda e: e["start"])
    pairs = []
    lone = []
    pending = None
    for ent in ordered:
        kind = ent["type"]
        if kind == "TEST":
            # A new TEST before the previous one was matched: the old one is orphaned.
            if pending is not None:
                lone.append(pending)
            pending = {"test": ent["text"], "start": ent["start"], "value": None}
        elif (
            kind == "VALUE"
            and pending is not None
            and ent["start"] - pending["start"] < 40
        ):
            pending["value"] = ent["text"]
            pairs.append({"test": pending["test"], "value": ent["text"]})
            pending = None
    if pending is not None:
        lone.append(pending)
    return pairs, lone
def extract_structured(text):
    """Run NER on ``text`` and derive a structured health-check record.

    Returns ``(tokens, entities, structured)`` where ``structured`` holds
    the patient's name/age/sex, key test findings with threshold status,
    inferred disease risks, recommendations, and a one-line Chinese summary.
    """
    tokens, entities = ner_predict_with_tokens(text)

    # --- Basic demographic fields ---
    name, ages, sex = None, [], None
    for e in entities:
        if e["type"] == "PER" and (name is None or len(e["text"]) > len(name)):
            # Keep the longest PER mention as the canonical name.
            name = e["text"]
        elif e["type"] == "AGE":
            v = cn_num(e["text"])
            if v is not None:
                ages.append(int(v))
        elif e["type"] == "SEX":
            sex = e["text"]
    # Fallback: detect sex directly from the raw text if the model missed it.
    if not sex:
        m_sex = re.search(r"(男|女)", text)
        if m_sex:
            sex = m_sex.group(1)

    # --- Key findings: TEST/VALUE pairs with a threshold status where known ---
    pairs, _ = pair_tests_values(entities)
    key_findings = []
    for p in pairs:
        st = status_for(p["test"], p["value"])
        row = {"test": p["test"], "value": p["value"]}
        if st:
            row["status"] = st
        key_findings.append(row)

    risks = set()
    # 4.1 Diabetes: FPG >= 126 mg/dL or HbA1c >= 6.5 %.
    fpg = next((cn_num(p["value"]) for p in pairs if p["test"] == "空腹血糖"), None)
    a1c = next((cn_num(p["value"]) for p in pairs if p["test"] == "HbA1c"), None)
    if (fpg is not None and fpg >= 126) or (a1c is not None and a1c >= 6.5):
        risks.add("糖尿病")
    # 4.2 Hyperlipidemia via LDL.
    ldl = next((cn_num(p["value"]) for p in pairs if p["test"] == "LDL"), None)
    if ldl is not None and ldl >= 160:
        risks.add("高血脂")
    elif ldl is not None and ldl >= 130:
        risks.add("高血脂(輕度)")
    # 4.3 Hypertension via a measured BP value or an explicit DISEASE mention.
    bp_val = next((p["value"] for p in pairs
                   if p["test"] in ["診間血壓", "家庭血壓", "24小時動態血壓"]), None)
    if bp_val:
        # Renamed from `sys`/`dia` to avoid shadowing the stdlib `sys` module name.
        sys_bp, dia_bp = parse_bp(bp_val)
        if sys_bp and dia_bp and (sys_bp >= 140 or dia_bp >= 90):
            risks.add("高血壓")
    if any(e["type"] == "DISEASE" and "高血壓" in e["text"] for e in entities):
        risks.add("高血壓")

    # --- Recommendations from DRUG / DRUG_CLASS / TREATMENT mentions ---
    recs = []
    for e in entities:
        if e["type"] == "DRUG":
            recs.append(f"開始服用 {e['text']}")
        elif e["type"] == "DRUG_CLASS":
            recs.append(f"考慮 {e['text']} 類藥物")
        elif e["type"] == "TREATMENT":
            t = e["text"]
            if "飲食" in t and "低鹽" not in t:
                t = "控制飲食"
            recs.append(f"建議{t}")

    # --- One-line summary ---
    age = max(ages) if ages else None
    name_disp = name if name else "病人"
    age_disp = f"{age}歲" if age is not None else ""
    abns = [f"{k['test']} {k['value']}" for k in key_findings
            if k.get("status") in ("異常", "偏高")]
    # Bug fix: the age used to be rendered with only a closing parenthesis
    # ("李偉65歲)"); wrap it in a balanced pair.
    parts = [f"{name_disp}({age_disp})"] if age_disp else [name_disp]
    if abns:
        parts.append(f"檢查顯示 " + "、".join(abns[:3]))
    if "糖尿病" in risks:
        parts.append("符合糖尿病診斷")
    if "高血脂" in risks or "高血脂(輕度)" in risks:
        parts.append("另見 LDL 偏高")
    if recs:
        parts.append("建議:" + "、".join(recs[:3]))
    summary = ",".join(parts) + "。"

    structured = {
        "name": name or None,
        "age": age if age is not None else None,
        "sex": sex,
        "key_findings": key_findings,
        "disease_risk": sorted(risks),
        "recommendations": recs,
        "summary": summary,
    }
    return tokens, entities, structured
# ---------- Human-readable report ----------
def make_readable_report(structured: dict) -> str:
    """Render the structured record produced by extract_structured as a
    human-readable Chinese report with headed sections."""
    name = structured.get("name") or "病人"
    age = structured.get("age")
    sex = structured.get("sex")

    head = f"【健檢摘要】{name}"
    if age is not None:
        head += f"({age}歲"
    else:
        head += "(年齡不詳"
    if sex:
        head += f",{sex})"
    else:
        head += ")"

    # Split key findings into abnormal/borderline vs everything else.
    kfs = structured.get("key_findings", [])
    abn_lines, nor_lines = [], []
    for k in kfs:
        t, v = k.get("test"), k.get("value")
        st = k.get("status")
        if st in ("異常", "偏高"):
            # Bug fix: this line used to render "…7.1%異常)" with an
            # unbalanced parenthesis; wrap the status like the "正常" branch.
            abn_lines.append(f".{t}: {v}({st})")
        elif st == "正常":
            nor_lines.append(f".{t}: {v}(正常)")
        else:
            nor_lines.append(f".{t}: {v}")

    risks = structured.get("disease_risk", [])
    recs = structured.get("recommendations", [])
    summary = structured.get("summary") or ""

    # Assemble sections, skipping any that are empty.
    sections = [head, ""]
    if abn_lines:
        sections += ["【異常/偏高項目】"] + abn_lines + [""]
    if nor_lines:
        sections += ["【其他檢測】"] + nor_lines + [""]
    if risks:
        sections += ["【疾病風險/診斷】", "." + "、".join(risks), ""]
    if recs:
        sections += ["【建議】", "." + "、".join(recs), ""]
    if summary:
        sections += ["【摘要敘述】", summary, ""]
    return "\n".join(sections).strip()
# ---------- Gradio UI ----------
# Default example shown in the input textbox (name/age/sex, BP, FPG, HbA1c, advice).
EXAMPLE = "李偉(65歲,男),有高血壓與糖尿病。\n診間血壓152/94mmHg,空腹血糖138mg/dL,HbA1c 7.1%。\n建議使用ARB類藥物並低鹽飲食。"
def run(text):
    """Gradio click handler: analyze ``text`` and return the four panels
    (readable report, structured JSON, entities JSON, token-level JSON)."""
    tokens, entities, structured = extract_structured(text)

    def as_json(obj):
        # Pretty-print while keeping CJK characters readable (no \u escapes).
        return json.dumps(obj, ensure_ascii=False, indent=2)

    report = make_readable_report(structured)
    return report, as_json(structured), as_json(entities), as_json(tokens)
# Build the Blocks layout: one text input, one button, four output panels.
with gr.Blocks(title="HTN NER (Chinese)") as demo:
    gr.Markdown("## Hypertension NER → Human Report / JSON / Entities / Tokens")
    inp = gr.Textbox(label="輸入文字 (中文)", lines=6, value=EXAMPLE)
    btn = gr.Button("Analyze")
    out_report = gr.Textbox(label="Doctor Report", lines=12)
    out_struct = gr.Code(label="Structured JSON")
    out_entities = gr.Code(label="Entities (spans)")
    out_tokens = gr.Code(label="Token-level (B/I/O)")
    # Wire the button to run(); output order must match run()'s return tuple.
    btn.click(run, inputs=inp, outputs=[out_report, out_struct, out_entities, out_tokens])
if __name__ == "__main__":
    # Bind to all interfaces so the Spaces container can expose the port.
    demo.launch(server_name="0.0.0.0")