from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import pandas as pd
from datasets import load_from_disk


@dataclass(slots=True)
class Config:
    verbose: bool = True
    max_words: Optional[int] = None
    min_words: Optional[int] = None


config = Config()


def mask_invalid_rows(df: pd.DataFrame) -> pd.DataFrame:
| """ |
| Remove rows with invalid text fields (non-string, empty) or invalid labels (not 0 or 1). |
| No need to check NaNs because the dataset does not contain NaNs. |
| """ |
| text1_ok = df["text1"].map(lambda x: isinstance(x, str) and x.strip() != "") |
| text2_ok = df["text2"].map(lambda x: isinstance(x, str) and x.strip() != "") |
| label_ok = df["same"].isin([0, 1]) |
| valid_mask = text1_ok & text2_ok & label_ok |
| return df[valid_mask] |
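

# Hedged usage sketch (not part of the audit pipeline): a tiny hypothetical
# frame showing what mask_invalid_rows keeps. Row 0 survives; row 1 fails the
# whitespace-only check on text1; row 2 has a label outside {0, 1}.
def _example_mask_invalid_rows() -> pd.DataFrame:
    toy = pd.DataFrame({
        "text1": ["The cat sat.", "   ", "A short note."],
        "text2": ["A dog barked.", "Some text.", "Another note."],
        "same": [1, 0, 2],
    })
    return mask_invalid_rows(toy)  # keeps only row 0
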
def _within_word_count_range(value: str,
                             min_words: Optional[int] = None,
                             max_words: Optional[int] = None) -> bool:
    """Return True if the word count of `value` lies within the given bounds."""
    count = len(value.split())
    if min_words is not None and count < min_words:
        return False
    if max_words is not None and count > max_words:
        return False
    return True


def mask_by_word_count(df: pd.DataFrame,
                       min_words: Optional[int] = None,
                       max_words: Optional[int] = None
                       ) -> pd.DataFrame:
    """
    Keep only rows whose text fields both fall within the given word-count
    bounds. If neither bound is set, the frame is returned unchanged.
    """
    if min_words is None and max_words is None:
        return df

    text1_ok = df["text1"].map(lambda x: _within_word_count_range(x, min_words, max_words))
    text2_ok = df["text2"].map(lambda x: _within_word_count_range(x, min_words, max_words))
    valid_mask = text1_ok & text2_ok
    return df[valid_mask]
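

# Hedged usage sketch with hypothetical data: with min_words=2, row 1 is
# dropped because its text2 contains a single word.
def _example_mask_by_word_count() -> pd.DataFrame:
    toy = pd.DataFrame({
        "text1": ["one two three", "four five"],
        "text2": ["six seven", "eight"],
        "same": [0, 1],
    })
    return mask_by_word_count(toy, min_words=2)  # keeps only row 0
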
def _symmetric_pair_id(text1: str, text2: str, same: int) -> tuple[str, str, int]:
    """
    Generate an order-independent identifier for a pair of texts:
    (text1="A", text2="B", same=1) and (text1="B", text2="A", same=1)
    map to the same ID.
    """
    text1, text2 = sorted([text1, text2])
    return text1, text2, same
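

# Hedged sanity check: both orderings of a pair map to the same ID, so
# reversed duplicates collapse onto one key during deduplication.
def _example_symmetric_pair_id() -> None:
    assert _symmetric_pair_id("A", "B", 1) == _symmetric_pair_id("B", "A", 1)
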
def _length_stats(series: pd.Series) -> dict[str, float]:
    """Compute word- and character-length statistics for a series of texts."""
    word_length = series.str.split().str.len()
    char_length = series.str.len()
    return {
        "median_word_length": round(word_length.median(), 2),
        "mean_word_length": round(word_length.mean(), 2),
        "std_word_length": round(word_length.std(), 2),
        "median_char_length": round(char_length.median(), 2),
        "mean_char_length": round(char_length.mean(), 2),
        "std_char_length": round(char_length.std(), 2),
    }
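

# Hedged usage sketch: two hypothetical texts of 2 and 4 words, giving
# mean_word_length == 3.0 and mean_char_length == 14.0.
def _example_length_stats() -> dict[str, float]:
    return _length_stats(pd.Series(["two words", "now four words here"]))
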
def load_all_splits(path: Path,
                    config: Config = config
                    ) -> dict[str, pd.DataFrame]:
    """
    Load the three splits (train, validation, test) from the given path and
    return them as a dict of DataFrames keyed by split name. Assumes each
    split is stored as Arrow files under the corresponding split folder.
    """
    dict_df: dict[str, pd.DataFrame] = {}
    for split in ["train", "validation", "test"]:
        dataset_dir = path / f"authorship_verification_{split}"
        ds = load_from_disk(str(dataset_dir))
        df = ds.to_pandas()
        df["__split__"] = split  # remember the origin split for later regrouping
        dict_df[split] = df

        if config.verbose:
            print(f"Loaded {split} split: {len(df)} rows")

    return dict_df


def mask_rows(dict_df: dict[str, pd.DataFrame],
              min_words: Optional[int] = None,
              max_words: Optional[int] = None,
              config: Config = config
              ) -> dict[str, pd.DataFrame]:
    """
    Mask rows with invalid text fields or labels, and optionally filter by
    word count. Then drop duplicate pairs across all splits, keeping the
    first occurrence in split-priority order (train > validation > test).
    """
    if config.verbose:
        if min_words is not None or max_words is not None:
            print(f"\nWord-count filter used: min_words={min_words}, max_words={max_words}")
        else:
            print("\nNo word-count filter used\n")

    rows_before_masking = {split: len(df) for split, df in dict_df.items()}
    if config.verbose:
        print("\nStarting masking rows based on invalid text fields and labels, and word count if specified...\n")

    df_valid: list[pd.DataFrame] = []
    for split, df in dict_df.items():
        if config.verbose:
            print(f" Processing split='{split}' ({len(df):,} rows)")

        df = mask_invalid_rows(df)
        df = mask_by_word_count(df, min_words, max_words)
        df_valid.append(df)

        if config.verbose:
            print(f" Split='{split}': removed {rows_before_masking[split] - len(df):,} rows, {len(df):,} rows remain\n")

    # Splits are concatenated in train -> validation -> test order, so
    # keep="first" below implements the split-priority rule.
    df_valid = pd.concat(df_valid, ignore_index=True)
    df_valid["__symmetric_pair_id__"] = df_valid.apply(
        lambda row: _symmetric_pair_id(row["text1"], row["text2"], row["same"]), axis=1
    )
    if config.verbose:
        print("Checking for duplicate pairs (same texts and label, regardless of order) before deduplication...")
        num_duplicates = df_valid.duplicated(subset="__symmetric_pair_id__").sum()
        print(f"There are {num_duplicates:,} duplicate pairs across all splits before deduplication\n")

    df_valid = df_valid.drop_duplicates(subset="__symmetric_pair_id__", keep="first", ignore_index=True)

    for split in dict_df:
        if config.verbose:
            print(f" Split='{split}': {len(df_valid[df_valid['__split__'] == split]):,} valid rows after deduplication")

        dict_df[split] = df_valid[df_valid["__split__"] == split].drop(columns=["__split__", "__symmetric_pair_id__"])

    return dict_df
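

# Hedged usage sketch: two hypothetical one-row splits that contain the same
# pair with text1/text2 swapped. The train copy wins, so the test split comes
# back empty after deduplication.
def _example_mask_rows() -> dict[str, pd.DataFrame]:
    toy = {
        "train": pd.DataFrame({"text1": ["A"], "text2": ["B"], "same": [1], "__split__": ["train"]}),
        "test": pd.DataFrame({"text1": ["B"], "text2": ["A"], "same": [1], "__split__": ["test"]}),
    }
    return mask_rows(toy, config=Config(verbose=False))
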
def summary_stats(dict_df: dict[str, pd.DataFrame]) -> pd.DataFrame:
    """Generate summary statistics for each split."""
    summary = []
    for split, df in dict_df.items():
        stats_text1 = _length_stats(df["text1"])
        stats_text2 = _length_stats(df["text2"])
        same_distribution = df["same"].value_counts().to_dict()
        summary.append({
            "split": split,
            "num_rows": len(df),
            "same_0_count": int(same_distribution.get(0, 0)),
            "same_1_count": int(same_distribution.get(1, 0)),
            "same_0_ratio": round(df["same"].eq(0).mean(), 4),
            "same_1_ratio": round(df["same"].eq(1).mean(), 4),
            **{f"text1_{k}": v for k, v in stats_text1.items()},
            **{f"text2_{k}": v for k, v in stats_text2.items()},
        })

    summary_df = pd.DataFrame(summary)
    print("\nSummary statistics for each split:")
    print(summary_df)

    return summary_df


def audit_wrapper(path: Path,
                  config: Config = config
                  ) -> tuple[dict[str, pd.DataFrame], pd.DataFrame]:
    """
    Run the full audit: load the splits, mask invalid and duplicate rows,
    then compute summary statistics.
    """
    if config.verbose:
        print("======= AUDIT START =======")
        print("")

    dict_df = load_all_splits(path, config)
    dict_df = mask_rows(dict_df, config.min_words, config.max_words, config)
    summary_df = summary_stats(dict_df)

    if config.verbose:
        print("")
        print("======= AUDIT END =======")
        print("")

    return dict_df, summary_df
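

# Hedged entry-point sketch: the path below is an assumption, not something
# this module ships with; point it at the directory that holds the
# authorship_verification_{train,validation,test} folders.
if __name__ == "__main__":
    data_dir = Path("data/authorship_verification")  # hypothetical location
    audited_splits, audit_summary = audit_wrapper(data_dir)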