""" PyTorch Dataset Classes for Sentiment Analysis Type-safe implementation with explicit per-key typing """ import torch from torch.utils.data import Dataset, DataLoader import numpy as np from typing import Dict, List, Optional, Union, Protocol, runtime_checkable, cast, Any import warnings import sys @runtime_checkable class AugmenterProtocol(Protocol): """Protocol defining the required interface for text augmenters""" def augment_text(self, text: str) -> str: ... @property def aug_methods(self) -> List[str]: ... class SentimentDataset(Dataset): """PyTorch Dataset for sentiment analysis""" def __init__( self, texts: Union[List[str], np.ndarray], labels: Union[List[int], np.ndarray, torch.Tensor], preprocessor, add_special_tokens: bool = False, augmenter: Optional[AugmenterProtocol] = None, augment_prob: float = 0.0, num_classes: int = 3 ): if len(texts) == 0: raise ValueError("texts cannot be empty") if len(labels) == 0: raise ValueError("labels cannot be empty") if len(texts) != len(labels): raise ValueError(f"Length mismatch: texts ({len(texts)}) != labels ({len(labels)})") self.texts = np.asarray(texts) self.labels = np.asarray(labels).astype(np.int64) if np.min(self.labels) < 0 or np.max(self.labels) >= num_classes: invalid_labels = np.unique(self.labels[(self.labels < 0) | (self.labels >= num_classes)]) raise ValueError( f"Labels must be in range [0, {num_classes-1}]. " f"Found invalid labels: {invalid_labels.tolist()}" ) if not hasattr(preprocessor, 'word2idx') or len(preprocessor.word2idx) == 0: raise RuntimeError( "Preprocessor vocabulary not built. Call preprocessor.build_vocabulary() first." ) self.preprocessor = preprocessor self.add_special_tokens = add_special_tokens self.augmenter = augmenter self.augment_prob = max(0.0, min(1.0, augment_prob)) self.num_classes = num_classes self.precomputed = (augmenter is None or self.augment_prob == 0.0) if self.precomputed: self._precompute_features() else: self.raw_texts = self.texts.copy() self.raw_labels = self.labels.copy() def _precompute_features(self): print("Precomputing features for faster training...") self.sequences = [] self.padded_sequences = [] self.lengths = [] self.vader_features = [] invalid_count = 0 for i, text in enumerate(self.texts): if not isinstance(text, str) or not text.strip(): text = "unknown" invalid_count += 1 try: sequence = self.preprocessor.text_to_sequence(text) padded = self.preprocessor.pad_sequence(sequence, self.add_special_tokens) length = min(len(sequence), self.preprocessor.max_length) try: vader = self.preprocessor.compute_vader_features(text) except Exception as e: warnings.warn(f"VADER failed for text '{text[:30]}...': {e}. Using zeros.") vader = np.zeros(4, dtype=np.float32) self.sequences.append(sequence) self.padded_sequences.append(padded) self.lengths.append(length) self.vader_features.append(vader) except Exception as e: raise RuntimeError( f"Failed to process text at index {i} ('{text[:50]}...'): {e}" ) if invalid_count > 0: warnings.warn(f"Found {invalid_count} invalid/empty texts. 
Replaced with 'unknown'.") print(f"Precomputed {len(self.sequences)} samples successfully") def __len__(self) -> int: return len(self.texts) def __getitem__(self, idx: int) -> Dict[str, Any]: if self.precomputed: text = self.texts[idx] label = self.labels[idx] padded = self.padded_sequences[idx] length = self.lengths[idx] vader = self.vader_features[idx] if self.augmenter is not None and np.random.random() < self.augment_prob: try: if isinstance(self.augmenter, AugmenterProtocol): augmented = self.augmenter.augment_text(text) if augmented and isinstance(augmented, str) and augmented.strip(): text = augmented sequence = self.preprocessor.text_to_sequence(text) padded = self.preprocessor.pad_sequence(sequence, self.add_special_tokens) length = min(len(sequence), self.preprocessor.max_length) vader = self.preprocessor.compute_vader_features(text) except Exception as e: warnings.warn(f"Augmentation failed for text '{text[:30]}...': {e}. Using original.") else: text = self.raw_texts[idx] label = self.raw_labels[idx] if not isinstance(text, str) or not text.strip(): text = "unknown" if self.augmenter is not None and np.random.random() < self.augment_prob: try: if isinstance(self.augmenter, AugmenterProtocol): augmented = self.augmenter.augment_text(text) if augmented and isinstance(augmented, str) and augmented.strip(): text = augmented except Exception as e: warnings.warn(f"Augmentation failed: {e}. Using original text.") try: sequence = self.preprocessor.text_to_sequence(text) padded = self.preprocessor.pad_sequence(sequence, self.add_special_tokens) length = min(len(sequence), self.preprocessor.max_length) vader = self.preprocessor.compute_vader_features(text) except Exception as e: raise RuntimeError( f"Failed to process text at index {idx} ('{text[:50]}...'): {e}" ) if not isinstance(padded, (list, np.ndarray)): raise RuntimeError(f"Invalid padded sequence type: {type(padded)}") if len(padded) != self.preprocessor.max_length: raise RuntimeError( f"Sequence length mismatch: expected {self.preprocessor.max_length}, got {len(padded)}" ) return { 'text': torch.LongTensor(padded), 'label': torch.LongTensor([int(label)]), 'length': torch.LongTensor([int(length)]), 'sentiment_score': torch.FloatTensor(vader), 'raw_text': str(text) } class TransformerDataset(Dataset): """Dataset for transformer models""" def __init__( self, texts: Union[List[str], np.ndarray], labels: Union[List[int], np.ndarray, torch.Tensor], tokenizer, max_length: int = 128, augmenter: Optional[AugmenterProtocol] = None, augment_prob: float = 0.0, num_classes: int = 3 ): if len(texts) == 0: raise ValueError("texts cannot be empty") if len(labels) == 0: raise ValueError("labels cannot be empty") if len(texts) != len(labels): raise ValueError(f"Length mismatch: texts ({len(texts)}) != labels ({len(labels)})") self.texts = np.asarray(texts) self.labels = np.asarray(labels).astype(np.int64) if np.min(self.labels) < 0 or np.max(self.labels) >= num_classes: invalid_labels = np.unique(self.labels[(self.labels < 0) | (self.labels >= num_classes)]) raise ValueError( f"Labels must be in range [0, {num_classes-1}]. 
" f"Found invalid labels: {invalid_labels.tolist()}" ) self.tokenizer = tokenizer self.max_length = max_length self.augmenter = augmenter self.augment_prob = max(0.0, min(1.0, augment_prob)) self.num_classes = num_classes def __len__(self) -> int: return len(self.texts) def __getitem__(self, idx: int) -> Dict[str, Any]: text = self.texts[idx] label = self.labels[idx] if not isinstance(text, str) or not text.strip(): text = "unknown" if self.augmenter is not None and np.random.random() < self.augment_prob: try: if isinstance(self.augmenter, AugmenterProtocol): augmented = self.augmenter.augment_text(text) if augmented and isinstance(augmented, str) and augmented.strip(): text = augmented except Exception as e: warnings.warn(f"Augmentation failed: {e}. Using original text.") try: encoded = self.tokenizer( text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt' ) except Exception as e: warnings.warn(f"Tokenization failed for text '{text[:30]}...': {e}. Using fallback.") encoded = self.tokenizer( "unknown", padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt' ) return { 'input_ids': encoded['input_ids'].squeeze(0), 'attention_mask': encoded['attention_mask'].squeeze(0), 'label': torch.LongTensor([int(label)]), 'raw_text': str(text) } # ✅ CRITICAL FIX: Return precise per-key types WITHOUT union def custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: """ Returns a dictionary with precise per-key types: - 'text', 'label', 'length', 'sentiment_score': torch.Tensor - 'raw_text': List[str] """ text_list = [item['text'] for item in batch] # List[Tensor] label_list = [item['label'] for item in batch] # List[Tensor] length_list = [item['length'] for item in batch] # List[Tensor] sentiment_list = [item['sentiment_score'] for item in batch] # List[Tensor] raw_texts = [item['raw_text'] for item in batch] # List[str] return { 'text': torch.stack(text_list), 'label': torch.cat(label_list), 'length': torch.cat(length_list), 'sentiment_score': torch.stack(sentiment_list), 'raw_text': raw_texts # Pure List[str] } def transformer_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: """ Returns a dictionary with precise per-key types for transformers """ input_ids_list = [item['input_ids'] for item in batch] attention_list = [item['attention_mask'] for item in batch] label_list = [item['label'] for item in batch] raw_texts = [item['raw_text'] for item in batch] return { 'input_ids': torch.stack(input_ids_list), 'attention_mask': torch.stack(attention_list), 'label': torch.cat(label_list), 'raw_text': raw_texts } def create_data_loaders( data_dict: Dict[str, Dict[str, np.ndarray]], preprocessor, batch_size: int = 64, add_special_tokens: bool = False, augmenter: Optional[AugmenterProtocol] = None, augment_train_only: bool = True, num_workers: int = 0, pin_memory: Optional[bool] = None ) -> Dict[str, DataLoader]: required_splits = ['train', 'val', 'test'] required_keys = ['texts', 'labels'] for split in required_splits: if split not in data_dict: raise ValueError(f"data_dict missing required split: '{split}'") for key in required_keys: if key not in data_dict[split]: raise ValueError(f"data_dict['{split}'] missing required key: '{key}'") if pin_memory is None: pin_memory = torch.cuda.is_available() print("\n" + "="*80) print("CREATING DATA LOADERS") print("="*80) train_dataset = SentimentDataset( texts=data_dict['train']['texts'], labels=data_dict['train']['labels'], preprocessor=preprocessor, 
def create_data_loaders(
    data_dict: Dict[str, Dict[str, np.ndarray]],
    preprocessor,
    batch_size: int = 64,
    add_special_tokens: bool = False,
    augmenter: Optional[AugmenterProtocol] = None,
    augment_train_only: bool = True,
    num_workers: int = 0,
    pin_memory: Optional[bool] = None
) -> Dict[str, DataLoader]:
    required_splits = ['train', 'val', 'test']
    required_keys = ['texts', 'labels']
    for split in required_splits:
        if split not in data_dict:
            raise ValueError(f"data_dict missing required split: '{split}'")
        for key in required_keys:
            if key not in data_dict[split]:
                raise ValueError(f"data_dict['{split}'] missing required key: '{key}'")

    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    print("\n" + "=" * 80)
    print("CREATING DATA LOADERS")
    print("=" * 80)

    train_dataset = SentimentDataset(
        texts=data_dict['train']['texts'],
        labels=data_dict['train']['labels'],
        preprocessor=preprocessor,
        add_special_tokens=add_special_tokens,
        augmenter=augmenter if augment_train_only else None,
        augment_prob=0.5 if (augmenter and augment_train_only) else 0.0
    )
    val_dataset = SentimentDataset(
        texts=data_dict['val']['texts'],
        labels=data_dict['val']['labels'],
        preprocessor=preprocessor,
        add_special_tokens=add_special_tokens,
        augmenter=None,
        augment_prob=0.0
    )
    test_dataset = SentimentDataset(
        texts=data_dict['test']['texts'],
        labels=data_dict['test']['labels'],
        preprocessor=preprocessor,
        add_special_tokens=add_special_tokens,
        augmenter=None,
        augment_prob=0.0
    )

    if num_workers > 0 and sys.platform == 'win32':
        warnings.warn(
            "Windows multiprocessing with DataLoader may cause issues. "
            "Consider setting num_workers=0 if you encounter errors."
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=custom_collate_fn,  # returns dict with precise per-key runtime types
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=custom_collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=custom_collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    print("\nDataset sizes:")
    print(f"  Train: {len(train_dataset):,} samples")
    print(f"  Val:   {len(val_dataset):,} samples")
    print(f"  Test:  {len(test_dataset):,} samples")
    print("\nBatch configuration:")
    print(f"  Batch size: {batch_size}")
    print(f"  Train batches: {len(train_loader)}")
    print(f"  Val batches:   {len(val_loader)}")
    print(f"  Test batches:  {len(test_loader)}")
    print(f"  Num workers: {num_workers}")
    print(f"  Pin memory: {pin_memory}")

    if augmenter and augment_train_only:
        print("\nData augmentation: ENABLED (train only)")
        print("  Augmentation probability: 50%")
        print(f"  Methods: {getattr(augmenter, 'aug_methods', 'N/A')}")
    else:
        print("\nData augmentation: DISABLED")

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }
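
# Illustrative only: the split layout both loader factories validate above.
# Texts and labels are toy stand-ins; real splits come from the project's
# preprocessing pipeline.
def _example_data_dict() -> Dict[str, Dict[str, np.ndarray]]:
    return {
        split: {
            'texts': np.array(['great product', 'awful service', 'it was fine']),
            'labels': np.array([2, 0, 1])  # assumed 3-class: negative/neutral/positive
        }
        for split in ('train', 'val', 'test')
    }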
def create_transformer_data_loaders(
    data_dict: Dict[str, Dict[str, np.ndarray]],
    tokenizer,
    batch_size: int = 32,
    max_length: int = 128,
    augmenter: Optional[AugmenterProtocol] = None,
    augment_train_only: bool = True,
    num_workers: int = 0,
    pin_memory: Optional[bool] = None
) -> Dict[str, DataLoader]:
    required_splits = ['train', 'val', 'test']
    required_keys = ['texts', 'labels']
    for split in required_splits:
        if split not in data_dict:
            raise ValueError(f"data_dict missing required split: '{split}'")
        for key in required_keys:
            if key not in data_dict[split]:
                raise ValueError(f"data_dict['{split}'] missing required key: '{key}'")

    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    print("\n" + "=" * 80)
    print("CREATING TRANSFORMER DATA LOADERS")
    print("=" * 80)

    train_dataset = TransformerDataset(
        texts=data_dict['train']['texts'],
        labels=data_dict['train']['labels'],
        tokenizer=tokenizer,
        max_length=max_length,
        augmenter=augmenter if augment_train_only else None,
        augment_prob=0.5 if (augmenter and augment_train_only) else 0.0
    )
    val_dataset = TransformerDataset(
        texts=data_dict['val']['texts'],
        labels=data_dict['val']['labels'],
        tokenizer=tokenizer,
        max_length=max_length,
        augmenter=None,
        augment_prob=0.0
    )
    test_dataset = TransformerDataset(
        texts=data_dict['test']['texts'],
        labels=data_dict['test']['labels'],
        tokenizer=tokenizer,
        max_length=max_length,
        augmenter=None,
        augment_prob=0.0
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=transformer_collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=transformer_collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=transformer_collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    print("\nDataset sizes:")
    print(f"  Train: {len(train_dataset):,} samples")
    print(f"  Val:   {len(val_dataset):,} samples")
    print(f"  Test:  {len(test_dataset):,} samples")
    print("\nBatch configuration:")
    print(f"  Batch size: {batch_size}")
    print(f"  Max length: {max_length}")
    print(f"  Train batches: {len(train_loader)}")
    print(f"  Val batches:   {len(val_loader)}")
    print(f"  Test batches:  {len(test_loader)}")
    print(f"  Num workers: {num_workers}")
    print(f"  Pin memory: {pin_memory}")

    if augmenter and augment_train_only:
        print("\nData augmentation: ENABLED (train only)")
        print("  Augmentation probability: 50%")
    else:
        print("\nData augmentation: DISABLED")

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }
def get_class_weights(
    labels: Union[np.ndarray, torch.Tensor, List[int]],
    method: str = 'inverse_freq',
    num_classes: int = 3,
    epsilon: float = 1e-6
) -> torch.Tensor:
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    elif isinstance(labels, list):
        labels = np.array(labels)

    if labels.size == 0:
        raise ValueError("labels cannot be empty")

    class_counts = np.bincount(labels, minlength=num_classes)
    if np.any(class_counts == 0):
        warnings.warn(
            f"Class(es) with zero samples detected: {np.where(class_counts == 0)[0].tolist()}. "
            "Using epsilon to avoid division by zero."
        )
        class_counts = np.maximum(class_counts, epsilon)

    if method == 'inverse_freq':
        total_samples = len(labels)
        weights = total_samples / (num_classes * (class_counts + epsilon))
    elif method == 'effective_num':
        # Class-balanced weighting via the "effective number of samples".
        beta = 0.9999
        effective_num = 1.0 - np.power(beta, class_counts + epsilon)
        weights = (1.0 - beta) / (effective_num + epsilon)
    else:
        raise ValueError(f"Unknown method: {method}. Choose 'inverse_freq' or 'effective_num'")

    # Normalize so the weights sum to num_classes.
    weights = weights / (np.sum(weights) + epsilon) * num_classes
    return torch.FloatTensor(weights)


def print_batch_info(batch: Dict[str, Any], dataset_type: str = 'standard'):
    """Print shapes, dtypes, and a sample entry for a collated batch."""
    print("\n" + "=" * 80)
    print("BATCH INFORMATION")
    print("=" * 80)

    if dataset_type == 'standard':
        # These keys are guaranteed these runtime types by custom_collate_fn.
        text_tensor = batch['text']                  # torch.Tensor
        label_tensor = batch['label']                # torch.Tensor
        length_tensor = batch['length']              # torch.Tensor
        sentiment_tensor = batch['sentiment_score']  # torch.Tensor
        raw_texts = batch['raw_text']                # List[str]

        print(f"Text shape: {text_tensor.shape} (dtype: {text_tensor.dtype})")
        print(f"Label shape: {label_tensor.shape} (dtype: {label_tensor.dtype})")
        print(f"Length shape: {length_tensor.shape} (dtype: {length_tensor.dtype})")
        print(f"Sentiment score shape: {sentiment_tensor.shape} (dtype: {sentiment_tensor.dtype})")
        print("\nSample text (first in batch):")
        print(f"  Tokens (first 20): {text_tensor[0][:20].tolist()}...")
        print(f"  Label: {label_tensor[0].item()}")
        print(f"  Length: {length_tensor[0].item()}")
        print(f"  VADER scores: {sentiment_tensor[0].tolist()}")
        print(f"  Raw text: {raw_texts[0][:100]}...")
    elif dataset_type == 'transformer':
        input_ids = batch['input_ids']            # torch.Tensor
        attention_mask = batch['attention_mask']  # torch.Tensor
        label_tensor = batch['label']             # torch.Tensor
        raw_texts = batch['raw_text']             # List[str]

        print(f"Input IDs shape: {input_ids.shape} (dtype: {input_ids.dtype})")
        print(f"Attention mask shape: {attention_mask.shape} (dtype: {attention_mask.dtype})")
        print(f"Label shape: {label_tensor.shape} (dtype: {label_tensor.dtype})")
        print("\nSample (first in batch):")
        print(f"  Input IDs (first 20): {input_ids[0][:20].tolist()}...")
        print(f"  Attention mask sum: {attention_mask[0].sum().item()}")
        print(f"  Label: {label_tensor[0].item()}")
        print(f"  Raw text: {raw_texts[0][:100]}...")
    else:
        raise ValueError(f"Unknown dataset_type: {dataset_type}. Choose 'standard' or 'transformer'")


if sys.platform == 'win32':
    warnings.warn(
        "Windows detected: Setting default num_workers=0 for DataLoader to avoid multiprocessing issues. "
        "You can override this in create_data_loaders() if needed."
    )


if __name__ == "__main__":
    print("=" * 80)
    print("TESTING DATASET MODULE")
    print("=" * 80)

    labels = np.array([0] * 100 + [1] * 500 + [2] * 200)

    weights_inv = get_class_weights(labels, method='inverse_freq')
    print(f"\nInverse frequency weights: {weights_inv.numpy()}")

    weights_eff = get_class_weights(labels, method='effective_num')
    print(f"Effective number weights: {weights_eff.numpy()}")

    try:
        get_class_weights(np.array([]))
    except ValueError as e:
        print(f"✓ Correctly rejected empty labels: {e}")

    try:
        get_class_weights(labels, method='unknown')
    except ValueError as e:
        print(f"✓ Correctly rejected invalid method: {e}")

    print("\n✅ Dataset module tested successfully!")
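
    # Worked example (illustrative): for counts [100, 500, 200] (N=800, K=3),
    # inverse-frequency gives N / (K * n_c) = [2.667, 0.533, 1.333], which
    # normalizes (sum -> K) to approximately [1.765, 0.353, 0.882].
    expected_inv = np.array([1.7647, 0.3529, 0.8824])
    assert np.allclose(weights_inv.numpy(), expected_inv, atol=1e-3)
    print(f"✓ Inverse-frequency weights match hand-computed values: {expected_inv}")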