##################### ### Version 1 ####### ##################### # import torch # import gradio as gr # from transformers import AutoModelForSequenceClassification, AutoTokenizer # import pandas as pd # # Load the model and tokenizer # model_name = "AHAAM/B2BERT" # model = AutoModelForSequenceClassification.from_pretrained(model_name) # tokenizer = AutoTokenizer.from_pretrained(model_name) # # Define dialects # DIALECTS = [ # "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya", # "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria", # "Tunisia", "UAE", "Yemen" # ] # def predict_dialects_with_confidence(text, threshold=0.3): # """ # Predict Arabic dialects for the given text and return confidence scores. # Args: # text: Input Arabic text # threshold: Confidence threshold for classification (default 0.3) # Returns: # DataFrame with dialects and their confidence scores # """ # if not text.strip(): # return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []}) # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # model.to(device) # # Tokenize input # encodings = tokenizer( # [text], # truncation=True, # padding=True, # max_length=128, # return_tensors="pt" # ) # input_ids = encodings["input_ids"].to(device) # attention_mask = encodings["attention_mask"].to(device) # # Get predictions # with torch.no_grad(): # outputs = model(input_ids=input_ids, attention_mask=attention_mask) # logits = outputs.logits # # Calculate probabilities # probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1) # # Create results dataframe # results = [] # for dialect, prob in zip(DIALECTS, probabilities): # prediction = "โœ“ Valid" if prob >= threshold else "โœ— Invalid" # results.append({ # "Dialect": dialect, # "Confidence": f"{prob:.4f}", # "Prediction": prediction # }) # # Sort by confidence (descending) # df = pd.DataFrame(results) # df = df.sort_values("Confidence", ascending=False, key=lambda x: 
x.astype(float)) # return df # def predict_wrapper(text, threshold): # """Wrapper function for Gradio interface""" # df = predict_dialects_with_confidence(text, threshold) # # Also create a summary of predicted dialects # predicted = df[df["Prediction"] == "โœ“ Valid"]["Dialect"].tolist() # summary = f"**Predicted Dialects ({len(predicted)}):** {', '.join(predicted) if predicted else 'None'}" # return df, summary # # Create Gradio interface # with gr.Blocks(theme=gr.themes.Soft()) as demo: # gr.Markdown( # """ # # ๐ŸŒ B2BERT Arabic Dialect Classifier # This model identifies which Arabic dialects are valid for a given text input. # Enter Arabic text below to see the dialect predictions and confidence scores. # **Supported Dialects:** Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, # Morocco, Oman, Palestine, Qatar, Saudi Arabia, Sudan, Syria, Tunisia, UAE, Yemen # """ # ) # with gr.Row(): # with gr.Column(): # text_input = gr.Textbox( # label="Arabic Text Input", # placeholder="ุฃุฏุฎู„ ุงู„ู†ุต ุงู„ุนุฑุจูŠ ู‡ู†ุง... 
(e.g., ูƒูŠู ุญุงู„ูƒุŸ)", # lines=3, # rtl=True # ) # threshold_slider = gr.Slider( # minimum=0.1, # maximum=0.9, # value=0.3, # step=0.05, # label="Confidence Threshold", # info="Dialects with confidence above this threshold will be marked as valid" # ) # predict_button = gr.Button("๐Ÿ” Predict Dialects", variant="primary") # with gr.Column(): # summary_output = gr.Markdown(label="Summary") # results_output = gr.Dataframe( # label="Detailed Results", # headers=["Dialect", "Confidence", "Prediction"], # datatype=["str", "str", "str"] # ) # # Examples # gr.Examples( # examples=[ # ["ูƒูŠู ุญุงู„ูƒุŸ", 0.3], # ["ุดู„ูˆู†ูƒุŸ", 0.3], # ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ", 0.3], # ["ุดูˆ ุฃุฎุจุงุฑูƒุŸ", 0.3], # ], # inputs=[text_input, threshold_slider], # label="Try these examples" # ) # # Connect button to function # predict_button.click( # fn=predict_wrapper, # inputs=[text_input, threshold_slider], # outputs=[results_output, summary_output] # ) # gr.Markdown( # """ # --- # **Model:** [AHAAM/B2BERT](https://huggingface.co/AHAAM/B2BERT) # **Note:** The model uses a multi-label classification approach where each dialect is # independently evaluated. A single text can be valid in multiple dialects. 
# """ # ) # # Launch the app # if __name__ == "__main__": # demo.launch() ##################### ### Version 2 ####### ##################### import json from pathlib import Path import torch import gradio as gr import pandas as pd from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, ) import re import xml.etree.ElementTree as ET import numpy as np from svgpathtools import parse_path # ====================== # Devices # ====================== DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ====================== # Multi-dialect model registry # ====================== MODEL_CHOICES = { "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline", "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi", "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty", } # Load default model at startup (LahjatBERT) _current_model_key = "LahjatBERT" base_model_name = MODEL_CHOICES[_current_model_key] base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE) base_tokenizer = AutoTokenizer.from_pretrained(base_model_name) # Define dialects (order must match model's label mapping) DIALECTS = [ "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya", "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria", "Tunisia", "UAE", "Yemen" ] # Dialect -> ISO2 country code mapping (must match SVG path ids) DIALECT_TO_ISO2 = { "Algeria": "dz", "Bahrain": "bh", "Egypt": "eg", "Iraq": "iq", "Jordan": "jo", "Kuwait": "kw", "Lebanon": "lb", "Libya": "ly", "Morocco": "ma", "Oman": "om", "Palestine": "ps", "Qatar": "qa", "Saudi_Arabia": "sa", "Sudan": "sd", "Syria": "sy", "Tunisia": "tn", "UAE": "ae", "Yemen": "ye", } # ====================== # Added: Egyptian-only model # ====================== egyptian_repo = "Mohamedelzeftawy/egyptian_marbert" egyptian_cfg = AutoConfig.from_pretrained(egyptian_repo) egyptian_tok = AutoTokenizer.from_pretrained(egyptian_repo) 
egyptian_model = AutoModelForSequenceClassification.from_pretrained(
    egyptian_repo, config=egyptian_cfg
).to(DEVICE)

