Spaces:
Runtime error
Runtime error
refactor the code for better scalability and update tsac naming to sentiment analysis, adding madar dataset for transliteration and normalization eval
bde1c71
| import torch | |
| from torch.utils.data import DataLoader | |
| from datasets import concatenate_datasets, load_dataset,Dataset | |
| from typing import Dict, Any, List, Optional | |
| import warnings | |
| from ..base_evaluator import BaseEvaluator | |
# Registry of sentiment datasets the evaluator can load, keyed by a short name.
# Each entry describes how to fetch one dataset and how to read it:
#   path              - dataset identifier passed to `load_dataset`
#   text_column       - column holding the input text
#   label_column      - column holding the original label
#   label_map         - original label -> harmonized binary label
#   trust_remote_code - forwarded to `load_dataset`
#   split             - which split to evaluate on
SUPPORTED_DATASETS = {
    "tsac": dict(
        path="tunis-ai/tsac",
        text_column="sentence",
        label_column="target",
        # Identity mapping: the source labels are already binary.
        label_map={0: 0, 1: 1},
        trust_remote_code=True,
        split="test",
    ),
}
class SentimentAnalysisEvaluator(BaseEvaluator):
    """
    Unified evaluator for Tunisian sentiment analysis.

    Supports multiple datasets and harmonizes every one of them to the same
    two-column schema: ``text`` (str) and ``label`` (0=neg, 1=pos).
    Examples whose label is missing from the dataset's ``label_map``, or maps
    to something other than 0/1 (e.g. neutral), are filtered out.
    """

    def __init__(
        self,
        datasets: Optional[List[str]] = None,
        max_samples_per_dataset: int = 500,
        batch_size: int = 16
    ):
        """
        Args:
            datasets: List of dataset keys from SUPPORTED_DATASETS.
                If None, uses all available.
            max_samples_per_dataset: Limit samples per dataset for faster eval.
                A falsy value (0 / None) disables trimming.
            batch_size: Inference batch size.

        Raises:
            ValueError: If any requested key is not in SUPPORTED_DATASETS.
        """
        if datasets is None:
            self.dataset_keys = list(SUPPORTED_DATASETS.keys())
        else:
            for d in datasets:
                if d not in SUPPORTED_DATASETS:
                    raise ValueError(
                        f"Dataset '{d}' not in supported list: {list(SUPPORTED_DATASETS.keys())}"
                    )
            # Copy so later mutation of the caller's list cannot change the evaluator.
            self.dataset_keys = list(datasets)
        self.max_samples_per_dataset = max_samples_per_dataset
        self.batch_size = batch_size

    def task_name(self) -> str:
        """Return the human-readable task name used in logs and results."""
        return "Sentiment Analysis"

    def _harmonize(self, ds, key: str, cfg: Dict[str, Any]):
        """Filter one raw dataset to valid binary examples and project it to text/label.

        Args:
            ds: Raw dataset as returned by ``datasets.load_dataset``.
            key: Registry key of the dataset (used only for progress output).
            cfg: The dataset's entry from SUPPORTED_DATASETS.

        Returns:
            A dataset with exactly the columns ``text`` and ``label`` (label in
            {0, 1}), trimmed to ``max_samples_per_dataset``; or the filtered
            (empty) dataset if nothing survived filtering.
        """
        text_column = cfg["text_column"]
        label_column = cfg["label_column"]
        label_map = cfg["label_map"]

        def is_valid(example) -> bool:
            # Keep only rows with string text and a label that maps to 0 or 1.
            # A missing column or an unhashable label simply drops the row.
            try:
                if not isinstance(example[text_column], str):
                    return False
                return label_map.get(example[label_column]) in (0, 1)
            except (KeyError, TypeError):
                return False

        def to_binary(example):
            return {
                "text": example[text_column],
                "label": label_map[example[label_column]],
            }

        # Rows must be dropped with `filter`: returning None from a `map`
        # function does not remove the row.
        print("  Filtering invalid/neutral samples...")
        ds = ds.filter(is_valid, load_from_cache_file=False)
        print(f"  Valid binary samples: {len(ds)}")
        if len(ds) == 0:
            return ds

        if self.max_samples_per_dataset and len(ds) > self.max_samples_per_dataset:
            ds = ds.select(range(self.max_samples_per_dataset))
            print(f"  Trimmed to {self.max_samples_per_dataset} samples")

        # Remove the original columns so every dataset ends up with the same
        # schema and can be concatenated with the others.
        print("  Harmonizing and filtering...")
        return ds.map(
            to_binary,
            remove_columns=ds.column_names,
            load_from_cache_file=False,
            desc=f"Harmonizing {key}"
        )

    def load_dataset(self) -> Dataset:
        """Load and harmonize all configured sentiment datasets.

        Datasets that fail to load are skipped with a warning; datasets that
        contain no valid binary examples are skipped silently.

        Returns:
            The concatenation of all harmonized datasets (columns ``text``, ``label``).

        Raises:
            ValueError: If no dataset yielded any valid example.
        """
        print("\n=== Loading Tunisian Sentiment Datasets ===")
        all_datasets = []
        for key in self.dataset_keys:
            cfg = SUPPORTED_DATASETS[key]
            # Looked up outside the f-string: reusing the enclosing quote type
            # inside an f-string is a SyntaxError before Python 3.12.
            description = cfg.get("description", "No description available.")
            print(f"\nLoading '{key}': {description}")
            try:
                # Inside a method body this name resolves to the module-level
                # `datasets.load_dataset`, not to this method.
                ds = load_dataset(
                    cfg["path"],
                    split=cfg["split"],
                    trust_remote_code=cfg.get("trust_remote_code", False)
                )
            except Exception as e:
                # Best effort: one unavailable dataset must not abort the whole run.
                warnings.warn(f"Failed to load {key}: {e}. Skipping.")
                continue
            print(f"  Raw size: {len(ds)}")

            ds = self._harmonize(ds, key, cfg)
            if len(ds) > 0:
                all_datasets.append(ds)

        if not all_datasets:
            raise ValueError("No valid sentiment data found!")

        # Combine all datasets
        combined = concatenate_datasets(all_datasets)
        print(f"\n✅ Total Tunisian sentiment samples: {len(combined)}")
        return combined

    def _tokenize_batch(self, examples, tokenizer):
        """Tokenize a batch of harmonized examples (reads the ``text`` column).

        Sequences are truncated to 512 tokens but not padded here; padding is
        applied per inference batch in ``_collate_fn`` so that every DataLoader
        batch has a consistent length.
        """
        return tokenizer(
            examples["text"],
            padding=False,
            truncation=True,
            max_length=512,
            return_tensors=None
        )

    def _collate_fn(self, batch, pad_token_id: int = 0, padding_side: str = "right"):
        """Pad a list of tokenized examples to a common length and stack them.

        Args:
            batch: List of dicts with ``input_ids``, ``attention_mask`` and ``labels``.
            pad_token_id: Value used to fill padded ``input_ids`` positions.
            padding_side: ``"left"`` pads at the start; anything else pads at the end.

        Returns:
            Dict of ``torch.long`` tensors: ``input_ids`` and ``attention_mask``
            of shape (len(batch), longest sequence in batch) and ``labels`` of
            shape (len(batch),).
        """
        longest = max(len(item["input_ids"]) for item in batch)
        input_ids = torch.full((len(batch), longest), pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros((len(batch), longest), dtype=torch.long)
        for row, item in enumerate(batch):
            ids = torch.as_tensor(item["input_ids"], dtype=torch.long)
            mask = torch.as_tensor(item["attention_mask"], dtype=torch.long)
            length = ids.shape[0]
            if padding_side == "left":
                span = slice(longest - length, longest)
            else:
                span = slice(0, length)
            input_ids[row, span] = ids
            attention_mask[row, span] = mask
        labels = torch.tensor([int(item["labels"]) for item in batch], dtype=torch.long)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

    def evaluate(self, model, tokenizer, device: str = "cuda") -> Dict[str, Any]:
        """Evaluate model on unified Tunisian sentiment task.

        Args:
            model: Callable accepting ``input_ids``/``attention_mask`` keyword
                tensors and returning either an object with ``.logits`` or a
                sequence whose first element is the logits. It is not moved to
                ``device`` here; the caller is expected to have placed it.
            tokenizer: Tokenizer compatible with ``model``.
            device: Device the input tensors are moved to.

        Returns:
            Dict with ``task``, ``accuracy``, ``main_metric`` (same as accuracy),
            ``total_samples`` and ``datasets_used``.
        """
        # task_name is a plain method, so it must be called to get the string.
        task = self.task_name()
        print(f"\n=== Evaluating {task} ===")
        print(f"Model: {model.__class__.__name__} | Device: {device}")
        print(f"Datasets: {self.dataset_keys}")

        # Load and prepare data
        raw_dataset = self.load_dataset()

        def tokenize_with_labels(examples):
            # Carry the labels through the same map so they stay aligned with
            # their encodings.
            encoded = dict(self._tokenize_batch(examples, tokenizer))
            encoded["labels"] = list(examples["label"])
            return encoded

        tokenized = raw_dataset.map(
            tokenize_with_labels,
            batched=True,
            remove_columns=raw_dataset.column_names
        )

        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            # Padded positions are masked out by attention_mask; 0 is a placeholder.
            pad_token_id = 0
        padding_side = getattr(tokenizer, "padding_side", "right")

        dataloader = DataLoader(
            tokenized,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=lambda items: self._collate_fn(
                items, pad_token_id=pad_token_id, padding_side=padding_side
            )
        )

        # Inference
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in dataloader:
                inputs = {
                    name: batch[name].to(device)
                    for name in ("input_ids", "attention_mask")
                }
                outputs = model(**inputs)
                logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
                if logits.dim() == 3:  # [B, L, C]
                    # Per-token output: use the first position as the sequence
                    # representation (presumably a [CLS]-style token — confirm
                    # for the models being evaluated).
                    logits = logits[:, 0, :]
                all_preds.extend(logits.argmax(dim=-1).cpu().tolist())
                # Labels are only compared on the host, so they never need to
                # leave the CPU.
                all_labels.extend(batch["labels"].tolist())

        # Metrics
        correct = sum(p == t for p, t in zip(all_preds, all_labels))
        total = len(all_preds)
        accuracy = correct / total if total > 0 else 0.0
        print(f"\n✅ {task} Results:")
        print(f"  Accuracy: {accuracy:.4f} ({correct}/{total})")
        return {
            "task": task,
            "accuracy": accuracy,
            "main_metric": accuracy,
            "total_samples": total,
            "datasets_used": self.dataset_keys
        }