Spaces:
Runtime error
Runtime error
Refactor the code for better scalability; rename the TSAC task to "sentiment analysis"; add the MADAR dataset for transliteration and normalization evaluation.
bde1c71
# src/evaluators/transliteration/evaluator.py
import warnings
from typing import Any, Dict, Optional

import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score

from ..base_evaluator import BaseEvaluator
from .datasets import TRANSLITERATION_DATASETS

warnings.filterwarnings("ignore")
class TransliterationEvaluator(BaseEvaluator):
    """Evaluate transliteration as nearest-neighbour retrieval over [CLS] embeddings.

    Every unique target string is embedded once with the model; each source
    string is then embedded and matched to the target whose embedding has the
    highest dot-product similarity. Accuracy is the fraction of sources whose
    retrieved target equals the gold target.
    """

    def __init__(self, dataset_key: str = "madar-tun", max_samples: Optional[int] = None):
        """Configure the evaluator.

        Args:
            dataset_key: Key into TRANSLITERATION_DATASETS selecting the
                dataset path, split, and column names.
            max_samples: If set, truncate the loaded pairs to this many.

        Raises:
            ValueError: If dataset_key is not a known dataset.
        """
        if dataset_key not in TRANSLITERATION_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_key}")
        self.config = TRANSLITERATION_DATASETS[dataset_key]
        self.max_samples = max_samples

    def task_name(self) -> str:
        """Return the human-readable task name."""
        return "Transliteration"

    def load_dataset(self):
        """Load and filter (source, target) string pairs from the configured dataset.

        Returns:
            list[tuple[str, str]]: Stripped (source, target) pairs, excluding
            empty/whitespace-only entries and "<eos>" sentinel rows, truncated
            to max_samples when set.
        """
        print(f"\nLoading transliteration data from {self.config['path']}...")
        ds = load_dataset(
            self.config["path"],
            split=self.config["split"]
        )
        valid = []
        for ex in ds:
            src = ex[self.config["source_col"]]
            tgt = ex[self.config["target_col"]]
            # Skip empty, whitespace-only, and "<eos>" sentinel rows.
            if src and tgt and src != "<eos>" and tgt != "<eos>" and src.strip() and tgt.strip():
                valid.append((src.strip(), tgt.strip()))
        if self.max_samples:
            valid = valid[:self.max_samples]
        print(f"Loaded {len(valid)} transliteration pairs.")
        return valid

    def _embed(self, texts, model, tokenizer, device):
        """Return the [CLS] (position-0) embeddings for a list of strings, no gradients."""
        enc = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=32,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            return model(**enc).last_hidden_state[:, 0]

    def evaluate(self, model, tokenizer, device: str = "cuda", batch_size: int = 32) -> Dict[str, Any]:
        """Run retrieval-based transliteration evaluation.

        Args:
            model: Encoder whose output exposes last_hidden_state; the
                position-0 token embedding is used as the sentence vector.
            tokenizer: Matching tokenizer (HF-style __call__ interface).
            device: Device the model already lives on.
            batch_size: Number of source strings embedded per forward pass.

        Returns:
            Dict with "task", "main_metric", "accuracy", "total_samples".

        Raises:
            ValueError: If no valid pairs could be loaded.
        """
        pairs = self.load_dataset()
        if not pairs:
            raise ValueError("No valid transliteration pairs found!")
        sources, targets = zip(*pairs)
        sources, targets = list(sources), list(targets)

        # Candidate vocabulary: one embedding per unique target string.
        unique_targets = sorted(set(targets))
        target_to_id = {t: i for i, t in enumerate(unique_targets)}
        target_embeds = self._embed(unique_targets, model, tokenizer, device)

        # Predict: for each source batch, pick the most similar target.
        predictions = []
        for i in range(0, len(sources), batch_size):
            src_embeds = self._embed(sources[i:i + batch_size], model, tokenizer, device)
            # Dot-product similarity against all candidate targets.
            logits = torch.matmul(src_embeds, target_embeds.T)
            predictions.extend(logits.argmax(dim=1).cpu().tolist())

        true_labels = [target_to_id[t] for t in targets]
        acc = accuracy_score(true_labels, predictions)
        print(f"✅ Transliteration Accuracy: {acc:.4f}")
        return {
            # Fix: original stored the bound method (self.task_name); call it
            # so the result dict holds the task-name string.
            "task": self.task_name(),
            "main_metric": acc,
            "accuracy": acc,
            "total_samples": len(pairs)
        }