"""Preprocessing helpers for transformer training.

This module provides utilities to parse multi-label strings, ensure the
``combo`` column exists, perform label-aware supersampling of a training
DataFrame, and a lightweight ``load_or_prepare_data`` entry point that loads
raw CSVs, optionally applies preprocessing, and writes processed CSVs.
"""

import logging
import os
from typing import Tuple

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


def parse_label_str(s: str) -> np.ndarray:
    """Convert a string like '[0 0 1 0 0 0 0]' into a float32 numpy array.

    Elements may be separated by whitespace (including the line breaks numpy
    inserts when printing long arrays) and/or commas, so ``'[0 0 1]'`` and
    ``'[0, 0, 1]'`` both parse to ``array([0., 0., 1.], dtype=float32)``.
    Surrounding whitespace and square brackets are ignored; non-string inputs
    are converted with ``str`` first. ``'[]'`` yields an empty array.

    Raises
    ------
    ValueError
        If any token is not a valid float. (``np.fromstring`` would instead
        stop at the first unparseable character and silently return a
        truncated vector.)
    """
    tokens = str(s).strip().strip("[]").replace(",", " ").split()
    return np.array([float(tok) for tok in tokens], dtype=np.float32)


def ensure_combo_column(df: pd.DataFrame) -> pd.DataFrame:
    """Ensure that the 'combo' column exists.

    If missing, it is built on a copy of *df* as
    ``"<comment_sentence> | <class>"`` (both columns cast to ``str``), so the
    caller's DataFrame is not modified. If already present, *df* itself is
    returned unchanged.
    """
    if "combo" not in df.columns:
        logger.info("Column 'combo' not found, creating it from 'comment_sentence' and 'class'.")
        df = df.copy()
        df["combo"] = df["comment_sentence"].astype(str) + " | " + df["class"].astype(str)
    else:
        logger.info("Column 'combo' already present, reusing it.")
    return df


def _labels_matrix(labels: pd.Series) -> np.ndarray:
    """Parse a Series of label strings into a 2-D ``(rows, num_labels)`` float32 matrix."""
    vectors = [parse_label_str(s) for s in labels]
    if not vectors:
        return np.empty((0, 0), dtype=np.float32)
    lengths = sorted({v.shape[0] for v in vectors})
    if len(lengths) > 1:
        # np.stack would also raise ValueError; make the cause explicit.
        raise ValueError(f"Inconsistent label vector lengths in 'labels' column: {lengths}")
    return np.stack(vectors)


def supersample_dataframe(
    df: pd.DataFrame,
    factor: float,
    random_state: int = 42,
) -> pd.DataFrame:
    """Offline label-aware supersampling of the training DataFrame.

    - Keeps all original rows.
    - For each label j, appends duplicates of rows that contain label j,
      drawn uniformly at random with replacement, ``target_j - freq_j`` of
      them, where::

          target_j = min(max_freq, int(freq_j * factor))

      ``freq_j`` is the original count for label j and ``max_freq`` is the
      maximum original count across labels. Labels already at or above their
      target get no duplicates.
    - Shuffles the resulting row order.

    Because whole (multi-label) rows are duplicated, a row drawn for label j
    also increments every other label set in that row. Final per-label counts
    can therefore exceed ``target_j`` (and ``max_freq``); they are logged
    after sampling.

    Assumes ``df['labels']`` holds string representations of multi-hot
    vectors (see ``parse_label_str``).

    Parameters
    ----------
    df : pd.DataFrame
        Training data with a ``labels`` column.
    factor : float
        Per-label growth factor. Values <= 1.0 disable supersampling.
    random_state : int
        Seed for ``np.random.default_rng`` (controls both sampling and shuffle).

    Returns
    -------
    pd.DataFrame
        A new DataFrame with a fresh ``RangeIndex``. When no sampling happens
        (factor <= 1.0, empty input, or all-zero / empty label vectors), a copy
        of *df* with its original index is returned instead.

    Raises
    ------
    ValueError
        If label vectors differ in length or contain non-numeric tokens.
    """
    if factor <= 1.0:
        logger.info(
            "Supersampling factor <= 1.0 (%.2f), returning original DataFrame.",
            factor,
        )
        return df.copy()

    if len(df) == 0:
        # Nothing to resample; np.stack would fail on an empty list.
        logger.warning("Empty DataFrame, skipping supersampling.")
        return df.copy()

    rng = np.random.default_rng(random_state)

    labels_array = _labels_matrix(df["labels"])
    num_samples, num_labels = labels_array.shape

    freq = labels_array.sum(axis=0).astype(int)
    # freq is empty when every label vector is '[]'; treat that as all-zero.
    max_freq = int(freq.max()) if freq.size else 0

    logger.info("Original label frequencies: %s", freq.tolist())
    logger.info("Max label frequency: %d", max_freq)

    if max_freq == 0:
        logger.warning("All label frequencies are zero, skipping supersampling.")
        return df.copy()

    target = np.minimum(max_freq, (freq * factor).astype(int))
    logger.info(
        "Target label frequencies after supersampling (capped by max_freq): %s",
        target.tolist(),
    )

    indices_by_label = {j: np.where(labels_array[:, j] == 1)[0] for j in range(num_labels)}

    # A Python list (rather than an ndarray) is kept on purpose so the seeded
    # shuffle reproduces the same row order as earlier runs.
    new_indices = list(range(num_samples))

    for j in range(num_labels):
        current = int(freq[j])
        desired = int(target[j])
        if desired <= current:
            continue

        candidate_indices = indices_by_label[j]
        if candidate_indices.size == 0:
            continue

        needed = desired - current
        extra = rng.choice(candidate_indices, size=needed, replace=True)
        new_indices.extend(extra.tolist())

        logger.info(
            "Label %d: current=%d, target=%d, added=%d samples.",
            j,
            current,
            desired,
            needed,
        )

    rng.shuffle(new_indices)
    df_sup = df.iloc[new_indices].reset_index(drop=True)

    freq_after = _labels_matrix(df_sup["labels"]).sum(axis=0).astype(int)
    logger.info("Final label frequencies after supersampling: %s", freq_after.tolist())
    logger.info("Training rows before: %d, after: %d", num_samples, len(df_sup))

    return df_sup


