groups.
"""
root = ET.fromstring(svg_text)
ids_lower = {i.lower() for i in ids}
xmin, ymin = np.inf, np.inf
xmax, ymax = -np.inf, -np.inf
def update_bbox_for_element(el):
nonlocal xmin, ymin, xmax, ymax
tag = el.tag.split("}")[-1].lower()
if tag == "path":
d = el.attrib.get("d")
if not d:
return
p = parse_path(d)
bxmin, bxmax, bymin, bymax = p.bbox()
xmin = min(xmin, bxmin)
xmax = max(xmax, bxmax)
ymin = min(ymin, bymin)
ymax = max(ymax, bymax)
elif tag == "polygon":
pts = el.attrib.get("points", "").strip()
if not pts:
return
coords = []
for chunk in pts.replace(",", " ").split():
coords.append(float(chunk))
xs = coords[0::2]
ys = coords[1::2]
xmin = min(xmin, min(xs))
xmax = max(xmax, max(xs))
ymin = min(ymin, min(ys))
ymax = max(ymax, max(ys))
for el in root.iter():
el_id = el.attrib.get("id")
if not el_id:
continue
el_id_lower = el_id.strip().lower()
if el_id_lower not in ids_lower:
continue
tag = el.tag.split("}")[-1].lower()
if tag in ("path", "polygon"):
update_bbox_for_element(el)
elif tag == "g":
# If a country is a group, include all its child shapes
for child in el.iter():
update_bbox_for_element(child)
if not np.isfinite(xmin):
return None
w = xmax - xmin
h = ymax - ymin
mx = w * margin_ratio
my = h * margin_ratio
xmin -= mx
ymin -= my
w += 2 * mx
h += 2 * my
return (float(xmin), float(ymin), float(w), float(h))
def set_viewbox(svg_text: str, viewbox):
    """Return *svg_text* re-serialized with its root viewBox replaced.

    Args:
        svg_text: Serialized SVG document.
        viewbox: Iterable of four numbers (min-x, min-y, width, height).

    Returns:
        The SVG string with the new ``viewBox`` and a
        ``preserveAspectRatio="xMidYMid meet"`` on the root element.
    """
    svg_root = ET.fromstring(svg_text)
    svg_root.set("viewBox", " ".join(str(component) for component in viewbox))
    # Keep the map centered and fully visible when the container aspect differs.
    svg_root.set("preserveAspectRatio", "xMidYMid meet")
    return ET.tostring(svg_root, encoding="unicode")
def render_map_html(conf_by_iso2, threshold):
    """Build the HTML fragment for the dialect map.

    Recolors the world-map SVG with a confidence gradient, zooms the
    viewBox onto the fixed set of Arab countries, and appends a legend.

    Args:
        conf_by_iso2: Dict mapping ISO2 country codes to confidence scores.
        threshold: Confidence threshold used by the coloring logic.

    Returns:
        An HTML string containing the colored SVG plus a legend, or an
        error message when the SVG asset is missing.
    """
    if not SVG_PATH.exists():
        # Graceful fallback when the asset has not been uploaded to the Space.
        return f"""
Map SVG not found.
Please add {SVG_PATH.as_posix()} to your Space repo.
"""
    raw_svg = SVG_PATH.read_text(encoding="utf-8")
    colored = recolor_svg(raw_svg, conf_by_iso2, threshold)
    # AUTO-ZOOM to the Arab world — a fixed id set, not "predicted only",
    # so the framing stays stable between predictions.
    viewbox = compute_viewbox_from_ids(colored, ARAB_IDS, margin_ratio=0.10)
    if viewbox is not None:
        colored = set_viewbox(colored, viewbox)
    # Static legend explaining the gradient scale.
    legend_html = """
