# (Hugging Face page residue from copy/paste — not part of the program)
# magicboker's picture
# Update app.py
# 529995b verified
# app.py
# Gradio NER demo using: tner/xlm-roberta-base-conll2003
#
# requirements.txt:
# gradio>=4.0
# transformers>=4.35
# torch
# sentencepiece
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Hugging Face model ID: XLM-RoBERTa base fine-tuned for token classification
# on CoNLL-2003 (labels: PER / ORG / LOC / MISC).
MODEL_ID = "tner/xlm-roberta-base-conll2003"
# use_fast=True: the fast (Rust) tokenizer is what supports
# return_offsets_mapping, which run_ner relies on below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
# Prefer GPU when available; this demo is inference-only, so eval mode
# (disables dropout etc.) is set once at startup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
def merge_spans(text: str, per_token):
    """Collapse token-level labels into contiguous entity spans.

    Handles both BIO tags (``B-X`` / ``I-X`` / ``O``) and plain tags
    (a bare label is treated as a continuation of an identical open span).

    Args:
        text: Original input text (unused here; kept for interface parity).
        per_token: Sequence of dicts with keys ``label``, ``start``, ``end``.

    Returns:
        List of dicts with keys ``entity``, ``start``, ``end``.
    """
    merged = []
    open_span = None  # span currently being extended, or None

    for tok in per_token:
        label = tok["label"] or "O"
        start, end = tok["start"], tok["end"]

        if label == "O":
            # Outside any entity: flush whatever is open.
            if open_span is not None:
                merged.append(open_span)
                open_span = None
            continue

        begins = label.startswith("B-")
        entity = label[2:] if begins or label.startswith("I-") else label

        # Extend the open span only for a continuation tag (I- or bare label)
        # whose entity type matches; B- always starts a fresh span.
        if not begins and open_span is not None and open_span["entity"] == entity:
            open_span["end"] = end
        else:
            if open_span is not None:
                merged.append(open_span)
            open_span = {"entity": entity, "start": start, "end": end}

    if open_span is not None:
        merged.append(open_span)
    return merged
def run_ner(text: str, max_length: int, show_tokens: bool):
    """Run the NER model over `text`.

    Args:
        text: Raw input text; None/blank yields an empty result.
        max_length: Tokenizer truncation limit.
        show_tokens: When True, also build a per-token debug dump.

    Returns:
        (table_rows, debug): `table_rows` is a 2D list of
        [entity, surface_text, start, end] rows; `debug` is a tab-separated
        token/label listing, or "" when `show_tokens` is False.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return [], ""

    encoded = tokenizer(
        cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=int(max_length),
        return_offsets_mapping=True,
    )
    # Pop offsets before the forward pass — the model does not accept them.
    offset_pairs = encoded.pop("offset_mapping")[0].tolist()
    token_ids = encoded["input_ids"][0].tolist()
    token_strs = tokenizer.convert_ids_to_tokens(token_ids)

    batch = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**batch).logits[0]  # (seq_len, num_labels)
    predictions = logits.argmax(dim=-1).tolist()

    id2label = model.config.id2label
    # start == end marks special/padding tokens with no source-text span.
    per_token = [
        {"token": tok, "label": id2label[pred], "start": int(st), "end": int(ed)}
        for tok, pred, (st, ed) in zip(token_strs, predictions, offset_pairs)
        if st != ed
    ]

    spans = merge_spans(cleaned, per_token)
    # Emit plain 2D rows so Gradio's Dataframe renders values,
    # not [object Object].
    table_rows = [
        [s["entity"], cleaned[s["start"]:s["end"]], s["start"], s["end"]]
        for s in spans
    ]

    debug = ""
    if show_tokens:
        lines = ["token\tlabel\t[offsets]"]
        lines.extend(
            f"{t['token']}\t{t['label']}\t[{t['start']},{t['end']}]"
            for t in per_token
        )
        debug = "\n".join(lines)
    return table_rows, debug
# --- Gradio UI: wires (text, max_length, show_tokens) into run_ner. --------
with gr.Blocks(title="XLM-R NER (CoNLL-2003)") as demo:
    gr.Markdown(f"""
# XLM-R NER Demo (CoNLL-2003)
Model: **{MODEL_ID}**
Labels: typically **PER / ORG / LOC / MISC**
""")
    with gr.Row():
        # Tokenizer truncation limit, forwarded to run_ner's max_length.
        max_length = gr.Slider(64, 512, value=256, step=32, label="max_length (truncate)")
        # When checked, run_ner also returns the per-token debug dump.
        show_tokens = gr.Checkbox(value=False, label="Show token labels (debug)")
    text = gr.Textbox(
        label="Input text",
        lines=10,
        value="Tim Chen\nSenior Software Engineer at Apple Inc.\nTaipei, Taiwan",
        placeholder="Paste text here (e.g., OCR output).",
    )
    btn = gr.Button("Run NER")
    # Headers/datatypes match run_ner's row shape: [entity, text, start, end].
    out_table = gr.Dataframe(
        label="Entities (spans)",
        headers=["entity", "text", "start", "end"],
        datatype=["str", "str", "number", "number"],
        interactive=False,
        wrap=True,
    )
    debug_box = gr.Textbox(label="Raw token output", lines=12)
    btn.click(fn=run_ner, inputs=[text, max_length, show_tokens], outputs=[out_table, debug_box])

# Launch the demo only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()