import pandas as pd
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac

# Primary: synonym replacement (only useful when WordNet covers the words).
aug_synonym = naw.SynonymAug(aug_src='wordnet', aug_p=0.15)

# Fallback: random swap of word positions — needs no dictionary.
aug_swap = naw.RandomWordAug(action='swap', aug_p=0.2)

# Second fallback: random char-level keyboard typo simulation.
aug_char = nac.KeyboardAug(aug_char_p=0.1, aug_word_p=0.2)


def _try_augment(text: str) -> str | None:
    """Return a variant of *text* that differs from it, or None.

    The module-level augmenters are tried in order (synonym, word swap,
    keyboard typo) and the first non-blank result that differs from the
    original (ignoring surrounding whitespace) is returned. If none of them
    yields a usable result, two words of the text are swapped manually.

    Args:
        text: The text to augment.

    Returns:
        The augmented text, or None when *text* is not a string, is blank,
        or contains no two distinct words to swap.
    """
    # Guard against NaN / None coming out of a DataFrame column; without it
    # the manual fallback below would raise AttributeError on ``split``.
    if not isinstance(text, str):
        return None
    stripped = text.strip()
    if not stripped:
        return None

    for augmenter in (aug_synonym, aug_swap, aug_char):
        try:
            result = augmenter.augment(text)
        except Exception:
            # Deliberate best-effort: any failure inside an augmenter just
            # means "try the next one".
            continue

        # ``augment`` may hand back either a list of strings or a bare
        # string; normalise to a single candidate.
        if isinstance(result, list):
            candidate = result[0] if result else None
        else:
            candidate = result

        if (
            isinstance(candidate, str)
            and candidate.strip()
            and candidate.strip() != stripped
        ):
            return candidate

    # All augmenters failed or returned the same text — swap two words manually.
    words = text.split()
    if len(words) < 2:
        return None  # single-word text, truly can't augment

    if words[0] != words[-1]:
        partner = len(words) - 1
    else:
        # Swapping identical first/last words would reproduce the original,
        # so swap the first word with the first word that differs from it.
        partner = next(
            (i for i, word in enumerate(words) if word != words[0]),
            None,
        )
        if partner is None:
            return None  # every word is identical; no swap can change it

    words[0], words[partner] = words[partner], words[0]
    return " ".join(words)


def augment_class(df: pd.DataFrame, label_id: int, target_count: int) -> pd.DataFrame:
    """Oversample one class with augmented texts until it has *target_count* rows.

    Works on a copy of *df*; the input frame is not modified. Source texts
    are drawn at random (with replacement) from the rows of the class, and
    each draw is passed to ``_try_augment``. At most ``10 * needed`` draws
    are made, so fewer rows than requested may be generated; a warning is
    printed in that case.

    Args:
        df: Frame with a ``"text"`` column and a ``"label"`` column whose
            values can be cast to int.
        label_id: Label of the class to grow.
        target_count: Desired number of rows for that class.

    Returns:
        A new frame. If the class already has at least *target_count* rows,
        this is the copy of *df* with ``label`` cast to int. Otherwise it is
        that copy plus the generated rows, with a fresh 0..n-1 index,
        ``label`` as int and ``text`` as str.

    Raises:
        ValueError: If no row of *df* has ``label == label_id``.
    """
    df = df.copy()
    df["label"] = df["label"].astype(int)
    class_df = df[df["label"] == label_id]

    if class_df.empty:
        raise ValueError(
            f"No rows found for label_id={label_id}. \n"
            f"Available labels: {df['label'].unique().tolist()} "
            f"(dtype: {df['label'].dtype})"
        )

    already_have = len(class_df)
    if already_have >= target_count:
        print(f" Label {label_id} already has {already_have} samples — skipping")
        return df

    needed = target_count - already_have
    print(f" Label {label_id}: {already_have} → {target_count} (generating {needed} samples)")

    # Only sample from texts that can actually be augmented: missing or
    # blank entries would just burn attempts (or crash on non-str values).
    texts = class_df["text"].dropna().astype(str)
    pool = texts[texts.str.strip() != ""]

    augmented_rows = []
    attempts = 0
    max_attempts = needed * 10  # generous budget — some texts are hard to augment
    while not pool.empty and len(augmented_rows) < needed and attempts < max_attempts:
        attempts += 1
        original = pool.sample(1).iloc[0]
        new_text = _try_augment(original)
        if new_text:
            augmented_rows.append({"text": new_text, "label": label_id})

    if len(augmented_rows) < needed:
        print(f" Warning: only generated {len(augmented_rows)}/{needed} after {attempts} attempts")

    if augmented_rows:
        result_df = pd.concat([df, pd.DataFrame(augmented_rows)], ignore_index=True)
    else:
        # Nothing to append; still hand back a 0..n-1 index like concat would.
        result_df = df.reset_index(drop=True)

    result_df["label"] = result_df["label"].astype(int)
    result_df["text"] = result_df["text"].astype(str)
    print(f" Done. Distribution: {result_df['label'].value_counts().sort_index().to_dict()}")
    return result_df