| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| import torch.nn.functional as F |
|
|
| |
| |
| |
MODEL_ID = "oddadmix/dialect-router-v0.1"

# Pick the execution device up front: CUDA when PyTorch can see a GPU,
# otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and classifier once, at import time, so every request
# served by the UI reuses the same objects. `.to(device)` and `.eval()` both
# return the model itself, so they can be chained; eval mode switches layers
# such as dropout to their inference behaviour.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = (
    AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    .to(device)
    .eval()
)
|
|
| |
| |
| |
def _flag(country_code):
    """Return the emoji flag for a two-letter ISO 3166-1 alpha-2 country code.

    A flag emoji is a pair of Unicode regional-indicator symbols
    (U+1F1E6 for "A" through U+1F1FF for "Z"). Building them from the code
    keeps the flags ASCII in this source file, which already suffered one
    encoding round-trip that destroyed the literal emoji.
    """
    return "".join(chr(0x1F1E6 + ord(ch) - ord("A")) for ch in country_code.upper())


# Display metadata for each dialect code: a human-readable dialect name, the
# region it belongs to, and an emoji shown in the UI.
LABEL_MAP = {
    "eg": {"dialect": "Egyptian", "region": "Egypt", "flag": _flag("EG")},
    "sa": {"dialect": "Saudi", "region": "Saudi Arabia", "flag": _flag("SA")},
    "mo": {"dialect": "Moroccan (Darija)", "region": "Morocco", "flag": _flag("MA")},
    "iq": {"dialect": "Iraqi", "region": "Iraq", "flag": _flag("IQ")},
    "sd": {"dialect": "Sudanese", "region": "Sudan", "flag": _flag("SD")},
    "tn": {"dialect": "Tunisian", "region": "Tunisia", "flag": _flag("TN")},
    "lb": {"dialect": "Lebanese", "region": "Lebanon", "flag": _flag("LB")},
    "sy": {"dialect": "Syrian", "region": "Syria", "flag": _flag("SY")},
    "ly": {"dialect": "Libyan", "region": "Libya", "flag": _flag("LY")},
    "ps": {"dialect": "Palestinian", "region": "Palestine", "flag": _flag("PS")},
    # MSA has no single region, so an em dash stands in for it.
    # NOTE(review): the original emoji for this entry was lost to an encoding
    # error; the globe (U+1F30D) is a best guess — confirm.
    "ar": {"dialect": "Modern Standard Arabic (MSA)", "region": "\u2014", "flag": "\U0001F30D"},
}
|
|
| |
# Class-index -> label-code mapping read from the loaded model's config.
# The scores are indexed by position when building results, so keys are
# expected to be ints; presumably the codes match LABEL_MAP keys (e.g. "eg")
# — TODO confirm against the model card.
id2label = model.config.id2label
|
|
|
|
def format_label(code):
    """Return a display string such as "<flag> Egyptian (Egypt)" for a label code.

    Args:
        code: A label code taken from the model config (e.g. "eg"). Lookup
            falls back to the lower-cased code, since the casing used in the
            model config is not visible here.

    Returns:
        "<flag> <dialect> (<region>)" for codes present in LABEL_MAP.
        Unknown codes yield "Unknown (<code>)". Each unknown code therefore
        stays distinct: these strings are used as dict keys for the score
        table, and a shared placeholder would make scores overwrite each other.
    """
    meta = LABEL_MAP.get(code) or LABEL_MAP.get(str(code).lower())
    if meta is None:
        return f"Unknown ({code})"
    return f"{meta.get('flag', '')} {meta.get('dialect', 'Unknown')} ({meta.get('region', '')})"
|
|
|
|
def predict_dialect(text):
    """Classify the Arabic dialect of ``text``.

    Args:
        text: Raw user input from the textbox. ``None`` is treated the same
            as an empty string.

    Returns:
        A ``(top_label, formatted_scores)`` tuple. ``top_label`` is the display
        label of the most probable class. ``formatted_scores`` maps every
        class's display label to its probability rounded to 4 decimals,
        ordered from most to least likely. Blank input returns
        ``("Please enter text", {})`` without running the model.
    """
    # Guard against None as well as empty / whitespace-only input.
    if not text or not text.strip():
        return "Please enter text", {}

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,  # clip over-long input to the tokenizer's max length
    ).to(device)

    with torch.inference_mode():
        logits = model(**inputs).logits
    # A single string was tokenized, so row 0 holds its class distribution.
    # One .tolist() copies all probabilities to Python floats at once instead
    # of forcing a device sync per element.
    probs = F.softmax(logits, dim=-1)[0].tolist()

    # Pair each class's label code with its probability, most likely first.
    # sorted() is stable, so ties keep class-index order.
    ranked = sorted(
        ((id2label[idx], p) for idx, p in enumerate(probs)),
        key=lambda item: item[1],
        reverse=True,
    )

    top_label = format_label(ranked[0][0])
    formatted_scores = {format_label(code): round(p, 4) for code, p in ranked}
    return top_label, formatted_scores
|
|
|
|
| |
| |
| |
with gr.Blocks() as demo:
    # NOTE(review): the heading emoji was destroyed by an encoding error in
    # this file; the globe is a best guess — confirm against the original.
    gr.Markdown("# 🌍 Arabic Dialect Router")
    gr.Markdown("Detect the dialect of Arabic text with region-aware labels.")

    text_input = gr.Textbox(
        label="Enter Arabic text",
        # Arabic for "Example: I'm going to work now".
        placeholder="مثال: انا رايح الشغل دلوقتي",
    )

    predict_btn = gr.Button("Predict")

    label_output = gr.Textbox(label="Predicted Dialect")
    scores_output = gr.JSON(label="All Dialect Scores")

    # Clicking an example copies it into the textbox.
    # NOTE(review): these strings were reconstructed from mojibake in which
    # several Arabic letters collapsed to the same byte pattern; the glosses
    # are the reviewer's reading — confirm, especially the third sample.
    gr.Examples(
        examples=[
            "انا رايح الشغل دلوقتي",  # Egyptian: "I'm going to work now"
            "شنو درتي اليوم",  # Moroccan: "What did you do today?"
            "كيفك اليوم",  # Levantine: "How are you today?"
            "وش تسوي الحين",  # Saudi: "What are you doing now?"
            "ماذا تفعل الآن",  # MSA: "What are you doing now?"
        ],
        inputs=text_input,
    )

    predict_btn.click(
        fn=predict_dialect,
        inputs=text_input,
        outputs=[label_output, scores_output],
    )


# Launch only when run as a script, so importing this module (e.g. for tests
# or reuse of `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()