import gradio as gr
from transformers import pipeline, AutoConfig
MODEL_ID = "CIAZIZ/arabic-ner-camelbert-wikiann"  # ← your model repo

ner = pipeline("token-classification", model=MODEL_ID, aggregation_strategy="simple")

# --- make sure labels are correct (no LABEL_0 placeholders) ---
cfg = ner.model.config
WIKIANN_LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
_id2label = getattr(cfg, "id2label", None)
if not _id2label or any(str(name).startswith("LABEL_") for name in _id2label.values()):
    # Checkpoint shipped without human-readable labels; install the WikiANN set.
    cfg.id2label = dict(enumerate(WIKIANN_LABELS))
    cfg.label2id = {name: idx for idx, name in enumerate(WIKIANN_LABELS)}
# Arabic display names for the WikiANN entity groups shown in the results table.
# (Original literals were mojibake-corrupted and split across lines — restored.)
AR_MAP = {"PER": "شخص", "ORG": "منظمة", "LOC": "مكان", "MISC": "كيان عام"}
# Hex highlight colors per entity group, consumed by gr.HighlightedText.
COLOR_MAP = {"PER": "#2f80ed", "ORG": "#9b51e0", "LOC": "#27ae60", "MISC": "#f2994a"}
def merge_touching_spans(ents):
    """Merge adjacent spans that carry the SAME entity group.

    Fuses subword pieces back into one span (e.g. "أر" + "ام" + "كو"
    → "أرامكو"). Confidence of a fused span is the pairwise average of
    the merged scores — simple and robust. Input is not mutated.
    """
    if not ents:
        return ents
    ordered = sorted(ents, key=lambda span: span["start"])
    result = [dict(ordered[0])]
    for span in ordered[1:]:
        prev = result[-1]
        prev_group = prev.get("entity_group") or prev.get("entity")
        span_group = span.get("entity_group") or span.get("entity")
        if prev_group == span_group and span["start"] == prev["end"]:
            # Touching span of the same group: extend the previous one.
            prev["end"] = span["end"]
            prev["score"] = (float(prev.get("score", 0)) + float(span.get("score", 0))) / 2.0
        else:
            result.append(dict(span))
    return result
def to_segments(text, ents):
    """Return [(span_text, label_or_None), ...] for gr.HighlightedText.

    Unlabeled gaps between entities (and any trailing tail) are emitted
    with label ``None`` so the full text is covered.
    """
    segments = []
    cursor = 0
    for span in sorted(ents, key=lambda s: s["start"]):
        start, end = span["start"], span["end"]
        if start > cursor:
            segments.append((text[cursor:start], None))
        segments.append((text[start:end], span.get("entity_group", span.get("entity"))))
        cursor = end
    if cursor < len(text):
        segments.append((text[cursor:], None))
    return segments
def run(text: str):
    """Run NER on *text* and return (highlight segments, table rows).

    Blank input short-circuits to empty outputs. Rows are
    [entity text, Arabic type name, confidence rounded to 4 places]
    — no 'position' column.
    """
    if not text.strip():
        return [], []
    entities = merge_touching_spans(ner(text))  # ← merge subword pieces
    segments = to_segments(text, entities)
    rows = []
    for ent in entities:
        group = ent.get("entity_group", ent.get("entity", ""))
        rows.append([
            text[ent["start"]:ent["end"]],
            AR_MAP.get(group, group),  # fall back to the raw tag if unmapped
            round(float(ent["score"]), 4),
        ])
    return segments, rows
# UI wiring. The original Arabic literals were mojibake-corrupted and split
# across lines (a syntax error); they are restored below from byte analysis.
with gr.Blocks(title="Arabic NER — CAMeLBERT") as demo:
    txt = gr.Textbox(
        label="اكتب النص بالعربي",  # "write the text in Arabic"
        lines=3,
        # Sample: "Mohammed bin Salman visited NEOM and met an Aramco delegation."
        value="زار محمد بن سلمان مدينة نيوم والتقى بوفد من شركة أرامكو.",
    )
    out_ht = gr.HighlightedText(label="", color_map=COLOR_MAP)  # ← empty label
    out_tbl = gr.Dataframe(
        headers=["الكيان", "النوع", "الثقة"],  # entity / type / confidence
        interactive=False,
    )  # ← no 'الموضع' (position) column
    gr.Button("تحليل النص").click(run, inputs=txt, outputs=[out_ht, out_tbl])

if __name__ == "__main__":
    demo.launch()
|