File size: 9,335 Bytes
66242b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import pandas as pd
from datasets import load_from_disk


@dataclass(slots=True)
class Config():
    verbose: bool = True
    max_words: Optional[int] = None
    min_words: Optional[int] = None

config = Config()


# ================ HELPERS ================



def mask_invalid_rows(df: pd.DataFrame) -> pd.DataFrame:
    """
    Remove rows with invalid text fields (non-string, empty) or invalid labels (not 0 or 1).
    No need to check NaNs because the dataset does not contain NaNs.
    """
    text1_ok = df["text1"].map(lambda x: isinstance(x, str) and x.strip() != "")
    text2_ok = df["text2"].map(lambda x: isinstance(x, str) and x.strip() != "")
    label_ok = df["same"].isin([0, 1])
    valid_mask = text1_ok & text2_ok & label_ok
    return df[valid_mask] # return the valid rows



def _within_word_count_range(value, min_words: Optional[int] = None, max_words: Optional[int] = None) -> bool:
    count = len(value.split()) # count the number of words by splitting on whitespace (rough estimate)
    if min_words is not None and count < min_words: return False
    if max_words is not None and count > max_words: return False
    return True # assume no invalid rows from previous check (mask_invalid_rows)

def mask_by_word_count(df: pd.DataFrame,
                        min_words: Optional[int] = None,
                        max_words: Optional[int] = None
                        ) -> pd.DataFrame:
    """
    """
    if min_words is None and max_words is None: return df # no filtering needed

    text1_ok = df["text1"].map(lambda x: _within_word_count_range(x, min_words, max_words))
    text2_ok = df["text2"].map(lambda x: _within_word_count_range(x, min_words, max_words))
    valid_mask = text1_ok & text2_ok
    return df[valid_mask] # return the valid rows



def _symmetric_pair_id(text1: str, text2: str, same: int) -> list[str | int]:
    """
    Generate a unique identifier for each pair of texts, such that the order of the texts does not matter.
    For example, if text1="A", text2="B", same=1, then the symmetric pair ID should be the same as text1="B", text2="A", same=1.
    """
    text1, text2 = sorted([text1, text2]) # ensure the order of texts does not matter by sorting them alphabetically
    return text1, text2, same # combine the sorted texts and label into a tuple to create a unique identifier for the pair

# def _duplicate_group_keep_row(df: pd.DataFrame, split_priority: dict[str, int]) -> pd.Series:
#     ranked = df.assign(__split_priority__=df["__split__"].map(split_priority)).sort_values(by=["__split_priority__", "__row_id__"], kind="stable") # stable sort to maintain original order within each split
#     return ranked.iloc[0] # keep the first row (highest priority split, then lowest row_id) and drop the rest as duplicates


def _length_stats(series: pd.Series) -> dict[str, float]:

    word_length = series.str.split().str.len()
    char_length = series.str.len()
    return {
        "median_word_length": round(word_length.median(), 2),
        "mean_word_length": round(word_length.mean(), 2),
        "std_word_length": round(word_length.std(), 2),
        "median_char_length": round(char_length.median(), 2),
        "mean_char_length": round(char_length.mean(), 2),
        "std_char_length": round(char_length.std(), 2),
    }


# ================ MAIN FUNCTIONS ================


def load_all_splits(path: Path,
                    config: Config = config
                    ) -> pd.DataFrame:
    """
    Loads all three splits (train, validation, test) from the given path and concatenates them into a single DataFrame.
    Assumes that the splits are stored in .arrow files under the corresponding split folder.
    """
    dict_df: dict[str, pd.DataFrame] = {}
    for split in ["train", "validation", "test"]:
        dataset_dir = path / f"authorship_verification_{split}"
        ds = load_from_disk(dataset_dir)
        df = ds.to_pandas()
        df["__split__"] = split # identify which split
        # df["__row_id__"] = df.index
        dict_df[split] = df

        if config.verbose:
            print(f"Loaded {split} split: {len(df)} rows")

    return dict_df

