Donlagon007 committed on
Commit
03cf9b3
·
verified ·
1 Parent(s): 149caca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -216
app.py CHANGED
@@ -1,216 +1,216 @@
1
- import re, json
2
- import torch
3
- import gradio as gr
4
- from transformers import BertTokenizerFast, BertForTokenClassification
5
-
6
- # === ตั้งค่าโมเดลจาก Hub ===
7
- # เปลี่ยนเป็นโมเดลของคุณ เช่น "donla/htn-ner"
8
- MODEL_ID = "Donlagon007/htn-ner-v1"
9
-
10
- # โหลดโมเดล/โทเคนไนเซอร์ (CPU เป็นค่าเริ่มต้นใน Spaces)
11
- tokenizer = BertTokenizerFast.from_pretrained(MODEL_ID)
12
- model = BertForTokenClassification.from_pretrained(MODEL_ID)
13
- model.eval()
14
- id2label = model.config.id2label
15
-
16
- # ---------- Utils ----------
17
- def decode_bio_to_spans(labels):
18
- spans, cur_type, s = [], None, None
19
- for i, lab in enumerate(labels):
20
- if lab == "O" or lab is None:
21
- if cur_type is not None:
22
- spans.append((cur_type, s, i-1))
23
- cur_type, s = None, None
24
- continue
25
- tag, typ = lab.split("-", 1)
26
- if tag == "B":
27
- if cur_type is not None:
28
- spans.append((cur_type, s, i-1))
29
- cur_type, s = typ, i
30
- elif tag == "I":
31
- if cur_type != typ:
32
- if cur_type is not None:
33
- spans.append((cur_type, s, i-1))
34
- cur_type, s = typ, i
35
- if cur_type is not None:
36
- spans.append((cur_type, s, len(labels)-1))
37
- return spans
38
-
39
- def ner_predict_with_tokens(text, max_length=256):
40
- enc = tokenizer(
41
- text,
42
- return_offsets_mapping=True,
43
- return_tensors="pt",
44
- truncation=True, max_length=max_length
45
- )
46
- with torch.no_grad():
47
- out = model(
48
- input_ids=enc["input_ids"],
49
- attention_mask=enc["attention_mask"]
50
- )
51
- pred_ids = out.logits.argmax(-1).squeeze(0).tolist()
52
-
53
- offsets = enc["offset_mapping"].squeeze(0).tolist()
54
- input_ids = enc["input_ids"].squeeze(0).tolist()
55
-
56
- tokens_info, kept_labels, kept_offsets = [], [], []
57
- # ตัด [CLS]/[SEP]/padding โดยเช็ค offset (0,0)
58
- for lid, (st, ed), tid in zip(pred_ids, offsets, input_ids):
59
- if st == ed == 0:
60
- continue
61
- tok = tokenizer.convert_ids_to_tokens([tid])[0]
62
- lab = id2label[lid]
63
- tokens_info.append({"token": tok, "label": lab, "start": st, "end": ed})
64
- kept_labels.append(lab)
65
- kept_offsets.append((st, ed))
66
-
67
- spans_tok = decode_bio_to_spans(kept_labels)
68
- entities = []
69
- for typ, s_tok, e_tok in spans_tok:
70
- cs, ce = kept_offsets[s_tok][0], kept_offsets[e_tok][1]
71
- entities.append({"type": typ, "text": text[cs:ce], "start": cs, "end": ce})
72
- return tokens_info, entities
73
-
74
- def cn_num(s):
75
- m = re.search(r"(\d+(?:\.\d+)?)", s.replace(" ", ""))
76
- return float(m.group(1)) if m else None
77
-
78
- def parse_bp(value_text):
79
- m = re.search(r"(\d{2,3})\s*/\s*(\d{2,3})", value_text.replace(" ", ""))
80
- if m:
81
- return int(m.group(1)), int(m.group(2))
82
- return None, None
83
-
84
- THRESHOLDS = {
85
- "空腹血糖": {"unit":"mg/dL", "abnormal": lambda v: v is not None and v >= 126},
86
- "HbA1c": {"unit":"%", "abnormal": lambda v: v is not None and v >= 6.5},
87
- "LDL": {"unit":"mg/dL", "abnormal": lambda v: v is not None and v >= 160,
88
- "borderline": lambda v: v is not None and 130 <= v < 160},
89
- }
90
-
91
- def status_for(test, val):
92
- if test in THRESHOLDS:
93
- th = THRESHOLDS[test]
94
- v = cn_num(val)
95
- if "borderline" in th and th["borderline"](v):
96
- return "偏高"
97
- return "異常" if th["abnormal"](v) else "正常"
98
- return None
99
-
100
- def pair_tests_values(entities):
101
- ents = sorted(entities, key=lambda x: x["start"])
102
- pairs, lone = [], []
103
- last_test = None
104
- for e in ents:
105
- if e["type"] == "TEST":
106
- if last_test: lone.append(last_test)
107
- last_test = {"test": e["text"], "start": e["start"], "value": None}
108
- elif e["type"] == "VALUE" and last_test and (e["start"] - last_test["start"]) < 40:
109
- last_test["value"] = e["text"]
110
- pairs.append({"test": last_test["test"], "value": e["text"]})
111
- last_test = None
112
- if last_test: lone.append(last_test)
113
- return pairs, lone
114
-
115
- def extract_structured(text):
116
- tokens, entities = ner_predict_with_tokens(text)
117
-
118
- # basic fields
119
- name, ages, sex = None, [], None
120
- for e in entities:
121
- if e["type"] == "PER" and (name is None or len(e["text"]) > len(name)):
122
- name = e["text"]
123
- elif e["type"] == "AGE":
124
- v = cn_num(e["text"])
125
- if v is not None: ages.append(int(v))
126
- elif e["type"] == "SEX":
127
- sex = e["text"]
128
-
129
- # fallback sex detect
130
- if not sex:
131
- m_sex = re.search(r"(男|女)", text)
132
- if m_sex: sex = m_sex.group(1)
133
-
134
- pairs, _ = pair_tests_values(entities)
135
- key_findings = []
136
- for p in pairs:
137
- st = status_for(p["test"], p["value"])
138
- row = {"test": p["test"], "value": p["value"]}
139
- if st: row["status"] = st
140
- key_findings.append(row)
141
-
142
- risks = set()
143
- fpg = next((cn_num(p["value"]) for p in pairs if p["test"] == "空腹血糖"), None)
144
- a1c = next((cn_num(p["value"]) for p in pairs if p["test"] == "HbA1c"), None)
145
- ldl = next((cn_num(p["value"]) for p in pairs if p["test"] == "LDL"), None)
146
- bp_val = next((p["value"] for p in pairs if p["test"] in ["診間血壓","家庭血壓","24小時動態血壓"]), None)
147
- if (fpg is not None and fpg >= 126) or (a1c is not None and a1c >= 6.5):
148
- risks.add("糖尿病")
149
- if ldl is not None and ldl >= 160:
150
- risks.add("高血脂")
151
- elif ldl is not None and ldl >= 130:
152
- risks.add("高血脂(輕度)")
153
- if bp_val:
154
- sys, dia = parse_bp(bp_val)
155
- if sys and dia and (sys >= 140 or dia >= 90):
156
- risks.add("高血壓")
157
- if any(e["type"] == "DISEASE" and "高血壓" in e["text"] for e in entities):
158
- risks.add("高血壓")
159
-
160
- recs = []
161
- for e in entities:
162
- if e["type"] == "DRUG":
163
- recs.append(f"開始服用 {e['text']}")
164
- elif e["type"] == "DRUG_CLASS":
165
- recs.append(f"考慮 {e['text']} 類藥物")
166
- elif e["type"] == "TREATMENT":
167
- t = e["text"]
168
- if "飲食" in t and "低鹽" not in t:
169
- t = "控制飲食"
170
- recs.append(f"建議{t}")
171
-
172
- age = max(ages) if ages else None
173
- name_disp = name if name else "病人"
174
- age_disp = f"{age}歲" if age is not None else ""
175
- abns = [f"{k['test']} {k['value']}" for k in key_findings if k.get("status") in ("異常","偏高")]
176
- parts = [f"{name_disp}({age_disp})"] if age_disp else [name_disp]
177
- if abns: parts.append(f"檢查顯示 " + "、".join(abns[:3]))
178
- if "糖尿病" in risks: parts.append("符合糖尿病診斷")
179
- if "高血脂" in risks or "高血脂(輕度)" in risks: parts.append("另見 LDL 偏高")
180
- if recs: parts.append("建議:" + "、".join(recs[:3]))
181
- summary = ",".join(parts) + "。"
182
-
183
- structured = {
184
- "name": name or None,
185
- "age": age if age is not None else None,
186
- "sex": sex,
187
- "key_findings": key_findings,
188
- "disease_risk": sorted(list(risks)),
189
- "recommendations": recs,
190
- "summary": summary
191
- }
192
- return tokens, entities, structured
193
-
194
- # ---------- Gradio UI ----------
195
- EXAMPLE = "李偉(65歲,男),有高血壓與糖尿病。\n診間血壓152/94mmHg,空腹血糖138mg/dL,HbA1c 7.1%。\n建議使用ARB類藥物並低鹽飲食。"
196
-
197
- def run(text):
198
- tokens, entities, structured = extract_structured(text)
199
- return json.dumps(tokens, ensure_ascii=False, indent=2), \
200
- json.dumps(entities, ensure_ascii=False, indent=2), \
201
- json.dumps(structured, ensure_ascii=False, indent=2)
202
-
203
- with gr.Blocks(title="HTN NER (Chinese)") as demo:
204
- gr.Markdown("## Hypertension NER → Tokens / Entities / Structured JSON")
205
- inp = gr.Textbox(label="輸入文字 (中文)", lines=6, value=EXAMPLE)
206
- btn = gr.Button("Analyze")
207
- out_tokens = gr.Code(label="Token-level (B/I/O)")
208
- out_entities = gr.Code(label="Entities (spans)")
209
- out_struct = gr.Code(label="Structured JSON")
210
- btn.click(run, inputs=inp, outputs=[out_tokens, out_entities, out_struct])
211
-
212
- demo.launch()
213
-
214
-
215
- if __name__ == "__main__":
216
- demo.launch(server_name="0.0.0.0")
 
