# Auto-FineTune-Ops / preprocessing / dataset_balancing.py
# Commit d4398e6 by aneeb15: "Initial release of Auto-FineTune-Ops"
"""
Dataset Balancing Module
=========================
Class balancing for classification datasets via
oversampling / undersampling strategies.
"""
from dataclasses import dataclass
from typing import Dict, Optional
import pandas as pd
@dataclass
class BalancingConfig:
    """
    Settings describing whether and how to rebalance class labels.

    Field order defines the positional constructor signature:
    ``BalancingConfig(enabled, label_column, strategy)``.
    """

    # Whether balancing is requested. Nothing in this module reads it;
    # presumably the caller checks it before calling balance_dataset.
    enabled: bool = False

    # Name of the DataFrame column holding the class labels.
    label_column: str = ""

    # Balancing mode. balance_dataset recognizes "none", "oversample" and
    # "undersample".
    strategy: str = "none"
def compute_label_distribution(
    df: pd.DataFrame,
    label_col: str,
) -> Dict[str, int]:
    """
    Count how many rows carry each label in ``label_col``.

    Keys are the label values exactly as stored in the column. Despite the
    ``str`` annotation, they may be ints, bools, and so on. Values are row
    counts, ordered from most to least frequent, as produced by
    ``Series.value_counts``.

    Missing labels (NaN/None) are not counted. An empty dict is returned
    when ``label_col`` is not a column of ``df``.
    """
    has_column = label_col in df.columns
    return df[label_col].value_counts().to_dict() if has_column else {}
def oversample_minority(
    df: pd.DataFrame,
    label_col: str,
    *,
    random_state: Optional[int] = 42,
) -> pd.DataFrame:
    """
    Oversample minority classes to match the majority class count.

    Each class present in ``label_col`` is topped up to the size of the
    largest class. The extra rows are drawn, with replacement, from that
    class's own rows. Original rows are always kept.

    In the output, each class's rows are contiguous (originals first, then
    the resampled extras). Classes appear in descending order of their
    original frequency.

    Rows whose label is missing (NaN/None) cannot be assigned to a class.
    They are appended unchanged at the end rather than silently dropped.

    Args:
        df: Input dataset.
        label_col: Name of the column holding class labels.
        random_state: Seed forwarded to ``DataFrame.sample`` so the
            resampling is reproducible. ``None`` draws a fresh sample on
            every call.

    Returns:
        A new DataFrame with a fresh ``RangeIndex``. If ``label_col`` is not
        a column of ``df``, ``df`` itself is returned unchanged.
    """
    if label_col not in df.columns:
        return df

    labels = df[label_col]
    counts = labels.value_counts()
    # Categorical columns report unused categories with a count of 0. There
    # are no rows to resample for them, and sampling an empty frame raises.
    counts = counts[counts > 0]

    frames = []
    if not counts.empty:
        max_count = counts.max()
        for label, count in counts.items():
            label_df = df[labels == label]
            frames.append(label_df)
            if count < max_count:
                # Resample with replacement to reach max_count.
                frames.append(
                    label_df.sample(
                        n=max_count - count,
                        replace=True,
                        random_state=random_state,
                    )
                )

    # value_counts() skips missing labels, so carry those rows through as-is.
    unlabeled = df[labels.isna()]
    if not unlabeled.empty:
        frames.append(unlabeled)

    if not frames:
        # Empty input: nothing to balance, and pd.concat([]) would raise.
        return df.reset_index(drop=True)
    return pd.concat(frames, ignore_index=True)
def undersample_majority(
    df: pd.DataFrame,
    label_col: str,
    *,
    random_state: Optional[int] = 42,
) -> pd.DataFrame:
    """
    Undersample majority classes to match the minority class count.

    Each class present in ``label_col`` with more rows than the smallest
    class is randomly reduced (sampling without replacement) to that size.
    Classes appear in descending order of their original frequency.

    Rows whose label is missing (NaN/None) cannot be assigned to a class.
    They are appended unchanged at the end rather than silently dropped.

    Args:
        df: Input dataset.
        label_col: Name of the column holding class labels.
        random_state: Seed forwarded to ``DataFrame.sample`` so the
            subsampling is reproducible. ``None`` draws a fresh sample on
            every call.

    Returns:
        A new DataFrame with a fresh ``RangeIndex``. If ``label_col`` is not
        a column of ``df``, ``df`` itself is returned unchanged.
    """
    if label_col not in df.columns:
        return df

    labels = df[label_col]
    counts = labels.value_counts()
    # Categorical columns report unused categories with a count of 0. If
    # those were kept, min_count would be 0 and every class would be wiped.
    counts = counts[counts > 0]

    frames = []
    if not counts.empty:
        min_count = counts.min()
        for label in counts.index:
            label_df = df[labels == label]
            if len(label_df) > min_count:
                label_df = label_df.sample(n=min_count, random_state=random_state)
            frames.append(label_df)

    # value_counts() skips missing labels, so carry those rows through as-is.
    unlabeled = df[labels.isna()]
    if not unlabeled.empty:
        frames.append(unlabeled)

    if not frames:
        # Empty input: nothing to balance, and pd.concat([]) would raise.
        return df.reset_index(drop=True)
    return pd.concat(frames, ignore_index=True)
def balance_dataset(
    df: pd.DataFrame,
    label_col: str,
    strategy: str = "none",
) -> pd.DataFrame:
    """
    Rebalance ``df`` by class label according to ``strategy``.

    strategy: "none", "oversample", or "undersample".

    ``"oversample"`` delegates to ``oversample_minority``, and
    ``"undersample"`` delegates to ``undersample_majority``. Any other value,
    including ``"none"``, is a no-op: ``df`` itself is returned unmodified.
    """
    if strategy not in ("oversample", "undersample"):
        # Unrecognized strategies are ignored rather than rejected.
        return df
    resample = oversample_minority if strategy == "oversample" else undersample_majority
    return resample(df, label_col)