File size: 9,335 Bytes
66242b8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import pandas as pd
from datasets import load_from_disk
@dataclass(slots=True)
class Config():
verbose: bool = True
max_words: Optional[int] = None
min_words: Optional[int] = None
config = Config()
# ================ HELPERS ================
def mask_invalid_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only rows with usable texts and a binary label.

    A row survives when both ``text1`` and ``text2`` are strings containing at least
    one non-whitespace character and ``same`` is 0 or 1. Missing values are rejected
    too: NaN/None are not ``str`` instances, and a NaN label fails ``isin([0, 1])``.
    The original index labels of the surviving rows are preserved.
    """
    def _is_nonblank_str(value) -> bool:
        return isinstance(value, str) and value.strip() != ""

    text1_valid, text2_valid = (df[column].map(_is_nonblank_str) for column in ("text1", "text2"))
    label_valid = df["same"].isin([0, 1])
    return df[text1_valid & text2_valid & label_valid]
def _within_word_count_range(value, min_words: Optional[int] = None, max_words: Optional[int] = None) -> bool:
count = len(value.split()) # count the number of words by splitting on whitespace (rough estimate)
if min_words is not None and count < min_words: return False
if max_words is not None and count > max_words: return False
return True # assume no invalid rows from previous check (mask_invalid_rows)
def mask_by_word_count(df: pd.DataFrame,
                       min_words: Optional[int] = None,
                       max_words: Optional[int] = None
                       ) -> pd.DataFrame:
    """Keep only rows whose ``text1`` and ``text2`` both have an in-range word count.

    Each text is checked with ``_within_word_count_range`` (whitespace split, both
    bounds inclusive, ``None`` meaning "no bound"). When both bounds are ``None`` the
    input DataFrame itself is returned, without filtering or copying.

    Assumes the text columns hold strings (e.g. after ``mask_invalid_rows``); the
    helper does not handle other value types.
    """
    if min_words is None and max_words is None:
        return df  # no bounds -> nothing to filter

    def _in_range(text: str) -> bool:
        return _within_word_count_range(text, min_words, max_words)

    both_in_range = df["text1"].map(_in_range) & df["text2"].map(_in_range)
    return df[both_in_range]
def _symmetric_pair_id(text1: str, text2: str, same: int) -> list[str | int]:
"""
Generate a unique identifier for each pair of texts, such that the order of the texts does not matter.
For example, if text1="A", text2="B", same=1, then the symmetric pair ID should be the same as text1="B", text2="A", same=1.
"""
text1, text2 = sorted([text1, text2]) # ensure the order of texts does not matter by sorting them alphabetically
return text1, text2, same # combine the sorted texts and label into a tuple to create a unique identifier for the pair
# def _duplicate_group_keep_row(df: pd.DataFrame, split_priority: dict[str, int]) -> pd.Series:
# ranked = df.assign(__split_priority__=df["__split__"].map(split_priority)).sort_values(by=["__split_priority__", "__row_id__"], kind="stable") # stable sort to maintain original order within each split
# return ranked.iloc[0] # keep the first row (highest priority split, then lowest row_id) and drop the rest as duplicates
def _length_stats(series: pd.Series) -> dict[str, float]:
word_length = series.str.split().str.len()
char_length = series.str.len()
return {
"median_word_length": round(word_length.median(), 2),
"mean_word_length": round(word_length.mean(), 2),
"std_word_length": round(word_length.std(), 2),
"median_char_length": round(char_length.median(), 2),
"mean_char_length": round(char_length.mean(), 2),
"std_char_length": round(char_length.std(), 2),
}
# ================ MAIN FUNCTIONS ================
def load_all_splits(path: Path,
                    config: Config = config
                    ) -> dict[str, pd.DataFrame]:
    """Load the train, validation and test splits into one DataFrame per split.

    Each split is read with ``datasets.load_from_disk`` from the directory
    ``path / f"authorship_verification_{split}"`` and converted with ``to_pandas()``.
    The splits are NOT concatenated here; every frame gets a ``"__split__"`` column
    holding its split name so later steps can tell rows apart after concatenation.

    Args:
        path: folder containing the three split directories (``str`` is accepted too).
        config: only ``config.verbose`` is read; when True a row count is printed per split.

    Returns:
        Dict ``{"train": df, "validation": df, "test": df}`` in that insertion order.
    """
    path = Path(path)  # allow plain strings; the `/` join below needs a Path
    dict_df: dict[str, pd.DataFrame] = {}
    for split in ("train", "validation", "test"):
        df = load_from_disk(path / f"authorship_verification_{split}").to_pandas()
        df["__split__"] = split  # identify which split each row came from
        dict_df[split] = df
        if config.verbose:
            print(f"Loaded {split} split: {len(df)} rows")
    return dict_df
def mask_rows(dict_df: dict[str, pd.DataFrame],
              min_words: Optional[int] = None,
              max_words: Optional[int] = None,
              config: Config = config
              ) -> dict[str, pd.DataFrame]:
    """Clean every split, then drop duplicate pairs across all splits.

    1. Per split: drop rows with a non-string/blank ``text1``/``text2`` or a ``same``
       label outside {0, 1} (``mask_invalid_rows``), then, if a bound is given, drop
       rows whose texts fall outside [``min_words``, ``max_words``] words
       (``mask_by_word_count``).
    2. Across splits: two rows are duplicates when they hold the same two texts (in
       either order) and the same label (``_symmetric_pair_id``). Only one copy is
       kept: the one from the highest-priority split (train > validation > test; any
       other split name ranks after those, in ``dict_df`` order), and within that
       split the earliest row.

    Args:
        dict_df: split name -> DataFrame with ``text1``, ``text2`` and ``same`` columns.
            It is modified in place (see Returns).
        min_words: inclusive lower bound on words per text, or None for no bound.
        max_words: inclusive upper bound on words per text, or None for no bound.
        config: only ``config.verbose`` is read; when True progress is printed.

    Returns:
        ``dict_df`` itself, with each split's value replaced by its surviving rows.
        The ``"__split__"`` helper column is removed from every returned frame.
    """
    split_col = "__split__"
    pair_col = "__symmetric_pair_id__"
    split_priority = {"train": 0, "validation": 1, "test": 2}

    if config.verbose:
        if min_words is not None or max_words is not None:
            print(f"\nWord-count filter used: min_words={min_words}, max_words={max_words}")
        else:
            print("\nNo word-count filter used\n")
        print("\nStarting masking rows based on invalid text fields and labels, and word count if specified...\n")

    if not dict_df:
        return dict_df  # nothing to clean; pd.concat([]) would raise ValueError

    # --- Step 1: per-split validity and word-count filtering ---
    filtered: dict[str, pd.DataFrame] = {}
    for split, df in dict_df.items():
        if config.verbose:
            print(f" Processing split='{split}' ({len(df):,} rows)")
        kept = mask_by_word_count(mask_invalid_rows(df), min_words, max_words)
        # Tag rows with the dict key so step 2 routes them back to the right split,
        # even if the input frame lacks (or disagrees on) a "__split__" column.
        filtered[split] = kept.assign(**{split_col: split})
        if config.verbose:
            print(f" Split='{split}': remove {len(df) - len(kept):,} rows, {len(kept):,} rows remain\n")

    # --- Step 2: cross-split deduplication ---
    # Concatenate in priority order (sorted() is stable, so ties keep dict order) so
    # that drop_duplicates(keep="first") keeps the copy from the highest-priority split.
    ordered_splits = sorted(filtered, key=lambda s: split_priority.get(s, len(split_priority)))
    combined = pd.concat([filtered[s] for s in ordered_splits], ignore_index=True)
    # Built with a comprehension rather than DataFrame.apply(axis=1): apply on an empty
    # frame returns a DataFrame, which cannot be assigned to a single column.
    combined[pair_col] = pd.Series(
        [_symmetric_pair_id(t1, t2, label)
         for t1, t2, label in zip(combined["text1"], combined["text2"], combined["same"])],
        index=combined.index,
        dtype=object,
    )
    if config.verbose:
        print("Checking for duplicate pairs (same texts and label, regardless of order) before deduplication...")
        num_duplicates = combined.duplicated(subset=pair_col).sum()
        print(f"There are {num_duplicates:,} duplicate pairs across all splits before deduplication\n")
    combined = combined.drop_duplicates(subset=pair_col, keep="first", ignore_index=True)

    # Route surviving rows back to their split (in the caller's key order), dropping helpers.
    for split in dict_df:
        survivors = combined[combined[split_col] == split].drop(columns=[split_col, pair_col])
        if config.verbose:
            print(f" Split='{split}': {len(survivors):,} valid rows after deduplication")
        dict_df[split] = survivors
    return dict_df
def summary_stats(dict_df: dict[str, pd.DataFrame],
                  config: Config = config
                  ) -> pd.DataFrame:
    """Build one row of descriptive statistics per split.

    Each row holds the split name, its number of rows, the count and share of each
    ``same`` label (0 and 1), and the word/character length statistics from
    ``_length_stats`` for ``text1`` and ``text2`` (keys prefixed ``text1_`` / ``text2_``).

    Args:
        dict_df: split name -> DataFrame with ``text1``, ``text2`` and ``same`` columns.
        config: only ``config.verbose`` is read; when True the table is printed, like
            the progress output of the other audit functions. The default is the
            module-level ``config``, so existing one-argument callers still see the printout.

    Returns:
        A DataFrame with one row per split, in ``dict_df`` order.
    """
    rows = []
    for split, df in dict_df.items():
        label_counts = df["same"].value_counts().to_dict()
        rows.append({
            "split": split,
            "num_rows": len(df),
            "same_0_count": int(label_counts.get(0, 0)),
            "same_1_count": int(label_counts.get(1, 0)),
            "same_0_ratio": round(df["same"].eq(0).mean(), 4),
            "same_1_ratio": round(df["same"].eq(1).mean(), 4),
            **{f"text1_{k}": v for k, v in _length_stats(df["text1"]).items()},
            **{f"text2_{k}": v for k, v in _length_stats(df["text2"]).items()},
        })
    summary_df = pd.DataFrame(rows)
    if config.verbose:
        print("\nSummary statistics for each split:")
        print(summary_df)
    return summary_df
def audit_wrapper(path: Path | str,
                  config: Config = config
                  ) -> tuple[dict[str, pd.DataFrame], pd.DataFrame]:
    """Run the whole audit: load the splits, clean/deduplicate them, summarize.

    Args:
        path: folder containing the three split directories (``str`` or ``Path``).
        config: audit settings; ``verbose`` controls progress output and
            ``min_words`` / ``max_words`` are forwarded to the word-count filter.

    Returns:
        ``(dict_df, summary_df)``: the cleaned split DataFrames keyed by split name,
        and the per-split summary table.
    """
    if config.verbose:
        print("======= AUDIT START =======")
        print("")
    # Pass `config` through so a caller's verbose=False is honored while loading too.
    dict_df = load_all_splits(Path(path), config)
    dict_df = mask_rows(dict_df, config.min_words, config.max_words, config)
    summary_df = summary_stats(dict_df)
    if config.verbose:
        print("")
        print("======= AUDIT END =======")
        print("")
    return dict_df, summary_df
|