# Heuristic: if num_labels==1 -> sigmoid; else softmax and assume positive label index=1
# NOTE(review): the "positive class is index 1" assumption is not verified against
# the repo's id2label mapping -- confirm on the Hub.
_EGY_SIGMOID = (egyptian_cfg.num_labels == 1)
_EGY_POS_INDEX = 1 if egyptian_cfg.num_labels >= 2 else 0

# ======================
# Map rendering
# ======================
# Put your SVG here: repo/assets/arab_world.svg
# IMPORTANT: each country shape must have id="EG", id="SA", id="AE", etc (ISO2)
SVG_PATH = Path("assets/world-map.svg")
SVG_NS = "http://www.w3.org/2000/svg"
# Register the default namespace so the serialized SVG keeps un-prefixed tags.
ET.register_namespace("", SVG_NS)


def load_multidialect_model(model_key: str):
    """
    Load the selected multi-dialect model + tokenizer.

    Uses global variables so the rest of the pipeline stays unchanged.
    No-op when the requested model is already active.

    Args:
        model_key: One of the keys of MODEL_CHOICES.
    """
    # NOTE(review): mutates module-level globals; concurrent requests that switch
    # models could race -- confirm the app runs with a single worker.
    global base_model, base_tokenizer, base_model_name, _current_model_key
    if model_key == _current_model_key:
        return  # already loaded
    repo = MODEL_CHOICES[model_key]
    base_model_name = repo
    base_model = AutoModelForSequenceClassification.from_pretrained(repo).to(DEVICE)
    base_tokenizer = AutoTokenizer.from_pretrained(repo)
    _current_model_key = model_key


def _merge_style(old_style: str, updates: dict) -> str:
    """
    Merge CSS style strings (e.g., "fill:#000;stroke:#fff") with an updates dict.

    Args:
        old_style: Existing inline style string (may be empty).
        updates: Property -> value overrides; a value of None drops the property.
    Returns:
        The merged "k:v;k:v" style string.
    """
    style_map = {}
    if old_style:
        for part in old_style.split(";"):
            part = part.strip()
            # Skip empty fragments and malformed declarations without a colon.
            if not part or ":" not in part:
                continue
            k, v = part.split(":", 1)
            style_map[k.strip()] = v.strip()
    style_map.update(updates)
    return ";".join([f"{k}:{v}" for k, v in style_map.items() if v is not None])


def get_gradient_color(confidence: float, threshold: float) -> str:
    """
    Generate a gradient color from dark to bright green based on confidence.
    Uses a power function for sharper contrast.

    Args:
        confidence: Probability score (0-1)
        threshold: Minimum threshold for prediction
    Returns:
        Hex color string
    """
    # NOTE(review): threshold == 1.0 would divide by zero below; the UI slider
    # caps at 0.9, so this is unreachable from the app itself.
    if confidence < threshold:
        return "#101418"  # base color (no prediction)
    # Normalize confidence to [0, 1] range starting from threshold
    normalized = (confidence - threshold) / (1.0 - threshold)
    # Apply power function for sharper gradient (higher power = sharper)
    # You can adjust the exponent: 2.0 = moderate sharp, 3.0 = very sharp
    normalized = normalized ** 2.5
    # Define gradient from dark green to bright green
    # Dark green: #1a5e3a, Bright green: #2ecc71
    r_start, g_start, b_start = 0x1a, 0x5e, 0x3a  # dark green
    r_end, g_end, b_end = 0x2e, 0xcc, 0x71  # bright green
    # Linear interpolation
    r = int(r_start + (r_end - r_start) * normalized)
    g = int(g_start + (g_end - g_start) * normalized)
    b = int(b_start + (b_end - b_start) * normalized)
    return f"#{r:02x}{g:02x}{b:02x}"


