"""Create the E1/E2/E3 monogram segmentation train/val/test splits."""

from __future__ import annotations

import argparse
import json
import math
from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split

from .dataset import DEFAULT_ADOLFO_ROOT, DEFAULT_METADATA_CSV, DEFAULT_OTHER_ROOT, build_sample_manifest


def _quality_labels(series: pd.Series) -> pd.Series:
    """Return *series* with missing or empty quality labels mapped to ``"unlabeled"``."""
    return series.fillna("").replace("", "unlabeled")


def _stratify_or_none(
    frame: pd.DataFrame,
    columns: list[str],
    *,
    test_size: float | None = None,
) -> pd.Series | None:
    """Return per-row stratification labels for *frame*, or None when stratifying is impossible.

    Labels are the values of *columns* (missing values treated as "") joined with ``"__"``.

    None is returned when:
    - *frame* is empty;
    - any label occurs fewer than twice;
    - *test_size* (a fraction of rows) is given and either side of the split
      would get fewer rows than there are distinct labels.

    ``train_test_split(..., stratify=labels)`` raises ``ValueError`` in those
    cases, so returning None lets callers fall back to an unstratified split.
    """
    if frame.empty:
        return None
    labels = frame[columns].fillna("").astype(str).agg("__".join, axis=1)
    counts = labels.value_counts()
    if counts.min() < 2:
        return None
    if test_size is not None:
        n_samples = len(labels)
        n_classes = len(counts)
        # Mirrors how scikit-learn sizes the test split for a float test_size.
        n_test = math.ceil(test_size * n_samples)
        n_train = n_samples - n_test
        if n_test < n_classes or n_train < n_classes:
            return None
    return labels


def _split_frame(
    frame: pd.DataFrame,
    *,
    test_size: float,
    seed: int,
    strat_columns: list[str],
    what: str,
) -> tuple[pd.DataFrame, pd.DataFrame, bool]:
    """Split *frame* into (train, held_out), stratifying on *strat_columns* when feasible.

    Returns the two partitions and a flag telling whether stratification was applied.

    Raises:
        ValueError: if *frame* has fewer than 2 rows. *what* names the pool in
            the error message.
    """
    if len(frame) < 2:
        raise ValueError(f"Cannot split {what}: need at least 2 samples, got {len(frame)}.")
    stratify = _stratify_or_none(frame, strat_columns, test_size=test_size)
    train, held_out = train_test_split(
        frame,
        test_size=test_size,
        random_state=seed,
        stratify=stratify,
    )
    return train, held_out, stratify is not None


def _fallback_note(stratified: dict[str, bool]) -> str:
    """Return a note suffix naming the splits that were NOT stratified ("" if all were)."""
    skipped = [split for split, applied in stratified.items() if not applied]
    if not skipped:
        return ""
    return (
        " NOTE: stratification was skipped for "
        + ", ".join(skipped)
        + " because some strata were too small; those splits are plain random samples."
    )


def _write_split(
    out_dir: Path,
    name: str,
    parts: dict[str, pd.DataFrame],
    stratification_note: str,
) -> dict[str, int]:
    """Write one setting's splits, stats and summary under ``out_dir / name``.

    Writes one ``<split>.csv`` per entry of *parts* and ``all_splits.csv``
    (every row plus a ``split`` column). Also writes ``stats.json`` and
    ``summary.txt``.

    Returns:
        Mapping of split name to row count.

    Raises:
        ValueError: if a ``sample_id`` appears in more than one split. The check
            runs before anything is written, so this setting's directory is not
            left half-populated.
    """
    seen: set[str] = set()
    for split_name, frame in parts.items():
        ids = set(frame["sample_id"])
        overlap = seen & ids
        if overlap:
            raise ValueError(f"{name} has leakage into {split_name}: {sorted(overlap)[:10]}")
        seen |= ids

    setting_dir = out_dir / name
    setting_dir.mkdir(parents=True, exist_ok=True)
    stats: dict[str, int] = {}
    for split_name, frame in parts.items():
        frame.to_csv(setting_dir / f"{split_name}.csv", index=False)
        stats[split_name] = len(frame)

    all_splits = pd.concat(
        [frame.assign(split=split_name) for split_name, frame in parts.items()],
        ignore_index=True,
    )
    all_splits.to_csv(setting_dir / "all_splits.csv", index=False)

    payload = _stats_payload(name, parts, stratification_note)
    (setting_dir / "stats.json").write_text(json.dumps(payload, indent=2), encoding="utf-8")
    (setting_dir / "summary.txt").write_text(_summary_text(payload, parts), encoding="utf-8")
    return stats


