Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split | |
| from .dataset import DEFAULT_ADOLFO_ROOT, DEFAULT_METADATA_CSV, DEFAULT_OTHER_ROOT, build_sample_manifest | |
def _stratify_or_none(frame: pd.DataFrame, columns: list[str]) -> pd.Series | None:
    """Build per-row stratification labels, or ``None`` when they are unusable.

    Each row's label is the ``"__"``-joined string form of its values in
    *columns*, with missing values treated as empty strings.

    Args:
        frame: Rows that are about to be split.
        columns: Columns whose combined value defines a stratum.

    Returns:
        A Series of labels aligned with ``frame.index``, or ``None`` when
        *frame* has no rows or when any stratum contains fewer than two rows.
    """
    # An empty frame has no strata to balance; say so explicitly instead of
    # relying on how the count comparison below treats an empty result.
    if len(frame) == 0:
        return None
    labels = frame[columns].fillna("").astype(str).agg("__".join, axis=1)
    # A stratum with a single member cannot be represented on both sides of a split.
    if labels.value_counts().min() < 2:
        return None
    return labels
def _write_split(
    out_dir: Path,
    name: str,
    parts: dict[str, pd.DataFrame],
    stratification_note: str,
) -> dict[str, int]:
    """Persist one split setting to disk and report the size of each part.

    Creates ``<out_dir>/<name>/`` and writes ``<part>.csv`` for every entry of
    *parts*, ``all_splits.csv`` (all parts stacked, tagged with a ``split``
    column), ``stats.json`` and ``summary.txt``.

    Args:
        out_dir: Parent directory for all settings.
        name: Setting name; used as the sub-directory name.
        parts: Mapping of part name to its rows; each frame needs a
            ``sample_id`` column.
        stratification_note: Free-text note stored alongside the statistics.

    Returns:
        Mapping of part name to row count.

    Raises:
        ValueError: If a ``sample_id`` appears in more than one part.
    """
    target = out_dir / name
    target.mkdir(parents=True, exist_ok=True)

    sizes: dict[str, int] = {}
    claimed_ids: set[str] = set()
    for part_name, part in parts.items():
        part_ids = set(part["sample_id"])
        shared = claimed_ids & part_ids
        if shared:
            # Only the first few offenders are shown to keep the message short.
            raise ValueError(f"{name} has leakage into {part_name}: {sorted(shared)[:10]}")
        claimed_ids |= part_ids
        part.to_csv(target / f"{part_name}.csv", index=False)
        sizes[part_name] = len(part)

    tagged = []
    for part_name, part in parts.items():
        tagged.append(part.assign(split=part_name))
    combined = pd.concat(tagged, ignore_index=True)
    combined.to_csv(target / "all_splits.csv", index=False)

    report = _stats_payload(name, parts, stratification_note)
    (target / "stats.json").write_text(json.dumps(report, indent=2))
    (target / "summary.txt").write_text(_summary_text(report, parts))
    return sizes
def _stats_payload(name: str, parts: dict[str, pd.DataFrame], stratification_note: str) -> dict:
    """Collect per-part counts for one split setting.

    Args:
        name: Setting name, stored under ``"setting"``.
        parts: Mapping of part name to its rows; each frame needs
            ``collection`` and ``quality_label`` columns.
        stratification_note: Stored verbatim under ``"stratification"``.

    Returns:
        A dict with ``"setting"``, ``"stratification"`` and ``"splits"``; the
        latter maps each part name to its row count ``"n"`` and to value
        counts by collection and by quality (missing or empty quality labels
        are counted as ``"unlabeled"``).
    """
    per_part = {}
    for part_name, part in parts.items():
        collection_counts = part["collection"].value_counts(dropna=False).sort_index()
        quality = part["quality_label"].fillna("").replace("", "unlabeled")
        quality_counts = quality.value_counts(dropna=False).sort_index()
        per_part[part_name] = {
            "n": len(part),
            "by_collection": collection_counts.to_dict(),
            "by_quality": quality_counts.to_dict(),
        }
    return {"setting": name, "stratification": stratification_note, "splits": per_part}
def _format_counts(counts: dict) -> str:
    """Render a counts mapping as ``key=value`` pairs joined by ``", "``.

    Returns the literal ``"none"`` when *counts* is empty (or otherwise falsy).
    """
    if counts:
        pieces = []
        for label, amount in counts.items():
            pieces.append(f"{label}={amount}")
        return ", ".join(pieces)
    return "none"
