---
library_name: transformers
tags: [Arabic, Dialect Identification, Multi-Label, BERT, MLADI, NLP]
---

# Model Card for B2BERT

## Model Details

### Model Description

B2BERT is a lightweight transformer-based model for **Multi-Label Arabic Dialect Identification (MLADI)**. Unlike traditional single-label approaches, MLADI captures the natural overlap between dialects, allowing a sentence to be associated with multiple dialects at once.

- **Model type:** Transformer-based multi-label classifier
- **Finetuned from model:** CAMeLBERT (~110M parameters)
- **Language(s) (NLP):** Arabic (Dialectal Variants, 18 dialects)
- **License:** TBD

**Key Innovations:**

- **Knowledge Distillation:** Multi-label annotations generated by GPT-4o, capturing real-world ambiguity.
- **Curriculum Learning:** Samples introduced progressively by label cardinality, mitigating imbalance and improving generalization.
- **Preprocessing:** Normalization (Alef variants), diacritics/emoji/punctuation removal, anonymization of mentions and URLs, mixed-language handling, and stopword removal.
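
The preprocessing steps listed above can be illustrated with a minimal sketch. The character ranges, the `URL` and `USER` placeholder tokens, and the order of the steps are assumptions made for illustration, not the exact rules used to train B2BERT. Mixed-language handling and stopword removal are omitted because their details are not documented here.

```python
import re

# Assumed patterns; the exact rules used for B2BERT are not documented here.
ALEF_VARIANTS = re.compile("[\u0622\u0623\u0625\u0671]")  # alef with madda, hamza above/below, wasla
DIACRITICS = re.compile("[\u064B-\u0652]")  # fathatan through sukun
URL = re.compile(r"https?://\S+|www\.\S+")
MENTION = re.compile(r"@\w+")
NON_WORD = re.compile(r"[^\w\s]")  # punctuation, emoji, and other symbols


def preprocess(text: str) -> str:
    """Apply an illustrative subset of the preprocessing steps."""
    text = URL.sub(" URL ", text)             # anonymize URLs (placeholder is hypothetical)
    text = MENTION.sub(" USER ", text)        # anonymize mentions (placeholder is hypothetical)
    text = DIACRITICS.sub("", text)           # strip diacritics
    text = ALEF_VARIANTS.sub("\u0627", text)  # map Alef variants to bare Alef
    text = NON_WORD.sub(" ", text)            # drop punctuation and emoji
    return " ".join(text.split())             # collapse whitespace


print(preprocess("@ali \u0623\u064E\u0647\u0652\u0644\u0627\u064B! https://t.co/x"))
```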

---

## Bias, Risks, and Limitations

### Biases

- Geographic bias in dataset annotation: the original labels are based on user location.
- Overlapping dialects may result in misclassification, especially in regions with many closely related dialects, such as the Maghreb and the Levant.
- Errors may arise from synthetic labels (pseudo-labeling from GPT-4o).

### Recommendations

Users should validate outputs before deploying in high-stakes or production settings, especially where dialect precision is critical.

---

## Training Details

### Training Data

- **Datasets:** NADI 2020, 2021, 2023, and NADI 2024 development set.
- **Synthetic dataset:** Converted to multi-label using GPT-4o pseudo-labeling.
- **Dialects Covered (18):** Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

### Training Procedure

- **Knowledge Distillation:** Pseudo-labels generated by GPT-4o.
- **Curriculum Learning:** Training samples are organized by label cardinality (from single-label to multi-label) to improve robustness without undersampling. The model is progressively exposed to examples with a growing number of labels, balancing the different cardinalities within each epoch, while simpler cases continue to be reinforced.
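
One way to picture the cardinality-based curriculum is the sketch below. It groups samples by the number of active labels and builds cumulative stages, so each stage adds harder samples while keeping the simpler ones. The function names and the cumulative schedule are assumptions for illustration. The per-epoch balancing between cardinalities described above is not implemented here, and the actual B2BERT schedule may differ.

```python
from collections import defaultdict


def label_cardinality(labels):
    """Number of active dialects in a multi-hot label vector."""
    return sum(labels)


def curriculum_stages(samples):
    """Yield (cardinality, stage) pairs in order of increasing cardinality.

    Each stage contains every sample with at most that many labels, so
    simpler cases keep being revisited as harder ones are introduced.
    """
    by_cardinality = defaultdict(list)
    for text, labels in samples:
        by_cardinality[label_cardinality(labels)].append((text, labels))

    stage = []
    for k in sorted(by_cardinality):
        stage = stage + by_cardinality[k]
        yield k, list(stage)


# Toy data with three hypothetical labels
samples = [
    ("a", [1, 0, 0]),
    ("b", [1, 1, 0]),
    ("c", [0, 1, 0]),
    ("d", [1, 1, 1]),
]
for k, stage in curriculum_stages(samples):
    print(k, [text for text, _ in stage])
```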

**Hyperparameters:**

- Optimizer: AdamW
- Learning Rate: 1e-5
- Dropout: 0.3
- Batch Size: 11
- Epochs: 10
- Hardware: NVIDIA RTX 6000 (24GB VRAM)
- Training Time: ~27 minutes per run

---

## Evaluation

### Testing Data & Metrics

- **Test Data:** Official MLADI test set from the NADI 2024 shared task (not public).
- **Metrics:** Macro F1-score, precision, recall.
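
For reference, macro-averaged scores in a multi-label setting are usually computed per label and then averaged with equal weight. The sketch below shows that computation on multi-hot vectors. The function name and the convention of scoring 0.0 when a denominator is zero are assumptions, and the official leaderboard scorer may handle edge cases differently.

```python
def macro_scores(y_true, y_pred):
    """Macro-averaged precision, recall, and F1 over multi-hot label vectors."""
    n_labels = len(y_true[0])
    precisions, recalls, f1s = [], [], []

    for j in range(n_labels):
        pairs = [(t[j], p[j]) for t, p in zip(y_true, y_pred)]
        tp = sum(1 for t, p in pairs if t == 1 and p == 1)
        fp = sum(1 for t, p in pairs if t == 0 and p == 1)
        fn = sum(1 for t, p in pairs if t == 1 and p == 0)

        # Assumed convention: a score is 0.0 when its denominator is zero
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)

    return {
        "precision": sum(precisions) / n_labels,
        "recall": sum(recalls) / n_labels,
        "f1": sum(f1s) / n_labels,
    }


# Toy example with two hypothetical labels
scores = macro_scores(
    y_true=[[1, 0], [1, 1], [0, 1]],
    y_pred=[[1, 0], [0, 1], [0, 1]],
)
print(scores)
```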

**Leaderboard Results (MLADI, NADI 2024):**

| Model                  | Macro F1   | Precision | Recall |
|------------------------|------------|-----------|--------|
| NADI 2024 Baseline     | 0.4698     | 0.6480    | 0.3986 |
| ELYADATA (best team)   | 0.5240     | 0.5015    | 0.5687 |
| Aya Expanse 32B        | 0.5447     | 0.4945    | 0.6451 |
| ALLaM 7B Instruct      | 0.2506     | 0.5791    | 0.1639 |
| **B2BERT (ours)**      | **0.5963** | 0.5818    | 0.6976 |

- **Strengths:** Strong performance on Gulf, MSA, and Egyptian dialects (F1 > 0.70).
- **Weaknesses:** Lower performance on Maghrebi, Levantine, and Nile Valley dialects due to overlap.

**Link to MLADI Leaderboard:** [Hugging Face Space](https://huggingface.co/spaces/AMR-KELEG/MLADI)

---

## Technical Specifications

### Model Architecture and Objective

- Transformer-based multi-label classifier (BERT backbone).
- Outputs sigmoid activations per dialect, allowing multi-label predictions.

### Compute Infrastructure

- **Hardware:** NVIDIA RTX 6000 (24GB VRAM)
- **Software:** Python, PyTorch, Hugging Face Transformers

## Using the Model

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the model and tokenizer
model_name = "AHAAM/B2BERT"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Dialect labels, in the order of the model's output logits
DIALECTS = [
    "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya",
    "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria",
    "Tunisia", "UAE", "Yemen"
]


def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
    """Predict the valid dialects for each text.

    A sigmoid activation is applied to each dialect's logit. Dialects with
    probabilities at or above the threshold (default 0.3) are predicted as valid.

    The model generates logits for each dialect in the following order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman,
    Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    Returns one list of dialect names per input text.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # make sure dropout is disabled at inference time

    encodings = tokenizer(
        texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    # Shape: (number of texts, number of dialects)
    probabilities = torch.sigmoid(logits).cpu().numpy()
    binary_predictions = (probabilities >= threshold).astype(int)

    # Map indices to dialect names, one list per input text
    return [
        [dialect for dialect, flag in zip(DIALECTS, row) if flag == 1]
        for row in binary_predictions
    ]


text = "كيف حالك؟"

# A threshold of 0.3 is recommended for better results.
predicted_dialects = predict_binary_outcomes(model, tokenizer, [text])[0]
print(f"Predicted Dialects: {predicted_dialects}")
```

## Credits

- Ali Mekky: ali.mekky@mbzuai.ac.ae
- Mohamed ElZeftawy: mohamed.elzeftawy@mbzuai.ac.ae
- Lara Hassan: lara.hassan@mbzuai.ac.ae