import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

MODEL_ID = "tner/xlm-roberta-base-conll2003"

# use_fast=True is required: only fast tokenizers support return_offsets_mapping,
# which run_ner relies on below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
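
# For comparison, transformers' token-classification pipeline can produce
# similar merged spans in one call (a sketch for reference, not what this
# demo uses):
#
#   from transformers import pipeline
#   ner = pipeline("token-classification", model=MODEL_ID, aggregation_strategy="simple")
#   ner("Tim Chen works at Apple Inc.")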

def merge_spans(per_token):
    """
    Merge token-level labels into contiguous entity spans.

    Supports BIO labels (B- / I- / O) as well as plain non-BIO labels.

    per_token: [{"token", "label", "start", "end"}]
    returns:   [{"entity", "start", "end"}]
    """
    spans = []
    cur = None

    def close_cur():
        nonlocal cur
        if cur:
            spans.append(cur)
        cur = None

    for t in per_token:
        lab = t["label"] or "O"
        st, ed = t["start"], t["end"]

        # "O" ends any open span.
        if lab == "O":
            close_cur()
            continue

        # "B-" always starts a new span.
        if lab.startswith("B-"):
            close_cur()
            cur = {"entity": lab[2:], "start": st, "end": ed}
            continue

        # "I-" extends a matching open span, otherwise starts a new one.
        if lab.startswith("I-"):
            ent = lab[2:]
            if cur and cur["entity"] == ent:
                cur["end"] = ed
            else:
                close_cur()
                cur = {"entity": ent, "start": st, "end": ed}
            continue

        # Non-BIO label: merge consecutive tokens that share the same label.
        ent = lab
        if cur and cur["entity"] == ent:
            cur["end"] = ed
        else:
            close_cur()
            cur = {"entity": ent, "start": st, "end": ed}

    close_cur()
    return spans
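
# Illustration: per_token entries labeled ("B-PER", 0, 3) then ("I-PER", 4, 8)
# merge into a single span {"entity": "PER", "start": 0, "end": 8}; an "O"
# label closes whatever span is currently open.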

def run_ner(text: str, max_length: int, show_tokens: bool):
    text = (text or "").strip()
    if not text:
        return [], ""

    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=int(max_length),
        return_offsets_mapping=True,
    )
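
    # offset_mapping holds (char_start, char_end) for each subword token, which
    # is what lets us project model predictions back onto the raw input string.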
    offsets = enc.pop("offset_mapping")[0].tolist()
    input_ids = enc["input_ids"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        out = model(**enc)

    logits = out.logits[0]
    pred_ids = logits.argmax(dim=-1).tolist()
    id2label = model.config.id2label

    per_token = []
    for tok, pid, (st, ed) in zip(tokens, pred_ids, offsets):
        # Special tokens such as <s> and </s> carry empty (0, 0) offsets; skip them.
        if st == ed:
            continue
        per_token.append({
            "token": tok,
            "label": id2label[pid],
            "start": int(st),
            "end": int(ed),
        })

    spans = merge_spans(per_token)
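
    # Slice the original text with the merged character offsets so the table
    # shows each entity's surface string alongside its span.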
    table_rows = []
    for s in spans:
        table_rows.append([
            s["entity"],
            text[s["start"]:s["end"]],
            s["start"],
            s["end"],
        ])

    debug = ""
    if show_tokens:
        lines = ["token\tlabel\t[offsets]"]
        for t in per_token:
            lines.append(f"{t['token']}\t{t['label']}\t[{t['start']},{t['end']}]")
        debug = "\n".join(lines)

    return table_rows, debug
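
# Quick smoke test without the UI (illustrative; exact spans depend on the model):
#   rows, dbg = run_ner("Tim Chen works at Apple Inc. in Taipei.", 256, False)
#   # rows -> e.g. [["PER", "Tim Chen", 0, 8], ["ORG", "Apple Inc.", 18, 28], ...]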

with gr.Blocks(title="XLM-R NER (CoNLL-2003)") as demo:
    gr.Markdown(f"""
# XLM-R NER Demo (CoNLL-2003)

Model: **{MODEL_ID}**

Labels: typically **PER / ORG / LOC / MISC**
""")

    with gr.Row():
        max_length = gr.Slider(64, 512, value=256, step=32, label="max_length (truncate)")
        show_tokens = gr.Checkbox(value=False, label="Show token labels (debug)")

    text = gr.Textbox(
        label="Input text",
        lines=10,
        value="Tim Chen\nSenior Software Engineer at Apple Inc.\nTaipei, Taiwan",
        placeholder="Paste text here (e.g., OCR output).",
    )

    btn = gr.Button("Run NER")

    out_table = gr.Dataframe(
        label="Entities (spans)",
        headers=["entity", "text", "start", "end"],
        datatype=["str", "str", "number", "number"],
        interactive=False,
        wrap=True,
    )

    debug_box = gr.Textbox(label="Raw token output", lines=12)

    btn.click(fn=run_ner, inputs=[text, max_length, show_tokens], outputs=[out_table, debug_box])

if __name__ == "__main__":
    demo.launch()