File size: 2,615 Bytes
d4398e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""

Dataset Balancing Module
========================

Class balancing for classification datasets via
oversampling / undersampling strategies.

"""

from dataclasses import dataclass
from typing import Dict, Optional
import pandas as pd


@dataclass
class BalancingConfig:
    """Settings bundle describing whether and how a dataset is balanced.

    The fields mirror the arguments of the balancing helpers in this module.
    NOTE(review): no function visible in this module reads this class, so how
    ``enabled`` is honoured is decided by the caller -- confirm there.
    """
    # Master switch; defaults to off.
    enabled: bool = False
    # Name of the column that holds the class labels; empty by default.
    label_column: str = ""
    # Balancing strategy name: "none", "oversample" or "undersample".
    strategy: str = "none"


def compute_label_distribution(
    df: pd.DataFrame,
    label_col: str,
) -> Dict[str, int]:
    """
    Count how many rows carry each distinct label.

    Args:
        df: Frame to inspect.
        label_col: Name of the column holding the labels.

    Returns:
        A mapping of label value -> number of rows with that label, or an
        empty dict when ``label_col`` is not a column of ``df``. Counting is
        done with ``Series.value_counts()`` defaults, so missing labels are
        not included. Keys are the raw label values (not converted to ``str``).
    """
    if label_col in df.columns:
        label_counts = df[label_col].value_counts()
        return label_counts.to_dict()
    # Unknown column: nothing to count.
    return {}


def oversample_minority(
    df: pd.DataFrame,
    label_col: str,
) -> pd.DataFrame:
    """
    Oversample minority classes to match the majority class count.

    Each class with fewer rows than the largest class keeps all of its
    original rows and is topped up with rows drawn from that class with
    replacement. The draw uses a fixed seed, so the result is reproducible.

    Args:
        df: Input frame.
        label_col: Name of the column holding the class labels.

    Returns:
        ``df`` itself when ``label_col`` is not a column of ``df``. Otherwise
        a new frame with a fresh ``0..n-1`` index in which every class has as
        many rows as the largest one; rows are grouped by class in
        ``value_counts()`` order. Rows with a missing label are not treated as
        a class and do not appear in the result. If no row has a label, an
        empty frame with the same columns is returned.
    """
    if label_col not in df.columns:
        return df

    counts = df[label_col].value_counts()
    # A categorical column reports unused categories with a count of 0; such
    # labels have no rows to draw from, so sampling them would raise.
    counts = counts[counts > 0]
    if counts.empty:
        # No labelled rows: pd.concat([]) would raise ValueError.
        return df.iloc[0:0].reset_index(drop=True)

    max_count = int(counts.max())

    frames = []
    for label, count in counts.items():
        label_df = df[df[label_col] == label]
        frames.append(label_df)
        if count < max_count:
            # Resample with replacement to reach max_count.
            frames.append(
                label_df.sample(n=max_count - int(count), replace=True, random_state=42)
            )

    return pd.concat(frames, ignore_index=True)


def undersample_majority(
    df: pd.DataFrame,
    label_col: str,
) -> pd.DataFrame:
    """
    Undersample majority classes to match the minority class count.

    Each class with more rows than the smallest class is reduced to that size
    by sampling without replacement. The draw uses a fixed seed, so the result
    is reproducible.

    Args:
        df: Input frame.
        label_col: Name of the column holding the class labels.

    Returns:
        ``df`` itself when ``label_col`` is not a column of ``df``. Otherwise
        a new frame with a fresh ``0..n-1`` index in which every class has as
        many rows as the smallest one; rows are grouped by class in
        ``value_counts()`` order. Rows with a missing label are not treated as
        a class and do not appear in the result. If no row has a label, an
        empty frame with the same columns is returned.
    """
    if label_col not in df.columns:
        return df

    counts = df[label_col].value_counts()
    # A categorical column reports unused categories with a count of 0. Left
    # in, they would make min_count 0 and throw away every row.
    counts = counts[counts > 0]
    if counts.empty:
        # No labelled rows: pd.concat([]) would raise ValueError.
        return df.iloc[0:0].reset_index(drop=True)

    min_count = int(counts.min())

    frames = []
    for label in counts.index:
        label_df = df[df[label_col] == label]
        if len(label_df) > min_count:
            # Trim this class down to the size of the smallest class.
            label_df = label_df.sample(n=min_count, random_state=42)
        frames.append(label_df)

    return pd.concat(frames, ignore_index=True)


def balance_dataset(
    df: pd.DataFrame,
    label_col: str,
    strategy: str = "none",
) -> pd.DataFrame:
    """
    Apply the requested class-balancing strategy to ``df``.

    Args:
        df: Input frame.
        label_col: Name of the column holding the class labels.
        strategy: "oversample" delegates to ``oversample_minority``,
            "undersample" delegates to ``undersample_majority``. Any other
            value, including the default "none", leaves the data alone.

    Returns:
        The balanced frame produced by the selected helper, or ``df`` itself
        when no balancing strategy matches.
    """
    if strategy == "undersample":
        return undersample_majority(df, label_col)
    if strategy == "oversample":
        return oversample_minority(df, label_col)
    # "none" and unrecognised strategy names pass the input straight through.
    return df