def _quality_collection_table(frame: pd.DataFrame) -> str:
    """Render a quality-by-collection contingency table as plain text.

    Missing or empty quality labels are shown as ``"unlabeled"`` and missing
    or empty collections as ``"unknown"``. Returns ``"none"`` when the
    cross-tabulation is empty.
    """
    quality = frame["quality_label"].fillna("").replace("", "unlabeled")
    collection = frame["collection"].fillna("").replace("", "unknown")
    counts = pd.crosstab(quality, collection)
    return "none" if counts.empty else counts.to_string()
def _summary_text(payload: dict, parts: dict[str, pd.DataFrame]) -> str:
    """Produce the human-readable summary for one split setting.

    Args:
        payload: Statistics dict with ``"setting"``, ``"stratification"`` and
            ``"splits"`` entries (each split providing ``"n"``,
            ``"by_quality"`` and ``"by_collection"``).
        parts: Mapping of part name to its rows, used for the
            quality-by-collection table of each part.

    Returns:
        The summary text, ending with exactly one newline.
    """
    split_stats = payload["splits"]

    row_total = 0
    for info in split_stats.values():
        row_total += info["n"]

    out = [
        f"Setting: {payload['setting']}",
        f"Stratification: {payload['stratification']}",
        "",
        f"Total rows across splits: {row_total}",
        "",
    ]

    for part_name, info in split_stats.items():
        part = parts[part_name]
        out.append(f"[{part_name}]")
        out.append(f"n={info['n']}")
        out.append(f"quality: {_format_counts(info['by_quality'])}")
        out.append(f"collection: {_format_counts(info['by_collection'])}")
        out.append("quality_by_collection:")
        out.append(_quality_collection_table(part))
        out.append("")

    # Drop trailing blank lines, then terminate with a single newline.
    return "\n".join(out).rstrip() + "\n"
def _split_with_optional_stratification(
    frame: pd.DataFrame,
    columns: list[str],
    test_size: float,
    seed: int,
) -> tuple[pd.DataFrame, pd.DataFrame, bool]:
    """Split *frame* in two, stratifying on *columns* when labels are usable.

    Returns the two parts in ``train_test_split`` order, plus a flag that is
    ``True`` only when stratification labels were actually passed to the split.
    """
    labels = _stratify_or_none(frame, columns)
    first, second = train_test_split(
        frame,
        test_size=test_size,
        random_state=seed,
        stratify=labels,
    )
    return first, second, labels is not None


def _stratification_note(note: str, applied: bool) -> str:
    """Return *note* unchanged when stratification ran, else append a warning."""
    if applied:
        return note
    return (
        f"{note} WARNING: stratification labels were unusable (a stratum with fewer "
        "than 2 samples), so this setting fell back to an unstratified random split."
    )


