"""Tunisian sentiment-analysis evaluator.

Loads the configured Hugging Face datasets and harmonizes each one to a common
``{"text": str, "label": int}`` schema with binary labels (0=neg, 1=pos). It
then reports classification accuracy for a model/tokenizer pair.
"""
import functools
import warnings
from typing import Any, Dict, List, Optional

import torch
from torch.utils.data import DataLoader
from datasets import concatenate_datasets, load_dataset, Dataset

from ..base_evaluator import BaseEvaluator

# Registry of datasets this evaluator knows how to load. Per-entry keys:
#   path              -- dataset id passed to ``datasets.load_dataset``
#   text_column       -- column holding the input text
#   label_column      -- column holding the original label
#   label_map         -- original label -> harmonized label. Samples whose label
#                        is absent from the map, or maps to anything other than
#                        0/1, are dropped.
#   trust_remote_code -- forwarded to ``load_dataset`` (defaults to False)
#   split             -- split to evaluate on
#   description       -- optional human-readable summary printed while loading
SUPPORTED_DATASETS = {
    "tsac": {
        "path": "tunis-ai/tsac",
        "text_column": "sentence",
        "label_column": "target",
        "label_map": {0: 0, 1: 1},  # already binary
        "trust_remote_code": True,
        "split": "test",
    },
}


