import torch
from datasets import load_dataset
# from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score
# import argparse
import warnings
warnings.filterwarnings("ignore")

def load_and_prepare_data():
    """Load MADAR-TUN and prepare normalization & transliteration pairs."""
    print("Loading MADAR-TUN dataset...")
    ds = load_dataset("tunis-ai/MADAR-TUN", split="train")
    

    # Keep only rows where all three fields are present, non-empty, and not the <eos> sentinel
    valid_examples = [
        ex for ex in ds 
        if ex["arabish"] != "<eos>" 
        and ex["words"] != "<eos>" 
        and ex["lem"] != "<eos>"
        and ex["arabish"] is not None
        and ex["arabish"].strip()
        and ex["words"] is not None
        and ex["words"].strip()
        and ex["lem"] is not None
        and ex["lem"].strip()
    ]
    
    print(f"Loaded {len(valid_examples)} valid token entries.")
    
    # Build unique pairs (deduplicate)
    norm_pairs = {}  # arabish -> canonical lemma
    trans_pairs = {}  # arabish -> Arabic-script word
    
    for ex in valid_examples:
        arabizi = ex["arabish"]
        arabic = ex["words"]
        lemma = ex["lem"]
        
        # For normalization: use lemma as canonical form
        if arabizi not in norm_pairs:
            norm_pairs[arabizi] = lemma
        

        # For transliteration: map the Arabizi token to its Arabic-script form
        if arabizi not in trans_pairs:
            trans_pairs[arabizi] = arabic
    
    print(f"Normalization pairs: {len(norm_pairs)}")
    print(f"Transliteration pairs: {len(trans_pairs)}")
    
    return norm_pairs, trans_pairs

def evaluate_word_classification(model, tokenizer, word_pairs, device, task_name):
    """
    Evaluate word-level classification (normalization or transliteration).
    Treats it as closed-vocabulary classification via embedding similarity.
    """
    words = list(word_pairs.keys())
    targets = list(word_pairs.values())
    
    # Build target vocabulary
    unique_targets = sorted(set(targets))
    target_to_id = {t: i for i, t in enumerate(unique_targets)}
    true_labels = [target_to_id[t] for t in targets]  # gold label indices for accuracy
    
    print(f"\n[{task_name}] Vocabulary size: {len(unique_targets)}")
    print(f"[{task_name}] Evaluation samples: {len(words)}")
    
    # Get embeddings for all target forms
    print(f"[{task_name}] Encoding target vocabulary...")
    target_encodings = tokenizer(
        unique_targets,
        padding=True,
        truncation=True,
        max_length=32,
        return_tensors="pt"
    ).to(device)
    
    with torch.no_grad():
        target_embeds = model(**target_encodings).last_hidden_state[:, 0]  # [V, H]
    
    # Predict for each input word
    predictions = []
    batch_size = 32
    
    print(f"[{task_name}] Predicting...")
    for i in range(0, len(words), batch_size):
        batch_words = words[i:i+batch_size]
        inputs = tokenizer(
            batch_words,
            padding=True,
            truncation=True,
            max_length=32,
            return_tensors="pt"
        ).to(device)
        
        with torch.no_grad():
            word_embeds = model(**inputs).last_hidden_state[:, 0]  # [B, H]
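            # Score each input against every target by raw dot-product similarity;
            # argmax over the vocabulary picks the nearest target form.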
            logits = torch.matmul(word_embeds, target_embeds.T)  # [B, V]
            preds = logits.argmax(dim=1).cpu().tolist()
            predictions.extend(preds)
    
    acc = accuracy_score(true_labels, predictions)
    print(f"[{task_name}] Accuracy: {acc:.4f}")
    return acc
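
if __name__ == "__main__":
    # Minimal usage sketch, not prescribed by the functions above: it assumes a
    # Hugging Face encoder is used (mirroring the commented-out transformers
    # import at the top). The checkpoint name below is only an example placeholder.
    from transformers import AutoTokenizer, AutoModel

    model_name = "UBC-NLP/MARBERT"  # assumption: any Arabic BERT-style encoder works here
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    norm_pairs, trans_pairs = load_and_prepare_data()
    evaluate_word_classification(model, tokenizer, norm_pairs, device, "Normalization")
    evaluate_word_classification(model, tokenizer, trans_pairs, device, "Transliteration")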