""" Synthetic Data Generator for Drift Testing Generates synthetic drifted datasets to test drift detection. """ import random import string from typing import List, Tuple from loguru import logger import numpy as np class SyntheticDataGenerator: """ Generates synthetic code comment data with controlled drift characteristics. """ def __init__(self, seed: int = 42): """ Initialize synthetic data generator. """ self.seed = seed np.random.seed(seed) random.seed(seed) def generate_short_comments( self, reference_texts: List[str], ratio: float = 0.5, n_samples: int = 100, ) -> List[str]: """ Generate shorter comments (text length drift). """ short_comments = [] for _ in range(n_samples): ref_text = np.random.choice(reference_texts) words = ref_text.split() truncated_len = max(1, int(len(words) * ratio)) short_text = " ".join(words[:truncated_len]) short_comments.append(short_text) logger.debug(f"Generated {len(short_comments)} short comments") return short_comments def generate_long_comments( self, reference_texts: List[str], ratio: float = 1.5, n_samples: int = 100, ) -> List[str]: """ Generate longer comments (text length drift upward). """ long_comments = [] for _ in range(n_samples): ref_text = np.random.choice(reference_texts) words = ref_text.split() target_len = max(1, int(len(words) * ratio)) extended_words = words.copy() while len(extended_words) < target_len: extended_words.append(np.random.choice(words)) long_text = " ".join(extended_words[:target_len]) long_comments.append(long_text) logger.debug(f"Generated {len(long_comments)} long comments") return long_comments def generate_corrupted_vocabulary( self, reference_texts: List[str], corruption_rate: float = 0.5, n_samples: int = 100, ) -> List[str]: """ Generate texts with corrupted vocabulary (typos, character swaps). Args: reference_texts: Reference training texts corruption_rate: Fraction of words to corrupt (0.0-1.0) n_samples: Number of samples to generate Returns: List of corrupted texts """ corrupted_texts = [] for _ in range(n_samples): ref_text = np.random.choice(reference_texts) words = ref_text.split() # Corrupt some words for i in range(len(words)): if random.random() < corruption_rate: word = words[i] if len(word) > 2: # Random character swap or substitution if random.random() < 0.5: # Character swap idx = random.randint(0, len(word) - 2) word = word[:idx] + word[idx + 1] + word[idx] + word[idx + 2 :] else: # Character substitution idx = random.randint(0, len(word) - 1) word = ( word[:idx] + random.choice(string.ascii_lowercase) + word[idx + 1 :] ) words[i] = word corrupted_text = " ".join(words) corrupted_texts.append(corrupted_text) logger.debug(f"Generated {len(corrupted_texts)} corrupted texts (rate={corruption_rate})") return corrupted_texts def generate_label_shift( self, reference_texts: List[str], reference_labels: np.ndarray, shift_type: str = "class_imbalance", n_samples: int = 100, ) -> Tuple[List[str], np.ndarray]: """ Generate batch with label distribution shift (class imbalance). Args: reference_texts: Reference training texts reference_labels: Reference training labels (binary matrix) shift_type: 'class_imbalance' - favor majority class n_samples: Number of samples to generate Returns: Tuple of (texts, shifted_labels) """ texts = [] shifted_labels = [] if reference_labels.ndim == 2: # Multi-label: get the first label per sample label_indices = np.argmax(reference_labels, axis=1) else: label_indices = reference_labels # Get class distribution unique_labels, counts = np.unique(label_indices, return_counts=True) majority_class = unique_labels[np.argmax(counts)] minority_classes = unique_labels[unique_labels != majority_class] # Create imbalanced distribution: 80% majority, 20% minority n_majority = int(n_samples * 0.8) n_minority = n_samples - n_majority # Sample indices with bias toward majority class majority_indices = np.where(label_indices == majority_class)[0] minority_indices = np.where(np.isin(label_indices, minority_classes))[0] selected_indices = [] selected_indices.extend(np.random.choice(majority_indices, size=n_majority, replace=True)) if len(minority_indices) > 0: selected_indices.extend( np.random.choice(minority_indices, size=n_minority, replace=True) ) np.random.shuffle(selected_indices) selected_indices = selected_indices[:n_samples] # Get texts and labels texts = [reference_texts[i] for i in selected_indices] shifted_labels = reference_labels[selected_indices] logger.debug(f"Generated {len(texts)} samples with class imbalance") return texts, shifted_labels def generate_synthetic_batch( self, reference_texts: List[str], reference_labels: np.ndarray, drift_type: str = "none", batch_size: int = 50, ) -> Tuple[List[str], np.ndarray]: """ Generate a synthetic batch with specified drift. Args: reference_texts: Reference training texts reference_labels: Reference training labels drift_type: Type of drift to introduce: - 'none': No drift (baseline) - 'text_length_short': Shortened texts - 'text_length_long': Elongated texts - 'corrupted_vocab': Typos and character swaps - 'class_imbalance': Biased label distribution batch_size: Number of samples to generate Returns: Tuple of (texts, labels) """ if drift_type == "none": indices = np.random.choice(len(reference_texts), size=batch_size, replace=True) texts = [reference_texts[i] for i in indices] labels = reference_labels[indices] elif drift_type == "text_length_short": texts = self.generate_short_comments(reference_texts, ratio=0.5, n_samples=batch_size) indices = np.random.choice(len(reference_labels), size=batch_size) labels = reference_labels[indices] elif drift_type == "text_length_long": texts = self.generate_long_comments(reference_texts, ratio=1.5, n_samples=batch_size) indices = np.random.choice(len(reference_labels), size=batch_size) labels = reference_labels[indices] elif drift_type == "corrupted_vocab": texts = self.generate_corrupted_vocabulary( reference_texts, corruption_rate=0.2, n_samples=batch_size ) indices = np.random.choice(len(reference_labels), size=batch_size) labels = reference_labels[indices] elif drift_type == "class_imbalance": texts, labels = self.generate_label_shift( reference_texts, reference_labels, shift_type="class_imbalance", n_samples=batch_size, ) else: raise ValueError(f"Unknown drift type: {drift_type}") logger.info(f"Generated synthetic batch: {drift_type}, size={batch_size}") return texts, labels