File size: 4,647 Bytes
24cf6c5
e103605
24cf6c5
7935012
 
24cf6c5
 
1a452c0
 
 
 
24cf6c5
 
8b3349d
24cf6c5
 
 
 
 
 
d956a72
 
24cf6c5
 
 
 
 
 
3b9a198
491aef1
 
3b9a198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e103605
 
 
571f9ec
e103605
 
 
 
 
 
 
 
 
 
 
8b3349d
 
e103605
 
8b3349d
e103605
 
6df4a25
 
 
 
9f4840a
 
 
 
6df4a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f4840a
6df4a25
 
9f4840a
6df4a25
9f4840a
 
6df4a25
 
9f4840a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
from constants import DIALECTS, DIALECTS_WITH_LABELS, DIALECT_IN_ARABIC

# Run on GPU when available. NOTE: "auto" is a valid value for transformers'
# `device_map` loading option, but NOT a valid torch device string —
# `.to("auto")` raises a RuntimeError, so the choice here must be an
# explicit device name.
device = "cuda" if torch.cuda.is_available() else "cpu"


def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the smallest top-k set of dialects whose cumulative probability is at least ``P``.

    The model is expected to generate one logit per dialect, in this order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman,
    Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    Args:
        model: Sequence-classification model emitting 18 dialect logits.
        tokenizer: Tokenizer matching ``model``.
        text: Input sentence to classify.
        P: Cumulative-probability threshold in [0, 1] (default 0.9).

    Returns:
        List of dialect-name strings, ranked dialects accumulated until their
        summed probability reaches ``P``.

    Raises:
        ValueError: If ``P`` is outside [0, 1].
    """
    if not 0 <= P <= 1:
        raise ValueError(f"P must be in [0, 1], got {P}")

    n_dialects = len(DIALECTS)  # 18 dialects expected from the model head

    # Inference only — no gradients needed (consistent with predict_binary_outcomes).
    with torch.no_grad():
        logits = model(**tokenizer(text, return_tensors="pt").to(device)).logits
    probabilities = torch.softmax(logits, dim=1).flatten().tolist()
    topk_predictions = torch.topk(logits, n_dialects).indices.flatten().tolist()

    # Internal invariant (resolves the old TODO): the head must be 18-way.
    assert len(probabilities) == n_dialects, (
        f"Expected {n_dialects} logits, got {len(probabilities)}"
    )

    predictions = [0] * n_dialects
    total_prob = 0.0

    # Walk the dialects in descending-probability order, marking each one,
    # until the accumulated mass reaches the threshold P.
    for idx in topk_predictions:
        total_prob += probabilities[idx]
        predictions[idx] = 1
        if total_prob >= P:
            break

    first_label = str(model.config.to_dict()["id2label"][0])
    if first_label in ("LABEL_0", "Algeria"):
        # Default / alphabetical label ordering.
        return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]
    else:
        # Use the custom label order of
        # https://huggingface.co/Abdelrahman-Rezk/bert-base-arabic-camelbert-msa-finetuned-Arabic_Dialect_Identification_model_1
        DIALECTS_LIST = [
            "Oman",
            "Sudan",
            "Saudi_Arabia",
            "Kuwait",
            "Qatar",
            "Lebanon",
            "Jordan",
            "Syria",
            "Iraq",
            "Morocco",
            "Egypt",
            "Palestine",
            "Yemen",
            "Bahrain",
            "Algeria",
            "UAE",
            "Tunisia",
            "Libya",
        ]

        return [DIALECTS_LIST[i] for i, p in enumerate(predictions) if p == 1]


def prompt_chat_LLM(model, tokenizer, text):
    """Prompt a chat LLM to judge whether *text* is acceptable in each labelled dialect.

    For every dialect in DIALECTS_WITH_LABELS, the model is asked (in Arabic) to
    answer only "نعم" (yes) or "لا" (no); dialects answered with "نعم" are returned.

    Args:
        model: Causal chat LLM supporting ``generate``.
        tokenizer: Tokenizer with a chat template matching ``model``.
        text: Input sentence to judge.

    Returns:
        List of dialect keys (from DIALECTS_WITH_LABELS) judged acceptable.
    """
    predicted_dialects = []
    # Fall back to CPU when no GPU is present — the previous hard-coded "cuda"
    # crashed on CPU-only machines, unlike the rest of this module which guards
    # on torch.cuda.is_available().
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    for dialect in DIALECTS_WITH_LABELS:
        messages = [
            {
                "role": "user",
                "content": f"حدد إذا كانت الجملة الأتية مقبولة في أحد اللهجات المستخدمة في {DIALECT_IN_ARABIC[dialect]}. أجب ب 'نعم' أو 'لا' فقط."
                + "\n"
                + f'الجملة: "{text}"',
            },
        ]
        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(target_device)
        gen_tokens = model.generate(**input_ids, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
        gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
        # TODO: Add a condition for the case of "لا" and other responses (e.g., refuse to answer)
        if gen_text.strip(". ").endswith("نعم"):
            predicted_dialects.append(dialect)
    return predicted_dialects


def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
    """Predict the validity in each dialect, by independently applying a sigmoid activation to each dialect's logit.
    Dialects with probabilities (sigmoid activations) above a threshold (set by default to 0.3) are predicted as valid.
    The model is expected to generate logits for each dialect of the following dialects in the same order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
    Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
    """
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenize the whole batch at once, capped at 128 tokens per text.
    batch = tokenizer(
        texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )

    # Move model inputs onto the chosen device.
    ids = batch["input_ids"].to(run_device)
    mask = batch["attention_mask"].to(run_device)

    # Inference only — no gradient tracking.
    with torch.no_grad():
        logits = model(input_ids=ids, attention_mask=mask).logits

    # One independent sigmoid per dialect logit, thresholded to 0/1 flags.
    activations = torch.sigmoid(logits).cpu().flatten()
    flags = [int(a >= threshold) for a in activations.tolist()]

    # Keep only the dialects whose flag came out as 1.
    return [dialect for dialect, flag in zip(DIALECTS, flags) if flag == 1]