#!/usr/bin/env python3 """ Resplit `data_v2/` into leakage-safe `data_v3_rebuild/` using MMseqs2 clustering. Default policy for the current rebuild: - Cluster `protein_seq` with MMseqs2 `linclust` - Define species by normalized binomial name (`genus species`) - Test species are exactly the normalized species present in `data_v2/test` - Validation is cluster-unseen but species-seen - Mixed seen/heldout clusters keep heldout rows in test and drop seen rows Typical usage (end-to-end): python resplit_data_v3.py all --threads 32 --split-memory-limit 120G --num-shards 256 """ from __future__ import annotations import argparse import json import os import shutil import stat import subprocess import sys import time from pathlib import Path from typing import Dict, Iterable, List, Optional, Sequence, Tuple def _default_mmseqs_path() -> str: cand = Path("MMseqs2/build/bin/mmseqs") if cand.exists(): return str(cand) return "mmseqs" def _run(cmd: List[str], *, cwd: Optional[str] = None, env: Optional[dict] = None) -> None: pretty = " ".join(cmd) print(f"+ {pretty}", flush=True) subprocess.run(cmd, cwd=cwd, env=env, check=True) def _sql_escape_path(path: str) -> str: return path.replace("'", "''") def _expand_parquet_inputs(inp: str) -> List[str]: import glob p = Path(inp) if p.exists() and p.is_dir(): files = sorted(str(x) for x in p.rglob("*.parquet")) else: files = sorted(glob.glob(inp)) seen = set() out: List[str] = [] for f in files: if f not in seen: out.append(f) seen.add(f) return out def _duckdb_parquet_source(inp: str, limit_files: int = 0) -> str: files = _expand_parquet_inputs(inp) if not files: raise SystemExit(f"No parquet files found for {inp!r}") if limit_files and int(limit_files) > 0: files = files[: int(limit_files)] quoted = ", ".join(f"'{_sql_escape_path(fp)}'" for fp in files) return f"read_parquet([{quoted}])" def _mem_total_bytes() -> Optional[int]: try: with open("/proc/meminfo", "r", encoding="utf-8") as f: for line in f: if 
line.startswith("MemTotal:"): parts = line.split() kb = int(parts[1]) return kb * 1024 except OSError: return None except (ValueError, IndexError): return None return None def _parse_mmseqs_bytes(s: str) -> Optional[int]: s = (s or "").strip() if not s: return None up = s.upper() suffix = up[-1] num_part = up[:-1] unit = suffix if suffix == "B" and len(up) >= 2 and up[-2] in "KMGT": unit = up[-2] num_part = up[:-2] if unit not in "BKMGT": return None try: val = float(num_part) except ValueError: return None mult = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}[unit] return int(val * mult) def _format_bytes(n: int) -> str: for unit, div in [("TiB", 1024**4), ("GiB", 1024**3), ("MiB", 1024**2), ("KiB", 1024)]: if n >= div: return f"{n / div:.1f}{unit}" return f"{n}B" def _seq_id_sql() -> str: # Keep the stable row identifier aligned with the existing pipeline. return "coalesce(protein_refseq_id, '') || '|' || coalesce(RefseqID, '')" def _taxon_norm_sql(col: str = "taxon") -> str: return f"regexp_replace(lower(trim(coalesce({col}, ''))), '\\\\s+', ' ', 'g')" def _species_key_sql(mode: str, col: str = "taxon") -> str: norm = _taxon_norm_sql(col) if mode == "taxon": return norm if mode == "binomial": return ( f"CASE " f"WHEN strpos({norm}, ' ') > 0 " f"THEN split_part({norm}, ' ', 1) || ' ' || split_part({norm}, ' ', 2) " f"ELSE {norm} END" ) raise ValueError(f"Unsupported species key mode: {mode}") def _protein_norm_sql(col: str = "protein_seq") -> str: cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')" no_stop = f"regexp_replace({cleaned}, '[_*]+$', '')" return f"regexp_replace({no_stop}, '[^A-Z]', 'X', 'g')" def _cds_norm_sql(col: str = "cds_DNA") -> str: cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')" return f"regexp_replace({cleaned}, '[^ACGTN]', 'N', 'g')" def _seq_expr_sql(seq_space: str) -> str: if seq_space == "protein": return _protein_norm_sql("protein_seq") if seq_space == "cds": return 
_cds_norm_sql("cds_DNA") raise ValueError(f"Unsupported seq space: {seq_space}") def _seq_space_input_col(seq_space: str) -> str: if seq_space == "protein": return "protein_seq" if seq_space == "cds": return "cds_DNA" raise ValueError(f"Unsupported seq space: {seq_space}") def _mmseqs_dbtype(seq_space: str) -> str: if seq_space == "protein": return "1" if seq_space == "cds": return "2" raise ValueError(f"Unsupported seq space: {seq_space}") def _default_max_input_seq_len(seq_space: str) -> int: if seq_space == "protein": # MMseqs linclust hit an internal SW bug on a tiny tail of ultra-long proteins # (~39k aa+). Filtering this tail removes <0.01% of rows and keeps the run stable. return 20_000 return 0 def _ensure_mmseqs_ready(mmseqs: str) -> Tuple[str, Dict[str, str]]: path = Path(mmseqs) env = os.environ.copy() if path.exists(): mode = path.stat().st_mode if not (mode & stat.S_IXUSR): path.chmod(mode | stat.S_IXUSR) py = Path(sys.executable).resolve() env_root = py.parent.parent conda_root = env_root.parent.parent if env_root.parent.name == "envs" else env_root.parent lib_candidates = [env_root / "lib", conda_root / "lib"] libs = [str(p) for p in lib_candidates if p.exists()] if libs: current = env.get("LD_LIBRARY_PATH", "") env["LD_LIBRARY_PATH"] = ":".join(libs + ([current] if current else [])) return str(path if path.exists() else mmseqs), env def _ensure_output_parent(path: Path) -> None: path.parent.mkdir(parents=True, exist_ok=True) def cmd_make_fasta(args: argparse.Namespace) -> None: out_fasta = Path(args.output_fasta) _ensure_output_parent(out_fasta) import duckdb con = duckdb.connect() con.execute(f"PRAGMA threads={int(args.threads)};") con.execute("PRAGMA enable_progress_bar=true;") source_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files)) out_path = _sql_escape_path(str(out_fasta)) seq_id = _seq_id_sql() seq_expr = _seq_expr_sql(args.seq_space) raw_col = _seq_space_input_col(args.seq_space) max_input_seq_len = 
int(args.max_input_seq_len) if max_input_seq_len <= 0: max_input_seq_len = _default_max_input_seq_len(args.seq_space) len_filter = ( f"AND length({seq_expr}) <= {max_input_seq_len}" if max_input_seq_len > 0 else "" ) sql = f""" COPY ( SELECT '>' || ({seq_id}) AS header, {seq_expr} AS seq FROM {source_sql} WHERE {raw_col} IS NOT NULL AND length({seq_expr}) > 0 {len_filter} AND length(({seq_id})) > 1 {f"LIMIT {int(args.limit_rows)}" if args.limit_rows and int(args.limit_rows) > 0 else ""} ) TO '{out_path}' (FORMAT CSV, DELIMITER '\n', QUOTE '', ESCAPE '', HEADER FALSE); """ t0 = time.time() con.execute(sql) print( f"Wrote FASTA: {out_fasta} seq_space={args.seq_space} " f"max_input_seq_len={max_input_seq_len if max_input_seq_len > 0 else 'none'} " f"(elapsed_s={time.time() - t0:.1f})" ) def cmd_mmseqs_cluster(args: argparse.Namespace) -> None: mmseqs, env = _ensure_mmseqs_ready(args.mmseqs) workdir = Path(args.workdir) workdir.mkdir(parents=True, exist_ok=True) fasta = Path(args.fasta) if not fasta.exists(): raise SystemExit(f"FASTA not found: {fasta}") seqdb = workdir / "seqdb" clu = workdir / "clu" tmp = workdir / "tmp" tsv = workdir / "clu.tsv" if args.overwrite: for p in (seqdb, clu, tmp, tsv): if p.is_dir(): shutil.rmtree(p, ignore_errors=True) else: for suffix in ("", ".dbtype", ".index", ".lookup", ".source"): try: os.remove(str(p) + suffix) except OSError: pass tmp.mkdir(parents=True, exist_ok=True) _run( [ mmseqs, "createdb", str(fasta), str(seqdb), "--dbtype", _mmseqs_dbtype(args.seq_space), "--shuffle", "0", "--createdb-mode", "1", "--threads", str(int(args.threads)), ], env=env, ) linclust_cmd = [ mmseqs, "linclust", str(seqdb), str(clu), str(tmp), "--min-seq-id", str(float(args.min_seq_id)), "-c", str(float(args.coverage)), "--cov-mode", str(int(args.cov_mode)), "--cluster-mode", str(int(args.cluster_mode)), "--threads", str(int(args.threads)), "--max-seq-len", str(int(args.max_seq_len)), "--remove-tmp-files", "1" if args.remove_tmp_files else "0", ] if 
args.split_memory_limit: mem_total = _mem_total_bytes() limit_bytes = _parse_mmseqs_bytes(args.split_memory_limit) if mem_total and limit_bytes and limit_bytes > mem_total: print( f"WARNING: --split-memory-limit={args.split_memory_limit} ({_format_bytes(limit_bytes)}) " f"exceeds system MemTotal ({_format_bytes(mem_total)}). " "MMseqs2 may under-split and crash; consider lowering it or leaving it empty.", file=sys.stderr, flush=True, ) linclust_cmd += ["--split-memory-limit", str(args.split_memory_limit)] if args.kmer_per_seq_scale is not None: linclust_cmd += ["--kmer-per-seq-scale", str(float(args.kmer_per_seq_scale))] _run(linclust_cmd, env=env) _run([mmseqs, "createtsv", str(seqdb), str(seqdb), str(clu), str(tsv)], env=env) print(f"Wrote cluster TSV: {tsv}") def cmd_make_seq_cluster(args: argparse.Namespace) -> None: import duckdb tsv = Path(args.cluster_tsv) if not tsv.exists(): raise SystemExit(f"Cluster TSV not found: {tsv}") out = Path(args.output_parquet) _ensure_output_parent(out) con = duckdb.connect() con.execute(f"PRAGMA threads={int(args.threads)};") con.execute("PRAGMA enable_progress_bar=true;") tsv_path = _sql_escape_path(str(tsv)) out_path = _sql_escape_path(str(out)) sql = f""" COPY ( SELECT DISTINCT seq_id, cluster_id FROM read_csv( '{tsv_path}', delim='\\t', header=false, columns={{'cluster_id':'VARCHAR','seq_id':'VARCHAR'}} ) ) TO '{out_path}' (FORMAT PARQUET); """ t0 = time.time() con.execute(sql) print(f"Wrote seq→cluster parquet: {out} (elapsed_s={time.time() - t0:.1f})") def _write_cluster_split_parquet( con, *, cluster_split_path: Path, seed: int, val_frac: float, ) -> Dict[str, int]: import pyarrow as pa import pyarrow.parquet as pq cluster_split_path.parent.mkdir(parents=True, exist_ok=True) if cluster_split_path.exists(): cluster_split_path.unlink() total_seen_rows = int( con.execute( "SELECT coalesce(sum(n_total), 0)::BIGINT FROM cluster_flags WHERE n_test = 0" ).fetchone()[0] ) target_val_rows = int(total_seen_rows * float(val_frac)) 
species_remaining = { species_key: int(n_clusters) for species_key, n_clusters in con.execute( """ SELECT cc.species_key, count(*)::BIGINT AS n_clusters FROM cluster_counts cc JOIN cluster_flags cf USING (cluster_id) WHERE cf.n_test = 0 GROUP BY cc.species_key """ ).fetchall() } cur = con.execute( f""" SELECT cf.cluster_id, cf.n_total, abs(hash(cf.cluster_id || ':{seed}')) AS rnd, cc.species_key FROM cluster_flags cf JOIN cluster_counts cc USING (cluster_id) WHERE cf.n_test = 0 ORDER BY rnd, cf.cluster_id, cc.species_key """ ) writer = None batch_cluster_ids: List[str] = [] batch_splits: List[str] = [] val_rows = 0 train_clusters = 0 val_clusters = 0 current_cluster: Optional[str] = None current_n_total = 0 current_species: List[str] = [] def flush_current() -> None: nonlocal writer, val_rows, train_clusters, val_clusters nonlocal current_cluster, current_n_total, current_species if current_cluster is None: return can_val = ( val_rows < target_val_rows and all(species_remaining.get(species_key, 0) > 1 for species_key in current_species) ) split = "val" if can_val else "train" if can_val: for species_key in current_species: species_remaining[species_key] -= 1 val_rows += int(current_n_total) val_clusters += 1 else: train_clusters += 1 batch_cluster_ids.append(current_cluster) batch_splits.append(split) if len(batch_cluster_ids) >= 200_000: table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits}) if writer is None: writer = pq.ParquetWriter(str(cluster_split_path), table.schema) writer.write_table(table) batch_cluster_ids.clear() batch_splits.clear() while True: rows = cur.fetchmany(200_000) if not rows: break for cluster_id, n_total, _rnd, species_key in rows: cluster_id = str(cluster_id) species_key = str(species_key) if current_cluster is None: current_cluster = cluster_id current_n_total = int(n_total) current_species = [species_key] continue if cluster_id != current_cluster: flush_current() current_cluster = cluster_id current_n_total = 
int(n_total) current_species = [species_key] continue current_species.append(species_key) flush_current() if batch_cluster_ids: table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits}) if writer is None: writer = pq.ParquetWriter(str(cluster_split_path), table.schema) writer.write_table(table) elif writer is None: empty = pa.table( { "cluster_id": pa.array([], type=pa.string()), "split": pa.array([], type=pa.string()), } ) writer = pq.ParquetWriter(str(cluster_split_path), empty.schema) writer.write_table(empty) if writer is not None: writer.close() return { "nonheldout_total_rows": total_seen_rows, "target_val_rows": target_val_rows, "actual_val_rows": val_rows, "train_clusters": train_clusters, "val_clusters": val_clusters, } def cmd_make_seq_split(args: argparse.Namespace) -> None: import duckdb seq_cluster = Path(args.seq_cluster_parquet) if not seq_cluster.exists(): raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}") out = Path(args.output_parquet) cluster_split = Path(args.cluster_split_parquet) _ensure_output_parent(out) _ensure_output_parent(cluster_split) con = duckdb.connect() con.execute(f"PRAGMA threads={int(args.threads)};") con.execute("PRAGMA enable_progress_bar=true;") input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files)) heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0) seq_cluster_path = _sql_escape_path(str(seq_cluster)) out_path = _sql_escape_path(str(out)) seq_id = _seq_id_sql() species_key = _species_key_sql(args.species_key_mode, "taxon") protein_norm = _protein_norm_sql("protein_seq") con.execute( f""" CREATE TEMP TABLE heldout_species AS SELECT DISTINCT {species_key} AS species_key FROM {heldout_sql} WHERE {species_key} != ''; """ ) con.execute( f""" CREATE TEMP TABLE cluster_counts AS WITH base AS ( SELECT {seq_id} AS seq_id, {species_key} AS species_key FROM {input_sql} WHERE length(({seq_id})) > 1 AND {species_key} != '' ) SELECT sc.cluster_id, base.species_key, 
count(*)::BIGINT AS n FROM base JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) GROUP BY sc.cluster_id, base.species_key; """ ) con.execute( """ CREATE TEMP TABLE cluster_flags AS SELECT cluster_id, sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test, sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen, sum(n)::BIGINT AS n_total, count(*)::BIGINT AS n_species FROM cluster_counts GROUP BY cluster_id; """ ) t0 = time.time() split_summary = _write_cluster_split_parquet( con, cluster_split_path=cluster_split, seed=int(args.seed), val_frac=float(args.val_frac), ) print( "Cluster assignment summary: " f"train_clusters={split_summary['train_clusters']:,} " f"val_clusters={split_summary['val_clusters']:,} " f"target_val_rows={split_summary['target_val_rows']:,} " f"actual_val_rows={split_summary['actual_val_rows']:,} " f"(elapsed_s={time.time() - t0:.1f})" ) cluster_split_path = _sql_escape_path(str(cluster_split)) con.execute( f""" COPY ( WITH base AS ( SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key, {protein_norm} AS protein_norm FROM {input_sql} WHERE length(({seq_id})) > 1 AND {species_key} != '' ), joined AS ( SELECT base.seq_id, base.species_key, base.protein_norm, sc.cluster_id FROM base LEFT JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) ), labeled AS ( SELECT j.seq_id, j.species_key, j.protein_norm, CASE WHEN j.cluster_id IS NULL THEN 'drop' WHEN j.species_key IN (SELECT species_key FROM heldout_species) THEN 'test' WHEN coalesce(cf.n_test, 0) > 0 THEN 'drop' ELSE coalesce(cs.split, 'drop') END AS split FROM joined j LEFT JOIN cluster_flags cf USING (cluster_id) LEFT JOIN read_parquet('{cluster_split_path}') cs USING (cluster_id) ), protein_flags AS ( SELECT protein_norm, max(CASE WHEN split = 'test' THEN 1 ELSE 0 END) AS has_test, max(CASE WHEN split = 'train' THEN 1 ELSE 0 END) AS has_train FROM labeled WHERE 
length(protein_norm) > 0 GROUP BY protein_norm ), guarded AS ( SELECT l.seq_id, l.species_key, CASE WHEN l.split = 'drop' THEN 'drop' WHEN length(l.protein_norm) = 0 THEN l.split WHEN coalesce(pf.has_test, 0) = 1 AND l.split IN ('train', 'val') THEN 'drop' WHEN coalesce(pf.has_train, 0) = 1 AND l.split = 'val' THEN 'drop' ELSE l.split END AS split FROM labeled l LEFT JOIN protein_flags pf USING (protein_norm) ), dedup AS ( SELECT seq_id, CASE WHEN count(DISTINCT species_key) > 1 THEN 'drop' WHEN count(DISTINCT split) > 1 THEN 'drop' ELSE any_value(split) END AS split FROM guarded GROUP BY seq_id ) SELECT seq_id, split FROM dedup ) TO '{out_path}' (FORMAT PARQUET); """ ) rows = con.execute( f""" WITH base AS ( SELECT {seq_id} AS seq_id FROM {input_sql} ) SELECT s.split, count(*)::BIGINT AS n_rows FROM base JOIN read_parquet('{out_path}') s USING (seq_id) GROUP BY s.split ORDER BY n_rows DESC; """ ).fetchall() print("Split summary (rows):") for split, n in rows: print(f" {split}\t{n:,}") print(f"Wrote cluster→split parquet: {cluster_split}") print(f"Wrote seq→split parquet: {out}") def cmd_write_data_v3(args: argparse.Namespace) -> None: import duckdb seq_split = Path(args.seq_split_parquet) if not seq_split.exists(): raise SystemExit(f"seq_split parquet not found: {seq_split}") seq_cluster = Path(args.seq_cluster_parquet) if args.representatives_only and not seq_cluster.exists(): raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}") out_root = Path(args.output_root) out_root.mkdir(parents=True, exist_ok=True) (out_root / "_work").mkdir(parents=True, exist_ok=True) for split_dir in (out_root / "train", out_root / "val", out_root / "test"): if split_dir.exists(): if not args.overwrite: raise SystemExit(f"Output split directory exists: {split_dir} (pass --overwrite)") shutil.rmtree(split_dir) split_dir.mkdir(parents=True, exist_ok=True) con = duckdb.connect() con.execute(f"PRAGMA threads={int(args.threads)};") con.execute("PRAGMA enable_progress_bar=true;") 
input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files)) seq_split_path = _sql_escape_path(str(seq_split)) seq_cluster_path = _sql_escape_path(str(seq_cluster)) seq_id = _seq_id_sql() num_shards = int(args.num_shards) if num_shards <= 0: raise SystemExit("--num-shards must be > 0") for split in ("train", "val", "test"): out_dir = _sql_escape_path(str(out_root / split)) if args.representatives_only: target_seq_ids_sql = f""" SELECT min(s.seq_id) AS seq_id FROM read_parquet('{seq_split_path}') s JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) WHERE s.split = '{split}' GROUP BY sc.cluster_id """ else: target_seq_ids_sql = f""" SELECT DISTINCT s.seq_id FROM read_parquet('{seq_split_path}') s WHERE s.split = '{split}' """ sql = f""" COPY ( WITH target_seq_ids AS ( {target_seq_ids_sql} ), rows AS ( SELECT p.*, abs(hash({seq_id})) % {num_shards} AS shard FROM {input_sql} p JOIN target_seq_ids t ON t.seq_id = ({seq_id}) QUALIFY row_number() OVER (PARTITION BY ({seq_id}) ORDER BY ({seq_id})) = 1 ) SELECT * FROM rows ) TO '{out_dir}' (FORMAT PARQUET, PARTITION_BY (shard)); """ t0 = time.time() con.execute(sql) print( f"Wrote {split} parquets to {out_root / split} " f"representatives_only={bool(args.representatives_only)} " f"(elapsed_s={time.time() - t0:.1f})" ) def cmd_verify(args: argparse.Namespace) -> None: import duckdb seq_cluster = Path(args.seq_cluster_parquet) seq_split = Path(args.seq_split_parquet) if not seq_cluster.exists(): raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}") if not seq_split.exists(): raise SystemExit(f"seq_split parquet not found: {seq_split}") con = duckdb.connect() con.execute(f"PRAGMA threads={int(args.threads)};") con.execute("PRAGMA enable_progress_bar=true;") input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files)) heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0) seq_cluster_path = _sql_escape_path(str(seq_cluster)) seq_split_path = 
_sql_escape_path(str(seq_split)) seq_id = _seq_id_sql() species_key = _species_key_sql(args.species_key_mode, "taxon") protein_norm = _protein_norm_sql("protein_seq") con.execute( f""" CREATE TEMP TABLE heldout_species AS SELECT DISTINCT {species_key} AS species_key FROM {heldout_sql} WHERE {species_key} != ''; """ ) con.execute( f""" CREATE TEMP TABLE cluster_counts AS WITH base AS ( SELECT {seq_id} AS seq_id, {species_key} AS species_key FROM {input_sql} WHERE length(({seq_id})) > 1 AND {species_key} != '' ) SELECT sc.cluster_id, base.species_key, count(*)::BIGINT AS n FROM base JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) GROUP BY sc.cluster_id, base.species_key; """ ) con.execute( """ CREATE TEMP TABLE cluster_flags AS SELECT cluster_id, sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test, sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen, sum(n)::BIGINT AS n_total, count(*)::BIGINT AS n_species FROM cluster_counts GROUP BY cluster_id; """ ) split_seq_ids = { split: int(n) for split, n in con.execute( f""" SELECT split, count(*)::BIGINT AS n FROM read_parquet('{seq_split_path}') GROUP BY split """ ).fetchall() } split_rows = { split: int(n) for split, n in con.execute( f""" WITH base AS ( SELECT {seq_id} AS seq_id FROM {input_sql} ) SELECT s.split, count(*)::BIGINT AS n FROM base JOIN read_parquet('{seq_split_path}') s USING (seq_id) GROUP BY s.split """ ).fetchall() } bad_clusters = int( con.execute( f""" WITH keep AS ( SELECT sc.cluster_id, ss.split FROM read_parquet('{seq_cluster_path}') sc JOIN read_parquet('{seq_split_path}') ss USING (seq_id) WHERE ss.split != 'drop' ) SELECT count(*)::BIGINT FROM ( SELECT cluster_id FROM keep GROUP BY cluster_id HAVING count(DISTINCT split) > 1 ); """ ).fetchone()[0] ) print(f"clusters_spanning_splits(excluding drop) = {bad_clusters}") bad_test = int( con.execute( f""" WITH base AS ( SELECT 
{seq_id} AS seq_id, {species_key} AS species_key FROM {input_sql} ) SELECT count(*)::BIGINT FROM base JOIN read_parquet('{seq_split_path}') ss USING (seq_id) WHERE ss.split = 'test' AND base.species_key NOT IN (SELECT species_key FROM heldout_species); """ ).fetchone()[0] ) print(f"test_rows_with_seen_species = {bad_test}") bad_val_species = int( con.execute( f""" WITH base AS ( SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key FROM {input_sql} ), labeled AS ( SELECT base.species_key, ss.split FROM base JOIN read_parquet('{seq_split_path}') ss USING (seq_id) WHERE ss.split IN ('train', 'val') ), train_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'train'), val_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'val') SELECT count(*)::BIGINT FROM (SELECT species_key FROM val_species EXCEPT SELECT species_key FROM train_species); """ ).fetchone()[0] ) print(f"val_species_not_in_train = {bad_val_species}") protein_overlap_train_val = int( con.execute( f""" WITH base AS ( SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm FROM {input_sql} WHERE length({protein_norm}) > 0 ), labeled AS ( SELECT base.protein_norm, ss.split FROM base JOIN read_parquet('{seq_split_path}') ss USING (seq_id) WHERE ss.split IN ('train', 'val') ), train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'), val_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'val') SELECT count(*)::BIGINT FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM val_p); """ ).fetchone()[0] ) protein_overlap_train_test = int( con.execute( f""" WITH base AS ( SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm FROM {input_sql} WHERE length({protein_norm}) > 0 ), labeled AS ( SELECT base.protein_norm, ss.split FROM base JOIN read_parquet('{seq_split_path}') ss USING (seq_id) WHERE ss.split IN ('train', 'test') ), train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 
'train'), test_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'test') SELECT count(*)::BIGINT FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM test_p); """ ).fetchone()[0] ) print(f"exact_protein_overlap_train_val = {protein_overlap_train_val}") print(f"exact_protein_overlap_train_test = {protein_overlap_train_test}") mixed_test_clusters = int( con.execute( "SELECT count(*)::BIGINT FROM cluster_flags WHERE n_test > 0 AND n_seen > 0" ).fetchone()[0] ) exact_holdout_seen_conflicts = int( con.execute( f""" WITH base AS ( SELECT DISTINCT {protein_norm} AS protein_norm, {species_key} AS species_key FROM {input_sql} WHERE length({protein_norm}) > 0 AND {species_key} != '' ) SELECT count(*)::BIGINT FROM ( SELECT protein_norm FROM base GROUP BY protein_norm HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 ); """ ).fetchone()[0] ) dropped_seen_rows_exact_holdout = int( con.execute( f""" WITH base AS ( SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key, {protein_norm} AS protein_norm FROM {input_sql} WHERE length(({seq_id})) > 1 AND {species_key} != '' ), conflict_proteins AS ( SELECT protein_norm FROM base WHERE length(protein_norm) > 0 GROUP BY protein_norm HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 ) SELECT count(*)::BIGINT FROM base JOIN conflict_proteins USING (protein_norm) JOIN read_parquet('{seq_split_path}') ss USING (seq_id) WHERE ss.split = 'drop' AND base.species_key NOT IN (SELECT species_key FROM heldout_species); """ ).fetchone()[0] ) dropped_val_rows_exact_train = int( con.execute( f""" WITH base AS ( SELECT DISTINCT {seq_id} AS seq_id, 
{protein_norm} AS protein_norm FROM {input_sql} WHERE length(({seq_id})) > 1 AND length({protein_norm}) > 0 ), labeled AS ( SELECT base.protein_norm, ss.split FROM base JOIN read_parquet('{seq_split_path}') ss USING (seq_id) ), train_proteins AS ( SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train' ) SELECT count(*)::BIGINT FROM base JOIN train_proteins USING (protein_norm) JOIN read_parquet('{seq_split_path}') ss USING (seq_id) WHERE ss.split = 'drop'; """ ).fetchone()[0] ) dropped_seen_rows_mixed = int( con.execute( f""" WITH base AS ( SELECT {seq_id} AS seq_id, {species_key} AS species_key FROM {input_sql} ) SELECT count(*)::BIGINT FROM base JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) JOIN cluster_flags cf USING (cluster_id) JOIN read_parquet('{seq_split_path}') ss USING (seq_id) WHERE ss.split = 'drop' AND cf.n_test > 0 AND base.species_key NOT IN (SELECT species_key FROM heldout_species); """ ).fetchone()[0] ) dropped_seen_seqids_mixed = int( con.execute( f""" WITH base AS ( SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key FROM {input_sql} ) SELECT count(*)::BIGINT FROM base JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id) JOIN cluster_flags cf USING (cluster_id) JOIN read_parquet('{seq_split_path}') ss USING (seq_id) WHERE ss.split = 'drop' AND cf.n_test > 0 AND base.species_key NOT IN (SELECT species_key FROM heldout_species); """ ).fetchone()[0] ) same_protein_multi_species = int( con.execute( f""" WITH base AS ( SELECT DISTINCT {protein_norm} AS protein_norm, {species_key} AS species_key FROM {input_sql} WHERE length({protein_norm}) > 0 AND {species_key} != '' ) SELECT count(*)::BIGINT FROM ( SELECT protein_norm FROM base GROUP BY protein_norm HAVING count(DISTINCT species_key) > 1 ); """ ).fetchone()[0] ) same_protein_cross_holdout = int( con.execute( f""" WITH base AS ( SELECT DISTINCT {protein_norm} AS protein_norm, {species_key} AS species_key FROM {input_sql} WHERE length({protein_norm}) > 0 AND 
{species_key} != '' ) SELECT count(*)::BIGINT FROM ( SELECT protein_norm FROM base GROUP BY protein_norm HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0 ); """ ).fetchone()[0] ) report = { "parameters": { "input_glob": args.input_glob, "heldout_test_glob": args.heldout_test_glob, "seq_cluster_parquet": str(seq_cluster), "seq_split_parquet": str(seq_split), "seq_space": args.seq_space, "species_key_mode": args.species_key_mode, "limit_files": int(args.limit_files), }, "split_seq_ids": split_seq_ids, "split_rows": split_rows, "verification": { "clusters_spanning_splits_excluding_drop": bad_clusters, "test_rows_with_seen_species": bad_test, "val_species_not_in_train": bad_val_species, "exact_protein_overlap_train_val": protein_overlap_train_val, "exact_protein_overlap_train_test": protein_overlap_train_test, }, "audit": { "mixed_test_clusters": mixed_test_clusters, "exact_protein_cross_holdout_seen_groups": exact_holdout_seen_conflicts, "dropped_seen_rows_from_exact_protein_holdout_overlap": dropped_seen_rows_exact_holdout, "dropped_rows_from_exact_protein_train_overlap": dropped_val_rows_exact_train, "dropped_seen_rows_from_mixed_test_clusters": dropped_seen_rows_mixed, "dropped_seen_seqids_from_mixed_test_clusters": dropped_seen_seqids_mixed, "same_protein_multi_species_exact_matches": same_protein_multi_species, "same_protein_cross_holdout_species_exact_matches": same_protein_cross_holdout, }, } if args.report_json: report_path = Path(args.report_json) report_path.parent.mkdir(parents=True, exist_ok=True) with open(report_path, "w", encoding="utf-8") as f: json.dump(report, f, indent=2, sort_keys=True) print(f"Wrote audit report: {report_path}") if ( bad_clusters != 0 or bad_test != 0 or bad_val_species != 0 or protein_overlap_train_val != 0 or protein_overlap_train_test != 0 ): raise 
SystemExit("Verification FAILED (see counts above).")  # NOTE(review): tail of cmd_verify — the `raise` presumably sits on the preceding (unseen) line; confirm against full file.
    print("Verification OK.")


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: one subcommand per pipeline stage plus `all`.

    Each subcommand binds its handler via ``set_defaults(func=...)`` so
    ``main`` can dispatch uniformly with ``args.func(args)``. Default paths
    all live under ``data_v3_rebuild/_work`` so stages chain without flags.
    """
    ap = argparse.ArgumentParser(
        description="Resplit data_v2 to data_v3_rebuild using MMseqs2 protein clustering."
    )
    sub = ap.add_subparsers(dest="cmd", required=True)

    # --- make-fasta: parquet shards -> MMseqs input FASTA -------------------
    p = sub.add_parser("make-fasta", help="Generate MMseqs FASTA from parquet shards.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--output-fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument(
        "--max-input-seq-len",
        type=int,
        default=0,
        help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
    )
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.add_argument("--limit-rows", type=int, default=0, help="Debug: limit number of rows written (0=all)")
    p.set_defaults(func=cmd_make_fasta)

    # --- mmseqs-cluster: createdb + linclust -> clu.tsv ---------------------
    p = sub.add_parser("mmseqs-cluster", help="Run MMseqs2 createdb+linclust and emit clustering TSV.")
    p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
    p.add_argument("--fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
    p.add_argument("--workdir", type=str, default="data_v3_rebuild/_work/mmseqs")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--min-seq-id", type=float, default=0.90)
    p.add_argument("-c", "--coverage", type=float, default=0.80)
    p.add_argument("--cov-mode", type=int, default=2, help="2=enforce representative/query coverage")
    p.add_argument("--cluster-mode", type=int, default=2, help="2=greedy clustering by sequence length")
    p.add_argument("--max-seq-len", type=int, default=200000)
    p.add_argument(
        "--kmer-per-seq-scale",
        type=float,
        default=None,
        help="Optional MMseqs2 override; leave empty to use MMseqs defaults.",
    )
    p.add_argument("--split-memory-limit", type=str, default="", help="e.g. 120G (empty=use MMseqs default)")
    # --remove-tmp-files / --keep-tmp-files write the same boolean dest;
    # default is to remove tmp files.
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--remove-tmp-files",
        dest="remove_tmp_files",
        action="store_true",
        default=True,
        help="Remove MMseqs2 tmp files (default).",
    )
    g.add_argument(
        "--keep-tmp-files",
        dest="remove_tmp_files",
        action="store_false",
        help="Keep MMseqs2 tmp files.",
    )
    p.add_argument("--overwrite", action="store_true")
    p.set_defaults(func=cmd_mmseqs_cluster)

    # --- make-seq-cluster: clu.tsv -> seq_id→cluster_id parquet -------------
    p = sub.add_parser("make-seq-cluster", help="Convert MMseqs TSV to parquet mapping seq_id→cluster_id.")
    p.add_argument("--cluster-tsv", type=str, default="data_v3_rebuild/_work/mmseqs/clu.tsv")
    p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--threads", type=int, default=32)
    p.set_defaults(func=cmd_make_seq_cluster)

    # --- make-seq-split: assign each seq_id to train/val/test/drop ----------
    p = sub.add_parser(
        "make-seq-split",
        help="Create seq_id→{train,val,test,drop} using cluster assignments and heldout species.",
    )
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--cluster-split-parquet", type=str, default="data_v3_rebuild/_work/cluster_split.parquet")
    p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--val-frac", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=13)
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.set_defaults(func=cmd_make_seq_split)

    # --- write-data-v3: materialize split parquet directories ---------------
    p = sub.add_parser("write-data-v3", help="Write data_v3 parquet directories from seq_split mapping.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--output-root", type=str, default="data_v3_rebuild")
    p.add_argument("--num-shards", type=int, default=256, help="Partition each split into N shards")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    # --representatives-only / --all-cluster-members share one boolean dest;
    # default writes a single representative per cluster.
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--representatives-only",
        dest="representatives_only",
        action="store_true",
        default=True,
        help="Write only one representative seq_id per MMseqs cluster (default).",
    )
    g.add_argument(
        "--all-cluster-members",
        dest="representatives_only",
        action="store_false",
        help="Write all seq_ids assigned to the split instead of one representative per cluster.",
    )
    p.add_argument("--overwrite", action="store_true")
    p.set_defaults(func=cmd_write_data_v3)

    # --- verify: leakage/species audit --------------------------------------
    p = sub.add_parser("verify", help="Verify leakage/species constraints and write an audit report.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--report-json", type=str, default="data_v3_rebuild/_work/split_report.json")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.set_defaults(func=cmd_verify)

    # --- all: run every stage in sequence ------------------------------------
    p = sub.add_parser(
        "all",
        help="Run the full pipeline: make-fasta → mmseqs-cluster → make-seq-cluster → make-seq-split → write-data-v3 → verify.",
    )
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--output-root", type=str, default="data_v3_rebuild")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument(
        "--max-input-seq-len",
        type=int,
        default=0,
        help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
    )
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.add_argument("--num-shards", type=int, default=256)
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--representatives-only",
        dest="representatives_only",
        action="store_true",
        default=True,
        help="Write only one representative seq_id per MMseqs cluster (default).",
    )
    g.add_argument(
        "--all-cluster-members",
        dest="representatives_only",
        action="store_false",
        help="Write all seq_ids assigned to the split instead of one representative per cluster.",
    )
    p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
    p.add_argument("--min-seq-id", type=float, default=0.90)
    p.add_argument("-c", "--coverage", type=float, default=0.80)
    p.add_argument("--cov-mode", type=int, default=2)
    p.add_argument("--cluster-mode", type=int, default=2)
    p.add_argument("--max-seq-len", type=int, default=200000)
    p.add_argument("--kmer-per-seq-scale", type=float, default=None)
    p.add_argument("--split-memory-limit", type=str, default="")
    p.add_argument("--val-frac", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=13)
    p.add_argument("--overwrite", action="store_true")

    def _run_all(a: argparse.Namespace) -> None:
        """Run every stage in order, wiring shared artifact paths.

        All intermediate files live under ``<output-root>/_work`` so each
        stage's output is the next stage's default input. Each ``cmd_*``
        handler is called with a synthetic Namespace whose field names must
        match the arguments that stage's own subparser defines.
        """
        out_root = Path(a.output_root)
        work = out_root / "_work"
        fasta = work / "mmseqs_input.fasta"
        mmseqs_work = work / "mmseqs"
        cluster_tsv = mmseqs_work / "clu.tsv"
        seq_cluster = work / "seq_cluster.parquet"
        cluster_split = work / "cluster_split.parquet"
        seq_split = work / "seq_split.parquet"
        report_json = work / "split_report.json"
        cmd_make_fasta(
            argparse.Namespace(
                input_glob=a.input_glob,
                output_fasta=str(fasta),
                seq_space=a.seq_space,
                max_input_seq_len=a.max_input_seq_len,
                threads=a.threads,
                limit_files=a.limit_files,
                limit_rows=0,  # debug row limit is never used in the full run
            )
        )
        cmd_mmseqs_cluster(
            argparse.Namespace(
                mmseqs=a.mmseqs,
                fasta=str(fasta),
                workdir=str(mmseqs_work),
                seq_space=a.seq_space,
                threads=a.threads,
                min_seq_id=a.min_seq_id,
                coverage=a.coverage,
                cov_mode=a.cov_mode,
                cluster_mode=a.cluster_mode,
                max_seq_len=a.max_seq_len,
                kmer_per_seq_scale=a.kmer_per_seq_scale,
                split_memory_limit=a.split_memory_limit,
                remove_tmp_files=True,  # `all` does not expose --keep-tmp-files
                overwrite=a.overwrite,
            )
        )
        cmd_make_seq_cluster(
            argparse.Namespace(
                cluster_tsv=str(cluster_tsv),
                output_parquet=str(seq_cluster),
                threads=a.threads,
            )
        )
        cmd_make_seq_split(
            argparse.Namespace(
                input_glob=a.input_glob,
                heldout_test_glob=a.heldout_test_glob,
                species_key_mode=a.species_key_mode,
                seq_cluster_parquet=str(seq_cluster),
                cluster_split_parquet=str(cluster_split),
                output_parquet=str(seq_split),
                val_frac=a.val_frac,
                seed=a.seed,
                threads=a.threads,
                limit_files=a.limit_files,
            )
        )
        cmd_write_data_v3(
            argparse.Namespace(
                input_glob=a.input_glob,
                seq_cluster_parquet=str(seq_cluster),
                seq_split_parquet=str(seq_split),
                output_root=str(out_root),
                num_shards=a.num_shards,
                threads=a.threads,
                limit_files=a.limit_files,
                representatives_only=a.representatives_only,
                overwrite=a.overwrite,
            )
        )
        cmd_verify(
            argparse.Namespace(
                input_glob=a.input_glob,
                heldout_test_glob=a.heldout_test_glob,
                species_key_mode=a.species_key_mode,
                seq_space=a.seq_space,
                seq_cluster_parquet=str(seq_cluster),
                seq_split_parquet=str(seq_split),
                report_json=str(report_json),
                threads=a.threads,
                limit_files=a.limit_files,
            )
        )

    p.set_defaults(func=_run_all)
    return ap


def main(argv: Optional[List[str]] = None) -> int:
    """Parse CLI args and dispatch to the selected subcommand handler."""
    ap = build_parser()
    args = ap.parse_args(argv)
    # Each subparser installed its handler via set_defaults(func=...).
    args.func(args)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())