""" Toxic-only back-translation augmentation with cosine deduplication. Augments only the positive (toxic) class in the training set via EN→ES→EN, then drops synthetic samples too similar to the original training corpus. """ from __future__ import annotations import time from typing import Iterable import numpy as np import pandas as pd from sklearn.metrics.pairwise import cosine_similarity from src.utils.logger import get_logger logger = get_logger(__name__) def toxic_back_translation( texts: Iterable[str], labels: Iterable[int | bool], *, source_lang: str = "en", pivot_lang: str = "es", min_words: int = 3, max_words: int = 60, rate_limit_every: int = 50, rate_limit_sleep_sec: float = 1.0, seed: int = 42, ) -> tuple[list[str], list[int]]: """ Back-translate toxic samples only (label == 1). Returns parallel lists of augmented texts and labels (all toxic). """ try: from deep_translator import GoogleTranslator except ImportError as e: raise ImportError( "Install augmentation deps: uv sync --extra train" ) from e import random random.seed(seed) to_pivot = GoogleTranslator(source=source_lang, target=pivot_lang) to_source = GoogleTranslator(source=pivot_lang, target=source_lang) aug_texts: list[str] = [] aug_labels: list[int] = [] pairs = [ (str(t).strip(), int(bool(l))) for t, l in zip(texts, labels, strict=False) if int(bool(l)) == 1 ] logger.info(f"Back-translation: {len(pairs)} toxic samples") for i, (text, label) in enumerate(pairs): words = text.split() if len(words) < min_words: continue try: short = " ".join(words[:max_words]) pivot = to_pivot.translate(short) back = to_source.translate(pivot) if back and back.strip() and back.strip() != short.strip(): aug_texts.append(back.strip()) aug_labels.append(label) except Exception as exc: logger.warning(f"Back-translation failed at index {i}: {exc}") continue if rate_limit_every > 0 and i > 0 and i % rate_limit_every == 0: time.sleep(rate_limit_sleep_sec) logger.info(f"Back-translation produced {len(aug_texts)} samples") return aug_texts, aug_labels def back_translate_texts( texts: Iterable[str], *, source_lang: str = "en", pivot_lang: str = "de", max_words: int = 60, rate_limit_every: int = 50, rate_limit_sleep_sec: float = 1.0, fallback_to_original: bool = True, ) -> list[str]: """ Back-translate every text (EN→pivot→EN) for test-time augmentation. On failure, returns the original string when ``fallback_to_original`` is True. """ try: from deep_translator import GoogleTranslator except ImportError as e: raise ImportError( "Install augmentation deps: uv sync --extra train" ) from e to_pivot = GoogleTranslator(source=source_lang, target=pivot_lang) to_source = GoogleTranslator(source=pivot_lang, target=source_lang) out: list[str] = [] for i, raw in enumerate(texts): text = str(raw).strip() if not text: out.append(text) continue words = text.split() short = " ".join(words[:max_words]) try: pivot = to_pivot.translate(short) back = to_source.translate(pivot) out.append(back.strip() if back and back.strip() else text) except Exception as exc: logger.warning(f"TTA back-translation failed at index {i}: {exc}") out.append(text if fallback_to_original else short) if rate_limit_every > 0 and i > 0 and i % rate_limit_every == 0: time.sleep(rate_limit_sleep_sec) return out def deduplicate_by_cosine( synthetic_texts: list[str], synthetic_labels: list[int], reference_texts: list[str], *, threshold: float = 0.95, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2", ) -> tuple[list[str], list[int]]: """ Remove synthetic samples with max cosine similarity > threshold vs reference. """ if not synthetic_texts: return [], [] try: from sentence_transformers import SentenceTransformer except ImportError as e: raise ImportError( "Install augmentation deps: uv sync --extra train" ) from e model = SentenceTransformer(embedding_model) ref_emb = model.encode(reference_texts, show_progress_bar=False, convert_to_numpy=True) syn_emb = model.encode(synthetic_texts, show_progress_bar=False, convert_to_numpy=True) sims = cosine_similarity(syn_emb, ref_emb) max_sim = sims.max(axis=1) kept_texts: list[str] = [] kept_labels: list[int] = [] dropped = 0 for text, label, sim in zip(synthetic_texts, synthetic_labels, max_sim, strict=False): if sim <= threshold: kept_texts.append(text) kept_labels.append(label) else: dropped += 1 logger.info( f"Dedup: kept {len(kept_texts)}/{len(synthetic_texts)} " f"(dropped {dropped} with cosine > {threshold})" ) return kept_texts, kept_labels def augment_toxic_train( X_train: pd.Series, y_train: pd.Series, cfg: dict, *, seed: int = 42, ) -> tuple[pd.Series, pd.Series]: """ Append toxic-only back-translated samples to training data (with dedup). """ aug_cfg = cfg.get("augmentation", {}) if not aug_cfg.get("enabled", True): return X_train, y_train syn_texts, syn_labels = toxic_back_translation( X_train.tolist(), y_train.tolist(), source_lang=aug_cfg.get("source_lang", "en"), pivot_lang=aug_cfg.get("pivot_lang", "es"), min_words=aug_cfg.get("min_words", 3), max_words=aug_cfg.get("max_words", 60), rate_limit_every=aug_cfg.get("rate_limit_every", 50), rate_limit_sleep_sec=aug_cfg.get("rate_limit_sleep_sec", 1.0), seed=seed, ) dedup_cfg = aug_cfg.get("dedup", {}) if dedup_cfg.get("enabled", True) and syn_texts: syn_texts, syn_labels = deduplicate_by_cosine( syn_texts, syn_labels, X_train.tolist(), threshold=float(dedup_cfg.get("cosine_threshold", 0.95)), embedding_model=dedup_cfg.get( "embedding_model", "sentence-transformers/all-MiniLM-L6-v2" ), ) if not syn_texts: return X_train, y_train X_aug = pd.concat( [X_train, pd.Series(syn_texts, name=X_train.name)], ignore_index=True, ) y_aug = pd.concat( [y_train, pd.Series(syn_labels, name=y_train.name)], ignore_index=True, ) logger.info(f"Train size after augmentation: {len(X_aug)} (+{len(syn_texts)})") return X_aug, y_aug