"""
Dataset Balancing Module
=========================

Class balancing for classification datasets via oversampling /
undersampling strategies.
"""

from dataclasses import dataclass
from typing import Dict, Hashable, Optional

import pandas as pd


@dataclass
class BalancingConfig:
    """Configuration for dataset balancing."""

    enabled: bool = False
    label_column: str = ""
    strategy: str = "none"  # "none", "oversample", "undersample"


def compute_label_distribution(
    df: pd.DataFrame,
    label_col: str,
) -> Dict[Hashable, int]:
    """
    Compute the label distribution for a given column.

    Args:
        df: Input frame.
        label_col: Name of the column holding class labels.

    Returns:
        Mapping of label value -> row count, as produced by
        ``Series.value_counts`` (most frequent label first; missing
        NaN/None labels are not counted). An empty dict is returned when
        ``label_col`` is not a column of ``df``.
    """
    if label_col not in df.columns:
        return {}
    return df[label_col].value_counts().to_dict()


def oversample_minority(
    df: pd.DataFrame,
    label_col: str,
    *,
    random_state: Optional[int] = 42,
) -> pd.DataFrame:
    """
    Oversample minority classes to match the majority class count.

    Every original labelled row is kept; each class with fewer rows than the
    largest class is topped up with rows drawn *with replacement* from that
    same class.

    Args:
        df: Input frame.
        label_col: Name of the column holding class labels.
        random_state: Seed forwarded to ``DataFrame.sample``. The default (42)
            keeps results reproducible; ``None`` uses NumPy's global RNG.

    Returns:
        A new frame with a fresh ``RangeIndex``, grouped by class in
        ``value_counts`` order (most frequent class first). Rows whose label
        is missing (NaN/None) are not included. If ``label_col`` is not a
        column of ``df``, ``df`` itself is returned unchanged.
    """
    if label_col not in df.columns:
        return df

    counts = df[label_col].value_counts()
    if counts.empty:
        # No labelled rows (empty frame, or every label missing): there is
        # nothing to balance and pd.concat([]) would raise, so return a
        # zero-row frame with the same columns.
        return df.iloc[0:0].reset_index(drop=True)

    max_count = counts.max()
    frames = []
    for label, count in counts.items():
        label_df = df[df[label_col] == label]
        frames.append(label_df)
        if count < max_count:
            # Resample with replacement to reach max_count.
            frames.append(
                label_df.sample(
                    n=max_count - count,
                    replace=True,
                    random_state=random_state,
                )
            )
    return pd.concat(frames, ignore_index=True)


def undersample_majority(
    df: pd.DataFrame,
    label_col: str,
    *,
    random_state: Optional[int] = 42,
) -> pd.DataFrame:
    """
    Undersample majority classes to match the minority class count.

    Each class with more rows than the smallest class is reduced to that
    count by sampling *without* replacement; the smallest class is kept whole.

    Args:
        df: Input frame.
        label_col: Name of the column holding class labels.
        random_state: Seed forwarded to ``DataFrame.sample``. The default (42)
            keeps results reproducible; ``None`` uses NumPy's global RNG.

    Returns:
        A new frame with a fresh ``RangeIndex``, grouped by class in
        ``value_counts`` order (most frequent class first). Rows whose label
        is missing (NaN/None) are not included. If ``label_col`` is not a
        column of ``df``, ``df`` itself is returned unchanged.
    """
    if label_col not in df.columns:
        return df

    counts = df[label_col].value_counts()
    if counts.empty:
        # No labelled rows (empty frame, or every label missing): there is
        # nothing to balance and pd.concat([]) would raise, so return a
        # zero-row frame with the same columns.
        return df.iloc[0:0].reset_index(drop=True)

    min_count = counts.min()
    frames = []
    for label in counts.index:
        label_df = df[df[label_col] == label]
        if len(label_df) > min_count:
            frames.append(label_df.sample(n=min_count, random_state=random_state))
        else:
            frames.append(label_df)
    return pd.concat(frames, ignore_index=True)


def balance_dataset(
    df: pd.DataFrame,
    label_col: str,
    strategy: str = "none",
    *,
    random_state: Optional[int] = 42,
) -> pd.DataFrame:
    """
    Balance dataset using the specified strategy.

    Args:
        df: Input frame.
        label_col: Name of the column holding class labels.
        strategy: "none", "oversample", or "undersample". Any other value is
            treated like "none" and ``df`` is returned unchanged.
        random_state: Seed forwarded to the chosen sampler (default 42).

    Returns:
        The balanced frame (see ``oversample_minority`` /
        ``undersample_majority``), or ``df`` itself for "none".
    """
    if strategy == "oversample":
        return oversample_minority(df, label_col, random_state=random_state)
    if strategy == "undersample":
        return undersample_majority(df, label_col, random_state=random_state)
    # NOTE(review): unrecognised strategies silently fall through to a no-op;
    # consider validating them upstream (e.g. when building BalancingConfig).
    return df