def _stats_payload(name: str, parts: dict[str, pd.DataFrame], stratification_note: str) -> dict:
    """Build the JSON-serialisable stats dict: per-split row counts by collection and quality."""
    payload = {"setting": name, "stratification": stratification_note, "splits": {}}
    for split_name, frame in parts.items():
        payload["splits"][split_name] = {
            "n": len(frame),
            "by_collection": frame["collection"].value_counts(dropna=False).sort_index().to_dict(),
            "by_quality": _quality_labels(frame["quality_label"]).value_counts(dropna=False).sort_index().to_dict(),
        }
    return payload


def _format_counts(counts: dict) -> str:
    """Render ``{key: value}`` as ``"key=value, ..."``, or ``"none"`` when empty."""
    if not counts:
        return "none"
    return ", ".join(f"{key}={value}" for key, value in counts.items())


def _quality_collection_table(frame: pd.DataFrame) -> str:
    """Return a quality-label x collection count table as text, or ``"none"`` if empty."""
    table = pd.crosstab(
        _quality_labels(frame["quality_label"]),
        frame["collection"].fillna("").replace("", "unknown"),
    )
    if table.empty:
        return "none"
    return table.to_string()


def _summary_text(payload: dict, parts: dict[str, pd.DataFrame]) -> str:
    """Render a human-readable summary of *payload*, ending with a single newline."""
    lines = [
        f"Setting: {payload['setting']}",
        f"Stratification: {payload['stratification']}",
        "",
    ]
    total = sum(split["n"] for split in payload["splits"].values())
    lines.append(f"Total rows across splits: {total}")
    lines.append("")
    for split_name, split_stats in payload["splits"].items():
        frame = parts[split_name]
        lines.extend(
            [
                f"[{split_name}]",
                f"n={split_stats['n']}",
                f"quality: {_format_counts(split_stats['by_quality'])}",
                f"collection: {_format_counts(split_stats['by_collection'])}",
                "quality_by_collection:",
                _quality_collection_table(frame),
                "",
            ]
        )
    return "\n".join(lines).rstrip() + "\n"