def create_splits(
    out_dir: str | Path = "splits",
    adolfo_root: str | Path = DEFAULT_ADOLFO_ROOT,
    other_root: str | Path = DEFAULT_OTHER_ROOT,
    metadata_csv: str | Path | None = DEFAULT_METADATA_CSV,
    seed: int = 7,
) -> dict[str, dict[str, int]]:
    """Build the sample manifest and write the E1/E2/E3 split settings.

    The manifest returned by ``build_sample_manifest`` is written to
    ``manifest.csv`` together with ``manifest_stats.json``; it must provide
    ``collection``, ``quality_label`` and ``sample_id`` columns. Three settings
    are then derived from it:

    * ``e1_cross_collection``: 80/20 train/val over rows whose collection is
      ``"adolfo"``; test is every other row.
    * ``e2_quality_stratified``: 85/15 train/val over rows labelled ``q0`` or
      ``q1``; test is every row labelled ``q2``.
    * ``e3_full_data``: 20% of all rows held out as test, then 12.5% of the
      remainder (10% overall) as val.

    Each split is stratified when every stratum has at least two rows. When it
    is not, the split is purely random and the stratification note stored with
    the setting says so, instead of claiming a stratification that never ran.

    Args:
        out_dir: Directory that receives the manifest and one sub-directory
            per setting.
        adolfo_root: Forwarded to ``build_sample_manifest``.
        other_root: Forwarded to ``build_sample_manifest``.
        metadata_csv: Forwarded to ``build_sample_manifest``.
        seed: ``random_state`` used for every split.

    Returns:
        Mapping of setting name to a mapping of part name to row count.

    Raises:
        ValueError: Propagated from ``_write_split`` when a ``sample_id``
            lands in more than one part of a setting.
    """
    out_dir = Path(out_dir)
    manifest = build_sample_manifest(adolfo_root, other_root, metadata_csv)
    out_dir.mkdir(parents=True, exist_ok=True)
    manifest.to_csv(out_dir / "manifest.csv", index=False)
    (out_dir / "manifest_stats.json").write_text(
        json.dumps(
            {
                "total": len(manifest),
                "by_collection": manifest["collection"].value_counts().sort_index().to_dict(),
                "by_quality": manifest["quality_label"].fillna("").replace("", "unlabeled").value_counts().sort_index().to_dict(),
            },
            indent=2,
        )
    )

    # E1: train/val inside the "adolfo" collection, test on everything else.
    adolfo = manifest[manifest["collection"] == "adolfo"].reset_index(drop=True)
    other = manifest[manifest["collection"] != "adolfo"].reset_index(drop=True)
    e1_train, e1_val, e1_stratified = _split_with_optional_stratification(
        adolfo, ["quality_label"], 0.2, seed
    )

    # E2: train/val on q0+q1, test on all q2 rows.
    labeled = manifest[manifest["quality_label"].isin(["q0", "q1", "q2"])].reset_index(drop=True)
    e2_train_all = labeled[labeled["quality_label"].isin(["q0", "q1"])].reset_index(drop=True)
    e2_test = labeled[labeled["quality_label"] == "q2"].reset_index(drop=True)
    e2_train, e2_val, e2_stratified = _split_with_optional_stratification(
        e2_train_all, ["collection", "quality_label"], 0.15, seed
    )

    # E3: two-stage split of the full manifest (test first, then val).
    e3_train, e3_test, e3_test_stratified = _split_with_optional_stratification(
        manifest, ["collection", "quality_label"], 0.2, seed
    )
    e3_train, e3_val, e3_val_stratified = _split_with_optional_stratification(
        e3_train, ["collection", "quality_label"], 0.125, seed
    )

    return {
        "e1_cross_collection": _write_split(
            out_dir,
            "e1_cross_collection",
            {"train": e1_train, "val": e1_val, "test": other},
            _stratification_note(
                "Adolfo train/val is stratified by quality_label. Test is the full held-out non-Adolfo collection, not sampled.",
                e1_stratified,
            ),
        ),
        "e2_quality_stratified": _write_split(
            out_dir,
            "e2_quality_stratified",
            {"train": e2_train, "val": e2_val, "test": e2_test},
            _stratification_note(
                "q0+q1 train/val is stratified by collection and quality_label. Test is all q2 samples by protocol.",
                e2_stratified,
            ),
        ),
        "e3_full_data": _write_split(
            out_dir,
            "e3_full_data",
            {"train": e3_train, "val": e3_val, "test": e3_test},
            _stratification_note(
                "Train/val/test is stratified by collection and quality_label.",
                e3_test_stratified and e3_val_stratified,
            ),
        ),
    }
def main() -> None:
    """Command-line entry point: create the splits and print their sizes as JSON.

    Every keyword that ``create_splits`` accepts can be set from the command
    line, including the two manifest roots.
    """
    parser = argparse.ArgumentParser(description="Create E1/E2/E3 monogram segmentation splits.")
    parser.add_argument("--out-dir", default="splits")
    # The defaults are the same objects create_splits uses on its own; argparse
    # passes non-string defaults through untouched, so omitting these flags
    # behaves exactly as before they existed.
    parser.add_argument(
        "--adolfo-root",
        default=DEFAULT_ADOLFO_ROOT,
        help="Value passed to create_splits as adolfo_root.",
    )
    parser.add_argument(
        "--other-root",
        default=DEFAULT_OTHER_ROOT,
        help="Value passed to create_splits as other_root.",
    )
    parser.add_argument("--metadata-csv", default=str(DEFAULT_METADATA_CSV))
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()
    stats = create_splits(
        args.out_dir,
        adolfo_root=args.adolfo_root,
        other_root=args.other_root,
        metadata_csv=args.metadata_csv,
        seed=args.seed,
    )
    print(json.dumps(stats, indent=2))


if __name__ == "__main__":
    main()