Spaces:
Configuration error
Configuration error
| """ | |
| Dataset Balancing Module | |
| ========================= | |
| Class balancing for classification datasets via | |
| oversampling / undersampling strategies. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Dict, Optional | |
| import pandas as pd | |
@dataclass
class BalancingConfig:
    """Configuration for dataset balancing.

    Declared as a dataclass so it can be built with keyword arguments, e.g.
    ``BalancingConfig(enabled=True, label_column="y", strategy="oversample")``,
    and gets a generated ``__init__``, ``__repr__`` and ``__eq__``.
    """

    # Flag to turn balancing on; note that balance_dataset() does not read it,
    # so callers are presumably expected to check it themselves.
    enabled: bool = False
    # Name of the DataFrame column holding the class labels (empty by default).
    label_column: str = ""
    # Balancing mode: "none", "oversample" or "undersample".
    strategy: str = "none"
def compute_label_distribution(
    df: pd.DataFrame,
    label_col: str,
) -> Dict[str, int]:
    """
    Count how many rows carry each value of ``label_col``.

    Args:
        df: Input data.
        label_col: Name of the label column.

    Returns:
        ``{label_value: row_count}`` built from ``Series.value_counts()``
        (most frequent label first; NaN labels are not counted). An empty
        dict is returned when ``label_col`` is not a column of ``df``.
    """
    column_present = label_col in df.columns
    if column_present:
        per_label = df[label_col].value_counts()
        return per_label.to_dict()
    return {}
def oversample_minority(
    df: pd.DataFrame,
    label_col: str,
    random_state: Optional[int] = 42,
) -> pd.DataFrame:
    """
    Oversample minority classes to match the majority class count.

    Every class with fewer rows than the largest class is topped up with rows
    drawn *with replacement* from that same class, so each class ends up with
    exactly as many rows as the majority class. All original rows are kept.

    Args:
        df: Input data.
        label_col: Name of the label column.
        random_state: Seed forwarded to ``DataFrame.sample`` for reproducible
            draws. Defaults to 42 (the previously hard-coded seed); pass
            ``None`` for non-deterministic sampling.

    Returns:
        A new DataFrame with a fresh 0..n-1 index, grouped by label in
        descending frequency order (each group: original rows, then the
        resampled extras). ``df`` itself is returned unchanged when
        ``label_col`` is not a column or holds no non-NaN labels. Rows whose
        label is NaN are otherwise dropped, because ``value_counts()``
        excludes NaN.
    """
    if label_col not in df.columns:
        return df
    counts = df[label_col].value_counts()
    if counts.empty:
        # Nothing to balance; pd.concat([]) below would raise ValueError.
        return df
    max_count = counts.max()
    frames = []
    for label, count in counts.items():
        label_df = df[df[label_col] == label]
        frames.append(label_df)
        deficit = int(max_count - count)
        if deficit > 0:
            # Resample with replacement to reach max_count.
            frames.append(
                label_df.sample(n=deficit, replace=True, random_state=random_state)
            )
    return pd.concat(frames, ignore_index=True)
def undersample_majority(
    df: pd.DataFrame,
    label_col: str,
    random_state: Optional[int] = 42,
) -> pd.DataFrame:
    """
    Undersample majority classes to match the minority class count.

    Every class with more rows than the smallest class is reduced to a
    random subset (drawn *without* replacement) of exactly that many rows.
    Classes already at the minimum are kept as-is.

    Args:
        df: Input data.
        label_col: Name of the label column.
        random_state: Seed forwarded to ``DataFrame.sample`` for reproducible
            draws. Defaults to 42 (the previously hard-coded seed); pass
            ``None`` for non-deterministic sampling.

    Returns:
        A new DataFrame with a fresh 0..n-1 index, grouped by label in
        descending frequency order. ``df`` itself is returned unchanged when
        ``label_col`` is not a column or holds no non-NaN labels. Rows whose
        label is NaN are otherwise dropped, because ``value_counts()``
        excludes NaN.
    """
    if label_col not in df.columns:
        return df
    counts = df[label_col].value_counts()
    if counts.empty:
        # Nothing to balance; pd.concat([]) below would raise ValueError.
        return df
    min_count = int(counts.min())
    frames = []
    for label in counts.index:
        label_df = df[df[label_col] == label]
        # Compare against the actual selection size: sampling without
        # replacement requires n <= len(label_df).
        if len(label_df) > min_count:
            frames.append(label_df.sample(n=min_count, random_state=random_state))
        else:
            frames.append(label_df)
    return pd.concat(frames, ignore_index=True)
def balance_dataset(
    df: pd.DataFrame,
    label_col: str,
    strategy: str = "none",
) -> pd.DataFrame:
    """
    Rebalance class frequencies in ``df`` according to ``strategy``.

    Args:
        df: Input data.
        label_col: Name of the label column.
        strategy: ``"oversample"`` delegates to ``oversample_minority`` and
            ``"undersample"`` delegates to ``undersample_majority``.
            ``"none"``, and any other unrecognised value, returns ``df``
            untouched; no error is raised for unknown strategies.

    Returns:
        The balanced DataFrame, or ``df`` itself when no balancing applies.
    """
    if strategy not in ("oversample", "undersample"):
        return df
    balancer = (
        oversample_minority if strategy == "oversample" else undersample_majority
    )
    return balancer(df, label_col)