""" Text augmentation module for sentiment analysis Handles synonym replacement, word swapping, deletion, insertion, and back-translation """ import random import numpy as np import torch import nlpaug.augmenter.word as naw import nlpaug.augmenter.sentence as nas from tqdm import tqdm class TextAugmenter: def __init__(self, aug_methods=['synonym', 'swap', 'delete'], aug_p=0.1): self.aug_methods = aug_methods self.aug_p = aug_p if 'synonym' in aug_methods: self.synonym_aug = naw.SynonymAug(aug_src='wordnet', aug_p=aug_p) if 'backtranslation' in aug_methods: try: self.back_trans_aug = naw.BackTranslationAug( from_model_name='facebook/wmt19-en-de', to_model_name='facebook/wmt19-de-en', device='cuda' if torch.cuda.is_available() else 'cpu' ) except Exception as e: print(f"Warning: Back-translation not available: {e}") if 'backtranslation' in self.aug_methods: self.aug_methods.remove('backtranslation') def random_swap(self, text, n=2): words = text.split() if len(words) < 2: return text for _ in range(n): idx1, idx2 = random.sample(range(len(words)), 2) words[idx1], words[idx2] = words[idx2], words[idx1] return ' '.join(words) def random_deletion(self, text, p=0.1): words = text.split() if len(words) == 1: return text new_words = [word for word in words if random.random() > p] if len(new_words) == 0: return random.choice(words) return ' '.join(new_words) def random_insertion(self, text, n=1): words = text.split() if len(words) == 0: return text for _ in range(n): random_word = random.choice(words) random_idx = random.randint(0, len(words)) words.insert(random_idx, random_word) return ' '.join(words) def synonym_replacement(self, text): if 'synonym' in self.aug_methods and hasattr(self, 'synonym_aug'): try: augmented = self.synonym_aug.augment(text) # ✅ FIX: Handle nlpaug's variable return types safely if augmented is None: return text elif isinstance(augmented, list): # nlpaug sometimes returns list even with n=1 if len(augmented) > 0 and isinstance(augmented[0], str) and augmented[0].strip(): return augmented[0] else: return text elif isinstance(augmented, str) and augmented.strip(): return augmented else: return text except Exception as e: print(f"Synonym augmentation failed for text '{text[:50]}...': {e}") return text return text def back_translation(self, text): if 'backtranslation' in self.aug_methods and hasattr(self, 'back_trans_aug'): try: augmented = self.back_trans_aug.augment(text) # ✅ FIX: Same safe handling as above if augmented is None: return text elif isinstance(augmented, list): if len(augmented) > 0 and isinstance(augmented[0], str) and augmented[0].strip(): return augmented[0] else: return text elif isinstance(augmented, str) and augmented.strip(): return augmented else: return text except Exception as e: print(f"Back-translation failed for text '{text[:50]}...': {e}") return text return text def augment_text(self, text, method=None): if not text or not isinstance(text, str) or not text.strip(): return text if method is None: method = random.choice(self.aug_methods) if method == 'swap': return self.random_swap(text, n=max(1, len(text.split()) // 10)) elif method == 'delete': return self.random_deletion(text, p=self.aug_p) elif method == 'insert': return self.random_insertion(text, n=max(1, len(text.split()) // 10)) elif method == 'synonym': return self.synonym_replacement(text) elif method == 'backtranslation': return self.back_translation(text) else: return text def augment_dataset(self, texts, labels, n_aug=1, keep_original=True): print(f"Augmenting dataset (n_aug={n_aug}, methods={self.aug_methods})...") augmented_texts = [] augmented_labels = [] if keep_original: augmented_texts.extend(texts) augmented_labels.extend(labels) for text, label in tqdm(zip(texts, labels), total=len(texts), desc="Augmenting"): for _ in range(n_aug): aug_text = self.augment_text(text) augmented_texts.append(aug_text) augmented_labels.append(label) print(f"Original size: {len(texts)}") print(f"Augmented size: {len(augmented_texts)}") return np.array(augmented_texts), np.array(augmented_labels) def augment_minority_classes(self, texts, labels, target_ratio=1.0): from collections import Counter class_counts = Counter(labels) max_count = max(class_counts.values()) augmented_texts = list(texts) augmented_labels = list(labels) for class_label, count in class_counts.items(): target_count = int(max_count * target_ratio) n_to_generate = max(0, target_count - count) if n_to_generate > 0: class_texts = [text for text, label in zip(texts, labels) if label == class_label] for _ in tqdm(range(n_to_generate), desc=f"Augmenting class {class_label}"): text = random.choice(class_texts) aug_text = self.augment_text(text) augmented_texts.append(aug_text) augmented_labels.append(class_label) print(f"\nOriginal size: {len(texts)}") print(f"Augmented size: {len(augmented_texts)}") print(f"Original distribution: {Counter(labels)}") print(f"Augmented distribution: {Counter(augmented_labels)}") return np.array(augmented_texts), np.array(augmented_labels) def augment_training_data(data_dict, n_aug=1, balance_classes=True, aug_methods=['synonym', 'swap', 'delete']): """ Augment training data to handle class imbalance Args: data_dict: Dictionary with 'train' key containing 'texts' and 'labels' n_aug: Number of augmentations per sample (if not balancing) balance_classes: If True, augment minority classes to match majority aug_methods: List of augmentation methods to use Returns: Updated data_dict with augmented training data """ augmenter = TextAugmenter(aug_methods=aug_methods, aug_p=0.1) train_texts = data_dict['train']['texts'] train_labels = data_dict['train']['labels'] if balance_classes: aug_texts, aug_labels = augmenter.augment_minority_classes( train_texts, train_labels, target_ratio=1.0 ) else: aug_texts, aug_labels = augmenter.augment_dataset( train_texts, train_labels, n_aug=n_aug, keep_original=True ) data_dict['train']['texts'] = aug_texts data_dict['train']['labels'] = aug_labels return data_dict