# src/evaluators/transliteration/evaluator.py
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from typing import Dict, Any, Optional
import warnings

from ..base_evaluator import BaseEvaluator
from .datasets import TRANSLITERATION_DATASETS
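# Each entry in TRANSLITERATION_DATASETS is expected to provide the keys this
# evaluator reads: "path", "split", "source_col", and "target_col" (shape
# inferred from the usage below; see .datasets for the actual registry).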

warnings.filterwarnings("ignore")

class TransliterationEvaluator(BaseEvaluator):
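    """Retrieval-style transliteration evaluation.

    Every unique target string is encoded once; each source string is then
    scored against all targets by dot product over [CLS] embeddings, and
    top-1 accuracy is reported.
    """
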
    def __init__(self, dataset_key: str = "madar-tun", max_samples: Optional[int] = None):
        if dataset_key not in TRANSLITERATION_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_key}")
        self.config = TRANSLITERATION_DATASETS[dataset_key]
        self.max_samples = max_samples

    @property
    def task_name(self) -> str:
        return "Transliteration"

    def load_dataset(self):
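        """Load the configured split and return clean (source, target) pairs."""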
        print(f"\nLoading transliteration data from {self.config['path']}...")
        ds = load_dataset(
            self.config["path"],
            split=self.config["split"]
        )
        
        valid = []
        for ex in ds:
            src = ex[self.config["source_col"]]
            tgt = ex[self.config["target_col"]]
            if src and tgt and src != "<eos>" and tgt != "<eos>" and src.strip() and tgt.strip():
                valid.append((src.strip(), tgt.strip()))
        
        if self.max_samples:
            valid = valid[:self.max_samples]
        
        print(f"Loaded {len(valid)} transliteration pairs.")
        return valid

    def evaluate(self, model, tokenizer, device: str = "cuda") -> Dict[str, Any]:
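        """Encode sources and all unique targets, score by dot product, report accuracy."""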
        model.eval()  # torch.no_grad() alone does not disable dropout
        pairs = self.load_dataset()
        if not pairs:
            raise ValueError("No valid transliteration pairs found!")

        sources, targets = zip(*pairs)
        sources, targets = list(sources), list(targets)

        # Build target vocab
        unique_targets = sorted(set(targets))
        target_to_id = {t: i for i, t in enumerate(unique_targets)}

        # Encode all unique targets in one batch (assumes the target set is
        # small enough for a single forward pass)
        target_enc = tokenizer(
            unique_targets,
            padding=True,
            truncation=True,
            max_length=32,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
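            # pool with the first token ([CLS]) of the last hidden state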
            target_embeds = model(**target_enc).last_hidden_state[:, 0]

        # Predict
        predictions = []
        batch_size = 32
        for i in range(0, len(sources), batch_size):
            batch = sources[i:i+batch_size]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt"
            ).to(device)

            with torch.no_grad():
                src_embeds = model(**inputs).last_hidden_state[:, 0]
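                # raw dot-product similarity against every candidate target
                # (unnormalized; L2-normalize both sides for cosine scoring)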
                logits = torch.matmul(src_embeds, target_embeds.T)
                preds = logits.argmax(dim=1).cpu().tolist()
                predictions.extend(preds)

        true_labels = [target_to_id[t] for t in targets]
        acc = accuracy_score(true_labels, predictions)

        print(f"✅ Transliteration Accuracy: {acc:.4f}")
        return {
            "task": self.task_name,
            "main_metric": acc,
            "accuracy": acc,
            "total_samples": len(pairs)
        }
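

# --- Example usage -----------------------------------------------------------
# A minimal sketch of how this evaluator can be driven; the checkpoint name is
# hypothetical (substitute any Hugging Face encoder whose forward pass returns
# last_hidden_state). Because of the relative imports above, run it as a module
# from the repo root, e.g. `python -m src.evaluators.transliteration.evaluator`,
# rather than executing the file directly.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    checkpoint = "some-org/arabic-encoder"  # hypothetical placeholder
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint).to(device)

    evaluator = TransliterationEvaluator(dataset_key="madar-tun", max_samples=500)
    results = evaluator.evaluate(model, tokenizer, device=device)
    print(results)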