"""
Resplit `data_v2/` into a leakage-safe `data_v3_rebuild/` using MMseqs2 clustering.

Default policy for the current rebuild:
- Cluster `protein_seq` with MMseqs2 `linclust`.
- Define species by normalized binomial name (`genus species`).
- Test species are exactly the normalized species present in `data_v2/test`.
- Validation is cluster-unseen but species-seen.
- Mixed seen/heldout clusters keep heldout rows in test and drop seen rows.

Typical usage (end-to-end):
    python resplit_data_v3.py all --threads 32 --split-memory-limit 120G --num-shards 256
"""

from __future__ import annotations

import argparse
import json
import os
import shutil
import stat
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple


def _default_mmseqs_path() -> str:
    cand = Path("MMseqs2/build/bin/mmseqs")
    if cand.exists():
        return str(cand)
    return "mmseqs"


def _run(cmd: List[str], *, cwd: Optional[str] = None, env: Optional[dict] = None) -> None:
    pretty = " ".join(cmd)
    print(f"+ {pretty}", flush=True)
    subprocess.run(cmd, cwd=cwd, env=env, check=True)


def _sql_escape_path(path: str) -> str:
    # Escape single quotes for use inside a single-quoted SQL string literal.
    return path.replace("'", "''")


def _expand_parquet_inputs(inp: str) -> List[str]:
    import glob

    p = Path(inp)
    if p.exists() and p.is_dir():
        files = sorted(str(x) for x in p.rglob("*.parquet"))
    else:
        files = sorted(glob.glob(inp))

    # De-duplicate while preserving sorted order.
    seen = set()
    out: List[str] = []
    for f in files:
        if f not in seen:
            out.append(f)
            seen.add(f)
    return out


def _duckdb_parquet_source(inp: str, limit_files: int = 0) -> str:
    files = _expand_parquet_inputs(inp)
    if not files:
        raise SystemExit(f"No parquet files found for {inp!r}")
    if limit_files and int(limit_files) > 0:
        files = files[: int(limit_files)]
    quoted = ", ".join(f"'{_sql_escape_path(fp)}'" for fp in files)
    return f"read_parquet([{quoted}])"


def _mem_total_bytes() -> Optional[int]:
    try:
        with open("/proc/meminfo", "r", encoding="utf-8") as f:
            for line in f:
                # Example /proc/meminfo line: "MemTotal:       131862012 kB".
                if line.startswith("MemTotal:"):
                    parts = line.split()
                    kb = int(parts[1])
                    return kb * 1024
    except OSError:
        return None
    except (ValueError, IndexError):
        return None
    return None


def _parse_mmseqs_bytes(s: str) -> Optional[int]:
    s = (s or "").strip()
    if not s:
        return None
    up = s.upper()
    suffix = up[-1]
    num_part = up[:-1]
    unit = suffix
    # Accept both "120G" and "120GB" style suffixes.
    if suffix == "B" and len(up) >= 2 and up[-2] in "KMGT":
        unit = up[-2]
        num_part = up[:-2]
    if unit not in "BKMGT":
        return None
    try:
        val = float(num_part)
    except ValueError:
        return None
    mult = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}[unit]
    return int(val * mult)
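
# Hedged worked examples for _parse_mmseqs_bytes: "120G" and "120GB" both
# yield 120 * 1024**3; "0.5T" yields 512 * 1024**3; a bare number such as
# "120" (no recognized unit suffix) returns None rather than guessing a unit.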


def _format_bytes(n: int) -> str:
    for unit, div in [("TiB", 1024**4), ("GiB", 1024**3), ("MiB", 1024**2), ("KiB", 1024)]:
        if n >= div:
            return f"{n / div:.1f}{unit}"
    return f"{n}B"


def _seq_id_sql() -> str:
    # Composite row key built from the two RefSeq accession columns.
    return "coalesce(protein_refseq_id, '') || '|' || coalesce(RefseqID, '')"


def _taxon_norm_sql(col: str = "taxon") -> str:
    # Lowercase, trim, and collapse internal whitespace runs to single spaces.
    return f"regexp_replace(lower(trim(coalesce({col}, ''))), '\\s+', ' ', 'g')"


def _species_key_sql(mode: str, col: str = "taxon") -> str:
    norm = _taxon_norm_sql(col)
    if mode == "taxon":
        return norm
    if mode == "binomial":
        # Keep only the first two whitespace-delimited tokens (genus + species).
        return (
            f"CASE "
            f"WHEN strpos({norm}, ' ') > 0 "
            f"THEN split_part({norm}, ' ', 1) || ' ' || split_part({norm}, ' ', 2) "
            f"ELSE {norm} END"
        )
    raise ValueError(f"Unsupported species key mode: {mode}")
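
# Hedged example of the binomial species key: a taxon value such as
# "  Escherichia   coli K-12  " normalizes to "escherichia coli k-12" and the
# binomial CASE keeps only "escherichia coli"; a single-token taxon passes
# through unchanged.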


def _protein_norm_sql(col: str = "protein_seq") -> str:
    # Uppercase, drop all whitespace, strip trailing stop markers ('*' or '_'),
    # and map any remaining non-letter character to 'X'.
    cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\s+', '', 'g')"
    no_stop = f"regexp_replace({cleaned}, '[_*]+$', '')"
    return f"regexp_replace({no_stop}, '[^A-Z]', 'X', 'g')"


def _cds_norm_sql(col: str = "cds_DNA") -> str:
    # Uppercase, drop all whitespace, and map any non-ACGTN character to 'N'.
    cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\s+', '', 'g')"
    return f"regexp_replace({cleaned}, '[^ACGTN]', 'N', 'g')"


def _seq_expr_sql(seq_space: str) -> str:
    if seq_space == "protein":
        return _protein_norm_sql("protein_seq")
    if seq_space == "cds":
        return _cds_norm_sql("cds_DNA")
    raise ValueError(f"Unsupported seq space: {seq_space}")


def _seq_space_input_col(seq_space: str) -> str:
    if seq_space == "protein":
        return "protein_seq"
    if seq_space == "cds":
        return "cds_DNA"
    raise ValueError(f"Unsupported seq space: {seq_space}")


def _mmseqs_dbtype(seq_space: str) -> str:
    # MMseqs2 createdb --dbtype: 1 = amino acid, 2 = nucleotide.
    if seq_space == "protein":
        return "1"
    if seq_space == "cds":
        return "2"
    raise ValueError(f"Unsupported seq space: {seq_space}")


def _default_max_input_seq_len(seq_space: str) -> int:
    if seq_space == "protein":
        # Cap protein inputs at 20k residues; longer entries are dropped from
        # the MMseqs input FASTA (rationale assumed: very long outlier proteins
        # are rare and inflate clustering cost).
        return 20_000
    return 0


def _ensure_mmseqs_ready(mmseqs: str) -> Tuple[str, Dict[str, str]]:
    path = Path(mmseqs)
    env = os.environ.copy()

    # Make a bundled binary executable if it is not already.
    if path.exists():
        mode = path.stat().st_mode
        if not (mode & stat.S_IXUSR):
            path.chmod(mode | stat.S_IXUSR)

    # Prepend the active (conda) environment's lib dirs so the binary can
    # locate its shared libraries.
    py = Path(sys.executable).resolve()
    env_root = py.parent.parent
    conda_root = env_root.parent.parent if env_root.parent.name == "envs" else env_root.parent
    lib_candidates = [env_root / "lib", conda_root / "lib"]
    libs = [str(p) for p in lib_candidates if p.exists()]
    if libs:
        current = env.get("LD_LIBRARY_PATH", "")
        env["LD_LIBRARY_PATH"] = ":".join(libs + ([current] if current else []))

    return str(path if path.exists() else mmseqs), env


def _ensure_output_parent(path: Path) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)


def cmd_make_fasta(args: argparse.Namespace) -> None:
    out_fasta = Path(args.output_fasta)
    _ensure_output_parent(out_fasta)

    import duckdb

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    source_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    out_path = _sql_escape_path(str(out_fasta))
    seq_id = _seq_id_sql()
    seq_expr = _seq_expr_sql(args.seq_space)
    raw_col = _seq_space_input_col(args.seq_space)
    max_input_seq_len = int(args.max_input_seq_len)
    if max_input_seq_len <= 0:
        max_input_seq_len = _default_max_input_seq_len(args.seq_space)
    len_filter = (
        f"AND length({seq_expr}) <= {max_input_seq_len}"
        if max_input_seq_len > 0
        else ""
    )

    # Write FASTA via DuckDB's CSV writer: a newline "delimiter" with quoting
    # disabled emits header and sequence as two physical lines per row.
    sql = f"""
    COPY (
        SELECT
            '>' || ({seq_id}) AS header,
            {seq_expr} AS seq
        FROM {source_sql}
        WHERE {raw_col} IS NOT NULL
          AND length({seq_expr}) > 0
          {len_filter}
          AND length(({seq_id})) > 1
        {f"LIMIT {int(args.limit_rows)}" if args.limit_rows and int(args.limit_rows) > 0 else ""}
    )
    TO '{out_path}'
    (FORMAT CSV, DELIMITER '\n', QUOTE '', ESCAPE '', HEADER FALSE);
    """
    t0 = time.time()
    con.execute(sql)
    print(
        f"Wrote FASTA: {out_fasta} seq_space={args.seq_space} "
        f"max_input_seq_len={max_input_seq_len if max_input_seq_len > 0 else 'none'} "
        f"(elapsed_s={time.time() - t0:.1f})"
    )
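
# Hedged sketch of the FASTA this emits (accessions invented for illustration);
# each record is exactly two physical lines because the CSV "row" is
# header<newline>sequence:
#   >WP_011001234.1|NC_004567.1
#   MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ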


def cmd_mmseqs_cluster(args: argparse.Namespace) -> None:
    mmseqs, env = _ensure_mmseqs_ready(args.mmseqs)
    workdir = Path(args.workdir)
    workdir.mkdir(parents=True, exist_ok=True)

    fasta = Path(args.fasta)
    if not fasta.exists():
        raise SystemExit(f"FASTA not found: {fasta}")

    seqdb = workdir / "seqdb"
    clu = workdir / "clu"
    tmp = workdir / "tmp"
    tsv = workdir / "clu.tsv"

    if args.overwrite:
        for p in (seqdb, clu, tmp, tsv):
            if p.is_dir():
                shutil.rmtree(p, ignore_errors=True)
            else:
                # MMseqs databases are a base path plus several sidecar files.
                for suffix in ("", ".dbtype", ".index", ".lookup", ".source"):
                    try:
                        os.remove(str(p) + suffix)
                    except OSError:
                        pass

    tmp.mkdir(parents=True, exist_ok=True)

    _run(
        [
            mmseqs,
            "createdb",
            str(fasta),
            str(seqdb),
            "--dbtype",
            _mmseqs_dbtype(args.seq_space),
            "--shuffle",
            "0",
            "--createdb-mode",
            "1",
            "--threads",
            str(int(args.threads)),
        ],
        env=env,
    )

    linclust_cmd = [
        mmseqs,
        "linclust",
        str(seqdb),
        str(clu),
        str(tmp),
        "--min-seq-id",
        str(float(args.min_seq_id)),
        "-c",
        str(float(args.coverage)),
        "--cov-mode",
        str(int(args.cov_mode)),
        "--cluster-mode",
        str(int(args.cluster_mode)),
        "--threads",
        str(int(args.threads)),
        "--max-seq-len",
        str(int(args.max_seq_len)),
        "--remove-tmp-files",
        "1" if args.remove_tmp_files else "0",
    ]
    if args.split_memory_limit:
        mem_total = _mem_total_bytes()
        limit_bytes = _parse_mmseqs_bytes(args.split_memory_limit)
        if mem_total and limit_bytes and limit_bytes > mem_total:
            print(
                f"WARNING: --split-memory-limit={args.split_memory_limit} ({_format_bytes(limit_bytes)}) "
                f"exceeds system MemTotal ({_format_bytes(mem_total)}). "
                "MMseqs2 may under-split and crash; consider lowering it or leaving it empty.",
                file=sys.stderr,
                flush=True,
            )
        linclust_cmd += ["--split-memory-limit", str(args.split_memory_limit)]
    if args.kmer_per_seq_scale is not None:
        linclust_cmd += ["--kmer-per-seq-scale", str(float(args.kmer_per_seq_scale))]

    _run(linclust_cmd, env=env)
    _run([mmseqs, "createtsv", str(seqdb), str(seqdb), str(clu), str(tsv)], env=env)
    print(f"Wrote cluster TSV: {tsv}")


def cmd_make_seq_cluster(args: argparse.Namespace) -> None:
    import duckdb

    tsv = Path(args.cluster_tsv)
    if not tsv.exists():
        raise SystemExit(f"Cluster TSV not found: {tsv}")
    out = Path(args.output_parquet)
    _ensure_output_parent(out)

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    tsv_path = _sql_escape_path(str(tsv))
    out_path = _sql_escape_path(str(out))

    sql = f"""
    COPY (
        SELECT DISTINCT
            seq_id,
            cluster_id
        FROM read_csv(
            '{tsv_path}',
            delim='\\t',
            header=false,
            columns={{'cluster_id':'VARCHAR','seq_id':'VARCHAR'}}
        )
    )
    TO '{out_path}'
    (FORMAT PARQUET);
    """
    t0 = time.time()
    con.execute(sql)
    print(f"Wrote seq→cluster parquet: {out} (elapsed_s={time.time() - t0:.1f})")


def _write_cluster_split_parquet(
    con,
    *,
    cluster_split_path: Path,
    seed: int,
    val_frac: float,
) -> Dict[str, int]:
    """Greedily assign non-test clusters to train/val, streaming to parquet.

    A cluster goes to val only while the val row budget is unmet and every
    species in the cluster still has more than one unassigned non-test
    cluster, so each species keeps at least one cluster in train.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    cluster_split_path.parent.mkdir(parents=True, exist_ok=True)
    if cluster_split_path.exists():
        cluster_split_path.unlink()

    total_seen_rows = int(
        con.execute(
            "SELECT coalesce(sum(n_total), 0)::BIGINT FROM cluster_flags WHERE n_test = 0"
        ).fetchone()[0]
    )
    target_val_rows = int(total_seen_rows * float(val_frac))

    # Clusters remaining per species among non-test clusters; used to keep
    # every species represented in train.
    species_remaining = {
        species_key: int(n_clusters)
        for species_key, n_clusters in con.execute(
            """
            SELECT
                cc.species_key,
                count(*)::BIGINT AS n_clusters
            FROM cluster_counts cc
            JOIN cluster_flags cf USING (cluster_id)
            WHERE cf.n_test = 0
            GROUP BY cc.species_key
            """
        ).fetchall()
    }

    # Deterministic pseudo-random visit order: hash(cluster_id || seed).
    cur = con.execute(
        f"""
        SELECT
            cf.cluster_id,
            cf.n_total,
            abs(hash(cf.cluster_id || ':{seed}')) AS rnd,
            cc.species_key
        FROM cluster_flags cf
        JOIN cluster_counts cc USING (cluster_id)
        WHERE cf.n_test = 0
        ORDER BY rnd, cf.cluster_id, cc.species_key
        """
    )

    writer = None
    batch_cluster_ids: List[str] = []
    batch_splits: List[str] = []
    val_rows = 0
    train_clusters = 0
    val_clusters = 0
    current_cluster: Optional[str] = None
    current_n_total = 0
    current_species: List[str] = []

    def flush_current() -> None:
        nonlocal writer, val_rows, train_clusters, val_clusters
        nonlocal current_cluster, current_n_total, current_species
        if current_cluster is None:
            return
        can_val = (
            val_rows < target_val_rows
            and all(species_remaining.get(species_key, 0) > 1 for species_key in current_species)
        )
        split = "val" if can_val else "train"
        if can_val:
            for species_key in current_species:
                species_remaining[species_key] -= 1
            val_rows += int(current_n_total)
            val_clusters += 1
        else:
            train_clusters += 1

        batch_cluster_ids.append(current_cluster)
        batch_splits.append(split)
        if len(batch_cluster_ids) >= 200_000:
            table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits})
            if writer is None:
                writer = pq.ParquetWriter(str(cluster_split_path), table.schema)
            writer.write_table(table)
            batch_cluster_ids.clear()
            batch_splits.clear()

    # Rows arrive grouped by cluster (one row per species in the cluster);
    # flush a cluster when its id changes.
    while True:
        rows = cur.fetchmany(200_000)
        if not rows:
            break
        for cluster_id, n_total, _rnd, species_key in rows:
            cluster_id = str(cluster_id)
            species_key = str(species_key)
            if current_cluster is None:
                current_cluster = cluster_id
                current_n_total = int(n_total)
                current_species = [species_key]
                continue
            if cluster_id != current_cluster:
                flush_current()
                current_cluster = cluster_id
                current_n_total = int(n_total)
                current_species = [species_key]
                continue
            current_species.append(species_key)

    flush_current()
    if batch_cluster_ids:
        table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits})
        if writer is None:
            writer = pq.ParquetWriter(str(cluster_split_path), table.schema)
        writer.write_table(table)
    elif writer is None:
        # No clusters at all: still write an empty parquet with the schema.
        empty = pa.table(
            {
                "cluster_id": pa.array([], type=pa.string()),
                "split": pa.array([], type=pa.string()),
            }
        )
        writer = pq.ParquetWriter(str(cluster_split_path), empty.schema)
        writer.write_table(empty)
    if writer is not None:
        writer.close()

    return {
        "nonheldout_total_rows": total_seen_rows,
        "target_val_rows": target_val_rows,
        "actual_val_rows": val_rows,
        "train_clusters": train_clusters,
        "val_clusters": val_clusters,
    }
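
# Hedged walk-through of the greedy rule above: with target_val_rows = 100, a
# visited cluster of 40 rows whose species each still have >1 unassigned
# non-test cluster goes to val (val_rows becomes 40) and each species floor
# decrements; a later cluster containing a species down to its last remaining
# cluster goes to train regardless of the unmet val budget.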


def cmd_make_seq_split(args: argparse.Namespace) -> None:
    import duckdb

    seq_cluster = Path(args.seq_cluster_parquet)
    if not seq_cluster.exists():
        raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")

    out = Path(args.output_parquet)
    cluster_split = Path(args.cluster_split_parquet)
    _ensure_output_parent(out)
    _ensure_output_parent(cluster_split)

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0)
    seq_cluster_path = _sql_escape_path(str(seq_cluster))
    out_path = _sql_escape_path(str(out))

    seq_id = _seq_id_sql()
    species_key = _species_key_sql(args.species_key_mode, "taxon")
    protein_norm = _protein_norm_sql("protein_seq")

    con.execute(
        f"""
        CREATE TEMP TABLE heldout_species AS
        SELECT DISTINCT {species_key} AS species_key
        FROM {heldout_sql}
        WHERE {species_key} != '';
        """
    )

    con.execute(
        f"""
        CREATE TEMP TABLE cluster_counts AS
        WITH base AS (
            SELECT
                {seq_id} AS seq_id,
                {species_key} AS species_key
            FROM {input_sql}
            WHERE length(({seq_id})) > 1
              AND {species_key} != ''
        )
        SELECT
            sc.cluster_id,
            base.species_key,
            count(*)::BIGINT AS n
        FROM base
        JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
        GROUP BY sc.cluster_id, base.species_key;
        """
    )

    con.execute(
        """
        CREATE TEMP TABLE cluster_flags AS
        SELECT
            cluster_id,
            sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test,
            sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen,
            sum(n)::BIGINT AS n_total,
            count(*)::BIGINT AS n_species
        FROM cluster_counts
        GROUP BY cluster_id;
        """
    )

    t0 = time.time()
    split_summary = _write_cluster_split_parquet(
        con,
        cluster_split_path=cluster_split,
        seed=int(args.seed),
        val_frac=float(args.val_frac),
    )
    print(
        "Cluster assignment summary: "
        f"train_clusters={split_summary['train_clusters']:,} "
        f"val_clusters={split_summary['val_clusters']:,} "
        f"target_val_rows={split_summary['target_val_rows']:,} "
        f"actual_val_rows={split_summary['actual_val_rows']:,} "
        f"(elapsed_s={time.time() - t0:.1f})"
    )

    # Label every seq_id, guard exact-protein duplicates across splits, and
    # collapse seq_ids that map to multiple species or multiple splits to 'drop'.
    cluster_split_path = _sql_escape_path(str(cluster_split))
    con.execute(
        f"""
        COPY (
            WITH base AS (
                SELECT DISTINCT
                    {seq_id} AS seq_id,
                    {species_key} AS species_key,
                    {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length(({seq_id})) > 1
                  AND {species_key} != ''
            ),
            joined AS (
                SELECT
                    base.seq_id,
                    base.species_key,
                    base.protein_norm,
                    sc.cluster_id
                FROM base
                LEFT JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
            ),
            labeled AS (
                SELECT
                    j.seq_id,
                    j.species_key,
                    j.protein_norm,
                    CASE
                        WHEN j.cluster_id IS NULL THEN 'drop'
                        WHEN j.species_key IN (SELECT species_key FROM heldout_species) THEN 'test'
                        WHEN coalesce(cf.n_test, 0) > 0 THEN 'drop'
                        ELSE coalesce(cs.split, 'drop')
                    END AS split
                FROM joined j
                LEFT JOIN cluster_flags cf USING (cluster_id)
                LEFT JOIN read_parquet('{cluster_split_path}') cs USING (cluster_id)
            ),
            protein_flags AS (
                SELECT
                    protein_norm,
                    max(CASE WHEN split = 'test' THEN 1 ELSE 0 END) AS has_test,
                    max(CASE WHEN split = 'train' THEN 1 ELSE 0 END) AS has_train
                FROM labeled
                WHERE length(protein_norm) > 0
                GROUP BY protein_norm
            ),
            guarded AS (
                SELECT
                    l.seq_id,
                    l.species_key,
                    CASE
                        WHEN l.split = 'drop' THEN 'drop'
                        WHEN length(l.protein_norm) = 0 THEN l.split
                        WHEN coalesce(pf.has_test, 0) = 1 AND l.split IN ('train', 'val') THEN 'drop'
                        WHEN coalesce(pf.has_train, 0) = 1 AND l.split = 'val' THEN 'drop'
                        ELSE l.split
                    END AS split
                FROM labeled l
                LEFT JOIN protein_flags pf USING (protein_norm)
            ),
            dedup AS (
                SELECT
                    seq_id,
                    CASE
                        WHEN count(DISTINCT species_key) > 1 THEN 'drop'
                        WHEN count(DISTINCT split) > 1 THEN 'drop'
                        ELSE any_value(split)
                    END AS split
                FROM guarded
                GROUP BY seq_id
            )
            SELECT seq_id, split FROM dedup
        )
        TO '{out_path}'
        (FORMAT PARQUET);
        """
    )

    rows = con.execute(
        f"""
        WITH base AS (
            SELECT {seq_id} AS seq_id
            FROM {input_sql}
        )
        SELECT s.split, count(*)::BIGINT AS n_rows
        FROM base
        JOIN read_parquet('{out_path}') s USING (seq_id)
        GROUP BY s.split
        ORDER BY n_rows DESC;
        """
    ).fetchall()
    print("Split summary (rows):")
    for split, n in rows:
        print(f" {split}\t{n:,}")

    print(f"Wrote cluster→split parquet: {cluster_split}")
    print(f"Wrote seq→split parquet: {out}")


def cmd_write_data_v3(args: argparse.Namespace) -> None:
    import duckdb

    seq_split = Path(args.seq_split_parquet)
    if not seq_split.exists():
        raise SystemExit(f"seq_split parquet not found: {seq_split}")
    seq_cluster = Path(args.seq_cluster_parquet)
    if args.representatives_only and not seq_cluster.exists():
        raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")

    out_root = Path(args.output_root)
    out_root.mkdir(parents=True, exist_ok=True)
    (out_root / "_work").mkdir(parents=True, exist_ok=True)

    for split_dir in (out_root / "train", out_root / "val", out_root / "test"):
        if split_dir.exists():
            if not args.overwrite:
                raise SystemExit(f"Output split directory exists: {split_dir} (pass --overwrite)")
            shutil.rmtree(split_dir)
        split_dir.mkdir(parents=True, exist_ok=True)

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    seq_split_path = _sql_escape_path(str(seq_split))
    seq_cluster_path = _sql_escape_path(str(seq_cluster))
    seq_id = _seq_id_sql()

    num_shards = int(args.num_shards)
    if num_shards <= 0:
        raise SystemExit("--num-shards must be > 0")

    for split in ("train", "val", "test"):
        out_dir = _sql_escape_path(str(out_root / split))
        if args.representatives_only:
            # One representative per cluster: the lexicographically smallest
            # seq_id among the cluster's members in this split.
            target_seq_ids_sql = f"""
                SELECT min(s.seq_id) AS seq_id
                FROM read_parquet('{seq_split_path}') s
                JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
                WHERE s.split = '{split}'
                GROUP BY sc.cluster_id
            """
        else:
            target_seq_ids_sql = f"""
                SELECT DISTINCT s.seq_id
                FROM read_parquet('{seq_split_path}') s
                WHERE s.split = '{split}'
            """
        sql = f"""
        COPY (
            WITH target_seq_ids AS (
                {target_seq_ids_sql}
            ),
            rows AS (
                SELECT
                    p.*,
                    abs(hash({seq_id})) % {num_shards} AS shard
                FROM {input_sql} p
                JOIN target_seq_ids t
                  ON t.seq_id = ({seq_id})
                QUALIFY row_number() OVER (PARTITION BY ({seq_id}) ORDER BY ({seq_id})) = 1
            )
            SELECT * FROM rows
        )
        TO '{out_dir}'
        (FORMAT PARQUET, PARTITION_BY (shard));
        """
        t0 = time.time()
        con.execute(sql)
        print(
            f"Wrote {split} parquets to {out_root / split} "
            f"representatives_only={bool(args.representatives_only)} "
            f"(elapsed_s={time.time() - t0:.1f})"
        )
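
# Hedged sketch of the resulting layout under PARTITION_BY (shard); exact file
# names within each partition directory are DuckDB's choice:
#   data_v3_rebuild/train/shard=0/*.parquet
#   data_v3_rebuild/train/shard=1/*.parquet
#   ...
#   data_v3_rebuild/test/shard=255/*.parquet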


def cmd_verify(args: argparse.Namespace) -> None:
    import duckdb

    seq_cluster = Path(args.seq_cluster_parquet)
    seq_split = Path(args.seq_split_parquet)
    if not seq_cluster.exists():
        raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")
    if not seq_split.exists():
        raise SystemExit(f"seq_split parquet not found: {seq_split}")

    con = duckdb.connect()
    con.execute(f"PRAGMA threads={int(args.threads)};")
    con.execute("PRAGMA enable_progress_bar=true;")

    input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
    heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0)
    seq_cluster_path = _sql_escape_path(str(seq_cluster))
    seq_split_path = _sql_escape_path(str(seq_split))

    seq_id = _seq_id_sql()
    species_key = _species_key_sql(args.species_key_mode, "taxon")
    protein_norm = _protein_norm_sql("protein_seq")

    con.execute(
        f"""
        CREATE TEMP TABLE heldout_species AS
        SELECT DISTINCT {species_key} AS species_key
        FROM {heldout_sql}
        WHERE {species_key} != '';
        """
    )
    con.execute(
        f"""
        CREATE TEMP TABLE cluster_counts AS
        WITH base AS (
            SELECT
                {seq_id} AS seq_id,
                {species_key} AS species_key
            FROM {input_sql}
            WHERE length(({seq_id})) > 1
              AND {species_key} != ''
        )
        SELECT
            sc.cluster_id,
            base.species_key,
            count(*)::BIGINT AS n
        FROM base
        JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
        GROUP BY sc.cluster_id, base.species_key;
        """
    )
    con.execute(
        """
        CREATE TEMP TABLE cluster_flags AS
        SELECT
            cluster_id,
            sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test,
            sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen,
            sum(n)::BIGINT AS n_total,
            count(*)::BIGINT AS n_species
        FROM cluster_counts
        GROUP BY cluster_id;
        """
    )

    split_seq_ids = {
        split: int(n)
        for split, n in con.execute(
            f"""
            SELECT split, count(*)::BIGINT AS n
            FROM read_parquet('{seq_split_path}')
            GROUP BY split
            """
        ).fetchall()
    }
    split_rows = {
        split: int(n)
        for split, n in con.execute(
            f"""
            WITH base AS (
                SELECT {seq_id} AS seq_id FROM {input_sql}
            )
            SELECT s.split, count(*)::BIGINT AS n
            FROM base
            JOIN read_parquet('{seq_split_path}') s USING (seq_id)
            GROUP BY s.split
            """
        ).fetchall()
    }

    # Hard checks: any nonzero count below fails verification.
    bad_clusters = int(
        con.execute(
            f"""
            WITH keep AS (
                SELECT sc.cluster_id, ss.split
                FROM read_parquet('{seq_cluster_path}') sc
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split != 'drop'
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT cluster_id
                FROM keep
                GROUP BY cluster_id
                HAVING count(DISTINCT split) > 1
            );
            """
        ).fetchone()[0]
    )
    print(f"clusters_spanning_splits(excluding drop) = {bad_clusters}")

    bad_test = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'test'
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    print(f"test_rows_with_seen_species = {bad_test}")

    bad_val_species = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            ),
            labeled AS (
                SELECT base.species_key, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split IN ('train', 'val')
            ),
            train_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'train'),
            val_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'val')
            SELECT count(*)::BIGINT
            FROM (SELECT species_key FROM val_species EXCEPT SELECT species_key FROM train_species);
            """
        ).fetchone()[0]
    )
    print(f"val_species_not_in_train = {bad_val_species}")

    protein_overlap_train_val = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
            ),
            labeled AS (
                SELECT base.protein_norm, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split IN ('train', 'val')
            ),
            train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'),
            val_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'val')
            SELECT count(*)::BIGINT
            FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM val_p);
            """
        ).fetchone()[0]
    )
    protein_overlap_train_test = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
            ),
            labeled AS (
                SELECT base.protein_norm, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
                WHERE ss.split IN ('train', 'test')
            ),
            train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'),
            test_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'test')
            SELECT count(*)::BIGINT
            FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM test_p);
            """
        ).fetchone()[0]
    )
    print(f"exact_protein_overlap_train_val = {protein_overlap_train_val}")
    print(f"exact_protein_overlap_train_test = {protein_overlap_train_test}")

    # Audit-only counts (reported in the JSON, not enforced).
    mixed_test_clusters = int(
        con.execute(
            "SELECT count(*)::BIGINT FROM cluster_flags WHERE n_test > 0 AND n_seen > 0"
        ).fetchone()[0]
    )
    exact_holdout_seen_conflicts = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {protein_norm} AS protein_norm,
                    {species_key} AS species_key
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
                  AND {species_key} != ''
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT protein_norm
                FROM base
                GROUP BY protein_norm
                HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
                   AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
            );
            """
        ).fetchone()[0]
    )
    dropped_seen_rows_exact_holdout = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {seq_id} AS seq_id,
                    {species_key} AS species_key,
                    {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length(({seq_id})) > 1
                  AND {species_key} != ''
            ),
            conflict_proteins AS (
                SELECT protein_norm
                FROM base
                WHERE length(protein_norm) > 0
                GROUP BY protein_norm
                HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
                   AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN conflict_proteins USING (protein_norm)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop'
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    # Dropped rows (any split) whose exact protein also occurs in train.
    dropped_rows_exact_train = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {seq_id} AS seq_id,
                    {protein_norm} AS protein_norm
                FROM {input_sql}
                WHERE length(({seq_id})) > 1
                  AND length({protein_norm}) > 0
            ),
            labeled AS (
                SELECT base.protein_norm, ss.split
                FROM base
                JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            ),
            train_proteins AS (
                SELECT DISTINCT protein_norm
                FROM labeled
                WHERE split = 'train'
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN train_proteins USING (protein_norm)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop';
            """
        ).fetchone()[0]
    )
    dropped_seen_rows_mixed = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
            JOIN cluster_flags cf USING (cluster_id)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop'
              AND cf.n_test > 0
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    dropped_seen_seqids_mixed = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key
                FROM {input_sql}
            )
            SELECT count(*)::BIGINT
            FROM base
            JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
            JOIN cluster_flags cf USING (cluster_id)
            JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
            WHERE ss.split = 'drop'
              AND cf.n_test > 0
              AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
            """
        ).fetchone()[0]
    )
    same_protein_multi_species = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {protein_norm} AS protein_norm,
                    {species_key} AS species_key
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
                  AND {species_key} != ''
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT protein_norm
                FROM base
                GROUP BY protein_norm
                HAVING count(DISTINCT species_key) > 1
            );
            """
        ).fetchone()[0]
    )
    same_protein_cross_holdout = int(
        con.execute(
            f"""
            WITH base AS (
                SELECT DISTINCT
                    {protein_norm} AS protein_norm,
                    {species_key} AS species_key
                FROM {input_sql}
                WHERE length({protein_norm}) > 0
                  AND {species_key} != ''
            )
            SELECT count(*)::BIGINT
            FROM (
                SELECT protein_norm
                FROM base
                GROUP BY protein_norm
                HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
                   AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
            );
            """
        ).fetchone()[0]
    )

    report = {
        "parameters": {
            "input_glob": args.input_glob,
            "heldout_test_glob": args.heldout_test_glob,
            "seq_cluster_parquet": str(seq_cluster),
            "seq_split_parquet": str(seq_split),
            "seq_space": args.seq_space,
            "species_key_mode": args.species_key_mode,
            "limit_files": int(args.limit_files),
        },
        "split_seq_ids": split_seq_ids,
        "split_rows": split_rows,
        "verification": {
            "clusters_spanning_splits_excluding_drop": bad_clusters,
            "test_rows_with_seen_species": bad_test,
            "val_species_not_in_train": bad_val_species,
            "exact_protein_overlap_train_val": protein_overlap_train_val,
            "exact_protein_overlap_train_test": protein_overlap_train_test,
        },
        "audit": {
            "mixed_test_clusters": mixed_test_clusters,
            "exact_protein_cross_holdout_seen_groups": exact_holdout_seen_conflicts,
            "dropped_seen_rows_from_exact_protein_holdout_overlap": dropped_seen_rows_exact_holdout,
            "dropped_rows_from_exact_protein_train_overlap": dropped_rows_exact_train,
            "dropped_seen_rows_from_mixed_test_clusters": dropped_seen_rows_mixed,
            "dropped_seen_seqids_from_mixed_test_clusters": dropped_seen_seqids_mixed,
            "same_protein_multi_species_exact_matches": same_protein_multi_species,
            "same_protein_cross_holdout_species_exact_matches": same_protein_cross_holdout,
        },
    }

    if args.report_json:
        report_path = Path(args.report_json)
        report_path.parent.mkdir(parents=True, exist_ok=True)
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, sort_keys=True)
        print(f"Wrote audit report: {report_path}")

    if (
        bad_clusters != 0
        or bad_test != 0
        or bad_val_species != 0
        or protein_overlap_train_val != 0
        or protein_overlap_train_test != 0
    ):
        raise SystemExit("Verification FAILED (see counts above).")
    print("Verification OK.")


def build_parser() -> argparse.ArgumentParser:
    ap = argparse.ArgumentParser(
        description="Resplit data_v2 to data_v3_rebuild using MMseqs2 protein clustering."
    )
    sub = ap.add_subparsers(dest="cmd", required=True)

    p = sub.add_parser("make-fasta", help="Generate MMseqs FASTA from parquet shards.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--output-fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument(
        "--max-input-seq-len",
        type=int,
        default=0,
        help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
    )
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.add_argument("--limit-rows", type=int, default=0, help="Debug: limit number of rows written (0=all)")
    p.set_defaults(func=cmd_make_fasta)

    p = sub.add_parser("mmseqs-cluster", help="Run MMseqs2 createdb+linclust and emit clustering TSV.")
    p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
    p.add_argument("--fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
    p.add_argument("--workdir", type=str, default="data_v3_rebuild/_work/mmseqs")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--min-seq-id", type=float, default=0.90)
    p.add_argument("-c", "--coverage", type=float, default=0.80)
    p.add_argument("--cov-mode", type=int, default=2, help="2=enforce representative/query coverage")
    p.add_argument("--cluster-mode", type=int, default=2, help="2=greedy clustering by sequence length")
    p.add_argument("--max-seq-len", type=int, default=200000)
    p.add_argument(
        "--kmer-per-seq-scale",
        type=float,
        default=None,
        help="Optional MMseqs2 override; leave empty to use MMseqs defaults.",
    )
    p.add_argument("--split-memory-limit", type=str, default="", help="e.g. 120G (empty=use MMseqs default)")
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--remove-tmp-files",
        dest="remove_tmp_files",
        action="store_true",
        default=True,
        help="Remove MMseqs2 tmp files (default).",
    )
    g.add_argument(
        "--keep-tmp-files",
        dest="remove_tmp_files",
        action="store_false",
        help="Keep MMseqs2 tmp files.",
    )
    p.add_argument("--overwrite", action="store_true")
    p.set_defaults(func=cmd_mmseqs_cluster)

    p = sub.add_parser("make-seq-cluster", help="Convert MMseqs TSV to parquet mapping seq_id→cluster_id.")
    p.add_argument("--cluster-tsv", type=str, default="data_v3_rebuild/_work/mmseqs/clu.tsv")
    p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--threads", type=int, default=32)
    p.set_defaults(func=cmd_make_seq_cluster)

    p = sub.add_parser(
        "make-seq-split",
        help="Create seq_id→{train,val,test,drop} using cluster assignments and heldout species.",
    )
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--cluster-split-parquet", type=str, default="data_v3_rebuild/_work/cluster_split.parquet")
    p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--val-frac", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=13)
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.set_defaults(func=cmd_make_seq_split)

    p = sub.add_parser("write-data-v3", help="Write data_v3 parquet directories from seq_split mapping.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--output-root", type=str, default="data_v3_rebuild")
    p.add_argument("--num-shards", type=int, default=256, help="Partition each split into N shards")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--representatives-only",
        dest="representatives_only",
        action="store_true",
        default=True,
        help="Write only one representative seq_id per MMseqs cluster (default).",
    )
    g.add_argument(
        "--all-cluster-members",
        dest="representatives_only",
        action="store_false",
        help="Write all seq_ids assigned to the split instead of one representative per cluster.",
    )
    p.add_argument("--overwrite", action="store_true")
    p.set_defaults(func=cmd_write_data_v3)

    p = sub.add_parser("verify", help="Verify leakage/species constraints and write an audit report.")
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
    p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
    p.add_argument("--report-json", type=str, default="data_v3_rebuild/_work/split_report.json")
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.set_defaults(func=cmd_verify)

    p = sub.add_parser(
        "all",
        help="Run the full pipeline: make-fasta → mmseqs-cluster → make-seq-cluster → make-seq-split → write-data-v3 → verify.",
    )
    p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
    p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
    p.add_argument("--output-root", type=str, default="data_v3_rebuild")
    p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
    p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
    p.add_argument(
        "--max-input-seq-len",
        type=int,
        default=0,
        help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
    )
    p.add_argument("--threads", type=int, default=32)
    p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
    p.add_argument("--num-shards", type=int, default=256)
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        "--representatives-only",
        dest="representatives_only",
        action="store_true",
        default=True,
        help="Write only one representative seq_id per MMseqs cluster (default).",
    )
    g.add_argument(
        "--all-cluster-members",
        dest="representatives_only",
        action="store_false",
        help="Write all seq_ids assigned to the split instead of one representative per cluster.",
    )
    p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
    p.add_argument("--min-seq-id", type=float, default=0.90)
    p.add_argument("-c", "--coverage", type=float, default=0.80)
    p.add_argument("--cov-mode", type=int, default=2)
    p.add_argument("--cluster-mode", type=int, default=2)
    p.add_argument("--max-seq-len", type=int, default=200000)
    p.add_argument("--kmer-per-seq-scale", type=float, default=None)
    p.add_argument("--split-memory-limit", type=str, default="")
    p.add_argument("--val-frac", type=float, default=0.01)
    p.add_argument("--seed", type=int, default=13)
    p.add_argument("--overwrite", action="store_true")

    def _run_all(a: argparse.Namespace) -> None:
        out_root = Path(a.output_root)
        work = out_root / "_work"
        fasta = work / "mmseqs_input.fasta"
        mmseqs_work = work / "mmseqs"
        cluster_tsv = mmseqs_work / "clu.tsv"
        seq_cluster = work / "seq_cluster.parquet"
        cluster_split = work / "cluster_split.parquet"
        seq_split = work / "seq_split.parquet"
        report_json = work / "split_report.json"

        cmd_make_fasta(
            argparse.Namespace(
                input_glob=a.input_glob,
                output_fasta=str(fasta),
                seq_space=a.seq_space,
                max_input_seq_len=a.max_input_seq_len,
                threads=a.threads,
                limit_files=a.limit_files,
                limit_rows=0,
            )
        )
        cmd_mmseqs_cluster(
            argparse.Namespace(
                mmseqs=a.mmseqs,
                fasta=str(fasta),
                workdir=str(mmseqs_work),
                seq_space=a.seq_space,
                threads=a.threads,
                min_seq_id=a.min_seq_id,
                coverage=a.coverage,
                cov_mode=a.cov_mode,
                cluster_mode=a.cluster_mode,
                max_seq_len=a.max_seq_len,
                kmer_per_seq_scale=a.kmer_per_seq_scale,
                split_memory_limit=a.split_memory_limit,
                remove_tmp_files=True,
                overwrite=a.overwrite,
            )
        )
        cmd_make_seq_cluster(
            argparse.Namespace(
                cluster_tsv=str(cluster_tsv),
                output_parquet=str(seq_cluster),
                threads=a.threads,
            )
        )
        cmd_make_seq_split(
            argparse.Namespace(
                input_glob=a.input_glob,
                heldout_test_glob=a.heldout_test_glob,
                species_key_mode=a.species_key_mode,
                seq_cluster_parquet=str(seq_cluster),
                cluster_split_parquet=str(cluster_split),
                output_parquet=str(seq_split),
                val_frac=a.val_frac,
                seed=a.seed,
                threads=a.threads,
                limit_files=a.limit_files,
            )
        )
        cmd_write_data_v3(
            argparse.Namespace(
                input_glob=a.input_glob,
                seq_cluster_parquet=str(seq_cluster),
                seq_split_parquet=str(seq_split),
                output_root=str(out_root),
                num_shards=a.num_shards,
                threads=a.threads,
                limit_files=a.limit_files,
                representatives_only=a.representatives_only,
                overwrite=a.overwrite,
            )
        )
        cmd_verify(
            argparse.Namespace(
                input_glob=a.input_glob,
                heldout_test_glob=a.heldout_test_glob,
                species_key_mode=a.species_key_mode,
                seq_space=a.seq_space,
                seq_cluster_parquet=str(seq_cluster),
                seq_split_parquet=str(seq_split),
                report_json=str(report_json),
                threads=a.threads,
                limit_files=a.limit_files,
            )
        )

    p.set_defaults(func=_run_all)
    return ap


def main(argv: Optional[List[str]] = None) -> int:
    ap = build_parser()
    args = ap.parse_args(argv)
    args.func(args)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())