Spaces:
Configuration error
Configuration error
File size: 2,615 Bytes
d4398e6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 | """
Dataset Balancing Module
=========================
Class balancing for classification datasets via
oversampling / undersampling strategies.
"""
from dataclasses import dataclass
from typing import Dict, Optional
import pandas as pd
@dataclass
class BalancingConfig:
    """Settings describing whether and how to rebalance a dataset.

    All fields have defaults, so ``BalancingConfig()`` describes a
    configuration with balancing disabled and strategy ``"none"``.
    """

    # Flag indicating whether balancing is requested.
    enabled: bool = False
    # Column holding the class labels; empty string when not configured.
    label_column: str = ""
    # Balancing strategy: "none", "oversample" or "undersample".
    strategy: str = "none"
def compute_label_distribution(
    df: pd.DataFrame,
    label_col: str,
) -> Dict[str, int]:
    """
    Count how many rows carry each distinct value of ``label_col``.

    Returns an empty dict when the column is absent from ``df``. Otherwise
    the mapping follows ``Series.value_counts`` defaults: most frequent label
    first, and missing (NaN) labels are not counted. Keys are the column's
    raw label values, which need not be strings despite the annotation.
    """
    has_column = label_col in df.columns
    if has_column:
        tally = df[label_col].value_counts()
        return tally.to_dict()
    return {}
def oversample_minority(
    df: pd.DataFrame,
    label_col: str,
    *,
    random_state: Optional[int] = 42,
) -> pd.DataFrame:
    """
    Oversample minority classes to match the majority class count.

    Every class smaller than the largest one is topped up with rows drawn
    with replacement from that same class, so each label ends up with the
    majority-class row count. All original labelled rows are kept.

    Args:
        df: Input frame.
        label_col: Name of the column holding class labels.
        random_state: Seed forwarded to ``DataFrame.sample``. The default of
            42 keeps results reproducible; ``None`` gives unseeded draws.

    Returns:
        A new frame with a fresh ``RangeIndex`` whose rows are grouped by
        label in ``value_counts`` order. Rows whose label is NaN are not
        carried over, because ``value_counts`` excludes them. ``df`` itself
        is returned unchanged when ``label_col`` is missing or the column
        has no non-null labels (including an empty frame).
    """
    if label_col not in df.columns:
        return df
    counts = df[label_col].value_counts()
    # No labelled rows: nothing to balance, and pd.concat([]) below would
    # raise "No objects to concatenate".
    if counts.empty:
        return df
    max_count = counts.max()
    frames = []
    for label, count in counts.items():
        label_df = df[df[label_col] == label]
        frames.append(label_df)
        if count < max_count:
            # Resample with replacement to make up the shortfall.
            frames.append(
                label_df.sample(
                    n=max_count - count,
                    replace=True,
                    random_state=random_state,
                )
            )
    return pd.concat(frames, ignore_index=True)
def undersample_majority(
    df: pd.DataFrame,
    label_col: str,
    *,
    random_state: Optional[int] = 42,
) -> pd.DataFrame:
    """
    Undersample majority classes to match the minority class count.

    Every class larger than the smallest one is reduced to the minority-class
    row count by sampling without replacement; the smallest class(es) are
    kept as-is.

    Args:
        df: Input frame.
        label_col: Name of the column holding class labels.
        random_state: Seed forwarded to ``DataFrame.sample``. The default of
            42 keeps results reproducible; ``None`` gives unseeded draws.

    Returns:
        A new frame with a fresh ``RangeIndex`` whose rows are grouped by
        label in ``value_counts`` order. Rows whose label is NaN are not
        carried over, because ``value_counts`` excludes them. ``df`` itself
        is returned unchanged when ``label_col`` is missing or the column
        has no non-null labels (including an empty frame).
    """
    if label_col not in df.columns:
        return df
    counts = df[label_col].value_counts()
    # No labelled rows: nothing to balance, and pd.concat([]) below would
    # raise "No objects to concatenate".
    if counts.empty:
        return df
    min_count = counts.min()
    frames = []
    for label in counts.index:
        label_df = df[df[label_col] == label]
        if len(label_df) > min_count:
            label_df = label_df.sample(n=min_count, random_state=random_state)
        frames.append(label_df)
    return pd.concat(frames, ignore_index=True)
def balance_dataset(
    df: pd.DataFrame,
    label_col: str,
    strategy: str = "none",
) -> pd.DataFrame:
    """
    Balance dataset using the specified strategy.

    Args:
        df: Input frame.
        label_col: Name of the column holding class labels.
        strategy: ``"none"`` returns ``df`` unchanged; ``"oversample"``
            delegates to ``oversample_minority``; ``"undersample"``
            delegates to ``undersample_majority``.

    Returns:
        The balanced frame (or ``df`` itself for ``"none"``).

    Raises:
        ValueError: If ``strategy`` is not one of the three values above.
    """
    if strategy == "none":
        return df
    if strategy == "oversample":
        return oversample_minority(df, label_col)
    if strategy == "undersample":
        return undersample_majority(df, label_col)
    # Fail loudly: an unrecognised value (e.g. a config typo such as
    # "oversampling") would otherwise silently skip balancing.
    raise ValueError(
        f"Unknown balancing strategy {strategy!r}; "
        "expected 'none', 'oversample' or 'undersample'."
    )
|