1
+ import re, json
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import BertTokenizerFast, BertForTokenClassification
5
+
6
# === Model configuration (Hugging Face Hub) ===
# Change this to your own checkpoint, e.g. "donla/htn-ner".
MODEL_ID = "Donlagon007/htn-ner-v1"

# Load the tokenizer and model once at import time (Spaces default to CPU).
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ID)
# ``eval()`` returns the module itself, so the call can be chained.
model = BertForTokenClassification.from_pretrained(MODEL_ID).eval()
id2label = model.config.id2label
15
+
16
+ # ---------- Utils ----------
17
def decode_bio_to_spans(labels):
    """Collapse a sequence of BIO tags into inclusive token-index spans.

    Args:
        labels: Per-token tags such as ``"B-PER"``, ``"I-PER"`` or ``"O"``.
            ``None``, the empty string and any tag without a ``-``
            separator (for example ``"PAD"`` or ``"LABEL_0"``) are treated
            like ``"O"`` instead of raising ``ValueError``.

    Returns:
        A list of ``(entity_type, first_index, last_index)`` tuples in the
        order the spans begin. Both indices are inclusive positions in
        ``labels``.
    """
    spans = []
    cur_type, start = None, None
    for i, lab in enumerate(labels):
        tag, sep, typ = (lab or "O").partition("-")
        if not sep:
            # Outside tag, or one that carries no entity type: close any
            # span that is still open.
            if cur_type is not None:
                spans.append((cur_type, start, i - 1))
            cur_type, start = None, None
            continue
        # "B" always opens a new span. "I" opens one only when it does not
        # continue the span that is currently open (stray or mismatched I).
        # Any other prefix leaves the current span untouched.
        if tag == "B" or (tag == "I" and typ != cur_type):
            if cur_type is not None:
                spans.append((cur_type, start, i - 1))
            cur_type, start = typ, i
    if cur_type is not None:
        # The sequence ended while a span was still open.
        spans.append((cur_type, start, len(labels) - 1))
    return spans
