# File-viewer header captured with this file (not part of the program):
# author avatar "oddadmix"; last commit "Update app.py" (095ad79, verified).
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
# =========================
# Model
# =========================
# Checkpoint identifier used for both the tokenizer and the classifier.
MODEL_ID = "oddadmix/dialect-router-v0.1"

# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(device)
model.eval()  # evaluation mode: this script only runs inference
# =========================
# Label Mapping
# =========================
# Display metadata for each label code: a human-readable dialect name, the
# region it is associated with, and an emoji shown in front of the name.
# The emoji and the em dash are non-ASCII, so this file must be saved and
# read as UTF-8; otherwise these literals degrade into mojibake.
LABEL_MAP = {
    "eg": {"dialect": "Egyptian", "region": "Egypt", "flag": "🇪🇬"},
    "sa": {"dialect": "Saudi", "region": "Saudi Arabia", "flag": "🇸🇦"},
    "mo": {"dialect": "Moroccan (Darija)", "region": "Morocco", "flag": "🇲🇦"},
    "iq": {"dialect": "Iraqi", "region": "Iraq", "flag": "🇮🇶"},
    "sd": {"dialect": "Sudanese", "region": "Sudan", "flag": "🇸🇩"},
    "tn": {"dialect": "Tunisian", "region": "Tunisia", "flag": "🇹🇳"},
    "lb": {"dialect": "Lebanese", "region": "Lebanon", "flag": "🇱🇧"},
    "sy": {"dialect": "Syrian", "region": "Syria", "flag": "🇸🇾"},
    "ly": {"dialect": "Libyan", "region": "Libya", "flag": "🇱🇾"},
    "ps": {"dialect": "Palestinian", "region": "Palestine", "flag": "🇵🇸"},
    # MSA is not tied to one country, hence the dash and a globe, not a flag.
    "ar": {"dialect": "Modern Standard Arabic (MSA)", "region": "—", "flag": "🌍"},
}
# Class-index -> label lookup taken from the loaded model's config. It is
# indexed with the integer position of each output probability and is expected
# to yield label codes such as "eg" or "sa" (the keys of LABEL_MAP).
id2label = model.config.id2label
def format_label(code):
    """Return a human-readable display label for a dialect code.

    Codes present in ``LABEL_MAP`` are rendered as
    ``"<flag> <dialect> (<region>)"``. Codes missing from ``LABEL_MAP`` are
    rendered as ``"Unknown (<code>)"`` so that every distinct code keeps a
    distinct label: these labels are used as dictionary keys when the scores
    are formatted, and a single shared placeholder would make unknown
    entries overwrite one another.

    Args:
        code: Label identifier, e.g. ``"eg"``.

    Returns:
        The formatted label string.
    """
    meta = LABEL_MAP.get(code)
    if meta is None:
        return f"Unknown ({code})"
    return f"{meta.get('flag', '')} {meta.get('dialect', 'Unknown')} ({meta.get('region', '')})"
def predict_dialect(text):
    """Classify the dialect of ``text`` and score every label.

    Args:
        text: Raw user input. ``None``, empty, and whitespace-only values are
            all treated as "no input".

    Returns:
        A ``(top_label, scores)`` tuple. ``top_label`` is the formatted label
        of the highest-probability class. ``scores`` maps each formatted
        label to its softmax probability rounded to 4 decimals, ordered from
        most to least likely. With no input the result is
        ``("Please enter text", {})``.
    """
    # Treat None like blank text instead of failing on `.strip()`.
    if not text or not text.strip():
        return "Please enter text", {}

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # One input string -> one row of logits. tolist() converts the whole row
    # to Python floats in a single step instead of one float() per element.
    probs = F.softmax(outputs.logits, dim=-1)[0].tolist()

    # Key by label code first, then rank from most to least probable.
    scores = {id2label[i]: prob for i, prob in enumerate(probs)}
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)

    top_label = format_label(ranked[0][0])
    formatted_scores = {
        format_label(code): round(score, 4)
        for code, score in ranked
    }
    return top_label, formatted_scores
# =========================
# UI
# =========================
# NOTE: the Arabic strings and the emoji below are non-ASCII literals; the file
# must be stored as UTF-8 or they turn into mojibake in the rendered page.
with gr.Blocks() as demo:
    gr.Markdown("# 🌍 Arabic Dialect Router")
    gr.Markdown("Detect the dialect of Arabic text with region-aware labels.")

    text_input = gr.Textbox(
        label="Enter Arabic text",
        placeholder="مثال: انا رايح الشغل دلوقتي",
    )
    predict_btn = gr.Button("Predict")
    label_output = gr.Textbox(label="Predicted Dialect")
    scores_output = gr.JSON(label="All Dialect Scores")

    # Clickable sample sentences that populate the input box.
    gr.Examples(
        examples=[
            "انا رايح الشغل دلوقتي",  # Egyptian
            "شنو درتي اليوم",  # Moroccan
            "كيفك اليوم",  # Levant
            "وش تسوي الحين",  # Saudi
            "ماذا تفعل الآن",  # MSA
        ],
        inputs=text_input,
    )

    # Button click runs the classifier and fills both output components.
    predict_btn.click(
        fn=predict_dialect,
        inputs=text_input,
        outputs=[label_output, scores_output],
    )

demo.launch()