# AER-bert / app.py
# CIAZIZ's picture
# Title edit
# 0f3ffac
import gradio as gr
from transformers import pipeline, AutoConfig
# Model repository id handed to the pipeline (replace with your own repo).
MODEL_ID = "CIAZIZ/arabic-ner-camelbert-wikiann"
ner = pipeline(
    "token-classification",
    model=MODEL_ID,
    aggregation_strategy="simple",
)

# Make sure the model config exposes real tag names rather than
# placeholder names such as LABEL_0.
cfg = ner.model.config
WIKIANN_LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]


def _has_placeholder_labels(config):
    """Return True when *config* has no id2label map or only LABEL_* names."""
    mapping = getattr(config, "id2label", None)
    if not mapping:
        return True
    return any(str(name).startswith("LABEL_") for name in mapping.values())


if _has_placeholder_labels(cfg):
    cfg.id2label = dict(enumerate(WIKIANN_LABELS))
    cfg.label2id = {name: idx for idx, name in enumerate(WIKIANN_LABELS)}

# Display label for each entity type (looked up when building table rows).
AR_MAP = {
    "PER": "ุดุฎุต",
    "ORG": "ู…ู†ุธู…ุฉ",
    "LOC": "ู…ูƒุงู†",
    "MISC": "ูƒูŠุงู† ุนุงู…",
}
# Highlight colour for each entity type.
COLOR_MAP = {
    "PER": "#2f80ed",
    "ORG": "#9b51e0",
    "LOC": "#27ae60",
    "MISC": "#f2994a",
}
def merge_touching_spans(ents):
    """Merge adjacent spans that share the same entity group.

    Two spans are merged when the later one starts exactly where the
    earlier one ends and both carry the same ``entity_group`` (falling
    back to ``entity``). This re-joins a single word that was emitted as
    several sub-word pieces.

    Args:
        ents: Iterable of dicts with at least ``start`` and ``end`` keys,
            and optionally ``entity_group`` / ``entity`` and ``score``.

    Returns:
        A new list of shallow-copied dicts sorted by ``start``. A merged
        span keeps every key of its first piece, except that ``end`` is
        extended and ``score`` becomes the arithmetic mean of the scores
        of all merged pieces. A falsy ``ents`` is returned unchanged.
    """
    if not ents:
        return ents

    ordered = sorted(ents, key=lambda x: x["start"])
    merged = [dict(ordered[0])]
    # Number of original pieces folded into each merged span, so the score
    # can be a true mean instead of a repeated pairwise average (which
    # would weight later pieces more heavily).
    piece_counts = [1]

    for ent in ordered[1:]:
        last = merged[-1]
        last_group = last.get("entity_group") or last.get("entity")
        cur_group = ent.get("entity_group") or ent.get("entity")

        if last_group == cur_group and ent["start"] == last["end"]:
            # Extend the previous span and update its running mean score.
            count = piece_counts[-1]
            running_total = float(last.get("score", 0)) * count
            last["end"] = ent["end"]
            last["score"] = (running_total + float(ent.get("score", 0))) / (count + 1)
            piece_counts[-1] = count + 1
        else:
            merged.append(dict(ent))
            piece_counts.append(1)

    return merged
def to_segments(text, ents):
    """Split *text* into ``(chunk, label)`` pairs for ``gr.HighlightedText``.

    Entity spans are visited in order of ``start``. Text between spans,
    and any text after the last span, is emitted with a ``None`` label.
    Each entity chunk is labelled with its ``entity_group`` value, or its
    ``entity`` value when that key is absent.
    """
    pieces = []
    cursor = 0
    for ent in sorted(ents, key=lambda x: x["start"]):
        begin = ent["start"]
        # Unlabelled gap between the previous entity and this one.
        if cursor < begin:
            pieces.append((text[cursor:begin], None))
        label = ent.get("entity_group", ent.get("entity"))
        stop = ent["end"]
        pieces.append((text[begin:stop], label))
        cursor = stop
    # Trailing text after the final entity.
    if cursor < len(text):
        pieces.append((text[cursor:], None))
    return pieces
def run(text: str):
    """Run NER over *text* and build the two outputs shown in the UI.

    Args:
        text: Raw input text. ``None``, empty and whitespace-only input
            are all treated as "nothing to analyse".

    Returns:
        A ``(segments, rows)`` pair. ``segments`` is the list produced by
        ``to_segments``; ``rows`` holds one
        ``[entity_text, display_label, score]`` list per merged entity,
        with the score rounded to 4 decimals. Both are empty lists when
        there is nothing to analyse.
    """
    # The caller may hand over None instead of "" (e.g. an unset input),
    # so guard before calling .strip().
    if not text or not text.strip():
        return [], []

    # Merge touching same-type spans so a word split into sub-word pieces
    # is reported as a single entity.
    entities = merge_touching_spans(ner(text))
    segs = to_segments(text, entities)

    # Build rows for the table (no 'position' column).
    rows = []
    for ent in entities:
        group = ent.get("entity_group", ent.get("entity", ""))
        rows.append([
            text[ent["start"]:ent["end"]],
            AR_MAP.get(group, group),  # fall back to the raw tag if unmapped
            round(float(ent["score"]), 4),
        ])
    return segs, rows
# UI layout: input text box, highlighted text view, results table, and a
# button that feeds the text box through run().
with gr.Blocks(title="Arabic NER โ€” CAMeLBERT") as demo:
    txt = gr.Textbox(
        label="ุงูƒุชุจ ุงู„ู†ุต ุจุงู„ุนุฑุจูŠ",
        lines=3,
        value="ุฒุงุฑ ู…ุญู…ุฏ ุจู† ุณู„ู…ุงู† ู…ุฏูŠู†ุฉ ู†ูŠูˆู… ูˆุงู„ุชู‚ู‰ ุจูˆูุฏ ู…ู† ุดุฑูƒุฉ ุฃุฑุงู…ูƒูˆ.",
    )
    # Highlighted view is deliberately given an empty label.
    out_ht = gr.HighlightedText(label="", color_map=COLOR_MAP)
    # Three columns only; there is no position column.
    out_tbl = gr.Dataframe(
        headers=["ุงู„ูƒูŠุงู†", "ุงู„ู†ูˆุน", "ุงู„ุซู‚ุฉ"],
        interactive=False,
    )
    analyze_btn = gr.Button("ุชุญู„ูŠู„ ุงู„ู†ุต")
    analyze_btn.click(run, inputs=txt, outputs=[out_ht, out_tbl])

if __name__ == "__main__":
    demo.launch()