38
+
39
def ner_predict_with_tokens(text, max_length=256):
    """Tag ``text`` with the module-level model and decode entity spans.

    Args:
        text: Input string to analyse.
        max_length: Passed to the tokenizer together with
            ``truncation=True``; tokens beyond it are not tagged.

    Returns:
        A ``(tokens_info, entities)`` tuple. ``tokens_info`` holds one dict
        per kept token with keys ``token``, ``label``, ``start`` and
        ``end`` (character offsets into ``text``). ``entities`` holds one
        dict per decoded span with keys ``type``, ``text``, ``start`` and
        ``end``.
    """
    encoded = tokenizer(
        text,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    with torch.no_grad():
        logits = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        ).logits

    # Single-item batch: drop the batch dimension and work with plain lists.
    label_ids = logits.argmax(-1).squeeze(0).tolist()
    char_offsets = encoded["offset_mapping"].squeeze(0).tolist()
    token_ids = encoded["input_ids"].squeeze(0).tolist()

    tokens_info = []
    for label_id, (start, end), token_id in zip(label_ids, char_offsets, token_ids):
        # Drop [CLS]/[SEP]/padding by checking for the (0, 0) offset.
        if start == 0 and end == 0:
            continue
        tokens_info.append({
            "token": tokenizer.convert_ids_to_tokens([token_id])[0],
            "label": id2label[label_id],
            "start": start,
            "end": end,
        })

    entities = []
    kept_labels = [info["label"] for info in tokens_info]
    for typ, first_tok, last_tok in decode_bio_to_spans(kept_labels):
        # A span covers the characters from its first token's start to its
        # last token's end.
        char_start = tokens_info[first_tok]["start"]
        char_end = tokens_info[last_tok]["end"]
        entities.append({
            "type": typ,
            "text": text[char_start:char_end],
            "start": char_start,
            "end": char_end,
        })
    return tokens_info, entities
