# LahjatBERT / app.py
# (Hugging Face Space page residue, kept as comments so the file stays valid Python)
# AHAAM's picture — "sharper transition" — commit 39c1eba
#####################
### Version 1 #######
#####################
# import torch
# import gradio as gr
# from transformers import AutoModelForSequenceClassification, AutoTokenizer
# import pandas as pd
# # Load the model and tokenizer
# model_name = "AHAAM/B2BERT"
# model = AutoModelForSequenceClassification.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# # Define dialects
# DIALECTS = [
# "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya",
# "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria",
# "Tunisia", "UAE", "Yemen"
# ]
# def predict_dialects_with_confidence(text, threshold=0.3):
# """
# Predict Arabic dialects for the given text and return confidence scores.
# Args:
# text: Input Arabic text
# threshold: Confidence threshold for classification (default 0.3)
# Returns:
# DataFrame with dialects and their confidence scores
# """
# if not text.strip():
# return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []})
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# # Tokenize input
# encodings = tokenizer(
# [text],
# truncation=True,
# padding=True,
# max_length=128,
# return_tensors="pt"
# )
# input_ids = encodings["input_ids"].to(device)
# attention_mask = encodings["attention_mask"].to(device)
# # Get predictions
# with torch.no_grad():
# outputs = model(input_ids=input_ids, attention_mask=attention_mask)
# logits = outputs.logits
# # Calculate probabilities
# probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
# # Create results dataframe
# results = []
# for dialect, prob in zip(DIALECTS, probabilities):
# prediction = "โœ“ Valid" if prob >= threshold else "โœ— Invalid"
# results.append({
# "Dialect": dialect,
# "Confidence": f"{prob:.4f}",
# "Prediction": prediction
# })
# # Sort by confidence (descending)
# df = pd.DataFrame(results)
# df = df.sort_values("Confidence", ascending=False, key=lambda x: x.astype(float))
# return df
# def predict_wrapper(text, threshold):
# """Wrapper function for Gradio interface"""
# df = predict_dialects_with_confidence(text, threshold)
# # Also create a summary of predicted dialects
# predicted = df[df["Prediction"] == "โœ“ Valid"]["Dialect"].tolist()
# summary = f"**Predicted Dialects ({len(predicted)}):** {', '.join(predicted) if predicted else 'None'}"
# return df, summary
# # Create Gradio interface
# with gr.Blocks(theme=gr.themes.Soft()) as demo:
# gr.Markdown(
# """
# # ๐ŸŒ B2BERT Arabic Dialect Classifier
# This model identifies which Arabic dialects are valid for a given text input.
# Enter Arabic text below to see the dialect predictions and confidence scores.
# **Supported Dialects:** Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya,
# Morocco, Oman, Palestine, Qatar, Saudi Arabia, Sudan, Syria, Tunisia, UAE, Yemen
# """
# )
# with gr.Row():
# with gr.Column():
# text_input = gr.Textbox(
# label="Arabic Text Input",
# placeholder="ุฃุฏุฎู„ ุงู„ู†ุต ุงู„ุนุฑุจูŠ ู‡ู†ุง... (e.g., ูƒูŠู ุญุงู„ูƒุŸ)",
# lines=3,
# rtl=True
# )
# threshold_slider = gr.Slider(
# minimum=0.1,
# maximum=0.9,
# value=0.3,
# step=0.05,
# label="Confidence Threshold",
# info="Dialects with confidence above this threshold will be marked as valid"
# )
# predict_button = gr.Button("๐Ÿ” Predict Dialects", variant="primary")
# with gr.Column():
# summary_output = gr.Markdown(label="Summary")
# results_output = gr.Dataframe(
# label="Detailed Results",
# headers=["Dialect", "Confidence", "Prediction"],
# datatype=["str", "str", "str"]
# )
# # Examples
# gr.Examples(
# examples=[
# ["ูƒูŠู ุญุงู„ูƒุŸ", 0.3],
# ["ุดู„ูˆู†ูƒุŸ", 0.3],
# ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ", 0.3],
# ["ุดูˆ ุฃุฎุจุงุฑูƒุŸ", 0.3],
# ],
# inputs=[text_input, threshold_slider],
# label="Try these examples"
# )
# # Connect button to function
# predict_button.click(
# fn=predict_wrapper,
# inputs=[text_input, threshold_slider],
# outputs=[results_output, summary_output]
# )
# gr.Markdown(
# """
# ---
# **Model:** [AHAAM/B2BERT](https://huggingface.co/AHAAM/B2BERT)
# **Note:** The model uses a multi-label classification approach where each dialect is
# independently evaluated. A single text can be valid in multiple dialects.
# """
# )
# # Launch the app
# if __name__ == "__main__":
# demo.launch()
#####################
### Version 2 #######
#####################
import json
from pathlib import Path
import torch
import gradio as gr
import pandas as pd
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
AutoConfig,
)
import re
import xml.etree.ElementTree as ET
import numpy as np
from svgpathtools import parse_path
# ======================
# Devices
# ======================
# Prefer GPU when available; every model and input tensor is moved to DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ======================
# Multi-dialect model registry
# ======================
# UI label -> Hugging Face Hub repo id; selectable at runtime via the dropdown.
MODEL_CHOICES = {
    "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline",
    "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi",
    "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty",
}

