segment-monograms / src /data /splits.py
Saranga7's picture
Deploy monogram segmentation demo
bcc432f verified
Raw
History Blame Contribute Delete
6.93 kB
from __future__ import annotations
import argparse
import json
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split
from .dataset import DEFAULT_ADOLFO_ROOT, DEFAULT_METADATA_CSV, DEFAULT_OTHER_ROOT, build_sample_manifest
def _stratify_or_none(frame: pd.DataFrame, columns: list[str]) -> pd.Series | None:
    """Build a per-row stratification key, or give up if a class is too small.

    The selected ``columns`` are converted to text (missing values become
    empty strings) and joined with ``"__"`` into one composite label per row.

    Returns:
        The composite labels, or ``None`` when the rarest label occurs fewer
        than two times.
    """
    as_text = frame[columns].fillna("").astype(str)
    keys = as_text.agg("__".join, axis=1)
    rarest = keys.value_counts().min()
    return None if rarest < 2 else keys
def _write_split(
    out_dir: Path,
    name: str,
    parts: dict[str, pd.DataFrame],
    stratification_note: str,
) -> dict[str, int]:
    """Validate one setting's splits and write them under ``out_dir / name``.

    Files written: one ``<split>.csv`` per entry of ``parts``,
    ``all_splits.csv`` (every row tagged with a ``split`` column),
    ``stats.json`` and ``summary.txt``.

    Args:
        out_dir: Parent directory that holds one sub-directory per setting.
        name: Setting name; used as the sub-directory name and in reports.
        parts: Mapping of split name to the rows belonging to that split.
            Each frame must have a ``sample_id`` column.
        stratification_note: Free-text description stored in the reports.

    Returns:
        Mapping of split name to its number of rows.

    Raises:
        ValueError: If a ``sample_id`` appears in more than one split.
    """
    # Check every split before touching the filesystem, so a leaking setting
    # never leaves a partially written directory behind.
    seen: set[str] = set()
    for split_name, frame in parts.items():
        ids = set(frame["sample_id"])
        overlap = seen & ids
        if overlap:
            raise ValueError(f"{name} has leakage into {split_name}: {sorted(overlap)[:10]}")
        seen |= ids

    setting_dir = out_dir / name
    setting_dir.mkdir(parents=True, exist_ok=True)

    stats: dict[str, int] = {}
    for split_name, frame in parts.items():
        frame.to_csv(setting_dir / f"{split_name}.csv", index=False)
        stats[split_name] = len(frame)

    all_splits = pd.concat(
        [frame.assign(split=split_name) for split_name, frame in parts.items()],
        ignore_index=True,
    )
    all_splits.to_csv(setting_dir / "all_splits.csv", index=False)

    payload = _stats_payload(name, parts, stratification_note)
    # Pin the encoding: the summary can contain arbitrary collection/label
    # text, and the platform default encoding is not always UTF-8.
    (setting_dir / "stats.json").write_text(json.dumps(payload, indent=2), encoding="utf-8")
    (setting_dir / "summary.txt").write_text(_summary_text(payload, parts), encoding="utf-8")
    return stats
def _stats_payload(name: str, parts: dict[str, pd.DataFrame], stratification_note: str) -> dict:
    """Collect per-split row counts for a setting into a JSON-friendly dict.

    For every split the payload records the row count ``n``, the counts per
    ``collection`` value, and the counts per ``quality_label`` value, where
    missing or empty quality labels are reported as ``"unlabeled"``.
    """

    def _tally(column: pd.Series) -> dict:
        # Keep missing values as their own bucket and order keys by label.
        return column.value_counts(dropna=False).sort_index().to_dict()

    per_split: dict = {}
    for split_name, frame in parts.items():
        quality = frame["quality_label"].fillna("").replace("", "unlabeled")
        per_split[split_name] = {
            "n": len(frame),
            "by_collection": _tally(frame["collection"]),
            "by_quality": _tally(quality),
        }
    return {"setting": name, "stratification": stratification_note, "splits": per_split}
def _format_counts(counts: dict) -> str:
    """Render a counts mapping as ``"key=value, key=value"``.

    An empty (or otherwise falsy) mapping is rendered as ``"none"``.
    """
    rendered = [f"{key}={value}" for key, value in counts.items()] if counts else []
    # A non-empty mapping always yields a non-empty string, so the fallback
    # only applies when there was nothing to render.
    return ", ".join(rendered) or "none"
def _quality_collection_table(frame: pd.DataFrame) -> str:
    """Return a text cross-tabulation of quality label versus collection.

    Missing or empty quality labels are shown as ``"unlabeled"`` and missing
    or empty collections as ``"unknown"``. An empty table is rendered as
    ``"none"``.
    """
    quality = frame["quality_label"].fillna("").replace("", "unlabeled")
    collection = frame["collection"].fillna("").replace("", "unknown")
    counts = pd.crosstab(quality, collection)
    return "none" if counts.empty else counts.to_string()
def _summary_text(payload: dict, parts: dict[str, pd.DataFrame]) -> str:
    """Render the plain-text report for one setting.

    The report starts with the setting name, the stratification note and the
    total row count, followed by one section per split listing its size, its
    quality and collection counts, and a quality-by-collection table.

    Args:
        payload: Statistics dict with ``setting``, ``stratification`` and
            ``splits`` keys (the shape produced by ``_stats_payload``).
        parts: Mapping of split name to its rows, used for the cross-table.

    Returns:
        The report text, ending in exactly one newline.
    """
    report = [
        f"Setting: {payload['setting']}",
        f"Stratification: {payload['stratification']}",
        "",
    ]
    row_total = sum(entry["n"] for entry in payload["splits"].values())
    report += [f"Total rows across splits: {row_total}", ""]

    for split_name, entry in payload["splits"].items():
        rows = parts[split_name]
        report.append(f"[{split_name}]")
        report.append(f"n={entry['n']}")
        report.append(f"quality: {_format_counts(entry['by_quality'])}")
        report.append(f"collection: {_format_counts(entry['by_collection'])}")
        report.append("quality_by_collection:")
        report.append(_quality_collection_table(rows))
        report.append("")

    # Drop trailing whitespace, then terminate with a single newline.
    text = "\n".join(report)
    return text.rstrip() + "\n"