73
+
74
def cn_num(s):
    """Extract the first decimal number found in ``s``.

    Spaces are removed before matching, so digit groups separated only by
    spaces are read as one number. Only digit characters are recognised;
    Chinese numerals such as "六十五" are not parsed.

    Args:
        s: Text containing a number (e.g. ``"138mg/dL"``). ``None`` is
            accepted, and non-string values are converted with ``str``.

    Returns:
        The number as a ``float``, or ``None`` when ``s`` is ``None`` or
        contains no digits.
    """
    if s is None:
        return None
    m = re.search(r"(\d+(?:\.\d+)?)", str(s).replace(" ", ""))
    return float(m.group(1)) if m else None
77
+
78
def parse_bp(value_text):
    """Parse a "systolic/diastolic" blood-pressure reading.

    Each side must be a complete 2-3 digit number; a fragment of a longer
    digit run (e.g. the "152" inside "1152/94") is rejected rather than
    reported as a reading. Spaces are removed before matching.

    Args:
        value_text: Text such as ``"152/94mmHg"``. ``None`` and the empty
            string are accepted.

    Returns:
        ``(systolic, diastolic)`` as ints, or ``(None, None)`` when no
        reading is found.
    """
    if not value_text:
        return None, None
    m = re.search(
        r"(?<!\d)(\d{2,3})\s*/\s*(\d{2,3})(?!\d)",
        str(value_text).replace(" ", ""),
    )
    if m is None:
        return None, None
    return int(m.group(1)), int(m.group(2))
83
+
84
def _at_least(limit):
    """Build a predicate that is true for non-``None`` values >= ``limit``."""
    def check(v):
        return v is not None and v >= limit
    return check


def _within(low, high):
    """Build a predicate that is true for non-``None`` values in [low, high)."""
    def check(v):
        return v is not None and low <= v < high
    return check


# Per-test cut-offs keyed by the test name as it appears in the text.
# "abnormal" is always present; "borderline" is optional.
THRESHOLDS = {
    "空腹血糖": {"unit": "mg/dL", "abnormal": _at_least(126)},
    "HbA1c": {"unit": "%", "abnormal": _at_least(6.5)},
    "LDL": {
        "unit": "mg/dL",
        "abnormal": _at_least(160),
        "borderline": _within(130, 160),
    },
}
90
+
91
def status_for(test, val):
    """Classify a test value against the ``THRESHOLDS`` table.

    Args:
        test: Test name; only names present in ``THRESHOLDS`` are rated.
        val: Value text; its first number is extracted with ``cn_num``.

    Returns:
        ``"偏高"`` when the test has a borderline rule that matches,
        ``"異常"`` when the abnormal rule matches, ``"正常"`` otherwise, or
        ``None`` when ``test`` has no entry in ``THRESHOLDS``.
    """
    rule = THRESHOLDS.get(test)
    if rule is None:
        return None

    number = cn_num(val)
    # The borderline rule is optional and is checked before the abnormal one.
    borderline = rule.get("borderline")
    if borderline is not None and borderline(number):
        return "偏高"
    if rule["abnormal"](number):
        return "異常"
    return "正常"