def load_or_prepare_data(
    lang: str,
    raw_data_dir: str,
    processed_data_dir: str,
    preprocessing_enabled: bool,
    preprocessing_factor: float,
    random_state: int = 42,
) -> Tuple[pd.DataFrame, pd.DataFrame, str]:
    """Load raw CSVs for the given language, optionally apply preprocessing
    (supersampling) on the train split, and save processed CSVs.

    - The test split is NEVER supersampled or augmented.
    - Both splits always get 'combo' and 'labels_array' columns.
    - The train split is supersampled only if preprocessing_enabled=True and
      preprocessing_factor > 1.0.
    - Processed CSVs are written after 'combo' / supersampling but before
      'labels_array' is attached, so the saved files do not contain it.
      Existing files at the processed paths are overwritten.

    Parameters
    ----------
    lang : str
        Language key (e.g., 'java', 'python', 'pharo').
    raw_data_dir : str
        Directory containing {lang}_train.csv and {lang}_test.csv.
    processed_data_dir : str
        Directory where processed CSVs will be saved (created if missing).
    preprocessing_enabled : bool
        Whether to apply supersampling on the training split.
    preprocessing_factor : float
        Supersampling factor (ignored if preprocessing_enabled=False).
    random_state : int
        RNG seed.

    Returns
    -------
    train_df : pd.DataFrame
    eval_df : pd.DataFrame
    preprocessing_used : str
        One of: 'none', 'supersampling'.

    Raises
    ------
    FileNotFoundError
        If either raw CSV is missing.
    """
    logger.info("Loading raw CSVs for language '%s' from '%s'.", lang, raw_data_dir)

    raw_train_path = os.path.join(raw_data_dir, f"{lang}_train.csv")
    raw_eval_path = os.path.join(raw_data_dir, f"{lang}_test.csv")

    if not os.path.exists(raw_train_path):
        raise FileNotFoundError(f"Raw train CSV not found: {raw_train_path}")
    if not os.path.exists(raw_eval_path):
        raise FileNotFoundError(f"Raw test CSV not found: {raw_eval_path}")

    train_df = pd.read_csv(raw_train_path)
    eval_df = pd.read_csv(raw_eval_path)

    train_df = ensure_combo_column(train_df)
    eval_df = ensure_combo_column(eval_df)

    if preprocessing_enabled and preprocessing_factor > 1.0:
        logger.info(
            "Preprocessing enabled: applying supersampling with factor=%.2f.",
            preprocessing_factor,
        )
        train_df = supersample_dataframe(
            train_df,
            factor=preprocessing_factor,
            random_state=random_state,
        )
        preprocessing_used = "supersampling"
    else:
        logger.info(
            "Preprocessing disabled or factor <= 1.0 (%.2f). Using original training data.",
            preprocessing_factor,
        )
        preprocessing_used = "none"

    # Save processed CSVs (for inspection / reproducibility)
    os.makedirs(processed_data_dir, exist_ok=True)
    processed_train_path = os.path.join(processed_data_dir, f"{lang}_train.csv")
    processed_eval_path = os.path.join(processed_data_dir, f"{lang}_test.csv")

    train_df.to_csv(processed_train_path, index=False)
    eval_df.to_csv(processed_eval_path, index=False)
    logger.info("Saved processed train/test CSVs to '%s'.", processed_data_dir)

    # Ensure 'labels_array' exists for both splits. An existing column is left
    # untouched. NOTE(review): a 'labels_array' column read from CSV would hold
    # strings, not arrays — confirm raw CSVs never carry one.
    for df, split_name in ((train_df, "train"), (eval_df, "test")):
        if "labels_array" not in df.columns:
            logger.info("Parsing label strings into arrays for split '%s'.", split_name)
            df["labels_array"] = df["labels"].apply(parse_label_str)

    return train_df, eval_df, preprocessing_used