Confidence Scale
Darker = closer to threshold | Brighter = higher confidence
"""
    return f"""
{colored}
{legend_html}
"""
# ======================
# Inference helpers
# ======================
def predict_dialects_with_confidence(text, threshold=0.3):
    """Run the multi-label dialect classifier and report per-dialect scores.

    Args:
        text: Input Arabic sentence.
        threshold: Probability cut-off for marking a dialect as valid.

    Returns:
        DataFrame with columns Dialect / Confidence / Prediction, sorted
        by confidence descending. An empty frame is returned for blank input.
    """
    if not text or not text.strip():
        return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []})
    encoded = base_tokenizer(
        [text], truncation=True, padding=True, max_length=128, return_tensors="pt"
    )
    with torch.no_grad():
        outputs = base_model(
            input_ids=encoded["input_ids"].to(DEVICE),
            attention_mask=encoded["attention_mask"].to(DEVICE),
        )
    # Multi-label head: an independent sigmoid per dialect, flattened to 1-D.
    probs = torch.sigmoid(outputs.logits).cpu().numpy().reshape(-1)
    rows = [
        {
            "Dialect": dialect,
            "Confidence": f"{p:.4f}",
            "Prediction": "โ Valid" if p >= threshold else "โ Invalid",
        }
        for dialect, p in zip(DIALECTS, probs)
    ]
    df = pd.DataFrame(rows)
    # Confidence is stored as a formatted string; sort on its float value.
    return df.sort_values("Confidence", ascending=False, key=lambda c: c.astype(float))
def predict_wrapper(model_key, text, threshold):
    """Full prediction pipeline wired to the Gradio button.

    Args:
        model_key: Key into MODEL_CHOICES selecting which variant to load.
        text: Input Arabic sentence.
        threshold: Confidence threshold for valid predictions.

    Returns:
        df (table),
        summary (markdown),
        map_html (HTML)
    """
    load_multidialect_model(model_key)
    df = predict_dialects_with_confidence(text, threshold)
    predicted_dialects = df[df["Prediction"] == "โ Valid"]["Dialect"].tolist()
    summary = f"**Predicted Dialects ({len(predicted_dialects)}):** {', '.join(predicted_dialects) if predicted_dialects else 'None'}"
    # Build confidence dict for ALL dialects (not just predicted ones) so the
    # map can shade every country by its score.
    conf_by_iso2 = {}
    for _, row in df.iterrows():
        dialect = row["Dialect"]
        if dialect not in DIALECT_TO_ISO2:
            continue
        conf_by_iso2[DIALECT_TO_ISO2[dialect]] = float(row["Confidence"])
    # NOTE: removed leftover debug print of conf_by_iso2 — it spammed stdout
    # on every prediction in the serving path.
    map_html = render_map_html(conf_by_iso2, threshold)
    return df, summary, map_html
def predict_egyptian(text, threshold=0.5):
    """Binary Egyptian-dialect prediction using the dedicated model.

    Args:
        text: Input Arabic sentence.
        threshold: Probability cut-off for the Egyptian label.

    Returns:
        Tuple of (one-row DataFrame, markdown summary). For blank input,
        an empty frame and a "no input" message are returned.
    """
    if not text or not text.strip():
        return pd.DataFrame({"Label": [], "Confidence": []}), "**No input provided.**"
    encoded = egyptian_tok(
        [text], truncation=True, padding=True, max_length=128, return_tensors="pt"
    )
    with torch.no_grad():
        outputs = egyptian_model(
            input_ids=encoded["input_ids"].to(DEVICE),
            attention_mask=encoded["attention_mask"].to(DEVICE),
        )
    logits = outputs.logits  # (1, num_labels)
    if _EGY_SIGMOID:
        # Single-logit head: sigmoid is the positive-class probability.
        p = torch.sigmoid(logits).item()
    else:
        # Multi-class head: softmax, then read the positive label's index.
        p = torch.softmax(logits, dim=-1).squeeze(0)[_EGY_POS_INDEX].item()
    label = "โ Egyptian" if p >= threshold else "โ Not Egyptian"
    result_df = pd.DataFrame([{"Label": label, "Confidence": f"{p:.4f}"}])
    summary_md = f"**Prediction:** {label} \n**Confidence:** {p:.4f} \n**Threshold:** {threshold:.2f}"
    return result_df, summary_md
# ======================
# Gradio UI
# ======================
# Build the Gradio UI. Layout: intro markdown, a two-column row
# (inputs left, results right), the SVG map, examples, and footer notes.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# ๐ LahjatBERT: Multi-Label Arabic Dialect Classifier
This demo predicts **which country-level Arabic dialects a sentence sounds natural in**.
Unlike classic "pick one dialect" systems, **a single sentence can be acceptable in multiple dialects**.
**How to use**
1) Paste an Arabic sentence
2) Adjust the **Confidence Threshold** (higher = fewer highlights)
3) Click **Predict Dialects**
**How to interpret the results**
- **Highlighted countries** = dialects predicted as *valid/acceptable* for the sentence
- **Color intensity** = confidence level (darker green = closer to threshold, brighter = higher confidence)
"""
)
# Two-column layout: model/text/threshold inputs on the left, outputs on the right.
with gr.Row():
with gr.Column(scale=1):
model_dropdown = gr.Dropdown(
choices=list(MODEL_CHOICES.keys()),
value="LahjatBERT",
label="Model",
info="Select which LahjatBERT variant to use for prediction."
)
text_input = gr.Textbox(
label="Arabic Text Input",
placeholder="ุฃุฏุฎู ูุตูุง ุนุฑุจููุง ููุง... ู
ุซุงู: ุดููููุ / ุฅุฒูู ูุง ุนู
ุ / ุดู ุฃุฎุจุงุฑูุ",
lines=4,
rtl=True,
)
# Decision threshold fed to predict_wrapper; 0.05 steps between 0.1 and 0.9.
threshold_slider = gr.Slider(
minimum=0.1,
maximum=0.9,
value=0.3,
step=0.05,
label="Confidence Threshold",
info=(
"Dialects with confidence โฅ threshold are marked as valid. "
"Try 0.30 for broader overlap, or 0.50 for stricter predictions."
),
)
predict_button = gr.Button("๐ Predict Dialects", variant="primary")
gr.Markdown(
"""
**Tip:** If you're testing a sentence that's close to Modern Standard Arabic (MSA),
you may see **many countries highlighted**โthat's expected, because MSA-like text
can be acceptable across dialects.
"""
)
# Right column: markdown summary plus the detailed per-dialect table.
with gr.Column(scale=1):
summary_output = gr.Markdown(label="Summary")
results_output = gr.Dataframe(
label="Detailed Results",
headers=["Dialect", "Confidence", "Prediction"],
datatype=["str", "str", "str"],
)
gr.Markdown("---")
gr.Markdown(
"""
## ๐บ๏ธ Dialect Map (Zoomed to the Arab World)
The map updates after each prediction.
Green countries indicate dialects predicted as valid at your selected threshold.
The intensity of the green color reflects the confidence level.
"""
)
# Initial map rendered at startup with no predictions (empty confidence dict).
map_output = gr.HTML(label="Arab World Map", value=render_map_html({}, 0.3))
gr.Markdown("---")
gr.Markdown(
"""
## โจ Try these examples
These examples are meant to show **dialect overlap**:
- Some expressions are widely shared and may light up multiple regions
- Others contain strong local signals (e.g., Egyptian, Gulf/Khaleeji, Levantine, Maghrebi)
"""
)
# Clickable examples; each entry fills [text_input, threshold_slider].
gr.Examples(
examples=[
# Broad / MSA-like (often acceptable widely)
["ููู ุญุงููุ", 0.30],
["ุงูุณูุงู
ุนูููู
ูุฑุญู
ุฉ ุงููู ูุจุฑูุงุชู", 0.30],
# Egyptian-leaning
["ุฅุฒูู ูุง ุนู
ุ ุนุงู
ู ุฅููุ", 0.30],
["ู
ุด ูุงูู
ููู ูุฏู ุจุตุฑุงุญุฉ", 0.30],
# Gulf / Iraqi-leaning
["ุดููููุ ุดุฎุจุงุฑูุ", 0.30],
["ูููู ู
ู ุฒู
ุงูุ", 0.30],
# Levantine-leaning
["ุดู ุฃุฎุจุงุฑูุ ููููุ", 0.30],
["ุจุฏูู ุฃุฑูุญ ูููู", 0.30],
# Maghrebi-leaning (may vary depending on spelling)
["ูุงุจุงุณ ุนูููุ ูุงุด ุฑุงูุ", 0.30],
["ุจุฒุงู ุฏูุงู ุงููุงุณ ูููุถุฑู ููุง", 0.30],
# Stricter threshold examples (fewer highlights)
["ุดููููุ", 0.30],
["ุฅุฒูู ูุง ุนู
ุ", 0.30],
],
inputs=[text_input, threshold_slider],
label="Click an example to auto-fill the input",
)
# Wire the button to the full pipeline: (model, text, threshold) -> (table, summary, map).
predict_button.click(
fn=predict_wrapper,
inputs=[model_dropdown, text_input, threshold_slider],
outputs=[results_output, summary_output, map_output],
)
gr.Markdown(
"""
---
### Notes
- The model outputs **multi-label** predictions: more than one dialect can be valid at once.
- Countries are colored with a **gradient** based on confidence: darker green means the confidence is closer to the threshold, brighter green means higher confidence.
If you use this demo in research, please cite the accompanying paper.
"""
)
# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
demo.launch()
#####################
### Version 3 #######
#####################
# import json
# from pathlib import Path
# import torch
# import gradio as gr
# import pandas as pd
# from transformers import (
# AutoModelForSequenceClassification,
# AutoTokenizer,
# AutoConfig,
# )
# import re
# import xml.etree.ElementTree as ET
# import numpy as np
# import xml.etree.ElementTree as ET
# from svgpathtools import parse_path
# # ======================
# # Devices
# # ======================
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # # ======================
# # # Base multi-dialect model (B2BERT)
# # # ======================
# # base_model_name = "Mohamedelzeftawy/b2bert_baseline"
# # base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
# # base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# # ======================
# # Multi-dialect model registry
# # ======================
# MODEL_CHOICES = {
# "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline", # default (current)
# "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi",
# "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty",
# }
# # Load default model at startup (LahjatBERT)
# _current_model_key = "LahjatBERT"
# base_model_name = MODEL_CHOICES[_current_model_key]
# base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
# base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# # Define dialects (order must match model's label mapping)
# DIALECTS = [
# "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya",
# "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria",
# "Tunisia", "UAE", "Yemen"
# ]
# # Dialect -> ISO2 country code mapping (must match SVG path ids)
# DIALECT_TO_ISO2 = {
# "Algeria": "dz",
# "Bahrain": "bh",
# "Egypt": "eg",
# "Iraq": "iq",
# "Jordan": "jo",
# "Kuwait": "kw",
# "Lebanon": "lb",
# "Libya": "ly",
# "Morocco": "ma",
# "Oman": "om",
# "Palestine": "ps",
# "Qatar": "qa",
# "Saudi_Arabia": "sa",
# "Sudan": "sd",
# "Syria": "sy",
# "Tunisia": "tn",
# "UAE": "ae",
# "Yemen": "ye",
# }
# # ======================
# # Added: Egyptian-only model
# # ======================
# egyptian_repo = "Mohamedelzeftawy/egyptian_marbert"
# egyptian_cfg = AutoConfig.from_pretrained(egyptian_repo)
# egyptian_tok = AutoTokenizer.from_pretrained(egyptian_repo)
# egyptian_model = AutoModelForSequenceClassification.from_pretrained(
# egyptian_repo, config=egyptian_cfg
# ).to(DEVICE)
# # Heuristic: if num_labels==1 -> sigmoid; else softmax and assume positive label index=1
# _EGY_SIGMOID = (egyptian_cfg.num_labels == 1)
# _EGY_POS_INDEX = 1 if egyptian_cfg.num_labels >= 2 else 0
# # ======================
# # Map rendering
# # ======================
# # Put your SVG here: repo/assets/arab_world.svg
# # IMPORTANT: each country shape must have id="EG", id="SA", id="AE", etc (ISO2)
# SVG_PATH = Path("assets/world-map.svg")
# SVG_NS = "http://www.w3.org/2000/svg"
# ET.register_namespace("", SVG_NS)
# def load_multidialect_model(model_key: str):
# """
# Load the selected multi-dialect model + tokenizer.
# Uses global variables so the rest of your pipeline stays unchanged.
# """
# global base_model, base_tokenizer, base_model_name, _current_model_key
# if model_key == _current_model_key:
# return # already loaded
# repo = MODEL_CHOICES[model_key]
# base_model_name = repo
# base_model = AutoModelForSequenceClassification.from_pretrained(repo).to(DEVICE)
# base_tokenizer = AutoTokenizer.from_pretrained(repo)
# _current_model_key = model_key
# def _merge_style(old_style: str, updates: dict) -> str:
# """
# Merge CSS style strings (e.g., "fill:#000;stroke:#fff") with updates dict.
# """
# style_map = {}
# if old_style:
# for part in old_style.split(";"):
# part = part.strip()
# if not part or ":" not in part:
# continue
# k, v = part.split(":", 1)
# style_map[k.strip()] = v.strip()
# style_map.update(updates)
# return ";".join([f"{k}:{v}" for k, v in style_map.items() if v is not None])
# def recolor_svg(svg_text: str, predicted_ids: set[str]) -> str:
# """
# Recolor SVG elements by id.
# Handles cases where a country is stored as ... ... .
# """
# root = ET.fromstring(svg_text)
# # normalize ids to lowercase for robust matching
# predicted_lower = {p.lower() for p in predicted_ids}
# base_fill = "#101418"
# base_stroke = "#2a2f3a"
# active_fill = "#2ecc71"
# active_stroke = "#ffffff"
# def apply_style(el, active: bool):
# tag = el.tag.split("}")[-1].lower()
# if tag not in ("path", "polygon"):
# return
# updates = {
# "fill": active_fill if active else base_fill,
# "stroke": active_stroke if active else base_stroke,
# "stroke-width": "0.9" if active else "0.5",
# "opacity": "1",
# }
# # remove conflicting attrs
# if "fill" in el.attrib:
# del el.attrib["fill"]
# if "stroke" in el.attrib:
# del el.attrib["stroke"]
# el.attrib["style"] = _merge_style(el.attrib.get("style", ""), updates)
# # Pass 1: default style for everything drawable (so the map stays consistent)
# for el in root.iter():
# apply_style(el, active=False)
# # Pass 2: activate predicted ids
# for el in root.iter():
# el_id = el.attrib.get("id")
# if not el_id:
# continue
# el_id_lower = el_id.strip().lower()
# if el_id_lower not in predicted_lower:
# continue
# tag = el.tag.split("}")[-1].lower()
# if tag in ("path", "polygon"):
# apply_style(el, active=True)
# elif tag == "g":
# # Important: if the country is a GROUP, color all its child shapes
# for child in el.iter():
# apply_style(child, active=True)
# return ET.tostring(root, encoding="unicode")
# ARAB_IDS = {
# "ma","dz","tn","ly","eg","sd",
# "ps","jo","lb","sy","iq",
# "sa","kw","bh","qa","ae","om","ye"
# }
# def compute_viewbox_from_ids(svg_text: str, ids: set[str], margin_ratio: float = 0.08):
# """
# Compute a tight viewBox around the given country ids based on their path geometry.
# Supports countries stored as groups.
# """
# root = ET.fromstring(svg_text)
# ids_lower = {i.lower() for i in ids}
# xmin, ymin = np.inf, np.inf
# xmax, ymax = -np.inf, -np.inf
# def update_bbox_for_element(el):
# nonlocal xmin, ymin, xmax, ymax
# tag = el.tag.split("}")[-1].lower()
# if tag == "path":
# d = el.attrib.get("d")
# if not d:
# return
# p = parse_path(d)
# bxmin, bxmax, bymin, bymax = p.bbox()
# xmin = min(xmin, bxmin)
# xmax = max(xmax, bxmax)
# ymin = min(ymin, bymin)
# ymax = max(ymax, bymax)
# elif tag == "polygon":
# pts = el.attrib.get("points", "").strip()
# if not pts:
# return
# coords = []
# for chunk in pts.replace(",", " ").split():
# coords.append(float(chunk))
# xs = coords[0::2]
# ys = coords[1::2]
# xmin = min(xmin, min(xs))
# xmax = max(xmax, max(xs))
# ymin = min(ymin, min(ys))
# ymax = max(ymax, max(ys))
# for el in root.iter():
# el_id = el.attrib.get("id")
# if not el_id:
# continue
# el_id_lower = el_id.strip().lower()
# if el_id_lower not in ids_lower:
# continue
# tag = el.tag.split("}")[-1].lower()
# if tag in ("path", "polygon"):
# update_bbox_for_element(el)
# elif tag == "g":
# # If a country is a group, include all its child shapes
# for child in el.iter():
# update_bbox_for_element(child)
# if not np.isfinite(xmin):
# return None
# w = xmax - xmin
# h = ymax - ymin
# mx = w * margin_ratio
# my = h * margin_ratio
# xmin -= mx
# ymin -= my
# w += 2 * mx
# h += 2 * my
# return (float(xmin), float(ymin), float(w), float(h))
# def set_viewbox(svg_text: str, viewbox):
# root = ET.fromstring(svg_text)
# root.attrib["viewBox"] = " ".join(str(x) for x in viewbox)
# root.attrib["preserveAspectRatio"] = "xMidYMid meet"
# return ET.tostring(root, encoding="unicode")
# def render_map_html(predicted_iso2, conf_by_iso2=None):
# if not SVG_PATH.exists():
# return f"""
#
# Map SVG not found.
# Please add {SVG_PATH.as_posix()} to your Space repo.
#
# """
# svg = SVG_PATH.read_text(encoding="utf-8")
# predicted_ids = set(predicted_iso2 or [])
# svg_colored = recolor_svg(svg, predicted_ids)
# # AUTO-ZOOM to Arab world (fixed set of Arab countries, not โpredicted onlyโ)
# vb = compute_viewbox_from_ids(svg_colored, ARAB_IDS, margin_ratio=0.10)
# if vb is not None:
# svg_colored = set_viewbox(svg_colored, vb)
# # No JS, no script tags โ just return the updated SVG
# return f"""
#
# {svg_colored}
#
# Highlighted = confidence โฅ threshold
#
#
# """
# # ======================
# # Inference helpers
# # ======================
# def predict_dialects_with_confidence(text, threshold=0.3):
# """
# Predict Arabic dialects for the given text (multi-label) and return confidence scores.
# """
# if not text or not text.strip():
# return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []})
# enc = base_tokenizer([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
# input_ids = enc["input_ids"].to(DEVICE)
# attention_mask = enc["attention_mask"].to(DEVICE)
# with torch.no_grad():
# outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
# logits = outputs.logits # (1, num_labels)
# probs = torch.sigmoid(logits).cpu().numpy().reshape(-1)
# rows = []
# for dialect, p in zip(DIALECTS, probs):
# rows.append({
# "Dialect": dialect,
# "Confidence": f"{p:.4f}",
# "Prediction": "โ Valid" if p >= threshold else "โ Invalid",
# })
# df = pd.DataFrame(rows)
# df = df.sort_values("Confidence", ascending=False, key=lambda x: x.astype(float))
# return df
# def predict_wrapper(model_key, text, threshold):
# """
# Returns:
# df (table),
# summary (markdown),
# map_html (HTML)
# """
# load_multidialect_model(model_key)
# df = predict_dialects_with_confidence(text, threshold)
# predicted_dialects = df[df["Prediction"] == "โ Valid"]["Dialect"].tolist()
# summary = f"**Predicted Dialects ({len(predicted_dialects)}):** {', '.join(predicted_dialects) if predicted_dialects else 'None'}"
# # Build predicted ISO2 list + confidences for tooltips
# predicted_iso2 = []
# conf_by_iso2 = {}
# for _, row in df.iterrows():
# if row["Prediction"] != "โ Valid":
# continue
# dialect = row["Dialect"]
# if dialect not in DIALECT_TO_ISO2:
# continue
# code = DIALECT_TO_ISO2[dialect]
# predicted_iso2.append(code)
# conf_by_iso2[code] = float(row["Confidence"])
# print("predicted_iso2:", predicted_iso2)
# map_html = render_map_html(predicted_iso2, conf_by_iso2)
# return df, summary, map_html
# def predict_egyptian(text, threshold=0.5):
# """
# Predict whether the input is Egyptian dialect using the dedicated model.
# Returns a small dataframe and a markdown summary.
# """
# if not text or not text.strip():
# return pd.DataFrame({"Label": [], "Confidence": []}), "**No input provided.**"
# enc = egyptian_tok([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
# input_ids = enc["input_ids"].to(DEVICE)
# attention_mask = enc["attention_mask"].to(DEVICE)
# with torch.no_grad():
# outputs = egyptian_model(input_ids=input_ids, attention_mask=attention_mask)
# logits = outputs.logits # (1, num_labels)
# if _EGY_SIGMOID:
# p = torch.sigmoid(logits).item()
# else:
# probs = torch.softmax(logits, dim=-1).squeeze(0)
# p = probs[_EGY_POS_INDEX].item()
# label = "โ Egyptian" if p >= threshold else "โ Not Egyptian"
# df = pd.DataFrame([{"Label": label, "Confidence": f"{p:.4f}"}])
# md = f"**Prediction:** {label} \n**Confidence:** {p:.4f} \n**Threshold:** {threshold:.2f}"
# return df, md
# # ======================
# # Gradio UI
# # ======================
# with gr.Blocks(theme=gr.themes.Soft()) as demo:
# gr.Markdown(
# """
# # ๐ LahjatBERT: Multi-Label Arabic Dialect Classifier
# This demo predicts **which country-level Arabic dialects a sentence sounds natural in**.
# Unlike classic โpick one dialectโ systems, **a single sentence can be acceptable in multiple dialects**.
# **How to use**
# 1) Paste an Arabic sentence
# 2) Adjust the **Confidence Threshold** (higher = fewer highlights)
# 3) Click **Predict Dialects**
# **How to interpret the results**
# - **Highlighted countries** = dialects predicted as *valid/acceptable* for the sentence
# """
# )
# with gr.Row():
# with gr.Column(scale=1):
# model_dropdown = gr.Dropdown(
# choices=list(MODEL_CHOICES.keys()),
# value="LahjatBERT",
# label="Model",
# info="Select which LahjatBERT variant to use for prediction."
# )
# text_input = gr.Textbox(
# label="Arabic Text Input",
# placeholder="ุฃุฏุฎู ูุตูุง ุนุฑุจููุง ููุง... ู
ุซุงู: ุดููููุ / ุฅุฒูู ูุง ุนู
ุ / ุดู ุฃุฎุจุงุฑูุ",
# lines=4,
# rtl=True,
# )
# threshold_slider = gr.Slider(
# minimum=0.1,
# maximum=0.9,
# value=0.3,
# step=0.05,
# label="Confidence Threshold",
# info=(
# "Dialects with confidence โฅ threshold are marked as valid. "
# "Try 0.30 for broader overlap, or 0.50 for stricter predictions."
# ),
# )
# predict_button = gr.Button("๐ Predict Dialects", variant="primary")
# gr.Markdown(
# """
# **Tip:** If youโre testing a sentence thatโs close to Modern Standard Arabic (MSA),
# you may see **many countries highlighted**โthatโs expected, because MSA-like text
# can be acceptable across dialects.
# """
# )
# with gr.Column(scale=1):
# summary_output = gr.Markdown(label="Summary")
# results_output = gr.Dataframe(
# label="Detailed Results",
# headers=["Dialect", "Confidence", "Prediction"],
# datatype=["str", "str", "str"],
# )
# gr.Markdown("---")
# gr.Markdown(
# """
# ## ๐บ๏ธ Dialect Map (Zoomed to the Arab World)
# The map updates after each prediction.
# Green countries indicate dialects predicted as valid at your selected threshold.
# """
# )
# map_output = gr.HTML(label="Arab World Map", value=render_map_html([], {}))
# gr.Markdown("---")
# gr.Markdown(
# """
# ## โจ Try these examples
# These examples are meant to show **dialect overlap**:
# - Some expressions are widely shared and may light up multiple regions
# - Others contain strong local signals (e.g., Egyptian, Gulf/Khaleeji, Levantine, Maghrebi)
# """
# )
# gr.Examples(
# examples=[
# # Broad / MSA-like (often acceptable widely)
# ["ููู ุญุงููุ", 0.30],
# ["ุงูุณูุงู
ุนูููู
ูุฑุญู
ุฉ ุงููู ูุจุฑูุงุชู", 0.30],
# # Egyptian-leaning
# ["ุฅุฒูู ูุง ุนู
ุ ุนุงู
ู ุฅููุ", 0.30],
# ["ู
ุด ูุงูู
ููู ูุฏู ุจุตุฑุงุญุฉ", 0.30],
# # Gulf / Iraqi-leaning
# ["ุดููููุ ุดุฎุจุงุฑูุ", 0.30],
# ["ูููู ู
ู ุฒู
ุงูุ", 0.30],
# # Levantine-leaning
# ["ุดู ุฃุฎุจุงุฑูุ ููููุ", 0.30],
# ["ุจุฏูู ุฃุฑูุญ ูููู", 0.30],
# # Maghrebi-leaning (may vary depending on spelling)
# ["ูุงุจุงุณ ุนูููุ ูุงุด ุฑุงูุ", 0.30],
# ["ุจุฒุงู ุฏูุงู ุงููุงุณ ูููุถุฑู ููุง", 0.30],
# # Stricter threshold examples (fewer highlights)
# ["ุดููููุ", 0.30],
# ["ุฅุฒูู ูุง ุนู
ุ", 0.30],
# ],
# inputs=[text_input, threshold_slider],
# label="Click an example to auto-fill the input",
# )
# predict_button.click(
# fn=predict_wrapper,
# inputs=[model_dropdown, text_input, threshold_slider],
# outputs=[results_output, summary_output, map_output],
# )
# gr.Markdown(
# """
# ---
# ### Notes
# - The model outputs **multi-label** predictions: more than one dialect can be valid at once.
# If you use this demo in research, please cite the accompanying paper.
# """
# )
# # Launch
# if __name__ == "__main__":
# demo.launch()