99
+
100
def pair_tests_values(entities, max_gap=40):
    """Pair each TEST entity with the VALUE entity that follows it.

    Entities are processed in order of their ``start`` offset. A VALUE is
    attached to the most recent unpaired TEST when the distance between
    their start offsets is below ``max_gap``. A TEST that is followed by
    another TEST, or that reaches the end of the input unpaired, is
    reported as "lone". VALUE entities with no pending TEST are ignored.

    Args:
        entities: Dicts with at least ``type``, ``text`` and ``start``.
        max_gap: Exclusive upper bound, in characters, on the distance
            between the TEST start and the VALUE start.

    Returns:
        ``(pairs, lone)``: ``pairs`` is a list of
        ``{"test": ..., "value": ...}`` dicts, and ``lone`` is a list of
        ``{"test": ..., "start": ..., "value": None}`` dicts.
    """
    pairs, lone = [], []
    pending = None  # most recent TEST still waiting for a VALUE
    for ent in sorted(entities, key=lambda x: x["start"]):
        kind = ent["type"]
        if kind == "TEST":
            if pending:
                lone.append(pending)
            pending = {"test": ent["text"], "start": ent["start"], "value": None}
        elif kind == "VALUE" and pending and (ent["start"] - pending["start"]) < max_gap:
            pairs.append({"test": pending["test"], "value": ent["text"]})
            pending = None
    if pending:
        lone.append(pending)
    return pairs, lone
114
+
115
def _first_number(pairs, test_name):
    """Return the parsed number of the first pair for ``test_name``, else ``None``."""
    return next((cn_num(p["value"]) for p in pairs if p["test"] == test_name), None)


def _assess_risks(pairs, entities):
    """Derive the set of disease-risk labels from paired results and DISEASE entities."""
    risks = set()
    fpg = _first_number(pairs, "空腹血糖")
    a1c = _first_number(pairs, "HbA1c")
    ldl = _first_number(pairs, "LDL")
    if (fpg is not None and fpg >= 126) or (a1c is not None and a1c >= 6.5):
        risks.add("糖尿病")
    if ldl is not None and ldl >= 160:
        risks.add("高血脂")
    elif ldl is not None and ldl >= 130:
        risks.add("高血脂(輕度)")

    bp_tests = ("診間血壓", "家庭血壓", "24小時動態血壓")
    bp_val = next((p["value"] for p in pairs if p["test"] in bp_tests), None)
    if bp_val:
        systolic, diastolic = parse_bp(bp_val)
        if systolic and diastolic and (systolic >= 140 or diastolic >= 90):
            risks.add("高血壓")
    # A DISEASE entity that mentions hypertension also counts as a risk.
    if any(e["type"] == "DISEASE" and "高血壓" in e["text"] for e in entities):
        risks.add("高血壓")
    return risks


def _build_recommendations(entities):
    """Turn DRUG / DRUG_CLASS / TREATMENT entities into recommendation phrases."""
    recs = []
    for e in entities:
        kind = e["type"]
        if kind == "DRUG":
            recs.append(f"開始服用 {e['text']}")
        elif kind == "DRUG_CLASS":
            recs.append(f"考慮 {e['text']} 類藥物")
        elif kind == "TREATMENT":
            treatment = e["text"]
            # Diet advice that does not mention low salt is normalised to a
            # generic "control diet" phrase.
            if "飲食" in treatment and "低鹽" not in treatment:
                treatment = "控制飲食"
            recs.append(f"建議{treatment}")
    return recs