class SentimentAnalysisEvaluator(BaseEvaluator):
    """
    Unified evaluator for Tunisian sentiment analysis.

    Supports multiple datasets, harmonizes labels to binary (0=neg, 1=pos).
    Neutral/mapped-to-invalid labels are filtered out.
    """

    def __init__(
        self,
        datasets: Optional[List[str]] = None,
        max_samples_per_dataset: int = 500,
        batch_size: int = 16,
        max_length: int = 512,
    ):
        """
        Args:
            datasets: List of dataset keys from SUPPORTED_DATASETS. If None, uses all available.
            max_samples_per_dataset: Limit samples per dataset for faster eval
                (a falsy value disables the limit).
            batch_size: Inference batch size.
            max_length: Maximum token length; longer inputs are truncated.

        Raises:
            ValueError: if a requested dataset key is not in SUPPORTED_DATASETS.
        """
        if datasets is None:
            self.dataset_keys = list(SUPPORTED_DATASETS.keys())
        else:
            for d in datasets:
                if d not in SUPPORTED_DATASETS:
                    raise ValueError(f"Dataset '{d}' not in supported list: {list(SUPPORTED_DATASETS.keys())}")
            # Copy so later mutation of the caller's list cannot change what we evaluate.
            self.dataset_keys = list(datasets)
        self.max_samples_per_dataset = max_samples_per_dataset
        self.batch_size = batch_size
        self.max_length = max_length

    @property
    def task_name(self) -> str:
        return "Sentiment Analysis"

    def load_dataset(self) -> Dataset:
        """Load and harmonize all configured sentiment datasets.

        Returns:
            A single ``Dataset`` with exactly two columns, ``text`` (str) and
            ``label`` (int in {0, 1}). It concatenates every configured dataset
            that loaded successfully and kept at least one valid sample.

        Raises:
            ValueError: if no configured dataset produced any valid sample.
        """
        print("\n=== Loading Tunisian Sentiment Datasets ===")
        all_datasets = []

        for key in self.dataset_keys:
            cfg = SUPPORTED_DATASETS[key]
            print(f"\nLoading '{key}': {cfg.get('description', 'No description available.')}")

            try:
                # Inside this method, ``load_dataset`` is the module-level
                # function imported from ``datasets``, not this method.
                ds = load_dataset(
                    cfg["path"],
                    split=cfg["split"],
                    trust_remote_code=cfg.get("trust_remote_code", False),
                )
                print(f" Raw size: {len(ds)}")
            except Exception as e:
                # Best-effort: one unavailable dataset should not abort the whole run.
                warnings.warn(f"Failed to load {key}: {e}. Skipping.")
                continue

            ds = self._harmonize(ds, key, cfg)
            print(f" Valid binary samples: {len(ds)}")

            if self.max_samples_per_dataset and len(ds) > self.max_samples_per_dataset:
                ds = ds.select(range(self.max_samples_per_dataset))
                print(f" Trimmed to {self.max_samples_per_dataset} samples")

            if len(ds) > 0:
                all_datasets.append(ds)

        if not all_datasets:
            raise ValueError("No valid sentiment data found!")

        # Combine all datasets (possible because every one now has the same schema).
        combined = concatenate_datasets(all_datasets)
        print(f"\n✅ Total Tunisian sentiment samples: {len(combined)}")
        return combined

    @staticmethod
    def _harmonize(ds: Dataset, key: str, cfg: Dict[str, Any]) -> Dataset:
        """Keep valid binary samples of ``ds`` and project them to ``{"text", "label"}``.

        A sample is kept only when its text is a ``str`` and its original label
        is a key of ``cfg["label_map"]`` that maps to 0 or 1. Every other column
        is removed so that datasets with different schemas can be concatenated.
        """
        text_col = cfg["text_column"]
        label_col = cfg["label_column"]
        label_map = cfg["label_map"]

        def is_valid(example) -> bool:
            if not isinstance(example.get(text_col), str):
                return False
            try:
                new_label = label_map[example[label_col]]
            except (KeyError, TypeError):
                # Missing column, label absent from the map, or unhashable label.
                return False
            return new_label in (0, 1)  # skip neutral/invalid

        def to_common_schema(example) -> Dict[str, Any]:
            return {"text": example[text_col], "label": int(label_map[example[label_col]])}

        # Filtering must be a real predicate: returning None from ``map`` means
        # "leave the example unchanged", not "drop it".
        print(" Filtering invalid/neutral samples...")
        ds = ds.filter(is_valid, load_from_cache_file=False, desc=f"Filtering {key}")

        print(" Harmonizing...")
        # Columns returned by the function are kept even if listed in
        # remove_columns (e.g. an original "label" column is overwritten).
        return ds.map(
            to_common_schema,
            remove_columns=ds.column_names,
            load_from_cache_file=False,
            desc=f"Harmonizing {key}",
        )

    def _tokenize_batch(self, examples, tokenizer):
        """Tokenize a batch of harmonized examples (the ``text`` column) without padding.

        Padding is deferred to ``_collate_fn``. Padding inside a batched
        ``Dataset.map`` only equalizes lengths within one map chunk, which
        breaks stacking of DataLoader batches that span two chunks.
        """
        return tokenizer(
            examples["text"],
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )

    def _collate_fn(self, batch, pad_token_id: int = 0, padding_side: str = "right"):
        """Pad a list of tokenized examples to the batch's longest sequence.

        Args:
            batch: Examples with ``input_ids``, optional ``attention_mask``, and ``labels``.
            pad_token_id: Id written into padded ``input_ids`` positions.
            padding_side: ``"left"`` pads at the start; anything else pads at the end.

        Returns:
            Dict of ``torch.long`` tensors: ``input_ids`` and ``attention_mask``
            of shape ``[B, max_len]``, and ``labels`` of shape ``[B]``.
        """
        max_len = max(len(b["input_ids"]) for b in batch)
        input_ids, attention_mask = [], []
        for b in batch:
            ids = [int(t) for t in b["input_ids"]]
            mask = b.get("attention_mask")
            mask = [1] * len(ids) if mask is None else [int(m) for m in mask]
            n_pad = max_len - len(ids)
            if padding_side == "left":
                ids = [pad_token_id] * n_pad + ids
                mask = [0] * n_pad + mask
            else:
                ids = ids + [pad_token_id] * n_pad
                mask = mask + [0] * n_pad
            input_ids.append(ids)
            attention_mask.append(mask)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor([int(b["labels"]) for b in batch], dtype=torch.long),
        }

    def evaluate(self, model, tokenizer, device: str = "cuda") -> Dict[str, Any]:
        """Evaluate model on unified Tunisian sentiment task.

        Predictions are the argmax over the model's logits. For 3-D logits
        ``[B, L, C]``, the first position's scores are used.

        Returns:
            Dict with ``task``, ``accuracy``, ``main_metric`` (= accuracy),
            ``total_samples`` and ``datasets_used``.
        """
        print(f"\n=== Evaluating {self.task_name} ===")
        print(f"Model: {model.__class__.__name__} | Device: {device}")
        print(f"Datasets: {self.dataset_keys}")

        # Load and prepare data. Keep "label" alongside the token columns so
        # labels stay aligned with their examples.
        raw_dataset = self.load_dataset()
        tokenized = raw_dataset.map(
            lambda ex: self._tokenize_batch(ex, tokenizer),
            batched=True,
            remove_columns=["text"],
        ).rename_column("label", "labels")

        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        collate = functools.partial(
            self._collate_fn,
            pad_token_id=0 if pad_token_id is None else pad_token_id,
            padding_side=getattr(tokenizer, "padding_side", "right"),
        )
        dataloader = DataLoader(
            tokenized,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=collate,
        )

        # Inference
        model.eval()
        all_preds: List[int] = []
        all_labels: List[int] = []
        with torch.no_grad():
            for batch in dataloader:
                inputs = {k: batch[k].to(device) for k in ("input_ids", "attention_mask")}
                outputs = model(**inputs)
                logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
                if logits.dim() == 3:  # [B, L, C]
                    logits = logits[:, 0, :]
                all_preds.extend(logits.argmax(dim=-1).cpu().tolist())
                # Labels never need to visit the device; they are only compared on CPU.
                all_labels.extend(batch["labels"].tolist())

        # Metrics
        correct = sum(p == t for p, t in zip(all_preds, all_labels))
        total = len(all_preds)
        accuracy = correct / total if total > 0 else 0.0

        print(f"\n✅ {self.task_name} Results:")
        print(f" Accuracy: {accuracy:.4f} ({correct}/{total})")

        return {
            "task": self.task_name,
            "accuracy": accuracy,
            "main_metric": accuracy,
            "total_samples": total,
            "datasets_used": self.dataset_keys,
        }