def create_splits(
    out_dir: str | Path = "splits",
    adolfo_root: str | Path = DEFAULT_ADOLFO_ROOT,
    other_root: str | Path = DEFAULT_OTHER_ROOT,
    metadata_csv: str | Path | None = DEFAULT_METADATA_CSV,
    seed: int = 7,
) -> dict[str, dict[str, int]]:
    """Build the sample manifest and write the E1/E2/E3 split settings under *out_dir*.

    Settings:
    - E1 (cross-collection): the "adolfo" collection is split 80/20 into
      train/val; every other row is the test set.
    - E2 (quality): q0+q1 rows are split 85/15 into train/val; all q2 rows are
      the test set.
    - E3 (full data): all rows are split 80/20 into (train+val)/test, then
      train+val is split 87.5/12.5, giving roughly 70/10/20 overall.

    Each split is stratified when feasible and falls back to a random split
    otherwise. Any fallback is recorded in that setting's stratification note.

    Returns:
        ``{setting_name: {split_name: row_count}}``.
    """
    out_dir = Path(out_dir)
    manifest = build_sample_manifest(adolfo_root, other_root, metadata_csv)
    out_dir.mkdir(parents=True, exist_ok=True)
    manifest.to_csv(out_dir / "manifest.csv", index=False)
    manifest_stats = {
        "total": len(manifest),
        "by_collection": manifest["collection"].value_counts().sort_index().to_dict(),
        "by_quality": _quality_labels(manifest["quality_label"]).value_counts().sort_index().to_dict(),
    }
    (out_dir / "manifest_stats.json").write_text(json.dumps(manifest_stats, indent=2), encoding="utf-8")

    adolfo = manifest[manifest["collection"] == "adolfo"].reset_index(drop=True)
    other = manifest[manifest["collection"] != "adolfo"].reset_index(drop=True)

    e1_train, e1_val, e1_strat = _split_frame(
        adolfo,
        test_size=0.2,
        seed=seed,
        strat_columns=["quality_label"],
        what="E1 Adolfo train/val pool",
    )

    labeled = manifest[manifest["quality_label"].isin(["q0", "q1", "q2"])].reset_index(drop=True)
    e2_train_all = labeled[labeled["quality_label"].isin(["q0", "q1"])].reset_index(drop=True)
    e2_test = labeled[labeled["quality_label"] == "q2"].reset_index(drop=True)
    e2_train, e2_val, e2_strat = _split_frame(
        e2_train_all,
        test_size=0.15,
        seed=seed,
        strat_columns=["collection", "quality_label"],
        what="E2 q0+q1 train/val pool",
    )

    e3_train_val, e3_test, e3_test_strat = _split_frame(
        manifest,
        test_size=0.2,
        seed=seed,
        strat_columns=["collection", "quality_label"],
        what="E3 full manifest",
    )
    # 0.125 of the remaining 80% is 10% of the whole manifest.
    e3_train, e3_val, e3_val_strat = _split_frame(
        e3_train_val,
        test_size=0.125,
        seed=seed,
        strat_columns=["collection", "quality_label"],
        what="E3 train/val pool",
    )

    return {
        "e1_cross_collection": _write_split(
            out_dir,
            "e1_cross_collection",
            {"train": e1_train, "val": e1_val, "test": other},
            "Adolfo train/val is stratified by quality_label. Test is the full held-out non-Adolfo collection, not sampled."
            + _fallback_note({"train/val": e1_strat}),
        ),
        "e2_quality_stratified": _write_split(
            out_dir,
            "e2_quality_stratified",
            {"train": e2_train, "val": e2_val, "test": e2_test},
            "q0+q1 train/val is stratified by collection and quality_label. Test is all q2 samples by protocol."
            + _fallback_note({"train/val": e2_strat}),
        ),
        "e3_full_data": _write_split(
            out_dir,
            "e3_full_data",
            {"train": e3_train, "val": e3_val, "test": e3_test},
            "Train/val/test is stratified by collection and quality_label."
            + _fallback_note({"train+val/test": e3_test_strat, "train/val": e3_val_strat}),
        ),
    }


def main() -> None:
    """CLI entry point: create the splits and print per-split row counts as JSON."""
    parser = argparse.ArgumentParser(description="Create E1/E2/E3 monogram segmentation splits.")
    parser.add_argument("--out-dir", default="splits")
    # Defaults are the same objects create_splits uses, so omitting the flags
    # behaves exactly as before these options existed.
    parser.add_argument("--adolfo-root", default=DEFAULT_ADOLFO_ROOT)
    parser.add_argument("--other-root", default=DEFAULT_OTHER_ROOT)
    parser.add_argument("--metadata-csv", default=str(DEFAULT_METADATA_CSV))
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()
    stats = create_splits(
        args.out_dir,
        adolfo_root=args.adolfo_root,
        other_root=args.other_root,
        metadata_csv=args.metadata_csv,
        seed=args.seed,
    )
    print(json.dumps(stats, indent=2))


if __name__ == "__main__":
    main()