def _compose_summary(name, age, key_findings, risks, recs):
    """Build the one-sentence summary from the already extracted pieces."""
    name_disp = name if name else "病人"
    age_disp = f"{age}歲" if age is not None else ""
    parts = [f"{name_disp}({age_disp})"] if age_disp else [name_disp]

    abns = [
        f"{k['test']} {k['value']}"
        for k in key_findings
        if k.get("status") in ("異常", "偏高")
    ]
    # At most three abnormal findings and three recommendations are quoted.
    if abns:
        parts.append("檢查顯示 " + "、".join(abns[:3]))
    if "糖尿病" in risks:
        parts.append("符合糖尿病診斷")
    if "高血脂" in risks or "高血脂(輕度)" in risks:
        parts.append("另見 LDL 偏高")
    if recs:
        parts.append("建議:" + "、".join(recs[:3]))
    return ",".join(parts) + "。"


def extract_structured(text):
    """Run NER on ``text`` and assemble a structured report.

    Args:
        text: Input note to analyse.

    Returns:
        ``(tokens, entities, structured)``. ``tokens`` and ``entities`` are
        returned unchanged from ``ner_predict_with_tokens``. ``structured``
        is a dict with keys ``name``, ``age``, ``sex``, ``key_findings``,
        ``disease_risk``, ``recommendations`` and ``summary``.
    """
    tokens, entities = ner_predict_with_tokens(text)

    # Basic fields: the longest PER entity is the name, every AGE entity
    # contributes a candidate age, and the last SEX entity wins.
    name, ages, sex = None, [], None
    for e in entities:
        if e["type"] == "PER" and (name is None or len(e["text"]) > len(name)):
            name = e["text"]
        elif e["type"] == "AGE":
            v = cn_num(e["text"])
            if v is not None:
                ages.append(int(v))
        elif e["type"] == "SEX":
            sex = e["text"]

    # Fallback: take the first 男/女 character anywhere in the raw text.
    if not sex:
        m_sex = re.search(r"(男|女)", text)
        if m_sex:
            sex = m_sex.group(1)

    pairs, _ = pair_tests_values(entities)
    key_findings = []
    for p in pairs:
        row = {"test": p["test"], "value": p["value"]}
        st = status_for(p["test"], p["value"])
        if st:
            row["status"] = st
        key_findings.append(row)

    risks = _assess_risks(pairs, entities)
    recs = _build_recommendations(entities)

    # When several ages were tagged, the largest one is reported.
    age = max(ages) if ages else None
    summary = _compose_summary(name, age, key_findings, risks, recs)

    structured = {
        "name": name or None,
        "age": age,
        "sex": sex,
        "key_findings": key_findings,
        "disease_risk": sorted(risks),
        "recommendations": recs,
        "summary": summary,
    }
    return tokens, entities, structured
193
+
194
+ # ---------- Gradio UI ----------
195
# Sample note used as the initial value of the input textbox.
EXAMPLE = (
    "李偉(65歲,男),有高血壓與糖尿病。\n"
    "診間血壓152/94mmHg,空腹血糖138mg/dL,HbA1c 7.1%。\n"
    "建議使用ARB類藥物並低鹽飲食。"
)
196
+
197
def run(text):
    """Analyse ``text`` and return the three results as JSON strings.

    Returns:
        A 3-tuple ``(tokens_json, entities_json, structured_json)``, each
        pretty-printed with a 2-space indent and non-ASCII characters left
        unescaped.
    """
    return tuple(
        json.dumps(part, ensure_ascii=False, indent=2)
        for part in extract_structured(text)
    )
202
+
203
# Build the UI at import time so that ``demo`` exists as a module attribute;
# the server itself is started only when the file is run as a script.
with gr.Blocks(title="HTN NER (Chinese)") as demo:
    gr.Markdown("## Hypertension NER → Tokens / Entities / Structured JSON")
    inp = gr.Textbox(label="輸入文字 (中文)", lines=6, value=EXAMPLE)
    btn = gr.Button("Analyze")
    out_tokens = gr.Code(label="Token-level (B/I/O)")
    out_entities = gr.Code(label="Entities (spans)")
    out_struct = gr.Code(label="Structured Reports")
    # One click fills all three output panes from the tuple returned by ``run``.
    btn.click(run, inputs=inp, outputs=[out_tokens, out_entities, out_struct])


if __name__ == "__main__":
    # Launch exactly once. The previous code also called ``demo.launch()``
    # unconditionally above this guard, which started a server on import and
    # a second one after the first returned.
    # NOTE(review): "0.0.0.0" listens on all interfaces — confirm this is
    # wanted outside a container/Spaces environment.
    demo.launch(server_name="0.0.0.0")