def create_splits(
    out_dir: str | Path = "splits",
    adolfo_root: str | Path = DEFAULT_ADOLFO_ROOT,
    other_root: str | Path = DEFAULT_OTHER_ROOT,
    metadata_csv: str | Path | None = DEFAULT_METADATA_CSV,
    seed: int = 7,
) -> dict[str, dict[str, int]]:
    """Build the sample manifest and write the E1/E2/E3 split settings.

    Writes ``manifest.csv`` and ``manifest_stats.json`` into ``out_dir`` and
    then one sub-directory per setting:

    * ``e1_cross_collection``: rows whose collection is ``"adolfo"`` are
      split 80/20 into train/val (stratified by ``quality_label``); every
      other row forms the test set.
    * ``e2_quality_stratified``: rows labelled ``q0`` or ``q1`` are split
      85/15 into train/val (stratified by collection and quality); all
      ``q2`` rows form the test set.
    * ``e3_full_data``: 20% of the manifest is held out as test, then 12.5%
      of the remainder as val (stratified by collection and quality).

    Stratification is skipped for a split whenever ``_stratify_or_none``
    returns ``None``.

    Args:
        out_dir: Directory that receives the manifest and the settings.
        adolfo_root: Forwarded to ``build_sample_manifest``.
        other_root: Forwarded to ``build_sample_manifest``.
        metadata_csv: Forwarded to ``build_sample_manifest``.
        seed: ``random_state`` used for every split.

    Returns:
        Mapping of setting name to ``{split name: row count}``.
    """
    out_dir = Path(out_dir)
    manifest = build_sample_manifest(adolfo_root, other_root, metadata_csv)
    out_dir.mkdir(parents=True, exist_ok=True)
    manifest.to_csv(out_dir / "manifest.csv", index=False)

    quality_or_unlabeled = manifest["quality_label"].fillna("").replace("", "unlabeled")
    manifest_stats = {
        "total": len(manifest),
        "by_collection": manifest["collection"].value_counts().sort_index().to_dict(),
        "by_quality": quality_or_unlabeled.value_counts().sort_index().to_dict(),
    }
    (out_dir / "manifest_stats.json").write_text(json.dumps(manifest_stats, indent=2))

    def _holdout_split(frame: pd.DataFrame, holdout: float, columns: list[str]):
        # Every split shares the same seed and falls back to an unstratified
        # split when the stratification key is unusable.
        return train_test_split(
            frame,
            test_size=holdout,
            random_state=seed,
            stratify=_stratify_or_none(frame, columns),
        )

    collection_and_quality = ["collection", "quality_label"]

    # E1: train/val inside the Adolfo collection, test on everything else.
    adolfo = manifest[manifest["collection"] == "adolfo"].reset_index(drop=True)
    other = manifest[manifest["collection"] != "adolfo"].reset_index(drop=True)
    e1_train, e1_val = _holdout_split(adolfo, 0.2, ["quality_label"])

    # E2: train/val on q0+q1, test on q2.
    labeled = manifest[manifest["quality_label"].isin(["q0", "q1", "q2"])].reset_index(drop=True)
    e2_pool = labeled[labeled["quality_label"].isin(["q0", "q1"])].reset_index(drop=True)
    e2_test = labeled[labeled["quality_label"] == "q2"].reset_index(drop=True)
    e2_train, e2_val = _holdout_split(e2_pool, 0.15, collection_and_quality)

    # E3: hold out the test set first, then carve val out of the remainder.
    e3_fit, e3_test = _holdout_split(manifest, 0.2, collection_and_quality)
    e3_train, e3_val = _holdout_split(e3_fit, 0.125, collection_and_quality)

    # All splits are computed before any setting is written.
    settings = [
        (
            "e1_cross_collection",
            {"train": e1_train, "val": e1_val, "test": other},
            "Adolfo train/val is stratified by quality_label. Test is the full held-out non-Adolfo collection, not sampled.",
        ),
        (
            "e2_quality_stratified",
            {"train": e2_train, "val": e2_val, "test": e2_test},
            "q0+q1 train/val is stratified by collection and quality_label. Test is all q2 samples by protocol.",
        ),
        (
            "e3_full_data",
            {"train": e3_train, "val": e3_val, "test": e3_test},
            "Train/val/test is stratified by collection and quality_label.",
        ),
    ]
    return {
        setting: _write_split(out_dir, setting, frames, note)
        for setting, frames, note in settings
    }
def main() -> None:
    """Command-line entry point: create the splits and print their sizes.

    Accepts ``--out-dir``, ``--metadata-csv`` and ``--seed``; the resulting
    per-setting row counts are printed to stdout as indented JSON.
    """
    cli = argparse.ArgumentParser(description="Create E1/E2/E3 monogram segmentation splits.")
    cli.add_argument("--out-dir", default="splits")
    cli.add_argument("--metadata-csv", default=str(DEFAULT_METADATA_CSV))
    cli.add_argument("--seed", type=int, default=7)
    options = cli.parse_args()

    row_counts = create_splits(
        options.out_dir,
        metadata_csv=options.metadata_csv,
        seed=options.seed,
    )
    print(json.dumps(row_counts, indent=2))


# Only run the CLI when this file is executed as a script.
if __name__ == "__main__":
    main()