Spaces:
Sleeping
Sleeping
| ##################### | |
| ### Version 1 ####### | |
| ##################### | |
| # import torch | |
| # import gradio as gr | |
| # from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| # import pandas as pd | |
| # # Load the model and tokenizer | |
| # model_name = "AHAAM/B2BERT" | |
| # model = AutoModelForSequenceClassification.from_pretrained(model_name) | |
| # tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # # Define dialects | |
| # DIALECTS = [ | |
| # "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya", | |
| # "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria", | |
| # "Tunisia", "UAE", "Yemen" | |
| # ] | |
| # def predict_dialects_with_confidence(text, threshold=0.3): | |
| # """ | |
| # Predict Arabic dialects for the given text and return confidence scores. | |
| # Args: | |
| # text: Input Arabic text | |
| # threshold: Confidence threshold for classification (default 0.3) | |
| # Returns: | |
| # DataFrame with dialects and their confidence scores | |
| # """ | |
| # if not text.strip(): | |
| # return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []}) | |
| # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # model.to(device) | |
| # # Tokenize input | |
| # encodings = tokenizer( | |
| # [text], | |
| # truncation=True, | |
| # padding=True, | |
| # max_length=128, | |
| # return_tensors="pt" | |
| # ) | |
| # input_ids = encodings["input_ids"].to(device) | |
| # attention_mask = encodings["attention_mask"].to(device) | |
| # # Get predictions | |
| # with torch.no_grad(): | |
| # outputs = model(input_ids=input_ids, attention_mask=attention_mask) | |
| # logits = outputs.logits | |
| # # Calculate probabilities | |
| # probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1) | |
| # # Create results dataframe | |
| # results = [] | |
| # for dialect, prob in zip(DIALECTS, probabilities): | |
| # prediction = "โ Valid" if prob >= threshold else "โ Invalid" | |
| # results.append({ | |
| # "Dialect": dialect, | |
| # "Confidence": f"{prob:.4f}", | |
| # "Prediction": prediction | |
| # }) | |
| # # Sort by confidence (descending) | |
| # df = pd.DataFrame(results) | |
| # df = df.sort_values("Confidence", ascending=False, key=lambda x: x.astype(float)) | |
| # return df | |
| # def predict_wrapper(text, threshold): | |
| # """Wrapper function for Gradio interface""" | |
| # df = predict_dialects_with_confidence(text, threshold) | |
| # # Also create a summary of predicted dialects | |
| # predicted = df[df["Prediction"] == "โ Valid"]["Dialect"].tolist() | |
| # summary = f"**Predicted Dialects ({len(predicted)}):** {', '.join(predicted) if predicted else 'None'}" | |
| # return df, summary | |
| # # Create Gradio interface | |
| # with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| # gr.Markdown( | |
| # """ | |
| # # ๐ B2BERT Arabic Dialect Classifier | |
| # This model identifies which Arabic dialects are valid for a given text input. | |
| # Enter Arabic text below to see the dialect predictions and confidence scores. | |
| # **Supported Dialects:** Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, | |
| # Morocco, Oman, Palestine, Qatar, Saudi Arabia, Sudan, Syria, Tunisia, UAE, Yemen | |
| # """ | |
| # ) | |
| # with gr.Row(): | |
| # with gr.Column(): | |
| # text_input = gr.Textbox( | |
| # label="Arabic Text Input", | |
| # placeholder="ุฃุฏุฎู ุงููุต ุงูุนุฑุจู ููุง... (e.g., ููู ุญุงููุ)", | |
| # lines=3, | |
| # rtl=True | |
| # ) | |
| # threshold_slider = gr.Slider( | |
| # minimum=0.1, | |
| # maximum=0.9, | |
| # value=0.3, | |
| # step=0.05, | |
| # label="Confidence Threshold", | |
| # info="Dialects with confidence above this threshold will be marked as valid" | |
| # ) | |
| # predict_button = gr.Button("๐ Predict Dialects", variant="primary") | |
| # with gr.Column(): | |
| # summary_output = gr.Markdown(label="Summary") | |
| # results_output = gr.Dataframe( | |
| # label="Detailed Results", | |
| # headers=["Dialect", "Confidence", "Prediction"], | |
| # datatype=["str", "str", "str"] | |
| # ) | |
| # # Examples | |
| # gr.Examples( | |
| # examples=[ | |
| # ["ููู ุญุงููุ", 0.3], | |
| # ["ุดููููุ", 0.3], | |
| # ["ุฅุฒูู ูุง ุนู ุ", 0.3], | |
| # ["ุดู ุฃุฎุจุงุฑูุ", 0.3], | |
| # ], | |
| # inputs=[text_input, threshold_slider], | |
| # label="Try these examples" | |
| # ) | |
| # # Connect button to function | |
| # predict_button.click( | |
| # fn=predict_wrapper, | |
| # inputs=[text_input, threshold_slider], | |
| # outputs=[results_output, summary_output] | |
| # ) | |
| # gr.Markdown( | |
| # """ | |
| # --- | |
| # **Model:** [AHAAM/B2BERT](https://huggingface.co/AHAAM/B2BERT) | |
| # **Note:** The model uses a multi-label classification approach where each dialect is | |
| # independently evaluated. A single text can be valid in multiple dialects. | |
| # """ | |
| # ) | |
| # # Launch the app | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| ##################### | |
| ### Version 2 ####### | |
| ##################### | |
| import json | |
| from pathlib import Path | |
| import torch | |
| import gradio as gr | |
| import pandas as pd | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| AutoConfig, | |
| ) | |
| import re | |
| import xml.etree.ElementTree as ET | |
| import numpy as np | |
| from svgpathtools import parse_path | |
# ======================
# Devices
# ======================
# Prefer GPU when available; models and input tensors are all moved to DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ======================
# Multi-dialect model registry
# ======================
# UI display name -> Hugging Face model repo id.
MODEL_CHOICES = {
    "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline",
    "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi",
    "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty",
}
# Load default model at startup (LahjatBERT).
# _current_model_key tracks which registry entry is loaded so
# load_multidialect_model() can skip redundant reloads.
_current_model_key = "LahjatBERT"
base_model_name = MODEL_CHOICES[_current_model_key]
base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Define dialects (order must match model's label mapping)
DIALECTS = [
    "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya",
    "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria",
    "Tunisia", "UAE", "Yemen"
]

# ISO2 country codes, positionally aligned with DIALECTS above.
# These must match the id attributes of the SVG country shapes.
_ISO2_CODES = [
    "dz", "bh", "eg", "iq", "jo", "kw", "lb", "ly",
    "ma", "om", "ps", "qa", "sa", "sd", "sy",
    "tn", "ae", "ye",
]

# Dialect -> ISO2 country code mapping (must match SVG path ids)
DIALECT_TO_ISO2 = dict(zip(DIALECTS, _ISO2_CODES))
# ======================
# Added: Egyptian-only model
# ======================
# Dedicated classifier answering a single question: "is this text Egyptian?"
egyptian_repo = "Mohamedelzeftawy/egyptian_marbert"
egyptian_cfg = AutoConfig.from_pretrained(egyptian_repo)
egyptian_tok = AutoTokenizer.from_pretrained(egyptian_repo)
egyptian_model = AutoModelForSequenceClassification.from_pretrained(
    egyptian_repo, config=egyptian_cfg
).to(DEVICE)
# Heuristic: if num_labels==1 -> sigmoid; else softmax and assume positive label index=1
# NOTE(review): the positive-class index of 1 is an assumption about the
# checkpoint's label order — confirm against the repo's id2label config.
_EGY_SIGMOID = (egyptian_cfg.num_labels == 1)
_EGY_POS_INDEX = 1 if egyptian_cfg.num_labels >= 2 else 0
# ======================
# Map rendering
# ======================
# Put your SVG here: repo/assets/arab_world.svg
# IMPORTANT: each country shape must have id="EG", id="SA", id="AE", etc (ISO2)
SVG_PATH = Path("assets/world-map.svg")
SVG_NS = "http://www.w3.org/2000/svg"
# Register SVG as the default namespace so serialized output carries no
# "ns0:" prefixes that would confuse browsers.
ET.register_namespace("", SVG_NS)
def load_multidialect_model(model_key: str):
    """
    Switch the active multi-dialect model/tokenizer to *model_key*.

    Mutates the module-level globals so the rest of the inference pipeline
    keeps working unchanged. Does nothing when the requested model is the
    one already loaded.
    """
    global base_model, base_tokenizer, base_model_name, _current_model_key
    if model_key == _current_model_key:
        return  # requested model is already active
    repo = MODEL_CHOICES[model_key]
    base_model_name = repo
    base_model = AutoModelForSequenceClassification.from_pretrained(repo).to(DEVICE)
    base_tokenizer = AutoTokenizer.from_pretrained(repo)
    _current_model_key = model_key
| def _merge_style(old_style: str, updates: dict) -> str: | |
| """ | |
| Merge CSS style strings (e.g., "fill:#000;stroke:#fff") with updates dict. | |
| """ | |
| style_map = {} | |
| if old_style: | |
| for part in old_style.split(";"): | |
| part = part.strip() | |
| if not part or ":" not in part: | |
| continue | |
| k, v = part.split(":", 1) | |
| style_map[k.strip()] = v.strip() | |
| style_map.update(updates) | |
| return ";".join([f"{k}:{v}" for k, v in style_map.items() if v is not None]) | |
def get_gradient_color(confidence: float, threshold: float) -> str:
    """
    Map a confidence score to a hex fill color for the map.

    Scores below *threshold* get the neutral base color. Otherwise the
    score is rescaled so threshold -> 0 and 1.0 -> 1, sharpened with a
    power curve, and used to interpolate from dark to bright green.

    Args:
        confidence: Probability score in [0, 1].
        threshold: Minimum score that counts as a prediction.
    Returns:
        Hex color string, e.g. "#2ecc71".
    """
    if confidence < threshold:
        return "#101418"  # base map color (no prediction)
    # Rescale to [0, 1] above the threshold, then sharpen the ramp;
    # exponent 2.5 keeps low scores dark and makes high scores pop.
    t = ((confidence - threshold) / (1.0 - threshold)) ** 2.5
    dark = (0x1A, 0x5E, 0x3A)    # #1a5e3a
    bright = (0x2E, 0xCC, 0x71)  # #2ecc71
    channels = (int(lo + (hi - lo) * t) for lo, hi in zip(dark, bright))
    return "#{:02x}{:02x}{:02x}".format(*channels)
def recolor_svg(svg_text: str, conf_by_iso2: dict, threshold: float) -> str:
    """
    Recolor SVG shapes by element id with confidence-driven gradient fills.

    Works whether a country is a bare <path>/<polygon> or a <g id=".."> group
    wrapping several shapes. Ids are matched case-insensitively.

    Args:
        svg_text: SVG document as a string.
        conf_by_iso2: Mapping of ISO2 codes to confidence scores.
        threshold: Scores >= threshold are rendered as "active".
    Returns:
        The recolored SVG serialized back to a string.
    """
    root = ET.fromstring(svg_text)
    # Normalize ids to lowercase for robust matching against the SVG.
    scores = {code.lower(): score for code, score in (conf_by_iso2 or {}).items()}

    base_fill = "#101418"
    base_stroke = "#2a2f3a"
    active_stroke = "#ffffff"

    def paint(el, confidence=None):
        # Only drawable shapes get styled; groups are handled by recursion below.
        if el.tag.split("}")[-1].lower() not in ("path", "polygon"):
            return
        active = confidence is not None and confidence >= threshold
        style = {
            "fill": get_gradient_color(confidence, threshold) if active else base_fill,
            "stroke": active_stroke if active else base_stroke,
            "stroke-width": "0.9" if active else "0.5",
            "opacity": "1",
        }
        # Presentation attributes would conflict with the style string; drop them.
        el.attrib.pop("fill", None)
        el.attrib.pop("stroke", None)
        el.attrib["style"] = _merge_style(el.attrib.get("style", ""), style)

    # Pass 1: neutral style for every drawable element so the map is consistent.
    for el in root.iter():
        paint(el, confidence=None)

    # Pass 2: overlay gradient colors for every id we have a score for.
    for el in root.iter():
        el_id = el.attrib.get("id")
        if not el_id:
            continue
        confidence = scores.get(el_id.strip().lower())
        if confidence is None:
            continue
        tag = el.tag.split("}")[-1].lower()
        if tag in ("path", "polygon"):
            paint(el, confidence=confidence)
        elif tag == "g":
            # Country stored as a GROUP: color all of its child shapes.
            for child in el.iter():
                paint(child, confidence=confidence)
    return ET.tostring(root, encoding="unicode")
# ISO2 ids of the Arab-world countries used for the auto-zoom viewBox.
ARAB_IDS = {
    "ma", "dz", "tn", "ly", "eg", "sd",        # North Africa
    "ps", "jo", "lb", "sy", "iq",              # Levant + Iraq
    "sa", "kw", "bh", "qa", "ae", "om", "ye",  # Gulf + Yemen
}
def compute_viewbox_from_ids(svg_text: str, ids: set[str], margin_ratio: float = 0.08):
    """
    Compute a tight viewBox (xmin, ymin, width, height) around the given
    element ids, padded by *margin_ratio* of the content size on each side.

    Supports countries stored either as shapes or as <g id=".."> groups.
    Id matching is case-insensitive. Returns None when none of the ids
    contribute any geometry.
    """
    root = ET.fromstring(svg_text)
    wanted = {i.lower() for i in ids}
    xmin, ymin = np.inf, np.inf
    xmax, ymax = -np.inf, -np.inf

    def grow(el):
        # Expand the running bounding box with this element's geometry.
        nonlocal xmin, ymin, xmax, ymax
        tag = el.tag.split("}")[-1].lower()
        if tag == "path":
            d = el.attrib.get("d")
            if not d:
                return
            bx0, bx1, by0, by1 = parse_path(d).bbox()
            xmin, xmax = min(xmin, bx0), max(xmax, bx1)
            ymin, ymax = min(ymin, by0), max(ymax, by1)
        elif tag == "polygon":
            pts = el.attrib.get("points", "").strip()
            if not pts:
                return
            numbers = [float(tok) for tok in pts.replace(",", " ").split()]
            xs, ys = numbers[0::2], numbers[1::2]
            xmin, xmax = min(xmin, min(xs)), max(xmax, max(xs))
            ymin, ymax = min(ymin, min(ys)), max(ymax, max(ys))

    for el in root.iter():
        el_id = el.attrib.get("id")
        if not el_id or el_id.strip().lower() not in wanted:
            continue
        tag = el.tag.split("}")[-1].lower()
        if tag in ("path", "polygon"):
            grow(el)
        elif tag == "g":
            # A grouped country: include all of its child shapes.
            for child in el.iter():
                grow(child)

    if not np.isfinite(xmin):
        return None  # nothing matched / no geometry found
    w, h = xmax - xmin, ymax - ymin
    mx, my = w * margin_ratio, h * margin_ratio
    return (float(xmin - mx), float(ymin - my), float(w + 2 * mx), float(h + 2 * my))
def set_viewbox(svg_text: str, viewbox):
    """Return *svg_text* with its root viewBox replaced by *viewbox* (x, y, w, h)."""
    root = ET.fromstring(svg_text)
    root.set("viewBox", " ".join(str(v) for v in viewbox))
    # Keep the aspect ratio and center the drawing inside the viewport.
    root.set("preserveAspectRatio", "xMidYMid meet")
    return ET.tostring(root, encoding="unicode")
def render_map_html(conf_by_iso2, threshold):
    """
    Render the map with gradient colors based on confidence scores.
    Args:
        conf_by_iso2: Dict mapping ISO2 codes to confidence scores
        threshold: Confidence threshold for predictions
    Returns:
        HTML string containing the recolored inline SVG plus a legend,
        or a friendly placeholder when the SVG asset is missing.
    """
    # Graceful fallback instead of crashing when the asset was not uploaded.
    if not SVG_PATH.exists():
        return f"""
        <div style="padding:12px; border:1px solid #ddd; border-radius:10px;">
            <b>Map SVG not found.</b><br/>
            Please add <code>{SVG_PATH.as_posix()}</code> to your Space repo.
        </div>
        """
    svg = SVG_PATH.read_text(encoding="utf-8")
    svg_colored = recolor_svg(svg, conf_by_iso2, threshold)
    # AUTO-ZOOM to Arab world (fixed set of Arab countries, not "predicted only")
    vb = compute_viewbox_from_ids(svg_colored, ARAB_IDS, margin_ratio=0.10)
    if vb is not None:
        svg_colored = set_viewbox(svg_colored, vb)
    # Add a legend showing the gradient scale
    legend_html = """
    <div style="margin-top: 12px; padding: 12px; background: #1a1d23; border-radius: 8px;">
        <div style="font-size: 13px; color: #e0e0e0; margin-bottom: 8px; font-weight: 500;">
            Confidence Scale
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <span style="font-size: 11px; color: #999;">Low</span>
            <div style="flex: 1; height: 20px; background: linear-gradient(to right, #1a5e3a, #2ecc71); border-radius: 4px;"></div>
            <span style="font-size: 11px; color: #999;">High</span>
        </div>
        <div style="margin-top: 6px; font-size: 11px; color: #888;">
            Darker = closer to threshold | Brighter = higher confidence
        </div>
    </div>
    """
    return f"""
    <div style="width:100%; max-width: 950px; margin: 0 auto;">
        {svg_colored}
        {legend_html}
    </div>
    """
# ======================
# Inference helpers
# ======================
def predict_dialects_with_confidence(text, threshold=0.3):
    """
    Predict Arabic dialects for the given text (multi-label) and return confidence scores.

    Runs the currently loaded base model, applies a per-label sigmoid, and
    returns a DataFrame of (Dialect, Confidence, Prediction) rows sorted by
    confidence descending. Empty/whitespace input yields an empty frame.
    """
    if not text or not text.strip():
        return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []})
    batch = base_tokenizer([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
    with torch.no_grad():
        outputs = base_model(
            input_ids=batch["input_ids"].to(DEVICE),
            attention_mask=batch["attention_mask"].to(DEVICE),
        )
        logits = outputs.logits  # (1, num_labels)
    # Independent sigmoid per label: several dialects can be valid at once.
    scores = torch.sigmoid(logits).cpu().numpy().reshape(-1)
    rows = [
        {
            "Dialect": dialect,
            "Confidence": f"{score:.4f}",
            "Prediction": "โ Valid" if score >= threshold else "โ Invalid",
        }
        for dialect, score in zip(DIALECTS, scores)
    ]
    frame = pd.DataFrame(rows)
    # "Confidence" is a formatted string column; sort it numerically.
    return frame.sort_values("Confidence", ascending=False, key=lambda col: col.astype(float))
def predict_wrapper(model_key, text, threshold):
    """
    Run the selected multi-dialect model on *text* and build all UI outputs.

    Args:
        model_key: Key into MODEL_CHOICES selecting the model variant.
        text: Input Arabic text.
        threshold: Confidence threshold for a dialect to count as valid.
    Returns:
        df: detailed per-dialect results table,
        summary: markdown line listing dialects predicted as valid,
        map_html: HTML of the recolored Arab-world map.
    """
    load_multidialect_model(model_key)
    df = predict_dialects_with_confidence(text, threshold)
    predicted_dialects = df[df["Prediction"] == "โ Valid"]["Dialect"].tolist()
    summary = f"**Predicted Dialects ({len(predicted_dialects)}):** {', '.join(predicted_dialects) if predicted_dialects else 'None'}"
    # Build confidence dict for ALL dialects (not just predicted ones) so the
    # map can shade every country by its score.
    conf_by_iso2 = {}
    for _, row in df.iterrows():
        dialect = row["Dialect"]
        if dialect not in DIALECT_TO_ISO2:
            continue
        # "Confidence" holds the formatted string; parse it back to float.
        conf_by_iso2[DIALECT_TO_ISO2[dialect]] = float(row["Confidence"])
    map_html = render_map_html(conf_by_iso2, threshold)
    return df, summary, map_html
def predict_egyptian(text, threshold=0.5):
    """
    Predict whether the input is Egyptian dialect using the dedicated model.
    Returns a small dataframe and a markdown summary.

    NOTE(review): this helper is not referenced by the Gradio UI in this file.
    """
    if not text or not text.strip():
        return pd.DataFrame({"Label": [], "Confidence": []}), "**No input provided.**"
    batch = egyptian_tok([text], truncation=True, padding=True, max_length=128, return_tensors="pt")
    with torch.no_grad():
        outputs = egyptian_model(
            input_ids=batch["input_ids"].to(DEVICE),
            attention_mask=batch["attention_mask"].to(DEVICE),
        )
    logits = outputs.logits  # (1, num_labels)
    if _EGY_SIGMOID:
        # Single-logit head: independent binary probability.
        p = torch.sigmoid(logits).item()
    else:
        # Multi-class head: softmax, take the assumed positive class.
        p = torch.softmax(logits, dim=-1).squeeze(0)[_EGY_POS_INDEX].item()
    label = "โ Egyptian" if p >= threshold else "โ Not Egyptian"
    df = pd.DataFrame([{"Label": label, "Confidence": f"{p:.4f}"}])
    md = f"**Prediction:** {label} \n**Confidence:** {p:.4f} \n**Threshold:** {threshold:.2f}"
    return df, md
# ======================
# Gradio UI
# ======================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # --- Header: what the demo does and how to read the output ---
    gr.Markdown(
        """
        # ๐ LahjatBERT: Multi-Label Arabic Dialect Classifier
        This demo predicts **which country-level Arabic dialects a sentence sounds natural in**.
        Unlike classic "pick one dialect" systems, **a single sentence can be acceptable in multiple dialects**.
        **How to use**
        1) Paste an Arabic sentence
        2) Adjust the **Confidence Threshold** (higher = fewer highlights)
        3) Click **Predict Dialects**
        **How to interpret the results**
        - **Highlighted countries** = dialects predicted as *valid/acceptable* for the sentence
        - **Color intensity** = confidence level (darker green = closer to threshold, brighter = higher confidence)
        """
    )
    with gr.Row():
        # --- Left column: model choice, text input, threshold, trigger ---
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CHOICES.keys()),
                value="LahjatBERT",
                label="Model",
                info="Select which LahjatBERT variant to use for prediction."
            )
            text_input = gr.Textbox(
                label="Arabic Text Input",
                placeholder="ุฃุฏุฎู ูุตูุง ุนุฑุจููุง ููุง... ู ุซุงู: ุดููููุ / ุฅุฒูู ูุง ุนู ุ / ุดู ุฃุฎุจุงุฑูุ",
                lines=4,
                rtl=True,
            )
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.3,
                step=0.05,
                label="Confidence Threshold",
                info=(
                    "Dialects with confidence โฅ threshold are marked as valid. "
                    "Try 0.30 for broader overlap, or 0.50 for stricter predictions."
                ),
            )
            predict_button = gr.Button("๐ Predict Dialects", variant="primary")
            gr.Markdown(
                """
                **Tip:** If you're testing a sentence that's close to Modern Standard Arabic (MSA),
                you may see **many countries highlighted**โthat's expected, because MSA-like text
                can be acceptable across dialects.
                """
            )
        # --- Right column: summary line + detailed results table ---
        with gr.Column(scale=1):
            summary_output = gr.Markdown(label="Summary")
            results_output = gr.Dataframe(
                label="Detailed Results",
                headers=["Dialect", "Confidence", "Prediction"],
                datatype=["str", "str", "str"],
            )
    gr.Markdown("---")
    gr.Markdown(
        """
        ## ๐บ๏ธ Dialect Map (Zoomed to the Arab World)
        The map updates after each prediction.
        Green countries indicate dialects predicted as valid at your selected threshold.
        The intensity of the green color reflects the confidence level.
        """
    )
    # Initial render: empty confidence dict -> every country in the base color.
    map_output = gr.HTML(label="Arab World Map", value=render_map_html({}, 0.3))
    gr.Markdown("---")
    gr.Markdown(
        """
        ## โจ Try these examples
        These examples are meant to show **dialect overlap**:
        - Some expressions are widely shared and may light up multiple regions
        - Others contain strong local signals (e.g., Egyptian, Gulf/Khaleeji, Levantine, Maghrebi)
        """
    )
    gr.Examples(
        examples=[
            # Broad / MSA-like (often acceptable widely)
            ["ููู ุญุงููุ", 0.30],
            ["ุงูุณูุงู ุนูููู ูุฑุญู ุฉ ุงููู ูุจุฑูุงุชู", 0.30],
            # Egyptian-leaning
            ["ุฅุฒูู ูุง ุนู ุ ุนุงู ู ุฅููุ", 0.30],
            ["ู ุด ูุงูู ููู ูุฏู ุจุตุฑุงุญุฉ", 0.30],
            # Gulf / Iraqi-leaning
            ["ุดููููุ ุดุฎุจุงุฑูุ", 0.30],
            ["ูููู ู ู ุฒู ุงูุ", 0.30],
            # Levantine-leaning
            ["ุดู ุฃุฎุจุงุฑูุ ููููุ", 0.30],
            ["ุจุฏูู ุฃุฑูุญ ูููู", 0.30],
            # Maghrebi-leaning (may vary depending on spelling)
            ["ูุงุจุงุณ ุนูููุ ูุงุด ุฑุงูุ", 0.30],
            ["ุจุฒุงู ุฏูุงู ุงููุงุณ ูููุถุฑู ููุง", 0.30],
            # Stricter threshold examples (fewer highlights)
            ["ุดููููุ", 0.30],
            ["ุฅุฒูู ูุง ุนู ุ", 0.30],
        ],
        inputs=[text_input, threshold_slider],
        label="Click an example to auto-fill the input",
    )
    # Wire the button to the full pipeline: table + summary + recolored map.
    predict_button.click(
        fn=predict_wrapper,
        inputs=[model_dropdown, text_input, threshold_slider],
        outputs=[results_output, summary_output, map_output],
    )
    gr.Markdown(
        """
        ---
        ### Notes
        - The model outputs **multi-label** predictions: more than one dialect can be valid at once.
        - Countries are colored with a **gradient** based on confidence: darker green means the confidence is closer to the threshold, brighter green means higher confidence.
        If you use this demo in research, please cite the accompanying paper.
        """
    )