def mask_rows(dict_df: dict[str, pd.DataFrame],
            min_words: Optional[int] = None,
            max_words: Optional[int] = None,
            config: Config = config
            ) -> pd.DataFrame:
    
    """
    Mask rows with invalid text fields, invalid labels, and optionally filter by word count. 
    Then identify and drop duplicate pairs across all splits, keeping only the first occurrence based on split priority (train > validation > test) and row_id.
    """
    if config.verbose:
        if min_words is not None or max_words is not None: 
            print(f"\nWord-count filter used: min_words={min_words}, max_words={max_words}")
        else: print("\nNo word-count filter used\n")

        rows_before_masking: dict[str, int] = {}
        for split, df in dict_df.items():
            rows_before_masking[split] = len(dict_df[split])
        print("\nStarting masking rows based on invalid text fields and labels, and word count if specified...\n")

    df_valid: list[pd.DataFrame] = [] # placeholder for valid rows from all splits
    for split, df in dict_df.items():

        if config.verbose:
            print(f"    Processing split='{split}' ({len(df):,} rows)")

        df = mask_invalid_rows(df)
        df = mask_by_word_count(df, min_words, max_words)
        df_valid.append(df) # update the filtered df back into the dict

        if config.verbose:
            print(f"    Split='{split}': remove {rows_before_masking[split] - len(df):,} rows, {len(df):,} rows remain\n")

    df_valid = pd.concat(df_valid, ignore_index=True) # combine valid rows from all splits into one DataFrame
    df_valid["__symmetric_pair_id__"] = df_valid.apply(lambda row: _symmetric_pair_id(row["text1"], row["text2"], row["same"]), axis=1) # identify symmetric pairs
    if config.verbose:
        print("Checking for duplicate pairs (same texts and label, regardless of order) before deduplication...")
        num_duplicates = df_valid.duplicated(subset="__symmetric_pair_id__").sum()
        print(f"There are {num_duplicates:,} duplicate pairs across all splits before deduplication\n")

    
    # drop duplicate pairs, keeping only the first occurrence
    # since df_valid is sorted by ["train", "validation", "test"] in load_all_splits, keep="first" will prioritize keeping the row from the train split, then validation, then test
    df_valid = df_valid.drop_duplicates(subset="__symmetric_pair_id__", keep="first", ignore_index=True)

    # for _, group in df_valid.groupby("__symmetric_pair_id__", sort=False):
    #     if len(group) < 2: continue # only interested in duplicate pairs
        
    #     split_priority = {"train": 3, "validation": 2, "test": 1} # define split priority for keeping rows
    #     keep_row = _duplicate_group_keep_row(group, split_priority) # get the row to keep based on split priority and row_id
    #     drop_rows = group.index.difference([keep_row.name]) # identify rows to drop (all except the keep_row)
    #     df_valid = df_valid.drop(index=drop_rows) # drop the duplicate rows, keeping only the one with the highest priority

    for split, df in dict_df.items():
        if config.verbose:
            print(f"    Split='{split}': {len(df_valid[df_valid['__split__'] == split]):,} valid rows after deduplication")

        dict_df[split] = df_valid[df_valid["__split__"] == split].drop(columns=["__split__", "__symmetric_pair_id__"]) # update the dict with the valid rows for each split, dropping the helper column

    return dict_df


def summary_stats(dict_df: dict[str, pd.DataFrame]) -> pd.DataFrame:
    """Generate summary statistics for each split.
    """
    summary = []
    for split, df in dict_df.items():
        stats_text1 = _length_stats(df["text1"])
        stats_text2 = _length_stats(df["text2"])
        same_distribution = df["same"].value_counts().to_dict()
        summary.append({
            "split": split,
            "num_rows": len(df),
            "same_0_count": int(same_distribution.get(0, 0)),
            "same_1_count": int(same_distribution.get(1, 0)),
            "same_0_ratio": round(df["same"].eq(0).mean(), 4),
            "same_1_ratio": round(df["same"].eq(1).mean(), 4),
            **{f"text1_{k}": v for k, v in stats_text1.items()},
            **{f"text2_{k}": v for k, v in stats_text2.items()},
        })

    summary_df = pd.DataFrame(summary)
    print("\nSummary statistics for each split:")
    print(summary_df)

    return summary_df

def audit_wrapper(path: Path, # folder path to all three splits
                  config: Config = config
                  ) -> tuple[dict[str, pd.DataFrame], pd.DataFrame]:
    """
    Wrapper around the audit structure.
    """

    if config.verbose:
        print("======= AUDIT START =======")
        print("")
    

    dict_df = load_all_splits(path)

    dict_df = mask_rows(dict_df, config.min_words, config.max_words, config)
    summary_df = summary_stats(dict_df)

    if config.verbose:
        print("")
        print("======= AUDIT END =======")
        print("")

    return dict_df, summary_df