File size: 3,168 Bytes
08865f9
7312fe1
08865f9
de73683
08865f9
 
de73683
7312fe1
 
de73683
7312fe1
 
 
de73683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08865f9
 
de73683
08865f9
 
 
 
 
7312fe1
 
08865f9
 
 
 
 
7312fe1
08865f9
 
 
de73683
08865f9
7312fe1
de73683
08865f9
 
7312fe1
 
 
 
 
 
08865f9
 
 
0f3ffac
de73683
 
 
7312fe1
08865f9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
from transformers import pipeline, AutoConfig

# Hub repository of the fine-tuned checkpoint; change this to use another model.
MODEL_ID = "CIAZIZ/arabic-ner-camelbert-wikiann"
ner = pipeline(
    "token-classification",
    model=MODEL_ID,
    aggregation_strategy="simple",
)

# If the loaded config has no label names, or only placeholder "LABEL_n"
# names, install the WikiANN tag set so predictions come back as PER/ORG/LOC.
cfg = ner.model.config
WIKIANN_LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
if not getattr(cfg, "id2label", None) or any(
    str(name).startswith("LABEL_") for name in cfg.id2label.values()
):
    cfg.id2label = dict(enumerate(WIKIANN_LABELS))
    cfg.label2id = {name: idx for idx, name in enumerate(WIKIANN_LABELS)}

# Arabic display name for each entity group (person, organization, place,
# general entity). The literals were previously stored as mis-decoded UTF-8.
AR_MAP = {"PER": "شخص", "ORG": "منظمة", "LOC": "مكان", "MISC": "كيان عام"}
# Highlight colour (hex) for each entity group.
COLOR_MAP = {"PER": "#2f80ed", "ORG": "#9b51e0", "LOC": "#27ae60", "MISC": "#f2994a"}

def merge_touching_spans(ents):
    """Merge adjacent spans that carry the same entity label.

    One word can come back as several touching pieces
    (e.g. أر + ام + كو → أرامكو). Two spans are merged when they have the
    same label and the second one starts exactly where the first one ends.

    Args:
        ents: Entity dicts with ``start``/``end`` offsets, a label under
            ``entity_group`` (or ``entity``) and, optionally, a ``score``.

    Returns:
        A new list of shallow-copied dicts ordered by ``start``. A merged
        span keeps the first piece's fields, has ``end`` extended to the
        last piece, and stores the arithmetic mean of all merged piece
        scores in ``score``. A falsy input is returned as-is.
    """
    if not ents:
        return ents

    merged = []
    # Parallel to ``merged``: [score_sum, piece_count] for each output span.
    # Kept outside the dicts so returned entities gain no bookkeeping keys.
    stats = []
    for ent in sorted(ents, key=lambda x: x["start"]):
        if merged:
            last = merged[-1]
            last_group = last.get("entity_group") or last.get("entity")
            group = ent.get("entity_group") or ent.get("entity")
            if last_group == group and ent["start"] == last["end"]:
                last["end"] = ent["end"]
                acc = stats[-1]
                # True mean over every merged piece; a pairwise running
                # average would over-weight the later pieces.
                acc[0] = float(acc[0]) + float(ent.get("score", 0))
                acc[1] += 1
                last["score"] = acc[0] / acc[1]
                continue
        merged.append(dict(ent))
        stats.append([ent.get("score", 0), 1])
    return merged

def to_segments(text, ents):
    """Split *text* into ``(substring, label)`` pairs for gr.HighlightedText.

    Entity spans get their ``entity_group`` (falling back to ``entity``) as
    the label; the stretches of text between and around them get ``None``.
    """
    pieces = []
    cursor = 0
    for span in sorted(ents, key=lambda item: item["start"]):
        begin, finish = span["start"], span["end"]
        # Plain text between the previous entity and this one.
        if cursor < begin:
            pieces.append((text[cursor:begin], None))
        if "entity_group" in span:
            label = span["entity_group"]
        else:
            label = span.get("entity")
        pieces.append((text[begin:finish], label))
        cursor = finish
    # Trailing plain text after the last entity.
    if cursor < len(text):
        pieces.append((text[cursor:], None))
    return pieces

def run(text: str):
    """Run NER on *text* and shape the result for the two output widgets.

    Args:
        text: Raw user input. ``None``, empty and whitespace-only input
            are all treated as "nothing to analyse".

    Returns:
        A ``(segments, rows)`` tuple. ``segments`` is the output of
        ``to_segments``; ``rows`` holds one
        ``[entity text, display label, score rounded to 4 places]`` list
        per entity. Both are empty lists when there is no input.
    """
    # A missing value would otherwise raise AttributeError on .strip().
    if not text or not text.strip():
        return [], []

    # Glue touching same-label pieces back together before display.
    entities = merge_touching_spans(ner(text))
    segs = to_segments(text, entities)

    def _group(ent):
        return ent.get("entity_group", ent.get("entity", ""))

    # Table rows: surface text is sliced from the input by offsets; the
    # label falls back to the raw group name when AR_MAP has no entry.
    rows = [
        [
            text[ent["start"]:ent["end"]],
            AR_MAP.get(_group(ent), _group(ent)),
            round(float(ent["score"]), 4),
        ]
        for ent in entities
    ]
    return segs, rows

# Demo UI: a text box, a highlighted view of the entities, a results table,
# and a button that runs `run` on the text box contents.
# All user-facing strings are real Arabic / Unicode text (they were previously
# stored as mis-decoded UTF-8 and would have rendered as gibberish).
with gr.Blocks(title="Arabic NER — CAMeLBERT") as demo:
    txt = gr.Textbox(
        label="اكتب النص بالعربي",  # "Write the text in Arabic"
        lines=3,
        value="زار محمد بن سلمان مدينة نيوم والتقى بوفد من شركة أرامكو.",  # sample sentence
    )
    out_ht = gr.HighlightedText(label="", color_map=COLOR_MAP)  # label left empty
    # Columns: entity, type, confidence.
    out_tbl = gr.Dataframe(headers=["الكيان", "النوع", "الثقة"], interactive=False)
    # Button caption: "Analyse the text".
    gr.Button("تحليل النص").click(run, inputs=txt, outputs=[out_ht, out_tbl])

if __name__ == "__main__":
    demo.launch()