File size: 4,404 Bytes
7105a56
529995b
7105a56
529995b
85a5e50
 
 
 
7105a56
77487c2
7105a56
85a5e50
7105a56
529995b
7105a56
 
 
 
85a5e50
 
 
 
5bdae6a
85a5e50
529995b
 
 
 
85a5e50
 
 
 
 
 
 
 
 
 
 
529995b
85a5e50
 
5bdae6a
 
 
 
85a5e50
 
 
5bdae6a
 
 
 
 
 
 
 
 
 
 
529995b
5bdae6a
 
85a5e50
 
 
5bdae6a
85a5e50
 
 
7105a56
 
 
 
 
 
6273d0a
 
 
 
 
 
 
7105a56
85a5e50
 
 
 
 
6273d0a
 
 
85a5e50
6273d0a
 
 
 
 
 
529995b
6273d0a
 
 
 
 
 
 
 
 
5bdae6a
6273d0a
529995b
85b9fc4
7105a56
85b9fc4
 
 
 
 
 
7105a56
 
 
85a5e50
6273d0a
85a5e50
 
7105a56
85b9fc4
7105a56
529995b
85a5e50
529995b
7105a56
529995b
85a5e50
7105a56
 
 
529995b
7105a56
 
 
 
529995b
7105a56
 
 
 
 
 
85a5e50
 
 
7105a56
 
 
 
85a5e50
7105a56
 
77487c2
85a5e50
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# app.py
# Gradio NER demo using: tner/xlm-roberta-base-conll2003
#
# requirements.txt:
# gradio>=4.0
# transformers>=4.35
# torch
# sentencepiece

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Hugging Face model id: XLM-RoBERTa base fine-tuned on CoNLL-2003 NER.
MODEL_ID = "tner/xlm-roberta-base-conll2003"

# Loaded once at import time; downloads weights on first run.
# use_fast=True is required for return_offsets_mapping in run_ner below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)

# Prefer GPU when available; inference only, so switch to eval mode
# (disables dropout etc.).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

def merge_spans(text: str, per_token):
    """
    Collapse token-level labels into contiguous entity spans.

    Handles BIO tagging (``B-``/``I-``/``O``) as well as plain labels,
    where identical consecutive labels are treated as one span.

    per_token: sequence of {"label", "start", "end"} dicts in text order.
    Returns a list of {"entity", "start", "end"} dicts. ``text`` is
    accepted for interface compatibility; it is not consulted here.
    """
    merged = []
    open_span = None  # span currently being extended, or None

    for item in per_token:
        label = item["label"] or "O"  # missing label counts as outside
        begin, finish = item["start"], item["end"]

        if label == "O":
            # Outside tag terminates any span in progress.
            if open_span is not None:
                merged.append(open_span)
                open_span = None
            continue

        if label.startswith("B-"):
            # Explicit beginning always starts a fresh span.
            entity = label[2:]
            starts_new = True
        elif label.startswith("I-"):
            # Continuation extends only a matching open span; otherwise
            # (dangling I-) it opens a new span of its own.
            entity = label[2:]
            starts_new = not (open_span and open_span["entity"] == entity)
        else:
            # Non-BIO scheme: a repeated label continues the same span.
            entity = label
            starts_new = not (open_span and open_span["entity"] == entity)

        if starts_new:
            if open_span is not None:
                merged.append(open_span)
            open_span = {"entity": entity, "start": begin, "end": finish}
        else:
            open_span["end"] = finish

    # Flush whatever span is still open at end of input.
    if open_span is not None:
        merged.append(open_span)
    return merged

def run_ner(text: str, max_length: int, show_tokens: bool):
    """
    Run token-classification NER over ``text`` and merge predictions
    into entity spans.

    Returns ``(table_rows, debug)``: a 2D list of
    ``[entity, surface_text, start, end]`` rows for the Dataframe, and a
    tab-separated per-token dump (empty string unless ``show_tokens``).
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return [], ""

    encoded = tokenizer(
        cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=int(max_length),
        return_offsets_mapping=True,
    )

    # Pull out the character offsets before moving tensors to the device;
    # the model forward() does not accept offset_mapping.
    offset_pairs = encoded.pop("offset_mapping")[0].tolist()
    token_ids = encoded["input_ids"][0].tolist()
    token_strs = tokenizer.convert_ids_to_tokens(token_ids)

    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output = model(**batch)

    label_ids = output.logits[0].argmax(dim=-1).tolist()  # (seq_len,)
    id2label = model.config.id2label

    # An offset with start == end marks a special/padding token or a
    # position with no corresponding source text, so it is skipped.
    per_token = [
        {
            "token": tok,
            "label": id2label[pid],
            "start": int(st),
            "end": int(ed),
        }
        for tok, pid, (st, ed) in zip(token_strs, label_ids, offset_pairs)
        if st != ed
    ]

    spans = merge_spans(cleaned, per_token)

    # Emit plain 2D rows so Gradio's Dataframe does not render
    # [object Object] for dict cells.
    table_rows = [
        [s["entity"], cleaned[s["start"]:s["end"]], s["start"], s["end"]]
        for s in spans
    ]

    debug = ""
    if show_tokens:
        lines = ["token\tlabel\t[offsets]"]
        lines.extend(
            f"{t['token']}\t{t['label']}\t[{t['start']},{t['end']}]"
            for t in per_token
        )
        debug = "\n".join(lines)

    return table_rows, debug

# UI layout: component creation order inside the Blocks context determines
# on-screen placement, so statement order here is load-bearing.
with gr.Blocks(title="XLM-R NER (CoNLL-2003)") as demo:
    gr.Markdown(f"""
# XLM-R NER Demo (CoNLL-2003)
Model: **{MODEL_ID}**  
Labels: typically **PER / ORG / LOC / MISC**
""")

    with gr.Row():
        # max_length feeds tokenizer truncation in run_ner (cast to int there).
        max_length = gr.Slider(64, 512, value=256, step=32, label="max_length (truncate)")
        # When checked, run_ner also fills the raw-token debug textbox.
        show_tokens = gr.Checkbox(value=False, label="Show token labels (debug)")

    text = gr.Textbox(
        label="Input text",
        lines=10,
        value="Tim Chen\nSenior Software Engineer at Apple Inc.\nTaipei, Taiwan",
        placeholder="Paste text here (e.g., OCR output).",
    )

    btn = gr.Button("Run NER")

    # Receives the 2D row list from run_ner; headers/datatypes must stay in
    # sync with the [entity, text, start, end] row shape it builds.
    out_table = gr.Dataframe(
        label="Entities (spans)",
        headers=["entity", "text", "start", "end"],
        datatype=["str", "str", "number", "number"],
        interactive=False,
        wrap=True,
    )

    debug_box = gr.Textbox(label="Raw token output", lines=12)

    # Wire the button: inputs/outputs order matches run_ner's signature and
    # its (table_rows, debug) return tuple.
    btn.click(fn=run_ner, inputs=[text, max_length, show_tokens], outputs=[out_table, debug_box])

if __name__ == "__main__":
    demo.launch()