Spaces:
Sleeping
Sleeping
File size: 2,323 Bytes
"""CUAD dataset loading and contract-level (leakage-safe) splitting."""
from __future__ import annotations
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from src.config import CUAD_HF_ID, CUAD_PARQUET, SEED, TEST_SIZE
def load_cuad(force_download: bool = False) -> pd.DataFrame:
    """Return the CUAD dataset as a DataFrame, using a local parquet cache.

    If the file at ``CUAD_PARQUET`` exists and ``force_download`` is false,
    the cached parquet is read and returned. Otherwise the dataset identified
    by ``CUAD_HF_ID`` is loaded through the ``datasets`` library, its
    ``"train"`` split is converted to pandas, written to ``CUAD_PARQUET``
    (parent directories are created as needed) and returned.

    Args:
        force_download: When true, skip the cached parquet and load the
            dataset again, overwriting the cache file.

    Returns:
        The ``"train"`` split as a ``pandas.DataFrame``.

    Raises:
        RuntimeError: If the ``datasets`` library cannot be imported when a
            download is needed.
    """
    cache_usable = CUAD_PARQUET.exists() and not force_download
    if cache_usable:
        return pd.read_parquet(CUAD_PARQUET)

    # Imported lazily so that reading from the cache does not require the
    # `datasets` package to be installed.
    try:
        from datasets import load_dataset as hf_load_dataset
    except ImportError as exc:
        raise RuntimeError(
            "datasets library required for initial download; install requirements.txt"
        ) from exc

    frame = hf_load_dataset(CUAD_HF_ID)["train"].to_pandas()

    # Make sure the cache directory exists before writing the parquet file.
    cache_dir = CUAD_PARQUET.parent
    cache_dir.mkdir(parents=True, exist_ok=True)
    frame.to_parquet(CUAD_PARQUET, index=False)
    return frame
def contract_group_split(
    df: pd.DataFrame, test_size: float = TEST_SIZE, seed: int = SEED
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split rows into train/test with every contract kept on one side.

    Rows are grouped by their ``file_name`` value and a single
    ``GroupShuffleSplit`` fold is drawn, so all rows sharing a ``file_name``
    land together in either the train or the test frame.

    Args:
        df: Frame to split; must contain a ``file_name`` column.
        test_size: Passed through to ``GroupShuffleSplit(test_size=...)``.
        seed: Passed through as ``random_state`` for reproducibility.

    Returns:
        ``(train_df, test_df)``, each with a fresh 0..n-1 index.
    """
    group_splitter = GroupShuffleSplit(
        n_splits=1, test_size=test_size, random_state=seed
    )
    contract_ids = df["file_name"]

    # n_splits=1, so the first (and only) fold is the split we want.
    fold = next(group_splitter.split(df, groups=contract_ids))
    train_rows, test_rows = fold

    def _subset(positions) -> pd.DataFrame:
        # Positional selection followed by dropping the old index.
        return df.iloc[positions].reset_index(drop=True)

    return _subset(train_rows), _subset(test_rows)
def naive_random_split(
    df: pd.DataFrame, test_size: float = TEST_SIZE, seed: int = SEED
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Shuffle rows and cut them into train/test without regard to contract.

    Intended only for the leakage-comparison demo; do not use for real
    evaluation. The split ignores ``file_name`` entirely, so rows from one
    contract may end up on both sides.

    Args:
        df: Frame to split.
        test_size: Fraction of rows for the test side; the train side gets
            the first ``int(len(df) * (1 - test_size))`` shuffled rows.
        seed: ``random_state`` used for the shuffle.

    Returns:
        ``(train_df, test_df)``, each with a fresh 0..n-1 index.
    """
    # frac=1.0 samples every row once, i.e. a full permutation of the frame.
    permuted = df.sample(frac=1.0, random_state=seed)

    # int() truncates, so any fractional row goes to the test side.
    boundary = int(len(permuted) * (1 - test_size))

    train_part = permuted.iloc[:boundary].reset_index(drop=True)
    test_part = permuted.iloc[boundary:].reset_index(drop=True)
    return train_part, test_part
def summarize(df: pd.DataFrame) -> dict:
    """Compute a few descriptive statistics for EDA and tests.

    Args:
        df: Frame with ``file_name``, ``label`` and ``clause`` columns;
            ``clause`` must support the pandas ``.str`` accessor.

    Returns:
        A dict with these keys, in order:
        ``n_rows`` (row count), ``n_contracts`` (distinct ``file_name``
        values), ``n_labels`` (distinct ``label`` values),
        ``mean_clause_chars`` (mean of ``clause`` string lengths) and
        ``label_top5`` (mapping of the five most frequent labels to their
        counts).
    """
    row_count = len(df)
    contract_count = df["file_name"].nunique()

    label_column = df["label"]
    label_count = label_column.nunique()

    clause_lengths = df["clause"].str.len()
    average_length = clause_lengths.mean()

    # value_counts() sorts by frequency descending, so head(5) is the top 5.
    top_labels = label_column.value_counts().head(5)

    return dict(
        n_rows=int(row_count),
        n_contracts=int(contract_count),
        n_labels=int(label_count),
        mean_clause_chars=float(average_length),
        label_top5=top_labels.to_dict(),
    )
|