# app.py
# Gradio NER demo using: tner/xlm-roberta-base-conll2003
#
# requirements.txt:
#   gradio>=4.0
#   transformers>=4.35
#   torch
#   sentencepiece

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

MODEL_ID = "tner/xlm-roberta-base-conll2003"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


def merge_spans(text: str, per_token):
    """Merge token-level labels into contiguous entity spans.

    Supports BIO labels (``B-XXX`` / ``I-XXX`` / ``O``) as well as plain
    non-BIO labels, where a run of identical labels is treated as one span.

    Args:
        text: The original input text (offsets in ``per_token`` index into
            it; kept in the signature for interface compatibility).
        per_token: Iterable of dicts with keys ``label``, ``start``, ``end``.

    Returns:
        List of dicts with keys ``entity``, ``start``, ``end``.
    """
    spans = []
    cur = None

    def close_cur():
        nonlocal cur
        if cur:
            spans.append(cur)
            cur = None

    for t in per_token:
        lab = t["label"] or "O"
        st, ed = t["start"], t["end"]

        if lab == "O":
            close_cur()
            continue

        if lab.startswith("B-"):
            # Explicit beginning of a new entity: close whatever was open.
            close_cur()
            cur = {"entity": lab[2:], "start": st, "end": ed}
            continue

        if lab.startswith("I-"):
            ent = lab[2:]
            if cur and cur["entity"] == ent:
                cur["end"] = ed
            else:
                # Orphan I- tag (no matching open span): start a new span.
                close_cur()
                cur = {"entity": ent, "start": st, "end": ed}
            continue

        # Non-BIO label: treat the same label on adjacent tokens as a
        # continuation of the current span.
        ent = lab
        if cur and cur["entity"] == ent:
            cur["end"] = ed
        else:
            close_cur()
            cur = {"entity": ent, "start": st, "end": ed}

    close_cur()
    return spans


def run_ner(text: str, max_length: int, show_tokens: bool):
    """Run token classification on *text*.

    Args:
        text: Raw input text; leading/trailing whitespace is stripped.
        max_length: Truncation length passed to the tokenizer.
        show_tokens: When True, also build a per-token debug dump.

    Returns:
        Tuple of (table_rows, debug_text) where ``table_rows`` is a 2D list
        of [entity, text, start, end] rows and ``debug_text`` is a
        tab-separated token/label listing (empty unless ``show_tokens``).
    """
    text = (text or "").strip()
    if not text:
        return [], ""

    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=int(max_length),
        return_offsets_mapping=True,
    )
    # Pop the offsets before sending tensors to the model — the model's
    # forward() does not accept offset_mapping as an input.
    offsets = enc.pop("offset_mapping")[0].tolist()
    input_ids = enc["input_ids"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        out = model(**enc)

    logits = out.logits[0]  # (seq_len, num_labels)
    pred_ids = logits.argmax(dim=-1).tolist()
    id2label = model.config.id2label

    per_token = []
    for tok, pid, (st, ed) in zip(tokens, pred_ids, offsets):
        # st == ed usually marks a special/padding token, or a token with
        # no corresponding characters in the original text.
        if st == ed:
            continue
        per_token.append({
            "token": tok,
            "label": id2label[pid],
            "start": int(st),
            "end": int(ed),
        })

    spans = merge_spans(text, per_token)

    # 2D list: avoids Gradio rendering row dicts as [object Object].
    table_rows = []
    for s in spans:
        # SentencePiece fast-tokenizer offsets can include a word's leading
        # space, which would leak into the entity text — trim whitespace at
        # the span edges so reported text and boundaries are clean.
        st, ed = s["start"], s["end"]
        while st < ed and text[st].isspace():
            st += 1
        while ed > st and text[ed - 1].isspace():
            ed -= 1
        table_rows.append([s["entity"], text[st:ed], st, ed])

    debug = ""
    if show_tokens:
        lines = ["token\tlabel\t[offsets]"]
        for t in per_token:
            lines.append(f"{t['token']}\t{t['label']}\t[{t['start']},{t['end']}]")
        debug = "\n".join(lines)

    return table_rows, debug


with gr.Blocks(title="XLM-R NER (CoNLL-2003)") as demo:
    gr.Markdown(f"""
# XLM-R NER Demo (CoNLL-2003)
Model: **{MODEL_ID}**
Labels: typically **PER / ORG / LOC / MISC**
""")
    with gr.Row():
        max_length = gr.Slider(64, 512, value=256, step=32, label="max_length (truncate)")
        show_tokens = gr.Checkbox(value=False, label="Show token labels (debug)")
    text = gr.Textbox(
        label="Input text",
        lines=10,
        value="Tim Chen\nSenior Software Engineer at Apple Inc.\nTaipei, Taiwan",
        placeholder="Paste text here (e.g., OCR output).",
    )
    btn = gr.Button("Run NER")
    out_table = gr.Dataframe(
        label="Entities (spans)",
        headers=["entity", "text", "start", "end"],
        datatype=["str", "str", "number", "number"],
        interactive=False,
        wrap=True,
    )
    debug_box = gr.Textbox(label="Raw token output", lines=12)

    btn.click(
        fn=run_ner,
        inputs=[text, max_length, show_tokens],
        outputs=[out_table, debug_box],
    )


if __name__ == "__main__":
    demo.launch()