maralzar
Initial commit: EPCC clause classifier Streamlit demo for HF Spaces
5212b8e
Raw
History Blame Contribute Delete
2.32 kB
"""CUAD dataset loading and contract-level (leakage-safe) splitting."""
from __future__ import annotations
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from src.config import CUAD_HF_ID, CUAD_PARQUET, SEED, TEST_SIZE
def load_cuad(force_download: bool = False) -> pd.DataFrame:
    """Load CUAD as a DataFrame. Cached locally as parquet after first run.

    On the first call (or when ``force_download`` is True) the dataset is
    fetched with ``datasets.load_dataset(CUAD_HF_ID)``. Its ``"train"`` split
    is converted to pandas and written to ``CUAD_PARQUET``. Later calls read
    that parquet file directly.

    Args:
        force_download: Ignore any existing cache file and re-download.

    Returns:
        The ``"train"`` split as a DataFrame.

    Raises:
        RuntimeError: If a download is needed but the ``datasets`` library
            is not installed.
    """
    if CUAD_PARQUET.exists() and not force_download:
        return pd.read_parquet(CUAD_PARQUET)

    # Imported lazily so that reading an existing cache does not require
    # the `datasets` package.
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise RuntimeError(
            "datasets library required for initial download; install requirements.txt"
        ) from e

    ds = load_dataset(CUAD_HF_ID)
    df = ds["train"].to_pandas()

    CUAD_PARQUET.parent.mkdir(parents=True, exist_ok=True)
    # Write to a sibling temp file, then rename it into place. The cache check
    # above only tests for existence. Without this, a partially written file
    # (crash or Ctrl-C mid-write) would be treated as a valid cache and make
    # every later call fail until it was deleted by hand.
    tmp_path = CUAD_PARQUET.with_name(CUAD_PARQUET.name + ".tmp")
    try:
        df.to_parquet(tmp_path, index=False)
        tmp_path.replace(CUAD_PARQUET)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    return df
def contract_group_split(
    df: pd.DataFrame, test_size: float = TEST_SIZE, seed: int = SEED
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split rows into train/test so each contract lands on exactly one side.

    Rows are grouped by ``file_name``, and whole groups are assigned by a
    single ``GroupShuffleSplit`` draw. This keeps clauses from one contract
    from leaking between train and test.

    Args:
        df: Clause-level data with a ``file_name`` column.
        test_size: Passed to ``GroupShuffleSplit``. scikit-learn interprets a
            float as the fraction of groups (contracts), not of rows.
        seed: ``random_state`` for the splitter.

    Returns:
        ``(train_df, test_df)``, each with a fresh 0..n-1 index.
    """
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    folds = splitter.split(df, groups=df["file_name"])
    train_positions, test_positions = next(folds)

    def _take(positions):
        # The splitter yields positional offsets, so select with iloc, not loc.
        return df.iloc[positions].reset_index(drop=True)

    return _take(train_positions), _take(test_positions)
def naive_random_split(
    df: pd.DataFrame, test_size: float = TEST_SIZE, seed: int = SEED
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Shuffle rows and cut once, ignoring which contract each row came from.

    Provided only for the leakage-comparison demo. Clauses from one contract
    can end up on both sides, so do not use this for real evaluation.

    Returns:
        ``(train_df, test_df)``: the first ``int(len(df) * (1 - test_size))``
        shuffled rows, then the remaining rows, each with a fresh 0..n-1 index.
    """
    shuffled = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    n_train = int(len(shuffled) * (1 - test_size))
    train_part = shuffled.iloc[:n_train]
    test_part = shuffled.iloc[n_train:]
    return train_part.reset_index(drop=True), test_part.reset_index(drop=True)
def summarize(df: pd.DataFrame) -> dict:
    """Return quick descriptive stats used in EDA and tests.

    Reads the ``file_name``, ``label`` and ``clause`` columns. All counts and
    averages are native Python ``int``/``float``, so the result can be passed
    straight to ``json.dumps``.

    Returns:
        A dict with keys:
            n_rows: number of rows.
            n_contracts: number of distinct ``file_name`` values.
            n_labels: number of distinct ``label`` values.
            mean_clause_chars: mean string length of ``clause``.
            label_top5: up to five most frequent labels, mapped to their
                row counts.
    """
    top5 = df["label"].value_counts().head(5)
    return {
        "n_rows": int(len(df)),
        "n_contracts": int(df["file_name"].nunique()),
        "n_labels": int(df["label"].nunique()),
        "mean_clause_chars": float(df["clause"].str.len().mean()),
        # Cast counts to int like the scalar fields above. On some pandas
        # versions Series.to_dict() yields NumPy integer values, which
        # json.dumps cannot serialize.
        "label_top5": {label: int(count) for label, count in top5.items()},
    }