AMR-KELEG committed on
Commit
9f4840a
·
1 Parent(s): 37b6ae3

Tweak the newly added evaluation method

Browse files
Files changed (1) hide show
  1. eval_utils.py +9 -6
eval_utils.py CHANGED
@@ -81,7 +81,10 @@ def prompt_chat_LLM(model, tokenizer, text):
81
 
82
  def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
83
  """Predict the validity in each dialect, by independently applying a sigmoid activation to each dialect's logit.
84
- Dialects with probabilities (sigmoid activations) at or above a threshold (set by default to 0.3) are considered predicted.
 
 
 
85
  """
86
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87
 
@@ -98,13 +101,13 @@ def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
98
  logits = outputs.logits
99
 
100
  probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
101
- predictions = (probabilities >= threshold).astype(int)
102
 
103
  # Map indices to actual labels
104
- predicted_labels = [
105
  dialect
106
- for dialect, dialect_probability in zip(DIALECTS, predictions)
107
- if dialect_probability == 1
108
  ]
109
 
110
- return predicted_labels
 
81
 
82
  def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
83
  """Predict the validity in each dialect, by independently applying a sigmoid activation to each dialect's logit.
84
+ Dialects with probabilities (sigmoid activations) at or above a threshold (set by default to 0.3) are predicted as valid.
85
+ The model is expected to generate one logit for each of the following dialects, in the same order:
86
+ Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
87
+ Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
88
  """
89
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
90
 
 
101
  logits = outputs.logits
102
 
103
  probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
104
+ binary_predictions = (probabilities >= threshold).astype(int)
105
 
106
  # Map indices to actual labels
107
+ predicted_dialects = [
108
  dialect
109
+ for dialect, dialect_prediction in zip(DIALECTS, binary_predictions)
110
+ if dialect_prediction == 1
111
  ]
112
 
113
+ return predicted_dialects