# (extraction metadata removed: file size 4,404 bytes; git-blame hashes and line-number dump)
# app.py
# Gradio NER demo using: tner/xlm-roberta-base-conll2003
#
# requirements.txt:
# gradio>=4.0
# transformers>=4.35
# torch
# sentencepiece
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Hugging Face model ID: XLM-RoBERTa base fine-tuned for CoNLL-2003 NER.
MODEL_ID = "tner/xlm-roberta-base-conll2003"
# Fast tokenizer is required: run_ner relies on return_offsets_mapping,
# which only fast (Rust-backed) tokenizers provide.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
# Prefer GPU when available; this demo is inference-only, so eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
def merge_spans(text: str, per_token):
    """
    Collapse token-level predictions into contiguous entity spans.

    Handles BIO-tagged labels (B-XXX / I-XXX / O) as well as plain,
    non-BIO labels, where a repeated label is treated as a continuation.

    per_token: list of {label, start, end} dicts (character offsets)
    returns:   list of {entity, start, end} dicts
    """
    merged = []
    active = None  # span currently being extended, or None

    for item in per_token:
        label = item["label"] or "O"
        start, end = item["start"], item["end"]

        if label == "O":
            # Outside any entity: finalize whatever was open.
            if active:
                merged.append(active)
                active = None
            continue

        if label.startswith("B-"):
            # Explicit beginning: always starts a fresh span.
            if active:
                merged.append(active)
            active = {"entity": label[2:], "start": start, "end": end}
            continue

        # I-XXX continues a matching open span; a bare (non-BIO) label
        # likewise continues when it matches the open span's entity.
        entity = label[2:] if label.startswith("I-") else label
        if active and active["entity"] == entity:
            active["end"] = end
        else:
            if active:
                merged.append(active)
            active = {"entity": entity, "start": start, "end": end}

    if active:
        merged.append(active)
    return merged
def run_ner(text: str, max_length: int, show_tokens: bool):
    """
    Tokenize the input, run token classification, and merge the
    per-token predictions into entity spans.

    Returns a (table_rows, debug_text) pair:
      table_rows: 2D list [[entity, surface_text, start, end], ...]
                  (plain 2D list so Gradio renders it, not [object Object])
      debug_text: tab-separated token/label dump, or "" unless requested
    """
    text = (text or "").strip()
    if not text:
        return [], ""

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=int(max_length),
        return_offsets_mapping=True,
    )
    # Pull offsets out before sending tensors to the model — the model
    # does not accept offset_mapping as an input.
    offset_pairs = encoded.pop("offset_mapping")[0].tolist()
    ids = encoded["input_ids"][0].tolist()
    token_strs = tokenizer.convert_ids_to_tokens(ids)

    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits[0]  # (seq_len, num_labels)
    predicted = logits.argmax(dim=-1).tolist()
    label_of = model.config.id2label

    # start == end marks special/padding tokens with no source text —
    # drop those; keep the rest as {token, label, start, end} records.
    per_token = [
        {
            "token": tok,
            "label": label_of[pred],
            "start": int(lo),
            "end": int(hi),
        }
        for tok, pred, (lo, hi) in zip(token_strs, predicted, offset_pairs)
        if lo != hi
    ]

    spans = merge_spans(text, per_token)
    table_rows = [
        [span["entity"], text[span["start"]:span["end"]], span["start"], span["end"]]
        for span in spans
    ]

    debug = ""
    if show_tokens:
        lines = ["token\tlabel\t[offsets]"]
        lines += [
            f"{t['token']}\t{t['label']}\t[{t['start']},{t['end']}]"
            for t in per_token
        ]
        debug = "\n".join(lines)
    return table_rows, debug
# UI wiring: inputs (text + tuning controls) -> run_ner -> table + debug box.
with gr.Blocks(title="XLM-R NER (CoNLL-2003)") as demo:
    gr.Markdown(f"""
    # XLM-R NER Demo (CoNLL-2003)
    Model: **{MODEL_ID}**
    Labels: typically **PER / ORG / LOC / MISC**
    """)
    with gr.Row():
        max_length = gr.Slider(64, 512, value=256, step=32, label="max_length (truncate)")
        show_tokens = gr.Checkbox(value=False, label="Show token labels (debug)")
    text = gr.Textbox(
        label="Input text",
        lines=10,
        value="Tim Chen\nSenior Software Engineer at Apple Inc.\nTaipei, Taiwan",
        placeholder="Paste text here (e.g., OCR output).",
    )
    btn = gr.Button("Run NER")
    # 2D list output: headers/datatypes match run_ner's table_rows shape.
    out_table = gr.Dataframe(
        label="Entities (spans)",
        headers=["entity", "text", "start", "end"],
        datatype=["str", "str", "number", "number"],
        interactive=False,
        wrap=True,
    )
    debug_box = gr.Textbox(label="Raw token output", lines=12)
    btn.click(fn=run_ner, inputs=[text, max_length, show_tokens], outputs=[out_table, debug_box])

if __name__ == "__main__":
    demo.launch()