# src/evaluators/transliteration/evaluator.py
"""Transliteration evaluation as embedding-similarity retrieval.

For each source string, the model's CLS embedding is compared (dot
product) against the CLS embeddings of every unique target string; the
nearest target is the prediction, scored with exact-match accuracy.
"""
import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score

from ..base_evaluator import BaseEvaluator
from .datasets import TRANSLITERATION_DATASETS

warnings.filterwarnings("ignore")


class TransliterationEvaluator(BaseEvaluator):
    """Evaluates an encoder model on a transliteration retrieval task.

    The dataset (chosen by ``dataset_key`` from ``TRANSLITERATION_DATASETS``)
    provides (source, target) string pairs; see :meth:`evaluate` for the
    retrieval protocol.
    """

    def __init__(self, dataset_key: str = "madar-tun", max_samples: Optional[int] = None):
        """
        Args:
            dataset_key: key into ``TRANSLITERATION_DATASETS``.
            max_samples: if given, truncate the loaded pairs to this many.

        Raises:
            ValueError: if ``dataset_key`` is not a known dataset.
        """
        if dataset_key not in TRANSLITERATION_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_key}")
        self.config = TRANSLITERATION_DATASETS[dataset_key]
        self.max_samples = max_samples

    @property
    def task_name(self) -> str:
        return "Transliteration"

    def load_dataset(self) -> List[Tuple[str, str]]:
        """Load and clean (source, target) pairs from the configured dataset.

        Returns:
            List of stripped, non-empty (source, target) string tuples,
            truncated to ``max_samples`` when set.
        """
        print(f"\nLoading transliteration data from {self.config['path']}...")
        # NOTE: calls the module-level `datasets.load_dataset`, not this method.
        ds = load_dataset(
            self.config["path"],
            split=self.config["split"]
        )
        valid = []
        for ex in ds:
            src = ex[self.config["source_col"]]
            tgt = ex[self.config["target_col"]]
            # `src and tgt` guards against None before .strip(); the strip()
            # truthiness also covers empty/whitespace-only strings (the
            # original's extra `!= ""` comparisons were redundant).
            if src and tgt and src.strip() and tgt.strip():
                valid.append((src.strip(), tgt.strip()))
        if self.max_samples:
            valid = valid[:self.max_samples]
        print(f"Loaded {len(valid)} transliteration pairs.")
        return valid

    @staticmethod
    def _embed(model, tokenizer, texts: List[str], device: str,
               batch_size: int = 32) -> torch.Tensor:
        """Encode ``texts`` in batches and return stacked CLS embeddings.

        Batching both source and target encoding avoids OOM on large
        inputs (the original embedded all unique targets in one pass).
        """
        chunks = []
        for start in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[start:start + batch_size],
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt"
            ).to(device)
            with torch.no_grad():
                # [:, 0] pools the first (CLS) token embedding.
                chunks.append(model(**enc).last_hidden_state[:, 0])
        return torch.cat(chunks, dim=0)

    def evaluate(self, model, tokenizer, device: str = "cuda") -> Dict[str, Any]:
        """Run retrieval-style transliteration evaluation.

        Each source embedding is matched (dot product) against the
        embeddings of all unique targets; accuracy is the fraction of
        sources whose nearest target is their gold target.

        Raises:
            ValueError: if the dataset yields no valid pairs.
        """
        pairs = self.load_dataset()
        if not pairs:
            raise ValueError("No valid transliteration pairs found!")

        sources, targets = zip(*pairs)
        sources, targets = list(sources), list(targets)

        # Disable dropout/batch-norm updates for deterministic inference
        # (the original relied on the caller having done this).
        model.eval()

        # Build target vocab: one embedding per unique target string.
        unique_targets = sorted(set(targets))
        target_to_id = {t: i for i, t in enumerate(unique_targets)}

        target_embeds = self._embed(model, tokenizer, unique_targets, device)
        src_embeds = self._embed(model, tokenizer, sources, device)

        # Nearest target by dot-product similarity.
        logits = torch.matmul(src_embeds, target_embeds.T)
        predictions = logits.argmax(dim=1).cpu().tolist()

        true_labels = [target_to_id[t] for t in targets]
        acc = accuracy_score(true_labels, predictions)
        print(f"✅ Transliteration Accuracy: {acc:.4f}")
        return {
            "task": self.task_name,
            "main_metric": acc,
            "accuracy": acc,
            "total_samples": len(pairs)
        }