File size: 4,647 Bytes
24cf6c5 e103605 24cf6c5 7935012 24cf6c5 1a452c0 24cf6c5 8b3349d 24cf6c5 d956a72 24cf6c5 3b9a198 491aef1 3b9a198 e103605 571f9ec e103605 8b3349d e103605 8b3349d e103605 6df4a25 9f4840a 6df4a25 9f4840a 6df4a25 9f4840a 6df4a25 9f4840a 6df4a25 9f4840a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import torch
from constants import DIALECTS, DIALECTS_WITH_LABELS, DIALECT_IN_ARABIC
# Module-wide device selector: "cpu" when CUDA is unavailable, "auto" otherwise.
# NOTE(review): "auto" looks like a Hugging Face `device_map` keyword rather than a
# torch device string -- confirm it is never handed to `.to()` on CUDA machines.
device = "cpu" if not torch.cuda.is_available() else "auto"
# Label order used by
# https://huggingface.co/Abdelrahman-Rezk/bert-base-arabic-camelbert-msa-finetuned-Arabic_Dialect_Identification_model_1
_CUSTOM_DIALECTS_LIST = (
    "Oman",
    "Sudan",
    "Saudi_Arabia",
    "Kuwait",
    "Qatar",
    "Lebanon",
    "Jordan",
    "Syria",
    "Iraq",
    "Morocco",
    "Egypt",
    "Palestine",
    "Yemen",
    "Bahrain",
    "Algeria",
    "UAE",
    "Tunisia",
    "Libya",
)


def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the top dialects with an accumulative confidence of at least P.

    The logits of the first input row are turned into probabilities with a
    softmax, the classes are visited from most to least probable, and classes
    are collected until their summed probability reaches ``P``. The most
    probable class is always included, even when ``P`` is 0.

    When the first entry of the model's ``id2label`` is "LABEL_0" or
    "Algeria", class indices are mapped through ``DIALECTS``; the model is
    then expected to generate logits for the dialects in this order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco,
    Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
    Otherwise the custom label order of the CAMeLBERT-MSA fine-tuned model
    (``_CUSTOM_DIALECTS_LIST``) is used.

    Args:
        model: Sequence-classification model; called as ``model(**inputs)``
            and must expose ``.logits`` on its output and ``.config.to_dict()``.
        tokenizer: Callable tokenizer returning an object with ``.to(device)``.
        text: The input text to classify.
        P: Cumulative probability to reach, in the range [0, 1].

    Returns:
        The selected dialect names, ordered by label index (not by confidence).

    Raises:
        ValueError: If ``P`` is outside [0, 1], or if the number of logits does
            not match the number of known dialect labels.
    """
    if not 0 <= P <= 1:
        raise ValueError(f"P must be in the range [0, 1], got {P!r}")

    # Move the inputs to wherever the model lives. The module-level `device`
    # may hold "auto", which is not a device string that `.to()` accepts.
    target_device = getattr(model, "device", None)
    if target_device is None:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"

    inputs = tokenizer(text, return_tensors="pt").to(target_device)
    with torch.no_grad():  # inference only, no autograd graph needed
        logits = model(**inputs).logits

    # Only the first row is scored: the function returns one list of dialects.
    probabilities = torch.softmax(logits, dim=1)[0]
    num_classes = probabilities.shape[0]

    first_label = str(model.config.to_dict()["id2label"][0])
    if first_label in ("LABEL_0", "Algeria"):
        labels = DIALECTS
    else:
        labels = _CUSTOM_DIALECTS_LIST

    if len(labels) != num_classes:
        raise ValueError(
            f"The model produced {num_classes} logits but {len(labels)} dialect labels are known"
        )

    sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)

    selected = set()
    total_prob = 0.0
    for prob, index in zip(sorted_probs.tolist(), sorted_indices.tolist()):
        total_prob += prob
        selected.add(index)
        if total_prob >= P:
            break

    return [label for i, label in enumerate(labels) if i in selected]
def prompt_chat_LLM(model, tokenizer, text):
    """Prompt a chat model to judge whether the text is acceptable in each dialect.

    For every dialect in ``DIALECTS_WITH_LABELS`` a yes/no question (written
    in Arabic, naming the dialect via ``DIALECT_IN_ARABIC``) is sent to the
    model through the tokenizer's chat template. A dialect is kept when the
    decoded output, after trimming trailing dots and whitespace, ends with
    "نعم" (yes).

    Args:
        model: Generative model exposing ``generate(...)``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``decode`` and
            ``eos_token_id``.
        text: The sentence to judge.

    Returns:
        The entries of ``DIALECTS_WITH_LABELS`` for which the model answered yes.
    """
    # Follow the model's own device instead of assuming CUDA, so the function
    # also works on CPU-only hosts and with models loaded on the CPU.
    target_device = getattr(model, "device", None)
    if target_device is None:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"

    predicted_dialects = []
    for dialect in DIALECTS_WITH_LABELS:
        messages = [
            {
                "role": "user",
                "content": f"حدد إذا كانت الجملة الأتية مقبولة في أحد اللهجات المستخدمة في {DIALECT_IN_ARABIC[dialect]}. أجب ب 'نعم' أو 'لا' فقط."
                + "\n"
                + f'الجملة: "{text}"',
            },
        ]
        model_inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(target_device)
        gen_tokens = model.generate(
            **model_inputs, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id
        )
        # The decoded text holds the prompt followed by the answer, hence the
        # suffix check below rather than an equality check.
        gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)

        # TODO: Add a condition for the case of "لا" and other responses (e.g., refuse to answer)
        # Trailing newlines/tabs are trimmed too, so "نعم\n" still counts as yes.
        if gen_text.strip(". \t\r\n").endswith("نعم"):
            predicted_dialects.append(dialect)

    return predicted_dialects
def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
    """Predict the valid dialects by independently applying a sigmoid to each dialect's logit.

    Dialects whose probability (sigmoid activation) is at or above the
    threshold (set by default to 0.3) are predicted as valid.

    The model is expected to generate logits for each dialect of the following
    dialects in the same order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco,
    Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    Note that the scores are flattened and paired with ``DIALECTS`` in order,
    so only the first ``len(DIALECTS)`` scores contribute to the result.

    Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    batch = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )
    token_ids = batch["input_ids"].to(device)
    token_mask = batch["attention_mask"].to(device)

    with torch.no_grad():
        model_output = model(input_ids=token_ids, attention_mask=token_mask)

    # One independent score per logit, flattened into a 1-D numpy array.
    scores = torch.sigmoid(model_output.logits).cpu().numpy().reshape(-1)
    is_valid = scores >= threshold

    # Keep the labels whose score cleared the threshold.
    return [dialect for dialect, keep in zip(DIALECTS, is_valid) if keep]
|