Refactor the code for better scalability, rename the TSAC task to sentiment analysis, and add the MADAR dataset for transliteration and normalization eval
bde1c71
# src/evaluators/normalization/evaluator.py
import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score

from ..base_evaluator import BaseEvaluator
from .datasets import NORMALIZATION_DATASETS

warnings.filterwarnings("ignore")
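

# NORMALIZATION_DATASETS maps each dataset key (e.g. "madar-tun") to a config
# dict; the evaluator below reads only its "path", "split", "arabish_col"
# and "canonical_col" entries.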
class NormalizationEvaluator(BaseEvaluator):
    def __init__(self, dataset_key: str = "madar-tun", max_samples: Optional[int] = None):
        if dataset_key not in NORMALIZATION_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_key}")
        self.config = NORMALIZATION_DATASETS[dataset_key]
        self.max_samples = max_samples

    @property
    def task_name(self) -> str:
        return "Normalization"

    def load_dataset(self) -> List[Tuple[str, str]]:
        """Load (noisy, canonical) pairs, dropping empty rows and <eos> sentinels."""
        print(f"\nLoading normalization data from {self.config['path']}...")
        ds = load_dataset(
            self.config["path"],
            split=self.config["split"]
        )

        valid = []
        for ex in ds:
            a = ex[self.config["arabish_col"]]
            c = ex[self.config["canonical_col"]]
            if a and c and a.strip() and c.strip() and a != "<eos>" and c != "<eos>":
                valid.append((a.strip(), c.strip()))

        if self.max_samples:
            valid = valid[:self.max_samples]

        print(f"Loaded {len(valid)} normalization pairs.")
        return valid  # List[Tuple[noisy, canonical]]
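
    # evaluate() frames normalization as retrieval: every unique canonical
    # form is embedded once, each noisy word is embedded, and the prediction
    # is the canonical form whose [CLS] embedding has the highest dot-product
    # similarity; accuracy is exact match over that label space.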
    def evaluate(self, model, tokenizer, device: str = "cuda") -> Dict[str, Any]:
        pairs = self.load_dataset()
        if not pairs:
            raise ValueError("No valid normalization pairs found!")

        words, targets = zip(*pairs)
        words, targets = list(words), list(targets)

        # Build the label space from the unique canonical forms
        unique_targets = sorted(set(targets))
        target_to_id = {t: i for i, t in enumerate(unique_targets)}

        # Encode all canonical targets once
        target_enc = tokenizer(
            unique_targets,
            padding=True,
            truncation=True,
            max_length=32,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            target_embeds = model(**target_enc).last_hidden_state[:, 0]

        # Predict in batches
        predictions = []
        batch_size = 32
        for i in range(0, len(words), batch_size):
            batch = words[i:i + batch_size]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt"
            ).to(device)
            with torch.no_grad():
                word_embeds = model(**inputs).last_hidden_state[:, 0]
            logits = torch.matmul(word_embeds, target_embeds.T)
            preds = logits.argmax(dim=1).cpu().tolist()
            predictions.extend(preds)

        true_labels = [target_to_id[t] for t in targets]
        acc = accuracy_score(true_labels, predictions)

        print(f"✅ Normalization Accuracy: {acc:.4f}")
        return {
            "task": self.task_name,
            "main_metric": acc,
            "accuracy": acc,
            "total_samples": len(pairs)
        }
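

# Minimal usage sketch. Assumptions: the multilingual encoder checkpoint named
# below (any Hugging Face encoder checkpoint works) and running this file as a
# module (python -m ...) so the relative imports above resolve.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = "bert-base-multilingual-cased"  # assumed; swap in the model under evaluation
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint).to(device).eval()

    evaluator = NormalizationEvaluator(dataset_key="madar-tun", max_samples=200)
    results = evaluator.evaluate(model, tokenizer, device=device)
    print(results)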