|
|
import warnings

import gradio as gr
from transformers import pipeline, AutoConfig
|
|
|
|
|
MODEL_ID = "CIAZIZ/arabic-ner-camelbert-wikiann"

# Token-classification pipeline; "simple" aggregation groups consecutive
# tokens of the same entity type into spans with character offsets.
ner = pipeline("token-classification", model=MODEL_ID, aggregation_strategy="simple")

# Label names are patched on the model's own config object.
# NOTE(review): this relies on the pipeline looking up label names from
# ner.model.config at post-processing time (true for transformers'
# TokenClassificationPipeline today) -- re-verify on library upgrades.
cfg = ner.model.config

# WikiANN NER tag set; list index == class id.
WIKIANN_LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

_current_labels = getattr(cfg, "id2label", None) or {}
# Checkpoints pushed without label metadata expose placeholders "LABEL_0", ...
_has_generic_labels = not _current_labels or any(
    str(v).startswith("LABEL_") for v in _current_labels.values()
)

if _has_generic_labels:
    # Only install the WikiANN names when the class count agrees; otherwise
    # names would be assigned to the wrong ids (or ids would be missing).
    if not _current_labels or len(_current_labels) == len(WIKIANN_LABELS):
        cfg.id2label = {i: lab for i, lab in enumerate(WIKIANN_LABELS)}
        cfg.label2id = {lab: i for i, lab in enumerate(WIKIANN_LABELS)}
    else:
        warnings.warn(
            f"{MODEL_ID} has {len(_current_labels)} generic labels but the "
            f"WikiANN tag set has {len(WIKIANN_LABELS)}; keeping the model's "
            "original label names."
        )
|
|
|
|
|
# Arabic display names for entity groups, shown in the results table:
# person, organization, location, generic entity.
AR_MAP = {"PER": "شخص", "ORG": "منظمة", "LOC": "مكان", "MISC": "كيان عام"}

# Highlight colors keyed by the raw (English) entity group.
COLOR_MAP = {"PER": "#2f80ed", "ORG": "#9b51e0", "LOC": "#27ae60", "MISC": "#f2994a"}
|
|
|
|
|
def merge_touching_spans(ents, text=None):
    """Merge adjacent spans that carry the same entity group.

    Sub-word pieces of one word can come back as separate entities of the
    same type (e.g. أر + ام + كو → أرامكو). A span whose ``start`` equals the
    previous span's ``end`` and whose group matches is fused into it.

    Args:
        ents: Entity dicts with integer ``start``/``end`` offsets, an
            ``entity_group`` (or ``entity``) label and an optional ``score``.
            Neither the list nor its dicts are modified.
        text: Optional source string. When given, every merged span's
            ``word`` is rewritten to ``text[start:end]`` so it matches the
            new offsets; otherwise ``word`` keeps the first fragment's value.

    Returns:
        A new list of shallow-copied dicts, sorted by ``start``. A merged
        span's ``score`` is the arithmetic mean of its fragments' scores
        (a missing score counts as 0). Empty/None *ents* is returned as is.
    """
    if not ents:
        return ents

    def group_of(span):
        return span.get("entity_group") or span.get("entity")

    ordered = sorted(ents, key=lambda x: x["start"])
    merged = [dict(ordered[0])]
    # Fragments folded into each merged span, so the score stays a true mean
    # instead of being halved pairwise on every additional merge.
    counts = [1]

    for e in ordered[1:]:
        last = merged[-1]
        if group_of(last) == group_of(e) and e["start"] == last["end"]:
            n = counts[-1]
            last["end"] = e["end"]
            last["score"] = (
                float(last.get("score", 0)) * n + float(e.get("score", 0))
            ) / (n + 1)
            counts[-1] = n + 1
            if text is not None:
                last["word"] = text[last["start"]:last["end"]]
        else:
            merged.append(dict(e))
            counts.append(1)

    return merged
|
|
|
|
|
def to_segments(text, ents):
    """Split *text* into labelled and unlabelled pieces for gr.HighlightedText.

    Args:
        text: The input string the entity offsets refer to.
        ents: Entity dicts with integer ``start``/``end`` character offsets
            and an ``entity_group`` (or ``entity``) label; any order.

    Returns:
        ``[(span_text, label_or_None), ...]`` covering *text* left to right;
        gaps between entities carry ``None``. With non-negative offsets the
        span texts concatenate back to exactly *text*: where entity spans
        overlap, the overlapping characters stay with the earlier span.
    """
    segs = []
    cursor = 0  # end of the text already emitted
    for e in sorted(ents, key=lambda x: x["start"]):
        # Clamp to the cursor so overlapping or nested spans never repeat
        # characters or move the cursor backwards.
        start = max(e["start"], cursor)
        end = e["end"]
        if end <= start:
            continue
        if start > cursor:
            segs.append((text[cursor:start], None))
        segs.append((text[start:end], e.get("entity_group", e.get("entity"))))
        cursor = end
    if cursor < len(text):
        segs.append((text[cursor:], None))
    return segs
|
|
|
|
|
def run(text: str):
    """Run NER on *text* and build the outputs for the two UI components.

    Returns:
        ``(segments, rows)``:
        - ``segments`` are ``[(span_text, label_or_None), ...]`` for
          gr.HighlightedText, labelled with the raw entity group.
        - ``rows`` are ``[[entity_text, arabic_label, score], ...]`` for the
          table. ``arabic_label`` comes from AR_MAP, falling back to the raw
          group. ``score`` is rounded to 4 decimals.

        Both lists are empty when *text* is None, empty, or only whitespace.
    """
    # Gradio/API callers can send None; treat it like empty input rather
    # than crashing on .strip().
    if not text or not text.strip():
        return [], []

    entities = merge_touching_spans(ner(text))
    segs = to_segments(text, entities)

    rows = []
    for e in entities:
        group = e.get("entity_group", e.get("entity", ""))
        rows.append([
            text[e["start"]:e["end"]],
            AR_MAP.get(group, group),
            # Same missing-score convention as merge_touching_spans.
            round(float(e.get("score", 0.0)), 4),
        ])
    return segs, rows
|
|
|
|
|
with gr.Blocks(title="Arabic NER — CAMeLBERT") as demo:
    # Input box ("Write the text in Arabic"), pre-filled with a sample sentence.
    txt = gr.Textbox(
        label="اكتب النص بالعربي",
        lines=3,
        value="زار محمد بن سلمان مدينة نيوم والتقى بوفد من شركة أرامكو.",
    )
    # Inline highlighting, colored per entity group via COLOR_MAP.
    out_ht = gr.HighlightedText(label="", color_map=COLOR_MAP)
    # Results table; headers mean "entity", "type", "confidence".
    out_tbl = gr.Dataframe(headers=["الكيان", "النوع", "الثقة"], interactive=False)
    # "Analyze text" button wires the textbox into run() and fills both outputs.
    gr.Button("تحليل النص").click(run, inputs=txt, outputs=[out_ht, out_tbl])


if __name__ == "__main__":
    demo.launch()
|
|
|