def recolor_svg(svg_text: str, conf_by_iso2: dict, threshold: float) -> str:
    """
    Recolor SVG elements by id with gradient colors based on confidence.
    Handles countries stored either as a single shape or as a group ("g")
    wrapping several shapes.

    Args:
        svg_text: SVG content as string
        conf_by_iso2: Dict mapping ISO2 codes to confidence scores
        threshold: Confidence threshold for predictions
    Returns:
        The recolored SVG serialized back to a string.
    """
    root = ET.fromstring(svg_text)
    # normalize ids to lowercase for robust matching
    conf_by_iso2_lower = {k.lower(): v for k, v in (conf_by_iso2 or {}).items()}
    base_fill = "#101418"
    base_stroke = "#2a2f3a"
    active_stroke = "#ffffff"

    def apply_style(el, confidence: float = None):
        # Only recolor drawable shapes; leave groups, defs, text, etc. alone.
        tag = el.tag.split("}")[-1].lower()
        if tag not in ("path", "polygon"):
            return
        if confidence is not None and confidence >= threshold:
            fill_color = get_gradient_color(confidence, threshold)
            stroke_color = active_stroke
            stroke_width = "0.9"
            opacity = "1"
        else:
            fill_color = base_fill
            stroke_color = base_stroke
            stroke_width = "0.5"
            opacity = "1"
        updates = {
            "fill": fill_color,
            "stroke": stroke_color,
            "stroke-width": stroke_width,
            "opacity": opacity,
        }
        # remove conflicting attrs (presentation attributes would override style)
        if "fill" in el.attrib:
            del el.attrib["fill"]
        if "stroke" in el.attrib:
            del el.attrib["stroke"]
        el.attrib["style"] = _merge_style(el.attrib.get("style", ""), updates)

    # Pass 1: default style for everything drawable
    for el in root.iter():
        apply_style(el, confidence=None)
    # Pass 2: apply gradient colors based on confidence
    for el in root.iter():
        el_id = el.attrib.get("id")
        if not el_id:
            continue
        el_id_lower = el_id.strip().lower()
        confidence = conf_by_iso2_lower.get(el_id_lower)
        if confidence is None:
            continue
        tag = el.tag.split("}")[-1].lower()
        if tag in ("path", "polygon"):
            apply_style(el, confidence=confidence)
        elif tag == "g":
            # If the country is a GROUP, color all its child shapes
            for child in el.iter():
                apply_style(child, confidence=confidence)
    return ET.tostring(root, encoding="unicode")


# ISO2 ids of the Arab-world countries used for the fixed auto-zoom viewport.
ARAB_IDS = {
    "ma", "dz", "tn", "ly", "eg", "sd",
    "ps", "jo", "lb", "sy", "iq",
    "sa", "kw", "bh", "qa", "ae", "om", "ye"
}


def compute_viewbox_from_ids(svg_text: str, ids: set[str], margin_ratio: float = 0.08):
    """
    Compute a tight viewBox around the given country ids based on their path geometry.
    Supports countries stored as groups.

    Args:
        svg_text: SVG content as string.
        ids: ISO2 ids (case-insensitive) to bound.
        margin_ratio: Fractional margin added on every side of the bounding box.
    Returns:
        (xmin, ymin, width, height) tuple, or None if no matching geometry found.
    """
    root = ET.fromstring(svg_text)
    ids_lower = {i.lower() for i in ids}
    # Running bounding box; stays infinite until the first shape is folded in.
    xmin, ymin = np.inf, np.inf
    xmax, ymax = -np.inf, -np.inf

    def update_bbox_for_element(el):
        nonlocal xmin, ymin, xmax, ymax
        tag = el.tag.split("}")[-1].lower()
        if tag == "path":
            d = el.attrib.get("d")
            if not d:
                return
            # svgpathtools parses the path data and gives an exact curve bbox.
            p = parse_path(d)
            bxmin, bxmax, bymin, bymax = p.bbox()
            xmin = min(xmin, bxmin)
            xmax = max(xmax, bxmax)
            ymin = min(ymin, bymin)
            ymax = max(ymax, bymax)
        elif tag == "polygon":
            pts = el.attrib.get("points", "").strip()
            if not pts:
                return
            # "x1,y1 x2,y2 ..." -> flat float list, then de-interleave x/y.
            coords = []
            for chunk in pts.replace(",", " ").split():
                coords.append(float(chunk))
            xs = coords[0::2]
            ys = coords[1::2]
            xmin = min(xmin, min(xs))
            xmax = max(xmax, max(xs))
            ymin = min(ymin, min(ys))
            ymax = max(ymax, max(ys))

    for el in root.iter():
        el_id = el.attrib.get("id")
        if not el_id:
            continue
        el_id_lower = el_id.strip().lower()
        if el_id_lower not in ids_lower:
            continue
        tag = el.tag.split("}")[-1].lower()
        if tag in ("path", "polygon"):
            update_bbox_for_element(el)
        elif tag == "g":
            # If a country is a group, include all its child shapes
            for child in el.iter():
                update_bbox_for_element(child)

    # No matching geometry at all -> caller keeps the original viewBox.
    if not np.isfinite(xmin):
        return None
    w = xmax - xmin
    h = ymax - ymin
    mx = w * margin_ratio
    my = h * margin_ratio
    xmin -= mx
    ymin -= my
    w += 2 * mx
    h += 2 * my
    return (float(xmin), float(ymin), float(w), float(h))


def set_viewbox(svg_text: str, viewbox):
    """Set the root viewBox (and centering aspect ratio) on an SVG string."""
    root = ET.fromstring(svg_text)
    root.attrib["viewBox"] = " ".join(str(x) for x in viewbox)
    root.attrib["preserveAspectRatio"] = "xMidYMid meet"
    return ET.tostring(root, encoding="unicode")


