File size: 3,819 Bytes
86b932c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import logging
import pandas as pd
import numpy as np

logger = logging.getLogger(__name__)

def compute_domain_weights(
    df: pd.DataFrame, 
    min_domain_samples: int = 20, 
    max_multiplier: float = 10.0
) -> pd.DataFrame:
    """
    Compute domain-aware sample weights to handle heavily biased domains.
    
    Rules:
    - Domains with < min_domain_samples -> merge into "other"
    - Calculate global class distributions.
    - Calculate per-domain class distributions.
    - Compute weight = global_class_ratio / domain_class_ratio
    - Clip weights at max_multiplier * median_weight
    
    Args:
        df: Input DataFrame containing 'source_domain' and 'binary_label'
        min_domain_samples: Threshold below which domains are grouped to 'other'
        max_multiplier: Max multiplier over the median weight to clip extreme weights
        
    Returns:
        DataFrame with an additional 'sample_weight' column.
    """
    df = df.copy()
    
    # Ensure source_domain exists
    if "source_domain" not in df.columns:
        logger.warning("'source_domain' not found in DataFrame. Returning weights=1.0")
        df["sample_weight"] = 1.0
        return df
        
    # 1. Merge small domains into "other"
    domain_counts = df["source_domain"].value_counts()
    small_domains = set(domain_counts[domain_counts < min_domain_samples].index)
    
    df["_effective_domain"] = df["source_domain"].apply(
        lambda x: "other" if x in small_domains or not isinstance(x, str) else x
    )
    
    # 2. Compute global class ratios
    global_counts = df["binary_label"].value_counts()
    global_total = len(df)
    global_ratio = {
        label: count / global_total 
        for label, count in global_counts.items()
    }
    
    # 3. Compute domain class ratios and assign weights
    # We group by domain and label to get counts per domain
    domain_label_counts = df.groupby(["_effective_domain", "binary_label"]).size().unstack(fill_value=0)
    domain_totals = domain_label_counts.sum(axis=1)
    
    weights_map = {}
    for domain in domain_label_counts.index:
        weights_map[domain] = {}
        d_total = domain_totals[domain]
        for label in global_ratio.keys():
            if label in domain_label_counts.columns:
                d_count = domain_label_counts.loc[domain, label]
                if d_count == 0:
                    # If domain has 0 instances of this class, we won't observe it here anyway,
                    # but set some fallback value.
                    weights_map[domain][label] = 1.0
                else:
                    d_ratio = d_count / d_total
                    weights_map[domain][label] = global_ratio[label] / d_ratio
            else:
                weights_map[domain][label] = 1.0
                
    # 4. Map weights back to dataframe
    df["sample_weight"] = df.apply(
        lambda r: weights_map[r["_effective_domain"]].get(r["binary_label"], 1.0),
        axis=1
    )
    
    # 5. Clip weights at max_multiplier * median_weight
    median_w = df["sample_weight"].median()
    max_w = max_multiplier * median_w
    df["sample_weight"] = df["sample_weight"].clip(upper=max_w)
    
    # Clean up temp col
    df.drop(columns=["_effective_domain"], inplace=True)
    
    logger.info("Computed domain weights (median: %.3f, max applied: %.3f)", median_w, df["sample_weight"].max())
    
    return df

if __name__ == "__main__":
    # Smoke test: two large, oppositely-biased domains plus one tiny domain
    # that falls under the min_domain_samples threshold (merged to "other").
    domains = ["nytimes.com"] * 100 + ["fakenews.biz"] * 100 + ["tinyblog.com"] * 5
    labels = [1] * 90 + [0] * 10 + [0] * 95 + [1] * 5 + [0] * 5
    sample_df = pd.DataFrame({"source_domain": domains, "binary_label": labels})

    weighted = compute_domain_weights(sample_df, min_domain_samples=20, max_multiplier=10.0)
    print(weighted.groupby("source_domain")["sample_weight"].mean())