# src/evaluators/normalization/evaluator.py
"""Normalization evaluator: scores a model's ability to map noisy (Arabish)
word forms onto their canonical spellings via embedding nearest-neighbor
retrieval over the set of canonical targets."""

import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score

from ..base_evaluator import BaseEvaluator
from .datasets import NORMALIZATION_DATASETS

# Module-level side effect kept from the original: silences all warnings
# (e.g. HF tokenizer/dataset chatter) for cleaner evaluation logs.
warnings.filterwarnings("ignore")


class NormalizationEvaluator(BaseEvaluator):
    """Retrieval-style normalization benchmark.

    Each noisy word is embedded with the model's [CLS] vector and matched
    (by dot-product similarity) against the embeddings of every unique
    canonical target seen in the dataset; accuracy is the fraction of words
    whose nearest target is the correct one.

    NOTE(review): the candidate vocabulary is built from the eval targets
    themselves, so this measures closed-set retrieval, not open normalization.
    """

    def __init__(self, dataset_key: str = "madar-tun", max_samples: Optional[int] = None):
        """Select a dataset config from NORMALIZATION_DATASETS.

        Args:
            dataset_key: key into NORMALIZATION_DATASETS.
            max_samples: optional cap on the number of evaluation pairs
                (None means use all pairs).

        Raises:
            ValueError: if dataset_key is not a known dataset.
        """
        if dataset_key not in NORMALIZATION_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_key}")
        self.config = NORMALIZATION_DATASETS[dataset_key]
        self.max_samples = max_samples

    @property
    def task_name(self) -> str:
        """Human-readable task name used in the results dict."""
        return "Normalization"

    def load_dataset(self) -> List[Tuple[str, str]]:
        """Load and filter (noisy, canonical) word pairs.

        Returns:
            List of (noisy, canonical) string tuples, both stripped; pairs
            where either side is missing or whitespace-only are dropped.
        """
        print(f"\nLoading normalization data from {self.config['path']}...")
        ds = load_dataset(
            self.config["path"],
            split=self.config["split"]
        )
        valid = []
        for ex in ds:
            a = ex[self.config["arabish_col"]]
            c = ex[self.config["canonical_col"]]
            # Truthiness already excludes None and ""; strip() additionally
            # rejects whitespace-only strings. (Replaces the original's
            # redundant chain of None/"" checks.)
            if a and c and a.strip() and c.strip():
                valid.append((a.strip(), c.strip()))
        # `is not None` (not truthiness) so that max_samples=0 is honored.
        if self.max_samples is not None:
            valid = valid[:self.max_samples]
        print(f"Loaded {len(valid)} normalization pairs.")
        return valid  # List[Tuple[noisy, canonical]]

    def evaluate(self, model, tokenizer, device: str = "cuda") -> Dict[str, Any]:
        """Run the retrieval evaluation.

        Args:
            model: a HF-style encoder whose forward returns an object with
                `.last_hidden_state` (batch, seq, hidden).
            tokenizer: matching HF tokenizer.
            device: torch device string for model inputs.

        Returns:
            Dict with task name, accuracy (also as `main_metric`), and the
            number of evaluated pairs.

        Raises:
            ValueError: if the dataset yields no valid pairs.
        """
        pairs = self.load_dataset()
        if not pairs:
            raise ValueError("No valid normalization pairs found!")

        words, targets = zip(*pairs)
        words, targets = list(words), list(targets)

        # Fix: ensure deterministic inference (disable dropout etc.);
        # the original never switched the model out of training mode.
        model.eval()

        # Candidate vocabulary = unique canonical targets, in sorted order
        # so label ids are stable across runs.
        unique_targets = sorted(set(targets))
        target_to_id = {t: i for i, t in enumerate(unique_targets)}

        # Encode all candidate targets once; [CLS] (position 0) embedding.
        target_enc = tokenizer(
            unique_targets,
            padding=True,
            truncation=True,
            max_length=32,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            target_embeds = model(**target_enc).last_hidden_state[:, 0]

        # Predict: nearest target by (unnormalized) dot-product similarity.
        # NOTE(review): not cosine similarity — target vector norms influence
        # the argmax; kept as-is to preserve the original metric.
        predictions = []
        batch_size = 32
        for i in range(0, len(words), batch_size):
            batch = words[i:i + batch_size]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt"
            ).to(device)
            with torch.no_grad():
                word_embeds = model(**inputs).last_hidden_state[:, 0]
            logits = torch.matmul(word_embeds, target_embeds.T)
            predictions.extend(logits.argmax(dim=1).cpu().tolist())

        true_labels = [target_to_id[t] for t in targets]
        acc = accuracy_score(true_labels, predictions)
        print(f"✅ Normalization Accuracy: {acc:.4f}")
        return {
            "task": self.task_name,
            "main_metric": acc,
            "accuracy": acc,
            "total_samples": len(pairs)
        }