def render_map_html(conf_by_iso2, threshold):
    """
    Render the map with gradient colors based on confidence scores.

    Args:
        conf_by_iso2: Dict mapping ISO2 codes to confidence scores
        threshold: Confidence threshold for predictions
    Returns:
        HTML string embedding the recolored, auto-zoomed SVG plus a legend.
    """
    # NOTE(review): the HTML tag markup inside the string literals below appears
    # to have been stripped from this copy of the file (only text content
    # survives) -- recover the original <div>/<span> wrappers from version control.
    if not SVG_PATH.exists():
        return f"""
Map SVG not found.
Please add {SVG_PATH.as_posix()} to your Space repo.
"""
    svg = SVG_PATH.read_text(encoding="utf-8")
    svg_colored = recolor_svg(svg, conf_by_iso2, threshold)
    # AUTO-ZOOM to Arab world (fixed set of Arab countries, not "predicted only")
    vb = compute_viewbox_from_ids(svg_colored, ARAB_IDS, margin_ratio=0.10)
    if vb is not None:
        svg_colored = set_viewbox(svg_colored, vb)
    # Add a legend showing the gradient scale
    legend_html = """
Confidence Scale
Low
High
Darker = closer to threshold | Brighter = higher confidence
"""
    return f"""
{svg_colored} {legend_html}
"""


# ======================
# Inference helpers
# ======================
def predict_dialects_with_confidence(text, threshold=0.3):
    """
    Predict Arabic dialects for the given text (multi-label) and return confidence scores.

    Args:
        text: Input Arabic sentence (empty/whitespace input yields an empty table).
        threshold: Per-label decision threshold on the sigmoid probabilities.
    Returns:
        DataFrame with Dialect / Confidence / Prediction columns, sorted by
        confidence descending.
    """
    if not text or not text.strip():
        return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []})
    enc = base_tokenizer([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)
    with torch.no_grad():
        outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # (1, num_labels)
    # Multi-label: independent sigmoid per dialect (not a softmax over labels).
    probs = torch.sigmoid(logits).cpu().numpy().reshape(-1)
    rows = []
    # Relies on DIALECTS matching the model's label order (see its definition).
    for dialect, p in zip(DIALECTS, probs):
        rows.append({
            "Dialect": dialect,
            "Confidence": f"{p:.4f}",
            "Prediction": "โœ“ Valid" if p >= threshold else "โœ— Invalid",
        })
    df = pd.DataFrame(rows)
    # Confidence is stored as a string; the key converts it back for numeric sort.
    df = df.sort_values("Confidence", ascending=False, key=lambda x: x.astype(float))
    return df


def predict_wrapper(model_key, text, threshold):
    """
    Gradio click handler: run prediction and build all three outputs.

    Returns:
        df (table),
        summary (markdown),
        map_html (HTML)
    """
    load_multidialect_model(model_key)
    df = predict_dialects_with_confidence(text, threshold)
    predicted_dialects = df[df["Prediction"] == "โœ“ Valid"]["Dialect"].tolist()
    summary = f"**Predicted Dialects ({len(predicted_dialects)}):** {', '.join(predicted_dialects) if predicted_dialects else 'None'}"
    # Build confidence dict for ALL dialects (not just predicted ones)
    conf_by_iso2 = {}
    for _, row in df.iterrows():
        dialect = row["Dialect"]
        if dialect not in DIALECT_TO_ISO2:
            continue
        code = DIALECT_TO_ISO2[dialect]
        conf_by_iso2[code] = float(row["Confidence"])
    print("conf_by_iso2:", conf_by_iso2)
    map_html = render_map_html(conf_by_iso2, threshold)
    return df, summary, map_html


def predict_egyptian(text, threshold=0.5):
    """
    Predict whether the input is Egyptian dialect using the dedicated model.
    Returns a small dataframe and a markdown summary.
    """
    # NOTE(review): not referenced by the Blocks UI below in this version.
    if not text or not text.strip():
        return pd.DataFrame({"Label": [], "Confidence": []}), "**No input provided.**"
    enc = egyptian_tok([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)
    with torch.no_grad():
        outputs = egyptian_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # (1, num_labels)
    # Single-logit head -> sigmoid; multi-class head -> softmax + positive index.
    if _EGY_SIGMOID:
        p = torch.sigmoid(logits).item()
    else:
        probs = torch.softmax(logits, dim=-1).squeeze(0)
        p = probs[_EGY_POS_INDEX].item()
    label = "โœ“ Egyptian" if p >= threshold else "โœ— Not Egyptian"
    df = pd.DataFrame([{"Label": label, "Confidence": f"{p:.4f}"}])
    md = f"**Prediction:** {label} \n**Confidence:** {p:.4f} \n**Threshold:** {threshold:.2f}"
    return df, md


# ======================
# Gradio UI
# ======================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ๐ŸŒ LahjatBERT: Multi-Label Arabic Dialect Classifier
        This demo predicts **which country-level Arabic dialects a sentence sounds natural in**.
        Unlike classic "pick one dialect" systems, **a single sentence can be acceptable in multiple dialects**.
        **How to use**
        1) Paste an Arabic sentence
        2) Adjust the **Confidence Threshold** (higher = fewer highlights)
        3) Click **Predict Dialects**
        **How to interpret the results**
        - **Highlighted countries** = dialects predicted as *valid/acceptable* for the sentence
        - **Color intensity** = confidence level (darker green = closer to threshold, brighter = higher confidence)
        """
    )
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CHOICES.keys()),
                value="LahjatBERT",
                label="Model",
                info="Select which LahjatBERT variant to use for prediction."
            )
            text_input = gr.Textbox(
                label="Arabic Text Input",
                placeholder="ุฃุฏุฎู„ ู†ุตู‹ุง ุนุฑุจูŠู‹ุง ู‡ู†ุง... ู…ุซุงู„: ุดู„ูˆู†ูƒุŸ / ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ / ุดูˆ ุฃุฎุจุงุฑูƒุŸ",
                lines=4,
                rtl=True,
            )
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.3,
                step=0.05,
                label="Confidence Threshold",
                info=(
                    "Dialects with confidence โ‰ฅ threshold are marked as valid. "
                    "Try 0.30 for broader overlap, or 0.50 for stricter predictions."
                ),
            )
            predict_button = gr.Button("๐Ÿ” Predict Dialects", variant="primary")
            gr.Markdown(
                """
                **Tip:** If you're testing a sentence that's close to Modern Standard Arabic (MSA),
                you may see **many countries highlighted**โ€”that's expected, because MSA-like text
                can be acceptable across dialects.
                """
            )
        # Right column: prediction outputs.
        with gr.Column(scale=1):
            summary_output = gr.Markdown(label="Summary")
            results_output = gr.Dataframe(
                label="Detailed Results",
                headers=["Dialect", "Confidence", "Prediction"],
                datatype=["str", "str", "str"],
            )
    gr.Markdown("---")
    gr.Markdown(
        """
        ## ๐Ÿ—บ๏ธ Dialect Map (Zoomed to the Arab World)
        The map updates after each prediction.
        Green countries indicate dialects predicted as valid at your selected threshold.
        The intensity of the green color reflects the confidence level.
        """
    )
    # Initial render: empty confidences so the map starts uncolored.
    map_output = gr.HTML(label="Arab World Map", value=render_map_html({}, 0.3))
    gr.Markdown("---")
    gr.Markdown(
        """
        ## โœจ Try these examples
        These examples are meant to show **dialect overlap**:
        - Some expressions are widely shared and may light up multiple regions
        - Others contain strong local signals (e.g., Egyptian, Gulf/Khaleeji, Levantine, Maghrebi)
        """
    )
    gr.Examples(
        examples=[
            # Broad / MSA-like (often acceptable widely)
            ["ูƒูŠู ุญุงู„ูƒุŸ", 0.30],
            ["ุงู„ุณู„ุงู… ุนู„ูŠูƒู… ูˆุฑุญู…ุฉ ุงู„ู„ู‡ ูˆุจุฑูƒุงุชู‡", 0.30],
            # Egyptian-leaning
            ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ ุนุงู…ู„ ุฅูŠู‡ุŸ", 0.30],
            ["ู…ุด ูุงู‡ู… ู„ูŠู‡ ูƒุฏู‡ ุจุตุฑุงุญุฉ", 0.30],
            # Gulf / Iraqi-leaning
            ["ุดู„ูˆู†ูƒุŸ ุดุฎุจุงุฑูƒุŸ", 0.30],
            ["ูˆูŠู†ูƒ ู…ู† ุฒู…ุงู†ุŸ", 0.30],
            # Levantine-leaning
            ["ุดูˆ ุฃุฎุจุงุฑูƒุŸ ูƒูŠููƒุŸ", 0.30],
            ["ุจุฏู‘ูŠ ุฃุฑูˆุญ ู‡ู„ู‘ู‚", 0.30],
            # Maghrebi-leaning (may vary depending on spelling)
            ["ู„ุงุจุงุณ ุนู„ูŠูƒุŸ ูˆุงุด ุฑุงูƒุŸ", 0.30],
            ["ุจุฒุงู ุฏูŠุงู„ ุงู„ู†ุงุณ ูƒูŠู‡ุถุฑูˆ ู‡ูƒุง", 0.30],
            # Stricter threshold examples (fewer highlights)
            ["ุดู„ูˆู†ูƒุŸ", 0.30],
            ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ", 0.30],
        ],
        inputs=[text_input, threshold_slider],
        label="Click an example to auto-fill the input",
    )
    predict_button.click(
        fn=predict_wrapper,
        inputs=[model_dropdown, text_input, threshold_slider],
        outputs=[results_output, summary_output, map_output],
    )
    gr.Markdown(
        """
        ---
        ### Notes
        - The model outputs **multi-label** predictions: more than one dialect can be valid at once.
        - Countries are colored with a **gradient** based on confidence: darker green means the confidence is closer to the threshold, brighter green means higher confidence.
        If you use this demo in research, please cite the accompanying paper.
