Tweak the newly added evaluation method
Browse files- eval_utils.py +9 -6
eval_utils.py
CHANGED
|
@@ -81,7 +81,10 @@ def prompt_chat_LLM(model, tokenizer, text):
|
|
| 81 |
|
| 82 |
def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
|
| 83 |
"""Predict the validity in each dialect, by independently applying a sigmoid activation to each dialect's logit.
|
| 84 |
-
Dialects with probabilities (sigmoid activations) above a threshold (set by default to 0.3) are
|
|
|
|
|
|
|
|
|
|
| 85 |
"""
|
| 86 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 87 |
|
|
@@ -98,13 +101,13 @@ def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
|
|
| 98 |
logits = outputs.logits
|
| 99 |
|
| 100 |
probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
|
| 101 |
-
|
| 102 |
|
| 103 |
# Map indices to actual labels
|
| 104 |
-
|
| 105 |
dialect
|
| 106 |
-
for dialect,
|
| 107 |
-
if
|
| 108 |
]
|
| 109 |
|
| 110 |
-
return
|
|
|
|
| 81 |
|
| 82 |
def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
|
| 83 |
"""Predict the validity in each dialect, by independently applying a sigmoid activation to each dialect's logit.
|
| 84 |
+
Dialects with probabilities (sigmoid activations) above a threshold (set by default to 0.3) are predicted as valid.
|
| 85 |
+
The model is expected to generate one logit for each of the following dialects, in the same order:
|
| 86 |
+
Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
|
| 87 |
+
Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
|
| 88 |
"""
|
| 89 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 90 |
|
|
|
|
| 101 |
logits = outputs.logits
|
| 102 |
|
| 103 |
probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
|
| 104 |
+
binary_predictions = (probabilities >= threshold).astype(int)
|
| 105 |
|
| 106 |
# Map indices to actual labels
|
| 107 |
+
predicted_dialects = [
|
| 108 |
dialect
|
| 109 |
+
for dialect, dialect_prediction in zip(DIALECTS, binary_predictions)
|
| 110 |
+
if dialect_prediction == 1
|
| 111 |
]
|
| 112 |
|
| 113 |
+
return predicted_dialects
|