# Load default model at startup (LahjatBERT).
# NOTE: these globals are swapped in-place by load_multidialect_model().
_current_model_key = "LahjatBERT"
base_model_name = MODEL_CHOICES[_current_model_key]
base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Define dialects (order must match model's label mapping)
DIALECTS = [
    "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya",
    "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria",
    "Tunisia", "UAE", "Yemen"
]

# Dialect -> ISO2 country code mapping (must match SVG path ids)
DIALECT_TO_ISO2 = {
    "Algeria": "dz",
    "Bahrain": "bh",
    "Egypt": "eg",
    "Iraq": "iq",
    "Jordan": "jo",
    "Kuwait": "kw",
    "Lebanon": "lb",
    "Libya": "ly",
    "Morocco": "ma",
    "Oman": "om",
    "Palestine": "ps",
    "Qatar": "qa",
    "Saudi_Arabia": "sa",
    "Sudan": "sd",
    "Syria": "sy",
    "Tunisia": "tn",
    "UAE": "ae",
    "Yemen": "ye",
}

# ======================
# Added: Egyptian-only model
# ======================
# Standalone Egyptian-vs-not classifier, separate from the registry above.
egyptian_repo = "Mohamedelzeftawy/egyptian_marbert"
egyptian_cfg = AutoConfig.from_pretrained(egyptian_repo)
egyptian_tok = AutoTokenizer.from_pretrained(egyptian_repo)
egyptian_model = AutoModelForSequenceClassification.from_pretrained(
    egyptian_repo, config=egyptian_cfg
).to(DEVICE)
# Heuristic: if num_labels==1 -> sigmoid; else softmax and assume positive label index=1
_EGY_SIGMOID = (egyptian_cfg.num_labels == 1)
_EGY_POS_INDEX = 1 if egyptian_cfg.num_labels >= 2 else 0

# ======================
# Map rendering
# ======================
# Put your SVG here: repo/assets/arab_world.svg
# IMPORTANT: each country shape must have id="EG", id="SA", id="AE", etc (ISO2)
SVG_PATH = Path("assets/world-map.svg")
SVG_NS = "http://www.w3.org/2000/svg"
# Register the default SVG namespace so re-serialized markup keeps plain tag names.
ET.register_namespace("", SVG_NS)
def load_multidialect_model(model_key: str):
    """Activate the requested multi-dialect model variant.

    The model, tokenizer, and bookkeeping names live in module-level
    globals so the rest of the pipeline stays unchanged. Loading is a
    no-op when the requested variant is already active.
    """
    global base_model, base_tokenizer, base_model_name, _current_model_key
    if model_key == _current_model_key:
        # Requested variant is already the active one — nothing to do.
        return
    target_repo = MODEL_CHOICES[model_key]
    base_model_name = target_repo
    base_model = AutoModelForSequenceClassification.from_pretrained(target_repo).to(DEVICE)
    base_tokenizer = AutoTokenizer.from_pretrained(target_repo)
    _current_model_key = model_key
def _merge_style(old_style: str, updates: dict) -> str:
"""
Merge CSS style strings (e.g., "fill:#000;stroke:#fff") with updates dict.
"""
style_map = {}
if old_style:
for part in old_style.split(";"):
part = part.strip()
if not part or ":" not in part:
continue
k, v = part.split(":", 1)
style_map[k.strip()] = v.strip()
style_map.update(updates)
return ";".join([f"{k}:{v}" for k, v in style_map.items() if v is not None])
def get_gradient_color(confidence: float, threshold: float) -> str:
    """Map a confidence score to a hex fill color on a dark-to-bright green ramp.

    Scores below *threshold* get the neutral base color. Scores at or above it
    are normalized over the remaining range and sharpened with a power curve
    so high-confidence countries stand out visually.

    Args:
        confidence: Probability score, expected in [0, 1].
        threshold: Minimum score that counts as a prediction.

    Returns:
        Hex color string, e.g. "#2ecc71".
    """
    if confidence < threshold:
        return "#101418"  # base color (no prediction)
    # Normalize confidence to [0, 1] starting from threshold.
    # Fix: guard the zero-width range at threshold == 1.0 (previously a
    # ZeroDivisionError) and clamp so out-of-range confidences can never
    # produce channel values outside 0..255.
    span = 1.0 - threshold
    normalized = (confidence - threshold) / span if span > 0.0 else 1.0
    normalized = min(max(normalized, 0.0), 1.0)
    # Power curve sharpens the gradient (2.0 = moderate sharp, 3.0 = very sharp).
    normalized = normalized ** 2.5
    # Linear interpolation: dark green #1a5e3a -> bright green #2ecc71.
    r_start, g_start, b_start = 0x1a, 0x5e, 0x3a  # dark green
    r_end, g_end, b_end = 0x2e, 0xcc, 0x71        # bright green
    r = int(r_start + (r_end - r_start) * normalized)
    g = int(g_start + (g_end - g_start) * normalized)
    b = int(b_start + (b_end - b_start) * normalized)
    return f"#{r:02x}{g:02x}{b:02x}"
def recolor_svg(svg_text: str, conf_by_iso2: dict, threshold: float) -> str:
    """Paint country shapes in an SVG according to per-country confidence.

    Every drawable shape first receives the neutral base style; shapes (or
    whole <g id="..."> groups) whose id matches a key in *conf_by_iso2* are
    then restyled with a gradient fill when their score clears *threshold*.

    Args:
        svg_text: SVG document as a string.
        conf_by_iso2: Dict mapping ISO2 codes to confidence scores.
        threshold: Confidence threshold for the highlighted style.
    """
    root = ET.fromstring(svg_text)
    # Lowercase the keys so id matching is case-insensitive.
    scores = {code.lower(): score for code, score in (conf_by_iso2 or {}).items()}

    base_fill = "#101418"
    base_stroke = "#2a2f3a"
    active_stroke = "#ffffff"

    def paint(shape, score=None):
        # Only actual drawable shapes get styled; groups and metadata are skipped.
        if shape.tag.split("}")[-1].lower() not in ("path", "polygon"):
            return
        if score is not None and score >= threshold:
            style = {
                "fill": get_gradient_color(score, threshold),
                "stroke": active_stroke,
                "stroke-width": "0.9",
                "opacity": "1",
            }
        else:
            style = {
                "fill": base_fill,
                "stroke": base_stroke,
                "stroke-width": "0.5",
                "opacity": "1",
            }
        # Presentation attributes would conflict with the style string — drop them.
        shape.attrib.pop("fill", None)
        shape.attrib.pop("stroke", None)
        shape.attrib["style"] = _merge_style(shape.attrib.get("style", ""), style)

    # Pass 1: neutral base style for everything drawable.
    for node in root.iter():
        paint(node)

    # Pass 2: highlight shapes whose id (or enclosing group id) has a score.
    for node in root.iter():
        node_id = node.attrib.get("id")
        if not node_id:
            continue
        score = scores.get(node_id.strip().lower())
        if score is None:
            continue
        tag = node.tag.split("}")[-1].lower()
        if tag in ("path", "polygon"):
            paint(node, score)
        elif tag == "g":
            # A country drawn as a group: color all of its child shapes.
            for child in node.iter():
                paint(child, score)

    return ET.tostring(root, encoding="unicode")
# ISO2 codes of all Arab-world countries; used to auto-zoom the map viewBox
# to the region regardless of which countries were actually predicted.
ARAB_IDS = {
    "ma","dz","tn","ly","eg","sd",
    "ps","jo","lb","sy","iq",
    "sa","kw","bh","qa","ae","om","ye"
}
def compute_viewbox_from_ids(svg_text: str, ids: set[str], margin_ratio: float = 0.08):
    """Compute a tight viewBox around the given country ids.

    Bounding boxes come from <path> geometry (via svgpathtools) and
    <polygon> points; countries stored as <g id="..."> groups are
    supported by descending into their child shapes.

    Returns:
        (xmin, ymin, width, height) with a margin of *margin_ratio* on
        each side, or None if no matching geometry was found.
    """
    root = ET.fromstring(svg_text)
    wanted = {i.lower() for i in ids}
    boxes = []  # per-shape (xmin, xmax, ymin, ymax)

    def shape_bbox(el):
        # Return the shape's bounding box, or None if it carries no geometry.
        tag = el.tag.split("}")[-1].lower()
        if tag == "path":
            d = el.attrib.get("d")
            if not d:
                return None
            bxmin, bxmax, bymin, bymax = parse_path(d).bbox()
            return (bxmin, bxmax, bymin, bymax)
        if tag == "polygon":
            pts = el.attrib.get("points", "").strip()
            if not pts:
                return None
            nums = [float(chunk) for chunk in pts.replace(",", " ").split()]
            xs, ys = nums[0::2], nums[1::2]
            return (min(xs), max(xs), min(ys), max(ys))
        return None

    for el in root.iter():
        el_id = el.attrib.get("id")
        if not el_id or el_id.strip().lower() not in wanted:
            continue
        tag = el.tag.split("}")[-1].lower()
        if tag in ("path", "polygon"):
            bb = shape_bbox(el)
            if bb is not None:
                boxes.append(bb)
        elif tag == "g":
            # A country stored as a group: include all its child shapes.
            for child in el.iter():
                bb = shape_bbox(child)
                if bb is not None:
                    boxes.append(bb)

    if not boxes:
        return None

    xmin = min(b[0] for b in boxes)
    xmax = max(b[1] for b in boxes)
    ymin = min(b[2] for b in boxes)
    ymax = max(b[3] for b in boxes)
    w = xmax - xmin
    h = ymax - ymin
    mx = w * margin_ratio
    my = h * margin_ratio
    return (float(xmin - mx), float(ymin - my), float(w + 2 * mx), float(h + 2 * my))
def set_viewbox(svg_text: str, viewbox):
    """Return *svg_text* with its root viewBox replaced by *viewbox* (x, y, w, h)."""
    root = ET.fromstring(svg_text)
    root.set("viewBox", " ".join(map(str, viewbox)))
    # Keep the aspect ratio and center the drawing inside the viewport.
    root.set("preserveAspectRatio", "xMidYMid meet")
    return ET.tostring(root, encoding="unicode")
def render_map_html(conf_by_iso2, threshold):
    """
    Render the map with gradient colors based on confidence scores.
    Args:
        conf_by_iso2: Dict mapping ISO2 codes to confidence scores
        threshold: Confidence threshold for predictions
    Returns:
        HTML string: the recolored, auto-zoomed SVG plus a gradient legend,
        or a friendly notice when the SVG asset is missing.
    """
    # Fail soft with an inline message if the SVG asset is not in the repo.
    if not SVG_PATH.exists():
        return f"""
        <div style="padding:12px; border:1px solid #ddd; border-radius:10px;">
            <b>Map SVG not found.</b><br/>
            Please add <code>{SVG_PATH.as_posix()}</code> to your Space repo.
        </div>
        """
    svg = SVG_PATH.read_text(encoding="utf-8")
    svg_colored = recolor_svg(svg, conf_by_iso2, threshold)
    # AUTO-ZOOM to Arab world (fixed set of Arab countries, not "predicted only")
    vb = compute_viewbox_from_ids(svg_colored, ARAB_IDS, margin_ratio=0.10)
    if vb is not None:
        svg_colored = set_viewbox(svg_colored, vb)
    # Add a legend showing the gradient scale
    legend_html = """
    <div style="margin-top: 12px; padding: 12px; background: #1a1d23; border-radius: 8px;">
        <div style="font-size: 13px; color: #e0e0e0; margin-bottom: 8px; font-weight: 500;">
            Confidence Scale
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <span style="font-size: 11px; color: #999;">Low</span>
            <div style="flex: 1; height: 20px; background: linear-gradient(to right, #1a5e3a, #2ecc71); border-radius: 4px;"></div>
            <span style="font-size: 11px; color: #999;">High</span>
        </div>
        <div style="margin-top: 6px; font-size: 11px; color: #888;">
            Darker = closer to threshold | Brighter = higher confidence
        </div>
    </div>
    """
    # No JS or script tags — just inline SVG so Gradio's HTML component renders it.
    return f"""
    <div style="width:100%; max-width: 950px; margin: 0 auto;">
        {svg_colored}
        {legend_html}
    </div>
    """
# ======================
# Inference helpers
# ======================
def predict_dialects_with_confidence(text, threshold=0.3):
    """Score *text* against all country dialects (multi-label).

    Runs the active model once, applies an independent sigmoid per label,
    and returns a DataFrame (Dialect, Confidence, Prediction) sorted by
    confidence, highest first. Blank input yields an empty table.
    """
    if not text or not text.strip():
        return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []})

    batch = base_tokenizer([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
    with torch.no_grad():
        result = base_model(
            input_ids=batch["input_ids"].to(DEVICE),
            attention_mask=batch["attention_mask"].to(DEVICE),
        )
    # Multi-label head: sigmoid per dialect, flattened to shape (num_labels,).
    scores = torch.sigmoid(result.logits).cpu().numpy().reshape(-1)

    rows = [
        {
            "Dialect": name,
            "Confidence": f"{score:.4f}",
            "Prediction": "โœ“ Valid" if score >= threshold else "โœ— Invalid",
        }
        for name, score in zip(DIALECTS, scores)
    ]
    frame = pd.DataFrame(rows)
    # Confidence is stored as a string, so sort numerically via the key.
    return frame.sort_values("Confidence", ascending=False, key=lambda col: col.astype(float))
def predict_wrapper(model_key, text, threshold):
    """Gradio click handler tying model selection, inference, and map together.

    Returns:
        df (table),
        summary (markdown),
        map_html (HTML)
    """
    load_multidialect_model(model_key)
    df = predict_dialects_with_confidence(text, threshold)

    valid = df[df["Prediction"] == "โœ“ Valid"]["Dialect"].tolist()
    summary = f"**Predicted Dialects ({len(valid)}):** {', '.join(valid) if valid else 'None'}"

    # Confidence for EVERY dialect (not only valid ones) drives the map shading.
    conf_by_iso2 = {
        DIALECT_TO_ISO2[row["Dialect"]]: float(row["Confidence"])
        for _, row in df.iterrows()
        if row["Dialect"] in DIALECT_TO_ISO2
    }
    print("conf_by_iso2:", conf_by_iso2)  # debug trace visible in Space logs
    map_html = render_map_html(conf_by_iso2, threshold)
    return df, summary, map_html
def predict_egyptian(text, threshold=0.5):
    """Classify *text* with the dedicated Egyptian-dialect model.

    Returns a one-row DataFrame (Label, Confidence) and a markdown summary;
    blank input yields an empty frame and a notice instead.
    """
    if not text or not text.strip():
        return pd.DataFrame({"Label": [], "Confidence": []}), "**No input provided.**"

    batch = egyptian_tok([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
    with torch.no_grad():
        result = egyptian_model(
            input_ids=batch["input_ids"].to(DEVICE),
            attention_mask=batch["attention_mask"].to(DEVICE),
        )
    logits = result.logits  # (1, num_labels)

    # Single-logit checkpoints use a sigmoid; multi-class heads use softmax
    # and read the probability of the configured positive index.
    if _EGY_SIGMOID:
        p = torch.sigmoid(logits).item()
    else:
        p = torch.softmax(logits, dim=-1).squeeze(0)[_EGY_POS_INDEX].item()

    label = "โœ“ Egyptian" if p >= threshold else "โœ— Not Egyptian"
    df = pd.DataFrame([{"Label": label, "Confidence": f"{p:.4f}"}])
    md = f"**Prediction:** {label} \n**Confidence:** {p:.4f} \n**Threshold:** {threshold:.2f}"
    return df, md
# ======================
# Gradio UI
# ======================
# Layout: intro markdown, a two-column row (inputs | results), the map,
# examples, the click wiring, and a footer. All callbacks are defined above.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Intro / usage instructions shown at the top of the page.
    gr.Markdown(
        """
        # ๐ŸŒ LahjatBERT: Multi-Label Arabic Dialect Classifier
        This demo predicts **which country-level Arabic dialects a sentence sounds natural in**.
        Unlike classic "pick one dialect" systems, **a single sentence can be acceptable in multiple dialects**.
        **How to use**
        1) Paste an Arabic sentence
        2) Adjust the **Confidence Threshold** (higher = fewer highlights)
        3) Click **Predict Dialects**
        **How to interpret the results**
        - **Highlighted countries** = dialects predicted as *valid/acceptable* for the sentence
        - **Color intensity** = confidence level (darker green = closer to threshold, brighter = higher confidence)
        """
    )
    with gr.Row():
        # Left column: model choice, text input, threshold, and the action button.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CHOICES.keys()),
                value="LahjatBERT",
                label="Model",
                info="Select which LahjatBERT variant to use for prediction."
            )
            text_input = gr.Textbox(
                label="Arabic Text Input",
                placeholder="ุฃุฏุฎู„ ู†ุตู‹ุง ุนุฑุจูŠู‹ุง ู‡ู†ุง... ู…ุซุงู„: ุดู„ูˆู†ูƒุŸ / ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ / ุดูˆ ุฃุฎุจุงุฑูƒุŸ",
                lines=4,
                rtl=True,
            )
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.3,
                step=0.05,
                label="Confidence Threshold",
                info=(
                    "Dialects with confidence โ‰ฅ threshold are marked as valid. "
                    "Try 0.30 for broader overlap, or 0.50 for stricter predictions."
                ),
            )
            predict_button = gr.Button("๐Ÿ” Predict Dialects", variant="primary")
            gr.Markdown(
                """
                **Tip:** If you're testing a sentence that's close to Modern Standard Arabic (MSA),
                you may see **many countries highlighted**โ€”that's expected, because MSA-like text
                can be acceptable across dialects.
                """
            )
        # Right column: prediction summary plus the detailed per-dialect table.
        with gr.Column(scale=1):
            summary_output = gr.Markdown(label="Summary")
            results_output = gr.Dataframe(
                label="Detailed Results",
                headers=["Dialect", "Confidence", "Prediction"],
                datatype=["str", "str", "str"],
            )
    gr.Markdown("---")
    gr.Markdown(
        """
        ## ๐Ÿ—บ๏ธ Dialect Map (Zoomed to the Arab World)
        The map updates after each prediction.
        Green countries indicate dialects predicted as valid at your selected threshold.
        The intensity of the green color reflects the confidence level.
        """
    )
    # The map starts in its neutral state (no scores) at the default threshold.
    map_output = gr.HTML(label="Arab World Map", value=render_map_html({}, 0.3))
    gr.Markdown("---")
    gr.Markdown(
        """
        ## โœจ Try these examples
        These examples are meant to show **dialect overlap**:
        - Some expressions are widely shared and may light up multiple regions
        - Others contain strong local signals (e.g., Egyptian, Gulf/Khaleeji, Levantine, Maghrebi)
        """
    )
    gr.Examples(
        examples=[
            # Broad / MSA-like (often acceptable widely)
            ["ูƒูŠู ุญุงู„ูƒุŸ", 0.30],
            ["ุงู„ุณู„ุงู… ุนู„ูŠูƒู… ูˆุฑุญู…ุฉ ุงู„ู„ู‡ ูˆุจุฑูƒุงุชู‡", 0.30],
            # Egyptian-leaning
            ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ ุนุงู…ู„ ุฅูŠู‡ุŸ", 0.30],
            ["ู…ุด ูุงู‡ู… ู„ูŠู‡ ูƒุฏู‡ ุจุตุฑุงุญุฉ", 0.30],
            # Gulf / Iraqi-leaning
            ["ุดู„ูˆู†ูƒุŸ ุดุฎุจุงุฑูƒุŸ", 0.30],
            ["ูˆูŠู†ูƒ ู…ู† ุฒู…ุงู†ุŸ", 0.30],
            # Levantine-leaning
            ["ุดูˆ ุฃุฎุจุงุฑูƒุŸ ูƒูŠููƒุŸ", 0.30],
            ["ุจุฏู‘ูŠ ุฃุฑูˆุญ ู‡ู„ู‘ู‚", 0.30],
            # Maghrebi-leaning (may vary depending on spelling)
            ["ู„ุงุจุงุณ ุนู„ูŠูƒุŸ ูˆุงุด ุฑุงูƒุŸ", 0.30],
            ["ุจุฒุงู ุฏูŠุงู„ ุงู„ู†ุงุณ ูƒูŠู‡ุถุฑูˆ ู‡ูƒุง", 0.30],
            # Stricter threshold examples (fewer highlights)
            ["ุดู„ูˆู†ูƒุŸ", 0.30],
            ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ", 0.30],
        ],
        inputs=[text_input, threshold_slider],
        label="Click an example to auto-fill the input",
    )
    # Wire the button to the prediction pipeline (table, summary, map).
    predict_button.click(
        fn=predict_wrapper,
        inputs=[model_dropdown, text_input, threshold_slider],
        outputs=[results_output, summary_output, map_output],
    )
    gr.Markdown(
        """
        ---
        ### Notes
        - The model outputs **multi-label** predictions: more than one dialect can be valid at once.
        - Countries are colored with a **gradient** based on confidence: darker green means the confidence is closer to the threshold, brighter green means higher confidence.
        If you use this demo in research, please cite the accompanying paper.
        """
    )

# Launch
if __name__ == "__main__":
    demo.launch()
#####################
### Version 3 #######
#####################
# import json
# from pathlib import Path
# import torch
# import gradio as gr
# import pandas as pd
# from transformers import (
# AutoModelForSequenceClassification,
# AutoTokenizer,
# AutoConfig,
# )
# import re
# import xml.etree.ElementTree as ET
# import numpy as np
# import xml.etree.ElementTree as ET
# from svgpathtools import parse_path
# # ======================
# # Devices
# # ======================
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # # ======================
# # # Base multi-dialect model (B2BERT)
# # # ======================
# # base_model_name = "Mohamedelzeftawy/b2bert_baseline"
# # base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
# # base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# # ======================
# # Multi-dialect model registry
# # ======================
# MODEL_CHOICES = {
# "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline", # default (current)
# "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi",
# "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty",
# }
# # Load default model at startup (LahjatBERT)
# _current_model_key = "LahjatBERT"
# base_model_name = MODEL_CHOICES[_current_model_key]
# base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
# base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# # Define dialects (order must match model's label mapping)
# DIALECTS = [
# "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya",
# "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria",
# "Tunisia", "UAE", "Yemen"
# ]
# # Dialect -> ISO2 country code mapping (must match SVG path ids)
# DIALECT_TO_ISO2 = {
# "Algeria": "dz",
# "Bahrain": "bh",
# "Egypt": "eg",
# "Iraq": "iq",
# "Jordan": "jo",
# "Kuwait": "kw",
# "Lebanon": "lb",
# "Libya": "ly",
# "Morocco": "ma",
# "Oman": "om",
# "Palestine": "ps",
# "Qatar": "qa",
# "Saudi_Arabia": "sa",
# "Sudan": "sd",
# "Syria": "sy",
# "Tunisia": "tn",
# "UAE": "ae",
# "Yemen": "ye",
# }
# # ======================
# # Added: Egyptian-only model
# # ======================
# egyptian_repo = "Mohamedelzeftawy/egyptian_marbert"
# egyptian_cfg = AutoConfig.from_pretrained(egyptian_repo)
# egyptian_tok = AutoTokenizer.from_pretrained(egyptian_repo)
# egyptian_model = AutoModelForSequenceClassification.from_pretrained(
# egyptian_repo, config=egyptian_cfg
# ).to(DEVICE)
# # Heuristic: if num_labels==1 -> sigmoid; else softmax and assume positive label index=1
# _EGY_SIGMOID = (egyptian_cfg.num_labels == 1)
# _EGY_POS_INDEX = 1 if egyptian_cfg.num_labels >= 2 else 0
# # ======================
# # Map rendering
# # ======================
# # Put your SVG here: repo/assets/arab_world.svg
# # IMPORTANT: each country shape must have id="EG", id="SA", id="AE", etc (ISO2)
# SVG_PATH = Path("assets/world-map.svg")
# SVG_NS = "http://www.w3.org/2000/svg"
# ET.register_namespace("", SVG_NS)
# def load_multidialect_model(model_key: str):
# """
# Load the selected multi-dialect model + tokenizer.
# Uses global variables so the rest of your pipeline stays unchanged.
# """
# global base_model, base_tokenizer, base_model_name, _current_model_key
# if model_key == _current_model_key:
# return # already loaded
# repo = MODEL_CHOICES[model_key]
# base_model_name = repo
# base_model = AutoModelForSequenceClassification.from_pretrained(repo).to(DEVICE)
# base_tokenizer = AutoTokenizer.from_pretrained(repo)
# _current_model_key = model_key
# def _merge_style(old_style: str, updates: dict) -> str:
# """
# Merge CSS style strings (e.g., "fill:#000;stroke:#fff") with updates dict.
# """
# style_map = {}
# if old_style:
# for part in old_style.split(";"):
# part = part.strip()
# if not part or ":" not in part:
# continue
# k, v = part.split(":", 1)
# style_map[k.strip()] = v.strip()
# style_map.update(updates)
# return ";".join([f"{k}:{v}" for k, v in style_map.items() if v is not None])
# def recolor_svg(svg_text: str, predicted_ids: set[str]) -> str:
# """
# Recolor SVG elements by id.
# Handles cases where a country is stored as <g id="ye"> ... <path/> ... </g>.
# """
# root = ET.fromstring(svg_text)
# # normalize ids to lowercase for robust matching
# predicted_lower = {p.lower() for p in predicted_ids}
# base_fill = "#101418"
# base_stroke = "#2a2f3a"
# active_fill = "#2ecc71"
# active_stroke = "#ffffff"
# def apply_style(el, active: bool):
# tag = el.tag.split("}")[-1].lower()
# if tag not in ("path", "polygon"):
# return
# updates = {
# "fill": active_fill if active else base_fill,
# "stroke": active_stroke if active else base_stroke,
# "stroke-width": "0.9" if active else "0.5",
# "opacity": "1",
# }
# # remove conflicting attrs
# if "fill" in el.attrib:
# del el.attrib["fill"]
# if "stroke" in el.attrib:
# del el.attrib["stroke"]
# el.attrib["style"] = _merge_style(el.attrib.get("style", ""), updates)
# # Pass 1: default style for everything drawable (so the map stays consistent)
# for el in root.iter():
# apply_style(el, active=False)
# # Pass 2: activate predicted ids
# for el in root.iter():
# el_id = el.attrib.get("id")
# if not el_id:
# continue
# el_id_lower = el_id.strip().lower()
# if el_id_lower not in predicted_lower:
# continue
# tag = el.tag.split("}")[-1].lower()
# if tag in ("path", "polygon"):
# apply_style(el, active=True)
# elif tag == "g":
# # Important: if the country is a GROUP, color all its child shapes
# for child in el.iter():
# apply_style(child, active=True)
# return ET.tostring(root, encoding="unicode")
# ARAB_IDS = {
# "ma","dz","tn","ly","eg","sd",
# "ps","jo","lb","sy","iq",
# "sa","kw","bh","qa","ae","om","ye"
# }
# def compute_viewbox_from_ids(svg_text: str, ids: set[str], margin_ratio: float = 0.08):
# """
# Compute a tight viewBox around the given country ids based on their path geometry.
# Supports countries stored as <g id="..."> groups.
# """
# root = ET.fromstring(svg_text)
# ids_lower = {i.lower() for i in ids}
# xmin, ymin = np.inf, np.inf
# xmax, ymax = -np.inf, -np.inf
# def update_bbox_for_element(el):
# nonlocal xmin, ymin, xmax, ymax
# tag = el.tag.split("}")[-1].lower()
# if tag == "path":
# d = el.attrib.get("d")
# if not d:
# return
# p = parse_path(d)
# bxmin, bxmax, bymin, bymax = p.bbox()
# xmin = min(xmin, bxmin)
# xmax = max(xmax, bxmax)
# ymin = min(ymin, bymin)
# ymax = max(ymax, bymax)
# elif tag == "polygon":
# pts = el.attrib.get("points", "").strip()
# if not pts:
# return
# coords = []
# for chunk in pts.replace(",", " ").split():
# coords.append(float(chunk))
# xs = coords[0::2]
# ys = coords[1::2]
# xmin = min(xmin, min(xs))
# xmax = max(xmax, max(xs))
# ymin = min(ymin, min(ys))
# ymax = max(ymax, max(ys))
# for el in root.iter():
# el_id = el.attrib.get("id")
# if not el_id:
# continue
# el_id_lower = el_id.strip().lower()
# if el_id_lower not in ids_lower:
# continue
# tag = el.tag.split("}")[-1].lower()
# if tag in ("path", "polygon"):
# update_bbox_for_element(el)
# elif tag == "g":
# # If a country is a group, include all its child shapes
# for child in el.iter():
# update_bbox_for_element(child)
# if not np.isfinite(xmin):
# return None
# w = xmax - xmin
# h = ymax - ymin
# mx = w * margin_ratio
# my = h * margin_ratio
# xmin -= mx
# ymin -= my
# w += 2 * mx
# h += 2 * my
# return (float(xmin), float(ymin), float(w), float(h))
# def set_viewbox(svg_text: str, viewbox):
# root = ET.fromstring(svg_text)
# root.attrib["viewBox"] = " ".join(str(x) for x in viewbox)
# root.attrib["preserveAspectRatio"] = "xMidYMid meet"
# return ET.tostring(root, encoding="unicode")
# def render_map_html(predicted_iso2, conf_by_iso2=None):
# if not SVG_PATH.exists():
# return f"""
# <div style="padding:12px; border:1px solid #ddd; border-radius:10px;">
# <b>Map SVG not found.</b><br/>
# Please add <code>{SVG_PATH.as_posix()}</code> to your Space repo.
# </div>
# """
# svg = SVG_PATH.read_text(encoding="utf-8")
# predicted_ids = set(predicted_iso2 or [])
# svg_colored = recolor_svg(svg, predicted_ids)
# # AUTO-ZOOM to Arab world (fixed set of Arab countries, not โ€œpredicted onlyโ€)
# vb = compute_viewbox_from_ids(svg_colored, ARAB_IDS, margin_ratio=0.10)
# if vb is not None:
# svg_colored = set_viewbox(svg_colored, vb)
# # No JS, no script tags โ€” just return the updated SVG
# return f"""
# <div style="width:100%; max-width: 950px; margin: 0 auto;">
# {svg_colored}
# <div style="margin-top:8px; font-size:12px; color:#999;">
# Highlighted = confidence โ‰ฅ threshold
# </div>
# </div>
# """
# # ======================
# # Inference helpers
# # ======================
# def predict_dialects_with_confidence(text, threshold=0.3):
# """
# Predict Arabic dialects for the given text (multi-label) and return confidence scores.
# """
# if not text or not text.strip():
# return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []})
# enc = base_tokenizer([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
# input_ids = enc["input_ids"].to(DEVICE)
# attention_mask = enc["attention_mask"].to(DEVICE)
# with torch.no_grad():
# outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
# logits = outputs.logits # (1, num_labels)
# probs = torch.sigmoid(logits).cpu().numpy().reshape(-1)
# rows = []
# for dialect, p in zip(DIALECTS, probs):
# rows.append({
# "Dialect": dialect,
# "Confidence": f"{p:.4f}",
# "Prediction": "โœ“ Valid" if p >= threshold else "โœ— Invalid",
# })
# df = pd.DataFrame(rows)
# df = df.sort_values("Confidence", ascending=False, key=lambda x: x.astype(float))
# return df
# def predict_wrapper(model_key, text, threshold):
# """
# Returns:
# df (table),
# summary (markdown),
# map_html (HTML)
# """
# load_multidialect_model(model_key)
# df = predict_dialects_with_confidence(text, threshold)
# predicted_dialects = df[df["Prediction"] == "โœ“ Valid"]["Dialect"].tolist()
# summary = f"**Predicted Dialects ({len(predicted_dialects)}):** {', '.join(predicted_dialects) if predicted_dialects else 'None'}"
# # Build predicted ISO2 list + confidences for tooltips
# predicted_iso2 = []
# conf_by_iso2 = {}
# for _, row in df.iterrows():
# if row["Prediction"] != "โœ“ Valid":
# continue
# dialect = row["Dialect"]
# if dialect not in DIALECT_TO_ISO2:
# continue
# code = DIALECT_TO_ISO2[dialect]
# predicted_iso2.append(code)
# conf_by_iso2[code] = float(row["Confidence"])
# print("predicted_iso2:", predicted_iso2)
# map_html = render_map_html(predicted_iso2, conf_by_iso2)
# return df, summary, map_html
# def predict_egyptian(text, threshold=0.5):
# """
# Predict whether the input is Egyptian dialect using the dedicated model.
# Returns a small dataframe and a markdown summary.
# """
# if not text or not text.strip():
# return pd.DataFrame({"Label": [], "Confidence": []}), "**No input provided.**"
# enc = egyptian_tok([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
# input_ids = enc["input_ids"].to(DEVICE)
# attention_mask = enc["attention_mask"].to(DEVICE)
# with torch.no_grad():
# outputs = egyptian_model(input_ids=input_ids, attention_mask=attention_mask)
# logits = outputs.logits # (1, num_labels)
# if _EGY_SIGMOID:
# p = torch.sigmoid(logits).item()
# else:
# probs = torch.softmax(logits, dim=-1).squeeze(0)
# p = probs[_EGY_POS_INDEX].item()
# label = "โœ“ Egyptian" if p >= threshold else "โœ— Not Egyptian"
# df = pd.DataFrame([{"Label": label, "Confidence": f"{p:.4f}"}])
# md = f"**Prediction:** {label} \n**Confidence:** {p:.4f} \n**Threshold:** {threshold:.2f}"
# return df, md
# # ======================
# # Gradio UI
# # ======================
# with gr.Blocks(theme=gr.themes.Soft()) as demo:
# gr.Markdown(
# """
# # ๐ŸŒ LahjatBERT: Multi-Label Arabic Dialect Classifier
# This demo predicts **which country-level Arabic dialects a sentence sounds natural in**.
# Unlike classic โ€œpick one dialectโ€ systems, **a single sentence can be acceptable in multiple dialects**.
# **How to use**
# 1) Paste an Arabic sentence
# 2) Adjust the **Confidence Threshold** (higher = fewer highlights)
# 3) Click **Predict Dialects**
# **How to interpret the results**
# - **Highlighted countries** = dialects predicted as *valid/acceptable* for the sentence
# """
# )
# with gr.Row():
# with gr.Column(scale=1):
# model_dropdown = gr.Dropdown(
# choices=list(MODEL_CHOICES.keys()),
# value="LahjatBERT",
# label="Model",
# info="Select which LahjatBERT variant to use for prediction."
# )
# text_input = gr.Textbox(
# label="Arabic Text Input",
# placeholder="ุฃุฏุฎู„ ู†ุตู‹ุง ุนุฑุจูŠู‹ุง ู‡ู†ุง... ู…ุซุงู„: ุดู„ูˆู†ูƒุŸ / ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ / ุดูˆ ุฃุฎุจุงุฑูƒุŸ",
# lines=4,
# rtl=True,
# )
# threshold_slider = gr.Slider(
# minimum=0.1,
# maximum=0.9,
# value=0.3,
# step=0.05,
# label="Confidence Threshold",
# info=(
# "Dialects with confidence โ‰ฅ threshold are marked as valid. "
# "Try 0.30 for broader overlap, or 0.50 for stricter predictions."
# ),
# )
# predict_button = gr.Button("๐Ÿ” Predict Dialects", variant="primary")
# gr.Markdown(
# """
# **Tip:** If youโ€™re testing a sentence thatโ€™s close to Modern Standard Arabic (MSA),
# you may see **many countries highlighted**โ€”thatโ€™s expected, because MSA-like text
# can be acceptable across dialects.
# """
# )
# with gr.Column(scale=1):
# summary_output = gr.Markdown(label="Summary")
# results_output = gr.Dataframe(
# label="Detailed Results",
# headers=["Dialect", "Confidence", "Prediction"],
# datatype=["str", "str", "str"],
# )
# gr.Markdown("---")
# gr.Markdown(
# """
# ## ๐Ÿ—บ๏ธ Dialect Map (Zoomed to the Arab World)
# The map updates after each prediction.
# Green countries indicate dialects predicted as valid at your selected threshold.
# """
# )
# map_output = gr.HTML(label="Arab World Map", value=render_map_html([], {}))
# gr.Markdown("---")
# gr.Markdown(
# """
# ## โœจ Try these examples
# These examples are meant to show **dialect overlap**:
# - Some expressions are widely shared and may light up multiple regions
# - Others contain strong local signals (e.g., Egyptian, Gulf/Khaleeji, Levantine, Maghrebi)
# """
# )
# gr.Examples(
# examples=[
# # Broad / MSA-like (often acceptable widely)
# ["ูƒูŠู ุญุงู„ูƒุŸ", 0.30],
# ["ุงู„ุณู„ุงู… ุนู„ูŠูƒู… ูˆุฑุญู…ุฉ ุงู„ู„ู‡ ูˆุจุฑูƒุงุชู‡", 0.30],
# # Egyptian-leaning
# ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ ุนุงู…ู„ ุฅูŠู‡ุŸ", 0.30],
# ["ู…ุด ูุงู‡ู… ู„ูŠู‡ ูƒุฏู‡ ุจุตุฑุงุญุฉ", 0.30],
# # Gulf / Iraqi-leaning
# ["ุดู„ูˆู†ูƒุŸ ุดุฎุจุงุฑูƒุŸ", 0.30],
# ["ูˆูŠู†ูƒ ู…ู† ุฒู…ุงู†ุŸ", 0.30],
# # Levantine-leaning
# ["ุดูˆ ุฃุฎุจุงุฑูƒุŸ ูƒูŠููƒุŸ", 0.30],
# ["ุจุฏู‘ูŠ ุฃุฑูˆุญ ู‡ู„ู‘ู‚", 0.30],
# # Maghrebi-leaning (may vary depending on spelling)
# ["ู„ุงุจุงุณ ุนู„ูŠูƒุŸ ูˆุงุด ุฑุงูƒุŸ", 0.30],
# ["ุจุฒุงู ุฏูŠุงู„ ุงู„ู†ุงุณ ูƒูŠู‡ุถุฑูˆ ู‡ูƒุง", 0.30],
# # Short greetings repeated — try raising the threshold slider (e.g. 0.50) on these for stricter, fewer highlights
# ["ุดู„ูˆู†ูƒุŸ", 0.30],
# ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ", 0.30],
# ],
# inputs=[text_input, threshold_slider],
# label="Click an example to auto-fill the input",
# )
# predict_button.click(
# fn=predict_wrapper,
# inputs=[model_dropdown, text_input, threshold_slider],
# outputs=[results_output, summary_output, map_output],
# )
# gr.Markdown(
# """
# ---
# ### Notes
# - The model outputs **multi-label** predictions: more than one dialect can be valid at once.
# - If you use this demo in research, please cite the accompanying paper.
# """
# )
# # Launch
# if __name__ == "__main__":
# demo.launch()