# Launch
if __name__ == "__main__":
    demo.launch()
| ##################### | |
| ### Version 3 ####### | |
| ##################### | |
| # import json | |
| # from pathlib import Path | |
| # import torch | |
| # import gradio as gr | |
| # import pandas as pd | |
| # from transformers import ( | |
| # AutoModelForSequenceClassification, | |
| # AutoTokenizer, | |
| # AutoConfig, | |
| # ) | |
| # import re | |
| # import xml.etree.ElementTree as ET | |
| # import numpy as np | |
| # import xml.etree.ElementTree as ET | |
| # from svgpathtools import parse_path | |
| # # ====================== | |
| # # Devices | |
| # # ====================== | |
| # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # # # ====================== | |
| # # # Base multi-dialect model (B2BERT) | |
| # # # ====================== | |
| # # base_model_name = "Mohamedelzeftawy/b2bert_baseline" | |
| # # base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE) | |
| # # base_tokenizer = AutoTokenizer.from_pretrained(base_model_name) | |
| # # ====================== | |
| # # Multi-dialect model registry | |
| # # ====================== | |
| # MODEL_CHOICES = { | |
| # "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline", # default (current) | |
| # "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi", | |
| # "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty", | |
| # } | |
| # # Load default model at startup (LahjatBERT) | |
| # _current_model_key = "LahjatBERT" | |
| # base_model_name = MODEL_CHOICES[_current_model_key] | |
| # base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE) | |
| # base_tokenizer = AutoTokenizer.from_pretrained(base_model_name) | |
| # # Define dialects (order must match model's label mapping) | |
| # DIALECTS = [ | |
| # "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya", | |
| # "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria", | |
| # "Tunisia", "UAE", "Yemen" | |
| # ] | |
| # # Dialect -> ISO2 country code mapping (must match SVG path ids) | |
| # DIALECT_TO_ISO2 = { | |
| # "Algeria": "dz", | |
| # "Bahrain": "bh", | |
| # "Egypt": "eg", | |
| # "Iraq": "iq", | |
| # "Jordan": "jo", | |
| # "Kuwait": "kw", | |
| # "Lebanon": "lb", | |
| # "Libya": "ly", | |
| # "Morocco": "ma", | |
| # "Oman": "om", | |
| # "Palestine": "ps", | |
| # "Qatar": "qa", | |
| # "Saudi_Arabia": "sa", | |
| # "Sudan": "sd", | |
| # "Syria": "sy", | |
| # "Tunisia": "tn", | |
| # "UAE": "ae", | |
| # "Yemen": "ye", | |
| # } | |
| # # ====================== | |
| # # Added: Egyptian-only model | |
| # # ====================== | |
| # egyptian_repo = "Mohamedelzeftawy/egyptian_marbert" | |
| # egyptian_cfg = AutoConfig.from_pretrained(egyptian_repo) | |
| # egyptian_tok = AutoTokenizer.from_pretrained(egyptian_repo) | |
| # egyptian_model = AutoModelForSequenceClassification.from_pretrained( | |
| # egyptian_repo, config=egyptian_cfg | |
| # ).to(DEVICE) | |
| # # Heuristic: if num_labels==1 -> sigmoid; else softmax and assume positive label index=1 | |
| # _EGY_SIGMOID = (egyptian_cfg.num_labels == 1) | |
| # _EGY_POS_INDEX = 1 if egyptian_cfg.num_labels >= 2 else 0 | |
| # # ====================== | |
| # # Map rendering | |
| # # ====================== | |
| # # Put your SVG here: repo/assets/arab_world.svg | |
| # # IMPORTANT: each country shape must have id="EG", id="SA", id="AE", etc (ISO2) | |
| # SVG_PATH = Path("assets/world-map.svg") | |
| # SVG_NS = "http://www.w3.org/2000/svg" | |
| # ET.register_namespace("", SVG_NS) | |
| # def load_multidialect_model(model_key: str): | |
| # """ | |
| # Load the selected multi-dialect model + tokenizer. | |
| # Uses global variables so the rest of your pipeline stays unchanged. | |
| # """ | |
| # global base_model, base_tokenizer, base_model_name, _current_model_key | |
| # if model_key == _current_model_key: | |
| # return # already loaded | |
| # repo = MODEL_CHOICES[model_key] | |
| # base_model_name = repo | |
| # base_model = AutoModelForSequenceClassification.from_pretrained(repo).to(DEVICE) | |
| # base_tokenizer = AutoTokenizer.from_pretrained(repo) | |
| # _current_model_key = model_key | |
| # def _merge_style(old_style: str, updates: dict) -> str: | |
| # """ | |
| # Merge CSS style strings (e.g., "fill:#000;stroke:#fff") with updates dict. | |
| # """ | |
| # style_map = {} | |
| # if old_style: | |
| # for part in old_style.split(";"): | |
| # part = part.strip() | |
| # if not part or ":" not in part: | |
| # continue | |
| # k, v = part.split(":", 1) | |
| # style_map[k.strip()] = v.strip() | |
| # style_map.update(updates) | |
| # return ";".join([f"{k}:{v}" for k, v in style_map.items() if v is not None]) | |
| # def recolor_svg(svg_text: str, predicted_ids: set[str]) -> str: | |
| # """ | |
| # Recolor SVG elements by id. | |
| # Handles cases where a country is stored as <g id="ye"> ... <path/> ... </g>. | |
| # """ | |
| # root = ET.fromstring(svg_text) | |
| # # normalize ids to lowercase for robust matching | |
| # predicted_lower = {p.lower() for p in predicted_ids} | |
| # base_fill = "#101418" | |
| # base_stroke = "#2a2f3a" | |
| # active_fill = "#2ecc71" | |
| # active_stroke = "#ffffff" | |
| # def apply_style(el, active: bool): | |
| # tag = el.tag.split("}")[-1].lower() | |
| # if tag not in ("path", "polygon"): | |
| # return | |
| # updates = { | |
| # "fill": active_fill if active else base_fill, | |
| # "stroke": active_stroke if active else base_stroke, | |
| # "stroke-width": "0.9" if active else "0.5", | |
| # "opacity": "1", | |
| # } | |
| # # remove conflicting attrs | |
| # if "fill" in el.attrib: | |
| # del el.attrib["fill"] | |
| # if "stroke" in el.attrib: | |
| # del el.attrib["stroke"] | |
| # el.attrib["style"] = _merge_style(el.attrib.get("style", ""), updates) | |
| # # Pass 1: default style for everything drawable (so the map stays consistent) | |
| # for el in root.iter(): | |
| # apply_style(el, active=False) | |
| # # Pass 2: activate predicted ids | |
| # for el in root.iter(): | |
| # el_id = el.attrib.get("id") | |
| # if not el_id: | |
| # continue | |
| # el_id_lower = el_id.strip().lower() | |
| # if el_id_lower not in predicted_lower: | |
| # continue | |
| # tag = el.tag.split("}")[-1].lower() | |
| # if tag in ("path", "polygon"): | |
| # apply_style(el, active=True) | |
| # elif tag == "g": | |
| # # Important: if the country is a GROUP, color all its child shapes | |
| # for child in el.iter(): | |
| # apply_style(child, active=True) | |
| # return ET.tostring(root, encoding="unicode") | |
| # ARAB_IDS = { | |
| # "ma","dz","tn","ly","eg","sd", | |
| # "ps","jo","lb","sy","iq", | |
| # "sa","kw","bh","qa","ae","om","ye" | |
| # } | |
| # def compute_viewbox_from_ids(svg_text: str, ids: set[str], margin_ratio: float = 0.08): | |
| # """ | |
| # Compute a tight viewBox around the given country ids based on their path geometry. | |
| # Supports countries stored as <g id="..."> groups. | |
| # """ | |
| # root = ET.fromstring(svg_text) | |
| # ids_lower = {i.lower() for i in ids} | |
| # xmin, ymin = np.inf, np.inf | |
| # xmax, ymax = -np.inf, -np.inf | |
| # def update_bbox_for_element(el): | |
| # nonlocal xmin, ymin, xmax, ymax | |
| # tag = el.tag.split("}")[-1].lower() | |
| # if tag == "path": | |
| # d = el.attrib.get("d") | |
| # if not d: | |
| # return | |
| # p = parse_path(d) | |
| # bxmin, bxmax, bymin, bymax = p.bbox() | |
| # xmin = min(xmin, bxmin) | |
| # xmax = max(xmax, bxmax) | |
| # ymin = min(ymin, bymin) | |
| # ymax = max(ymax, bymax) | |
| # elif tag == "polygon": | |
| # pts = el.attrib.get("points", "").strip() | |
| # if not pts: | |
| # return | |
| # coords = [] | |
| # for chunk in pts.replace(",", " ").split(): | |
| # coords.append(float(chunk)) | |
| # xs = coords[0::2] | |
| # ys = coords[1::2] | |
| # xmin = min(xmin, min(xs)) | |
| # xmax = max(xmax, max(xs)) | |
| # ymin = min(ymin, min(ys)) | |
| # ymax = max(ymax, max(ys)) | |
| # for el in root.iter(): | |
| # el_id = el.attrib.get("id") | |
| # if not el_id: | |
| # continue | |
| # el_id_lower = el_id.strip().lower() | |
| # if el_id_lower not in ids_lower: | |
| # continue | |
| # tag = el.tag.split("}")[-1].lower() | |
| # if tag in ("path", "polygon"): | |
| # update_bbox_for_element(el) | |
| # elif tag == "g": | |
| # # If a country is a group, include all its child shapes | |
| # for child in el.iter(): | |
| # update_bbox_for_element(child) | |
| # if not np.isfinite(xmin): | |
| # return None | |
| # w = xmax - xmin | |
| # h = ymax - ymin | |
| # mx = w * margin_ratio | |
| # my = h * margin_ratio | |
| # xmin -= mx | |
| # ymin -= my | |
| # w += 2 * mx | |
| # h += 2 * my | |
| # return (float(xmin), float(ymin), float(w), float(h)) | |
| # def set_viewbox(svg_text: str, viewbox): | |
| # root = ET.fromstring(svg_text) | |
| # root.attrib["viewBox"] = " ".join(str(x) for x in viewbox) | |
| # root.attrib["preserveAspectRatio"] = "xMidYMid meet" | |
| # return ET.tostring(root, encoding="unicode") | |
| # def render_map_html(predicted_iso2, conf_by_iso2=None): | |
| # if not SVG_PATH.exists(): | |
| # return f""" | |
| # <div style="padding:12px; border:1px solid #ddd; border-radius:10px;"> | |
| # <b>Map SVG not found.</b><br/> | |
| # Please add <code>{SVG_PATH.as_posix()}</code> to your Space repo. | |
| # </div> | |
| # """ | |
| # svg = SVG_PATH.read_text(encoding="utf-8") | |
| # predicted_ids = set(predicted_iso2 or []) | |
| # svg_colored = recolor_svg(svg, predicted_ids) | |
| # # AUTO-ZOOM to Arab world (fixed set of Arab countries, not "predicted only") | |
| # vb = compute_viewbox_from_ids(svg_colored, ARAB_IDS, margin_ratio=0.10) | |
| # if vb is not None: | |
| # svg_colored = set_viewbox(svg_colored, vb) | |
| # # No JS, no script tags — just return the updated SVG | |
| # return f""" | |
| # <div style="width:100%; max-width: 950px; margin: 0 auto;"> | |
| # {svg_colored} | |
| # <div style="margin-top:8px; font-size:12px; color:#999;"> | |
| # Highlighted = confidence ≥ threshold | |
| # </div> | |
| # </div> | |
| # """ | |
| # # ====================== | |
| # # Inference helpers | |
| # # ====================== | |
| # def predict_dialects_with_confidence(text, threshold=0.3): | |
| # """ | |
| # Predict Arabic dialects for the given text (multi-label) and return confidence scores. | |
| # """ | |
| # if not text or not text.strip(): | |
| # return pd.DataFrame({"Dialect": [], "Confidence": [], "Prediction": []}) | |
| # enc = base_tokenizer([text], truncation=True, padding=True, max_length=128, return_tensors="pt") | |
| # input_ids = enc["input_ids"].to(DEVICE) | |
| # attention_mask = enc["attention_mask"].to(DEVICE) | |
| # with torch.no_grad(): | |
| # outputs = base_model(input_ids=input_ids, attention_mask=attention_mask) | |
| # logits = outputs.logits # (1, num_labels) | |
| # probs = torch.sigmoid(logits).cpu().numpy().reshape(-1) | |
| # rows = [] | |
| # for dialect, p in zip(DIALECTS, probs): | |
| # rows.append({ | |
| # "Dialect": dialect, | |
| # "Confidence": f"{p:.4f}", | |
| # "Prediction": "โ Valid" if p >= threshold else "โ Invalid", | |
| # }) | |
| # df = pd.DataFrame(rows) | |
| # df = df.sort_values("Confidence", ascending=False, key=lambda x: x.astype(float)) | |
| # return df | |
| # def predict_wrapper(model_key, text, threshold): | |
| # """ | |
| # Returns: | |
| # df (table), | |
| # summary (markdown), | |
| # map_html (HTML) | |
| # """ | |
| # load_multidialect_model(model_key) | |
| # df = predict_dialects_with_confidence(text, threshold) | |
| # predicted_dialects = df[df["Prediction"] == "โ Valid"]["Dialect"].tolist() | |
| # summary = f"**Predicted Dialects ({len(predicted_dialects)}):** {', '.join(predicted_dialects) if predicted_dialects else 'None'}" | |
| # # Build predicted ISO2 list + confidences for tooltips | |
| # predicted_iso2 = [] | |
| # conf_by_iso2 = {} | |
| # for _, row in df.iterrows(): | |
| # if row["Prediction"] != "โ Valid": | |
| # continue | |
| # dialect = row["Dialect"] | |
| # if dialect not in DIALECT_TO_ISO2: | |
| # continue | |
| # code = DIALECT_TO_ISO2[dialect] | |
| # predicted_iso2.append(code) | |
| # conf_by_iso2[code] = float(row["Confidence"]) | |
| # print("predicted_iso2:", predicted_iso2) | |
| # map_html = render_map_html(predicted_iso2, conf_by_iso2) | |
| # return df, summary, map_html | |
| # def predict_egyptian(text, threshold=0.5): | |
| # """ | |
| # Predict whether the input is Egyptian dialect using the dedicated model. | |
| # Returns a small dataframe and a markdown summary. | |
| # """ | |
| # if not text or not text.strip(): | |
| # return pd.DataFrame({"Label": [], "Confidence": []}), "**No input provided.**" | |
| # enc = egyptian_tok([text], truncation=True, padding=True, max_length=128, return_tensors="pt") | |
| # input_ids = enc["input_ids"].to(DEVICE) | |
| # attention_mask = enc["attention_mask"].to(DEVICE) | |
| # with torch.no_grad(): | |
| # outputs = egyptian_model(input_ids=input_ids, attention_mask=attention_mask) | |
| # logits = outputs.logits # (1, num_labels) | |
| # if _EGY_SIGMOID: | |
| # p = torch.sigmoid(logits).item() | |
| # else: | |
| # probs = torch.softmax(logits, dim=-1).squeeze(0) | |
| # p = probs[_EGY_POS_INDEX].item() | |
| # label = "โ Egyptian" if p >= threshold else "โ Not Egyptian" | |
| # df = pd.DataFrame([{"Label": label, "Confidence": f"{p:.4f}"}]) | |
| # md = f"**Prediction:** {label} \n**Confidence:** {p:.4f} \n**Threshold:** {threshold:.2f}" | |
| # return df, md | |
| # # ====================== | |
| # # Gradio UI | |
| # # ====================== | |
| # with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| # gr.Markdown( | |
| # """ | |
| # # ๐ LahjatBERT: Multi-Label Arabic Dialect Classifier | |
| # This demo predicts **which country-level Arabic dialects a sentence sounds natural in**. | |
| # Unlike classic "pick one dialect" systems, **a single sentence can be acceptable in multiple dialects**. | |
| # **How to use** | |
| # 1) Paste an Arabic sentence | |
| # 2) Adjust the **Confidence Threshold** (higher = fewer highlights) | |
| # 3) Click **Predict Dialects** | |
| # **How to interpret the results** | |
| # - **Highlighted countries** = dialects predicted as *valid/acceptable* for the sentence | |
| # """ | |
| # ) | |
| # with gr.Row(): | |
| # with gr.Column(scale=1): | |
| # model_dropdown = gr.Dropdown( | |
| # choices=list(MODEL_CHOICES.keys()), | |
| # value="LahjatBERT", | |
| # label="Model", | |
| # info="Select which LahjatBERT variant to use for prediction." | |
| # ) | |
| # text_input = gr.Textbox( | |
| # label="Arabic Text Input", | |
| # placeholder="ุฃุฏุฎู ูุตูุง ุนุฑุจููุง ููุง... ู ุซุงู: ุดููููุ / ุฅุฒูู ูุง ุนู ุ / ุดู ุฃุฎุจุงุฑูุ", | |
| # lines=4, | |
| # rtl=True, | |
| # ) | |
| # threshold_slider = gr.Slider( | |
| # minimum=0.1, | |
| # maximum=0.9, | |
| # value=0.3, | |
| # step=0.05, | |
| # label="Confidence Threshold", | |
| # info=( | |
| # "Dialects with confidence โฅ threshold are marked as valid. " | |
| # "Try 0.30 for broader overlap, or 0.50 for stricter predictions." | |
| # ), | |
| # ) | |
| # predict_button = gr.Button("๐ Predict Dialects", variant="primary") | |
| # gr.Markdown( | |
| # """ | |
| # **Tip:** If you're testing a sentence that's close to Modern Standard Arabic (MSA), | |
| # you may see **many countries highlighted**—that's expected, because MSA-like text | |
| # can be acceptable across dialects. | |
| # """ | |
| # ) | |
| # with gr.Column(scale=1): | |
| # summary_output = gr.Markdown(label="Summary") | |
| # results_output = gr.Dataframe( | |
| # label="Detailed Results", | |
| # headers=["Dialect", "Confidence", "Prediction"], | |
| # datatype=["str", "str", "str"], | |
| # ) | |
| # gr.Markdown("---") | |
| # gr.Markdown( | |
| # """ | |
| # ## 🗺️ Dialect Map (Zoomed to the Arab World) | |
| # The map updates after each prediction. | |
| # Green countries indicate dialects predicted as valid at your selected threshold. | |
| # """ | |
| # ) | |
| # map_output = gr.HTML(label="Arab World Map", value=render_map_html([], {})) | |
| # gr.Markdown("---") | |
| # gr.Markdown( | |
| # """ | |
| # ## ✨ Try these examples | |
| # These examples are meant to show **dialect overlap**: | |
| # - Some expressions are widely shared and may light up multiple regions | |
| # - Others contain strong local signals (e.g., Egyptian, Gulf/Khaleeji, Levantine, Maghrebi) | |
| # """ | |
| # ) | |
| # gr.Examples( | |
| # examples=[ | |
| # # Broad / MSA-like (often acceptable widely) | |
| # ["ููู ุญุงููุ", 0.30], | |
| # ["ุงูุณูุงู ุนูููู ูุฑุญู ุฉ ุงููู ูุจุฑูุงุชู", 0.30], | |
| # # Egyptian-leaning | |
| # ["ุฅุฒูู ูุง ุนู ุ ุนุงู ู ุฅููุ", 0.30], | |
| # ["ู ุด ูุงูู ููู ูุฏู ุจุตุฑุงุญุฉ", 0.30], | |
| # # Gulf / Iraqi-leaning | |
| # ["ุดููููุ ุดุฎุจุงุฑูุ", 0.30], | |
| # ["ูููู ู ู ุฒู ุงูุ", 0.30], | |
| # # Levantine-leaning | |
| # ["ุดู ุฃุฎุจุงุฑูุ ููููุ", 0.30], | |
| # ["ุจุฏูู ุฃุฑูุญ ูููู", 0.30], | |
| # # Maghrebi-leaning (may vary depending on spelling) | |
| # ["ูุงุจุงุณ ุนูููุ ูุงุด ุฑุงูุ", 0.30], | |
| # ["ุจุฒุงู ุฏูุงู ุงููุงุณ ูููุถุฑู ููุง", 0.30], | |
| # # Stricter threshold examples (fewer highlights) | |
| # ["ุดููููุ", 0.30], | |
| # ["ุฅุฒูู ูุง ุนู ุ", 0.30], | |
| # ], | |
| # inputs=[text_input, threshold_slider], | |
| # label="Click an example to auto-fill the input", | |
| # ) | |
| # predict_button.click( | |
| # fn=predict_wrapper, | |
| # inputs=[model_dropdown, text_input, threshold_slider], | |
| # outputs=[results_output, summary_output, map_output], | |
| # ) | |
| # gr.Markdown( | |
| # """ | |
| # --- | |
| # ### Notes | |
| # - The model outputs **multi-label** predictions: more than one dialect can be valid at once. | |
| # If you use this demo in research, please cite the accompanying paper. | |
| # """ | |
| # ) | |
| # # Launch | |
| # if __name__ == "__main__": | |
| # demo.launch() |