""" ) # Launch if __name__ == "__main__": demo.launch() ##################### ### Version 3 ####### ##################### # import json # from pathlib import Path # import torch # import gradio as gr # import pandas as pd # from transformers import ( # AutoModelForSequenceClassification, # AutoTokenizer, # AutoConfig, # ) # import re # import xml.etree.ElementTree as ET # import numpy as np # import xml.etree.ElementTree as ET # from svgpathtools import parse_path # # ====================== # # Devices # # ====================== # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # # # ====================== # # # Base multi-dialect model (B2BERT) # # # ====================== # # base_model_name = "Mohamedelzeftawy/b2bert_baseline" # # base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE) # # base_tokenizer = AutoTokenizer.from_pretrained(base_model_name) # # ====================== # # Multi-dialect model registry # # ====================== # MODEL_CHOICES = { # "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline", # default (current) # "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi", # "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty", # } # # Load default model at startup (LahjatBERT) # _current_model_key = "LahjatBERT" # base_model_name = MODEL_CHOICES[_current_model_key] # base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE) # base_tokenizer = AutoTokenizer.from_pretrained(base_model_name) # # Define dialects (order must match model's label mapping) # DIALECTS = [ # "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya", # "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria", # "Tunisia", "UAE", "Yemen" # ] # # Dialect -> ISO2 country code mapping (must match SVG path ids) # DIALECT_TO_ISO2 = { # "Algeria": "dz", # "Bahrain": "bh", # "Egypt": "eg", # "Iraq": "iq", # "Jordan": "jo", # "Kuwait": "kw", 
# "Lebanon": "lb", # "Libya": "ly", # "Morocco": "ma", # "Oman": "om", # "Palestine": "ps", # "Qatar": "qa", # "Saudi_Arabia": "sa", # "Sudan": "sd", # "Syria": "sy", # "Tunisia": "tn", # "UAE": "ae", # "Yemen": "ye", # } # # ====================== # # Added: Egyptian-only model # # ====================== # egyptian_repo = "Mohamedelzeftawy/egyptian_marbert" # egyptian_cfg = AutoConfig.from_pretrained(egyptian_repo) # egyptian_tok = AutoTokenizer.from_pretrained(egyptian_repo) # egyptian_model = AutoModelForSequenceClassification.from_pretrained( # egyptian_repo, config=egyptian_cfg # ).to(DEVICE) # # Heuristic: if num_labels==1 -> sigmoid; else softmax and assume positive label index=1 # _EGY_SIGMOID = (egyptian_cfg.num_labels == 1) # _EGY_POS_INDEX = 1 if egyptian_cfg.num_labels >= 2 else 0 # # ====================== # # Map rendering # # ====================== # # Put your SVG here: repo/assets/arab_world.svg # # IMPORTANT: each country shape must have id="EG", id="SA", id="AE", etc (ISO2) # SVG_PATH = Path("assets/world-map.svg") # SVG_NS = "http://www.w3.org/2000/svg" # ET.register_namespace("", SVG_NS) # def load_multidialect_model(model_key: str): # """ # Load the selected multi-dialect model + tokenizer. # Uses global variables so the rest of your pipeline stays unchanged. # """ # global base_model, base_tokenizer, base_model_name, _current_model_key # if model_key == _current_model_key: # return # already loaded # repo = MODEL_CHOICES[model_key] # base_model_name = repo # base_model = AutoModelForSequenceClassification.from_pretrained(repo).to(DEVICE) # base_tokenizer = AutoTokenizer.from_pretrained(repo) # _current_model_key = model_key # def _merge_style(old_style: str, updates: dict) -> str: # """ # Merge CSS style strings (e.g., "fill:#000;stroke:#fff") with updates dict. 
# """ # style_map = {} # if old_style: # for part in old_style.split(";"): # part = part.strip() # if not part or ":" not in part: # continue # k, v = part.split(":", 1) # style_map[k.strip()] = v.strip() # style_map.update(updates) # return ";".join([f"{k}:{v}" for k, v in style_map.items() if v is not None]) # def recolor_svg(svg_text: str, predicted_ids: set[str]) -> str: # """ # Recolor SVG elements by id. # Handles cases where a country is stored as ... ... . # """ # root = ET.fromstring(svg_text) # # normalize ids to lowercase for robust matching # predicted_lower = {p.lower() for p in predicted_ids} # base_fill = "#101418" # base_stroke = "#2a2f3a" # active_fill = "#2ecc71" # active_stroke = "#ffffff" # def apply_style(el, active: bool): # tag = el.tag.split("}")[-1].lower() # if tag not in ("path", "polygon"): # return # updates = { # "fill": active_fill if active else base_fill, # "stroke": active_stroke if active else base_stroke, # "stroke-width": "0.9" if active else "0.5", # "opacity": "1", # } # # remove conflicting attrs # if "fill" in el.attrib: # del el.attrib["fill"] # if "stroke" in el.attrib: # del el.attrib["stroke"] # el.attrib["style"] = _merge_style(el.attrib.get("style", ""), updates) # # Pass 1: default style for everything drawable (so the map stays consistent) # for el in root.iter(): # apply_style(el, active=False) # # Pass 2: activate predicted ids # for el in root.iter(): # el_id = el.attrib.get("id") # if not el_id: # continue # el_id_lower = el_id.strip().lower() # if el_id_lower not in predicted_lower: # continue # tag = el.tag.split("}")[-1].lower() # if tag in ("path", "polygon"): # apply_style(el, active=True) # elif tag == "g": # # Important: if the country is a GROUP, color all its child shapes # for child in el.iter(): # apply_style(child, active=True) # return ET.tostring(root, encoding="unicode") # ARAB_IDS = { # "ma","dz","tn","ly","eg","sd", # "ps","jo","lb","sy","iq", # "sa","kw","bh","qa","ae","om","ye" # } # def 
compute_viewbox_from_ids(svg_text: str, ids: set[str], margin_ratio: float = 0.08): # """ # Compute a tight viewBox around the given country ids based on their path geometry. # Supports countries stored as groups. # """ # root = ET.fromstring(svg_text) # ids_lower = {i.lower() for i in ids} # xmin, ymin = np.inf, np.inf # xmax, ymax = -np.inf, -np.inf # def update_bbox_for_element(el): # nonlocal xmin, ymin, xmax, ymax # tag = el.tag.split("}")[-1].lower() # if tag == "path": # d = el.attrib.get("d") # if not d: # return # p = parse_path(d) # bxmin, bxmax, bymin, bymax = p.bbox() # xmin = min(xmin, bxmin) # xmax = max(xmax, bxmax) # ymin = min(ymin, bymin) # ymax = max(ymax, bymax) # elif tag == "polygon": # pts = el.attrib.get("points", "").strip() # if not pts: # return # coords = [] # for chunk in pts.replace(",", " ").split(): # coords.append(float(chunk)) # xs = coords[0::2] # ys = coords[1::2] # xmin = min(xmin, min(xs)) # xmax = max(xmax, max(xs)) # ymin = min(ymin, min(ys)) # ymax = max(ymax, max(ys)) # for el in root.iter(): # el_id = el.attrib.get("id") # if not el_id: # continue # el_id_lower = el_id.strip().lower() # if el_id_lower not in ids_lower: # continue # tag = el.tag.split("}")[-1].lower() # if tag in ("path", "polygon"): # update_bbox_for_element(el) # elif tag == "g": # # If a country is a group, include all its child shapes # for child in el.iter(): # update_bbox_for_element(child) # if not np.isfinite(xmin): # return None # w = xmax - xmin # h = ymax - ymin # mx = w * margin_ratio # my = h * margin_ratio # xmin -= mx # ymin -= my # w += 2 * mx # h += 2 * my # return (float(xmin), float(ymin), float(w), float(h)) # def set_viewbox(svg_text: str, viewbox): # root = ET.fromstring(svg_text) # root.attrib["viewBox"] = " ".join(str(x) for x in viewbox) # root.attrib["preserveAspectRatio"] = "xMidYMid meet" # return ET.tostring(root, encoding="unicode") # def render_map_html(predicted_iso2, conf_by_iso2=None): # if not SVG_PATH.exists(): # return 
f""" #
# Map SVG not found.
# Please add {SVG_PATH.as_posix()} to your Space repo. #
# """ # svg = SVG_PATH.read_text(encoding="utf-8") # predicted_ids = set(predicted_iso2 or []) # svg_colored = recolor_svg(svg, predicted_ids) # # AUTO-ZOOM to Arab world (fixed set of Arab countries, not โ€œpredicted onlyโ€) # vb = compute_viewbox_from_ids(svg_colored, ARAB_IDS, margin_ratio=0.10) # if vb is not None: # svg_colored = set_viewbox(svg_colored, vb) # # No JS, no script tags โ€” just return the updated SVG # return f""" #
# {svg_colored} #
# Highlighted = confidence โ‰ฅ threshold #
#
# """ # # ====================== # # Inference helpers # # ====================== # def predict_dialects_with_confidence(text, threshold=0.3): # """ # Predict Arabic dialects for the given text (multi-label) and return confidence scores. # """ # if not text or not text.strip(): # return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []}) # enc = base_tokenizer([text], truncation=True, padding=True, max_length=128, return_tensors="pt") # input_ids = enc["input_ids"].to(DEVICE) # attention_mask = enc["attention_mask"].to(DEVICE) # with torch.no_grad(): # outputs = base_model(input_ids=input_ids, attention_mask=attention_mask) # logits = outputs.logits # (1, num_labels) # probs = torch.sigmoid(logits).cpu().numpy().reshape(-1) # rows = [] # for dialect, p in zip(DIALECTS, probs): # rows.append({ # "Dialect": dialect, # "Confidence": f"{p:.4f}", # "Prediction": "โœ“ Valid" if p >= threshold else "โœ— Invalid", # }) # df = pd.DataFrame(rows) # df = df.sort_values("Confidence", ascending=False, key=lambda x: x.astype(float)) # return df # def predict_wrapper(model_key, text, threshold): # """ # Returns: # df (table), # summary (markdown), # map_html (HTML) # """ # load_multidialect_model(model_key) # df = predict_dialects_with_confidence(text, threshold) # predicted_dialects = df[df["Prediction"] == "โœ“ Valid"]["Dialect"].tolist() # summary = f"**Predicted Dialects ({len(predicted_dialects)}):** {', '.join(predicted_dialects) if predicted_dialects else 'None'}" # # Build predicted ISO2 list + confidences for tooltips # predicted_iso2 = [] # conf_by_iso2 = {} # for _, row in df.iterrows(): # if row["Prediction"] != "โœ“ Valid": # continue # dialect = row["Dialect"] # if dialect not in DIALECT_TO_ISO2: # continue # code = DIALECT_TO_ISO2[dialect] # predicted_iso2.append(code) # conf_by_iso2[code] = float(row["Confidence"]) # print("predicted_iso2:", predicted_iso2) # map_html = render_map_html(predicted_iso2, conf_by_iso2) # return df, summary, map_html # 
def predict_egyptian(text, threshold=0.5): # """ # Predict whether the input is Egyptian dialect using the dedicated model. # Returns a small dataframe and a markdown summary. # """ # if not text or not text.strip(): # return pd.DataFrame({"Label": [], "Confidence": []}), "**No input provided.**" # enc = egyptian_tok([text], truncation=True, padding=True, max_length=128, return_tensors="pt") # input_ids = enc["input_ids"].to(DEVICE) # attention_mask = enc["attention_mask"].to(DEVICE) # with torch.no_grad(): # outputs = egyptian_model(input_ids=input_ids, attention_mask=attention_mask) # logits = outputs.logits # (1, num_labels) # if _EGY_SIGMOID: # p = torch.sigmoid(logits).item() # else: # probs = torch.softmax(logits, dim=-1).squeeze(0) # p = probs[_EGY_POS_INDEX].item() # label = "โœ“ Egyptian" if p >= threshold else "โœ— Not Egyptian" # df = pd.DataFrame([{"Label": label, "Confidence": f"{p:.4f}"}]) # md = f"**Prediction:** {label} \n**Confidence:** {p:.4f} \n**Threshold:** {threshold:.2f}" # return df, md # # ====================== # # Gradio UI # # ====================== # with gr.Blocks(theme=gr.themes.Soft()) as demo: # gr.Markdown( # """ # # ๐ŸŒ LahjatBERT: Multi-Label Arabic Dialect Classifier # This demo predicts **which country-level Arabic dialects a sentence sounds natural in**. # Unlike classic โ€œpick one dialectโ€ systems, **a single sentence can be acceptable in multiple dialects**. # **How to use** # 1) Paste an Arabic sentence # 2) Adjust the **Confidence Threshold** (higher = fewer highlights) # 3) Click **Predict Dialects** # **How to interpret the results** # - **Highlighted countries** = dialects predicted as *valid/acceptable* for the sentence # """ # ) # with gr.Row(): # with gr.Column(scale=1): # model_dropdown = gr.Dropdown( # choices=list(MODEL_CHOICES.keys()), # value="LahjatBERT", # label="Model", # info="Select which LahjatBERT variant to use for prediction." 
# ) # text_input = gr.Textbox( # label="Arabic Text Input", # placeholder="ุฃุฏุฎู„ ู†ุตู‹ุง ุนุฑุจูŠู‹ุง ู‡ู†ุง... ู…ุซุงู„: ุดู„ูˆู†ูƒุŸ / ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ / ุดูˆ ุฃุฎุจุงุฑูƒุŸ", # lines=4, # rtl=True, # ) # threshold_slider = gr.Slider( # minimum=0.1, # maximum=0.9, # value=0.3, # step=0.05, # label="Confidence Threshold", # info=( # "Dialects with confidence โ‰ฅ threshold are marked as valid. " # "Try 0.30 for broader overlap, or 0.50 for stricter predictions." # ), # ) # predict_button = gr.Button("๐Ÿ” Predict Dialects", variant="primary") # gr.Markdown( # """ # **Tip:** If youโ€™re testing a sentence thatโ€™s close to Modern Standard Arabic (MSA), # you may see **many countries highlighted**โ€”thatโ€™s expected, because MSA-like text # can be acceptable across dialects. # """ # ) # with gr.Column(scale=1): # summary_output = gr.Markdown(label="Summary") # results_output = gr.Dataframe( # label="Detailed Results", # headers=["Dialect", "Confidence", "Prediction"], # datatype=["str", "str", "str"], # ) # gr.Markdown("---") # gr.Markdown( # """ # ## ๐Ÿ—บ๏ธ Dialect Map (Zoomed to the Arab World) # The map updates after each prediction. # Green countries indicate dialects predicted as valid at your selected threshold. 
# """ # ) # map_output = gr.HTML(label="Arab World Map", value=render_map_html([], {})) # gr.Markdown("---") # gr.Markdown( # """ # ## โœจ Try these examples # These examples are meant to show **dialect overlap**: # - Some expressions are widely shared and may light up multiple regions # - Others contain strong local signals (e.g., Egyptian, Gulf/Khaleeji, Levantine, Maghrebi) # """ # ) # gr.Examples( # examples=[ # # Broad / MSA-like (often acceptable widely) # ["ูƒูŠู ุญุงู„ูƒุŸ", 0.30], # ["ุงู„ุณู„ุงู… ุนู„ูŠูƒู… ูˆุฑุญู…ุฉ ุงู„ู„ู‡ ูˆุจุฑูƒุงุชู‡", 0.30], # # Egyptian-leaning # ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ ุนุงู…ู„ ุฅูŠู‡ุŸ", 0.30], # ["ู…ุด ูุงู‡ู… ู„ูŠู‡ ูƒุฏู‡ ุจุตุฑุงุญุฉ", 0.30], # # Gulf / Iraqi-leaning # ["ุดู„ูˆู†ูƒุŸ ุดุฎุจุงุฑูƒุŸ", 0.30], # ["ูˆูŠู†ูƒ ู…ู† ุฒู…ุงู†ุŸ", 0.30], # # Levantine-leaning # ["ุดูˆ ุฃุฎุจุงุฑูƒุŸ ูƒูŠููƒุŸ", 0.30], # ["ุจุฏู‘ูŠ ุฃุฑูˆุญ ู‡ู„ู‘ู‚", 0.30], # # Maghrebi-leaning (may vary depending on spelling) # ["ู„ุงุจุงุณ ุนู„ูŠูƒุŸ ูˆุงุด ุฑุงูƒุŸ", 0.30], # ["ุจุฒุงู ุฏูŠุงู„ ุงู„ู†ุงุณ ูƒูŠู‡ุถุฑูˆ ู‡ูƒุง", 0.30], # # Stricter threshold examples (fewer highlights) # ["ุดู„ูˆู†ูƒุŸ", 0.30], # ["ุฅุฒูŠูƒ ูŠุง ุนู…ุŸ", 0.30], # ], # inputs=[text_input, threshold_slider], # label="Click an example to auto-fill the input", # ) # predict_button.click( # fn=predict_wrapper, # inputs=[model_dropdown, text_input, threshold_slider], # outputs=[results_output, summary_output, map_output], # ) # gr.Markdown( # """ # --- # ### Notes # - The model outputs **multi-label** predictions: more than one dialect can be valid at once. # If you use this demo in research, please cite the accompanying paper. # """ # ) # # Launch # if __name__ == "__main__": # demo.launch()