"""CUAD dataset loading and contract-level (leakage-safe) splitting.""" from __future__ import annotations import pandas as pd from sklearn.model_selection import GroupShuffleSplit from src.config import CUAD_HF_ID, CUAD_PARQUET, SEED, TEST_SIZE def load_cuad(force_download: bool = False) -> pd.DataFrame: """Load CUAD as a DataFrame. Cached locally as parquet after first run.""" if CUAD_PARQUET.exists() and not force_download: return pd.read_parquet(CUAD_PARQUET) try: from datasets import load_dataset except ImportError as e: raise RuntimeError( "datasets library required for initial download; install requirements.txt" ) from e ds = load_dataset(CUAD_HF_ID) df = ds["train"].to_pandas() CUAD_PARQUET.parent.mkdir(parents=True, exist_ok=True) df.to_parquet(CUAD_PARQUET, index=False) return df def contract_group_split( df: pd.DataFrame, test_size: float = TEST_SIZE, seed: int = SEED ) -> tuple[pd.DataFrame, pd.DataFrame]: """Split by file_name so no contract appears in both train and test. Returns (train_df, test_df). Index reset. """ splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed) train_idx, test_idx = next(splitter.split(df, groups=df["file_name"])) train_df = df.iloc[train_idx].reset_index(drop=True) test_df = df.iloc[test_idx].reset_index(drop=True) return train_df, test_df def naive_random_split( df: pd.DataFrame, test_size: float = TEST_SIZE, seed: int = SEED ) -> tuple[pd.DataFrame, pd.DataFrame]: """Naive row-level shuffle split. Provided for the leakage-comparison demo only; do not use for real eval. """ shuffled = df.sample(frac=1.0, random_state=seed).reset_index(drop=True) cut = int(len(shuffled) * (1 - test_size)) return shuffled.iloc[:cut].reset_index(drop=True), shuffled.iloc[cut:].reset_index(drop=True) def summarize(df: pd.DataFrame) -> dict: """Quick descriptive stats used in EDA and tests.""" return { "n_rows": int(len(df)), "n_contracts": int(df["file_name"].nunique()), "n_labels": int(df["label"].nunique()), "mean_clause_chars": float(df["clause"].str.len().mean()), "label_top5": df["label"].value_counts().head(5).to_dict(), }