File size: 3,203 Bytes
bde1c71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# src/evaluators/normalization/evaluator.py
import warnings
from typing import Any, Dict, Optional

import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score

from ..base_evaluator import BaseEvaluator
from .datasets import NORMALIZATION_DATASETS

warnings.filterwarnings("ignore")

class NormalizationEvaluator(BaseEvaluator):
    """Evaluate an encoder on noisy-word -> canonical-word normalization.

    The task is framed as retrieval: every distinct canonical target form
    is embedded once ([CLS] vector), each noisy (Arabizi) word is embedded
    the same way, and the prediction for a word is the canonical form whose
    embedding has the highest dot product with the word's embedding.
    Accuracy is then computed against the gold canonical form.
    """

    def __init__(self, dataset_key: str = "madar-tun", max_samples: Optional[int] = None):
        """
        Args:
            dataset_key: Key into NORMALIZATION_DATASETS selecting the corpus
                configuration (path, split, column names).
            max_samples: If truthy, cap the number of evaluation pairs.

        Raises:
            ValueError: If dataset_key is not a known dataset.
        """
        # NOTE(review): BaseEvaluator.__init__ is not visible here; if it
        # does setup work, a super().__init__() call may be needed — confirm.
        if dataset_key not in NORMALIZATION_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_key}")
        self.config = NORMALIZATION_DATASETS[dataset_key]
        self.max_samples = max_samples

    @property
    def task_name(self) -> str:
        """Human-readable task name used in the result dict."""
        return "Normalization"

    def load_dataset(self):
        """Load and filter (noisy, canonical) word pairs.

        Drops rows where either side is falsy (None / empty string),
        whitespace-only, or the literal "<eos>" sentinel. Both sides are
        stripped before being kept.

        Returns:
            List[Tuple[str, str]] of (arabish, canonical) pairs, truncated
            to self.max_samples when that is truthy.
        """
        print(f"\nLoading normalization data from {self.config['path']}...")
        ds = load_dataset(
            self.config["path"],
            split=self.config["split"]
        )

        valid = []
        for ex in ds:
            a = ex[self.config["arabish_col"]]
            c = ex[self.config["canonical_col"]]
            # Truthiness already excludes None and "", so no separate
            # `is not None` checks are needed.
            if a and c and a != "<eos>" and c != "<eos>" and a.strip() and c.strip():
                valid.append((a.strip(), c.strip()))

        if self.max_samples:
            valid = valid[:self.max_samples]

        print(f"Loaded {len(valid)} normalization pairs.")
        return valid  # List[Tuple[noisy, canonical]]

    def evaluate(self, model, tokenizer, device: str = "cuda") -> Dict[str, Any]:
        """Run retrieval-style normalization evaluation.

        Args:
            model: Encoder whose output exposes `.last_hidden_state`
                (HF-transformers style); [CLS] (position 0) is used as the
                sentence/word embedding.
            tokenizer: Matching tokenizer callable returning PyTorch tensors.
            device: Device to run inference on.

        Returns:
            Dict with "task", "main_metric"/"accuracy" (float), and
            "total_samples".

        Raises:
            ValueError: If no valid pairs survive filtering.
        """
        pairs = self.load_dataset()
        if not pairs:
            raise ValueError("No valid normalization pairs found!")

        words, targets = zip(*pairs)
        words, targets = list(words), list(targets)

        # Candidate vocabulary: every distinct canonical form in the data.
        unique_targets = sorted(set(targets))
        target_to_id = {t: i for i, t in enumerate(unique_targets)}

        # Disable dropout etc. so the evaluation is deterministic.
        model.eval()

        batch_size = 32

        # Embed candidate targets in batches: the original single forward
        # pass over the whole vocabulary can OOM on large target sets.
        # With attention masking, per-batch padding does not change the
        # [CLS] embeddings, so results are identical.
        target_chunks = []
        with torch.no_grad():
            for i in range(0, len(unique_targets), batch_size):
                enc = tokenizer(
                    unique_targets[i:i + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=32,
                    return_tensors="pt"
                ).to(device)
                target_chunks.append(model(**enc).last_hidden_state[:, 0])
            target_embeds = torch.cat(target_chunks, dim=0)

        # Predict: nearest target by (unnormalized) dot-product similarity.
        predictions = []
        for i in range(0, len(words), batch_size):
            batch = words[i:i + batch_size]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt"
            ).to(device)

            with torch.no_grad():
                word_embeds = model(**inputs).last_hidden_state[:, 0]
                logits = torch.matmul(word_embeds, target_embeds.T)
                preds = logits.argmax(dim=1).cpu().tolist()
                predictions.extend(preds)

        true_labels = [target_to_id[t] for t in targets]
        acc = accuracy_score(true_labels, predictions)

        print(f"✅ Normalization Accuracy: {acc:.4f}")
        return {
            "task": self.task_name,
            "main_metric": acc,
            "accuracy": acc,
            "total_samples": len(pairs)
        }