"""Gradio demo that classifies the Arabic dialect of a piece of text.

Loads the ``oddadmix/dialect-router-v0.1`` sequence-classification model and
exposes a small web UI that shows the top predicted dialect together with the
probability assigned to every dialect label.
"""

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# =========================
# Model
# =========================

MODEL_ID = "oddadmix/dialect-router-v0.1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # evaluation mode: this app only runs inference

# =========================
# Label Mapping
# =========================

# Display metadata for each label code. Presumably these keys match the label
# strings in the model's config.id2label — TODO confirm against the model card.
LABEL_MAP = {
    "eg": {"dialect": "Egyptian", "region": "Egypt", "flag": "🇪🇬"},
    "sa": {"dialect": "Saudi", "region": "Saudi Arabia", "flag": "🇸🇦"},
    "mo": {"dialect": "Moroccan (Darija)", "region": "Morocco", "flag": "🇲🇦"},
    "iq": {"dialect": "Iraqi", "region": "Iraq", "flag": "🇮🇶"},
    "sd": {"dialect": "Sudanese", "region": "Sudan", "flag": "🇸🇩"},
    "tn": {"dialect": "Tunisian", "region": "Tunisia", "flag": "🇹🇳"},
    "lb": {"dialect": "Lebanese", "region": "Lebanon", "flag": "🇱🇧"},
    "sy": {"dialect": "Syrian", "region": "Syria", "flag": "🇸🇾"},
    "ly": {"dialect": "Libyan", "region": "Libya", "flag": "🇱🇾"},
    "ps": {"dialect": "Palestinian", "region": "Palestine", "flag": "🇵🇸"},
    "ar": {"dialect": "Modern Standard Arabic (MSA)", "region": "—", "flag": "🌍"},
}

# Class index -> label code (e.g. 0 -> "eg"), taken from the model config.
id2label = model.config.id2label


def format_label(code):
    """Return a display string such as ``"🇪🇬 Egyptian (Egypt)"`` for *code*.

    Codes that are not in ``LABEL_MAP`` are rendered as ``"Unknown (<code>)"``.
    The raw code keeps distinct unknown labels distinguishable; the result is
    used as a dict key in ``predict_dialect``, so identical strings would
    overwrite each other's scores.
    """
    meta = LABEL_MAP.get(code)
    if meta is None:
        return f"Unknown ({code})"
    return f"{meta['flag']} {meta['dialect']} ({meta['region']})"


def predict_dialect(text):
    """Classify *text* and return ``(top_label, scores)`` for the UI.

    Args:
        text: Input string from the textbox. ``None`` or whitespace-only input
            is treated as empty.

    Returns:
        A tuple ``(top_label, scores)``. ``top_label`` is the formatted label of
        the most probable class. ``scores`` maps every formatted label to its
        softmax probability, rounded to 4 decimals and ordered from most to
        least likely. For empty input, returns ``("Please enter text", {})``.
    """
    if not text or not text.strip():
        return "Please enter text", {}

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # A single string was tokenized, so keep only the first row of the batch.
    probs = F.softmax(logits, dim=-1)[0].tolist()

    # (label_code, probability) pairs, most likely first. sorted() is stable,
    # so ties keep class-index order.
    ranked = sorted(
        ((id2label.get(i, str(i)), prob) for i, prob in enumerate(probs)),
        key=lambda item: item[1],
        reverse=True,
    )

    top_label = format_label(ranked[0][0])
    formatted_scores = {
        format_label(code): round(prob, 4) for code, prob in ranked
    }
    return top_label, formatted_scores


# =========================
# UI
# =========================

with gr.Blocks() as demo:
    gr.Markdown("# 🌍 Arabic Dialect Router")
    gr.Markdown("Detect the dialect of Arabic text with region-aware labels.")

    text_input = gr.Textbox(
        label="Enter Arabic text",
        placeholder="مثال: انا رايح الشغل دلوقتي",
    )

    predict_btn = gr.Button("Predict")

    label_output = gr.Textbox(label="Predicted Dialect")
    scores_output = gr.JSON(label="All Dialect Scores")

    gr.Examples(
        examples=[
            "انا رايح الشغل دلوقتي",  # Egyptian
            "شنو درتي اليوم",  # Moroccan
            "كيفك اليوم",  # Levant
            "وش تسوي الحين",  # Saudi
            "ماذا تفعل الآن",  # MSA
        ],
        inputs=text_input,
    )

    predict_btn.click(
        fn=predict_dialect,
        inputs=text_input,
        outputs=[label_output, scores_output],
    )

demo.launch()