Ali Mekky committed on
Commit 4e4c436 · verified · 1 Parent(s): fdf0710

Update README.md

Files changed (1)
  1. README.md +62 -1
README.md CHANGED
@@ -46,7 +46,7 @@ Users should be aware of biases in dataset annotation and carefully validate out
 
  - **Testing Data:** NADI 2024 Test set
  - **Metrics:** Macro F1-score, precision, recall
- - **Link to NADI2024 Leaderboard** https://huggingface.co/spaces/AMR-KELEG/NADI2024-leaderboard
+ - **Link to NADI2024 Leaderboard** https://huggingface.co/spaces/AMR-KELEG/MLADI
 
 
 
@@ -61,4 +61,65 @@ Users should be aware of biases in dataset annotation and carefully validate out
  - **Hardware:** NVIDIA RTX 6000 (24GB VRAM)
  - **Software:** Python, PyTorch, Hugging Face Transformers
 
+ ## Using the Model
+
+ ```python
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ # Load the model and tokenizer
+ model_name = "AliMekky/MDABERT"
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # Dialect labels, in the same order as the model's output logits
+ DIALECTS = [
+     "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait", "Lebanon", "Libya",
+     "Morocco", "Oman", "Palestine", "Qatar", "Saudi_Arabia", "Sudan", "Syria",
+     "Tunisia", "UAE", "Yemen"
+ ]
+
+ def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
+     """Predict the validity of a text in each dialect by applying a sigmoid
+     activation to each dialect's logit. Dialects whose probabilities (sigmoid
+     activations) exceed the threshold (default 0.3) are predicted as valid.
+
+     The model emits one logit per dialect, in the order given by DIALECTS.
+     """
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+
+     encodings = tokenizer(
+         texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
+     )
+
+     input_ids = encodings["input_ids"].to(device)
+     attention_mask = encodings["attention_mask"].to(device)
+
+     with torch.no_grad():
+         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+         logits = outputs.logits
+
+     # Note: the predictions are flattened, so pass a single text at a time
+     probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
+     binary_predictions = (probabilities >= threshold).astype(int)
+
+     # Map indices to dialect labels
+     predicted_dialects = [
+         dialect
+         for dialect, dialect_prediction in zip(DIALECTS, binary_predictions)
+         if dialect_prediction == 1
+     ]
+
+     return predicted_dialects
+
+ text = "كيف حالك؟"  # "How are you?"
+
+ # The default threshold of 0.3 gives better results.
+ predicted_dialects = predict_binary_outcomes(model, tokenizer, [text])
+ print(f"Predicted Dialects: {predicted_dialects}")
+ ```
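
The diff above reports Macro F1-score, precision, and recall on the multi-label dialect predictions. As a minimal sketch of how such metrics are computed (this is not taken from the model card's evaluation code; the function names and the tiny gold/pred matrices are illustrative), the per-dialect F1 scores are simply averaged, treating each dialect as an independent binary classification:

```python
def per_class_scores(gold, pred):
    """Compute precision, recall, and F1 for one class from binary vectors."""
    tp = sum(g == 1 and p == 1 for g, p in zip(gold, pred))
    fp = sum(g == 0 and p == 1 for g, p in zip(gold, pred))
    fn = sum(g == 1 and p == 0 for g, p in zip(gold, pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

def macro_f1(gold_matrix, pred_matrix):
    """Average the per-class F1 scores over all classes (dialects)."""
    scores = [per_class_scores(g, p)[2] for g, p in zip(gold_matrix, pred_matrix)]
    return sum(scores) / len(scores)

# Illustrative data: two dialects, three sentences (1 = valid in that dialect)
gold = [[1, 0, 1], [0, 1, 1]]
pred = [[1, 0, 0], [0, 1, 1]]
print(macro_f1(gold, pred))
```

Macro-averaging weights each dialect equally regardless of how often it appears in the test set, which is why it is the headline metric for imbalanced multi-label tasks like this one.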