#!/usr/bin/env python3
"""
Resplit `data_v2/` into leakage-safe `data_v3_rebuild/` using MMseqs2 clustering.
Default policy for the current rebuild:
- Cluster `protein_seq` with MMseqs2 `linclust`
- Define species by normalized binomial name (`genus species`)
- Test species are exactly the normalized species present in `data_v2/test`
- Validation is cluster-unseen but species-seen
- Mixed seen/heldout clusters keep heldout rows in test and drop seen rows
Typical usage (end-to-end):
python resplit_data_v3.py all --threads 32 --split-memory-limit 120G --num-shards 256
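Each stage can also be run on its own; the subcommands share these defaults and write
intermediates under data_v3_rebuild/_work/ (see build_parser for the full option list):
python resplit_data_v3.py make-fasta
python resplit_data_v3.py mmseqs-cluster --threads 32 --split-memory-limit 120G
python resplit_data_v3.py make-seq-cluster
python resplit_data_v3.py make-seq-split
python resplit_data_v3.py write-data-v3 --num-shards 256
python resplit_data_v3.py verify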
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
import stat
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
def _default_mmseqs_path() -> str:
cand = Path("MMseqs2/build/bin/mmseqs")
if cand.exists():
return str(cand)
return "mmseqs"
def _run(cmd: List[str], *, cwd: Optional[str] = None, env: Optional[dict] = None) -> None:
pretty = " ".join(cmd)
print(f"+ {pretty}", flush=True)
subprocess.run(cmd, cwd=cwd, env=env, check=True)
def _sql_escape_path(path: str) -> str:
return path.replace("'", "''")
def _expand_parquet_inputs(inp: str) -> List[str]:
import glob
p = Path(inp)
if p.exists() and p.is_dir():
files = sorted(str(x) for x in p.rglob("*.parquet"))
else:
files = sorted(glob.glob(inp))
seen = set()
out: List[str] = []
for f in files:
if f not in seen:
out.append(f)
seen.add(f)
return out
def _duckdb_parquet_source(inp: str, limit_files: int = 0) -> str:
files = _expand_parquet_inputs(inp)
if not files:
raise SystemExit(f"No parquet files found for {inp!r}")
if limit_files and int(limit_files) > 0:
files = files[: int(limit_files)]
quoted = ", ".join(f"'{_sql_escape_path(fp)}'" for fp in files)
return f"read_parquet([{quoted}])"
def _mem_total_bytes() -> Optional[int]:
try:
with open("/proc/meminfo", "r", encoding="utf-8") as f:
for line in f:
if line.startswith("MemTotal:"):
parts = line.split()
kb = int(parts[1])
return kb * 1024
except OSError:
return None
except (ValueError, IndexError):
return None
return None
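# Parse an MMseqs2-style size string such as "120G" or "1.5GB" into bytes.
# Returns None for strings it cannot interpret; the caller then skips the sanity warning.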
def _parse_mmseqs_bytes(s: str) -> Optional[int]:
s = (s or "").strip()
if not s:
return None
up = s.upper()
suffix = up[-1]
num_part = up[:-1]
unit = suffix
if suffix == "B" and len(up) >= 2 and up[-2] in "KMGT":
unit = up[-2]
num_part = up[:-2]
if unit not in "BKMGT":
return None
try:
val = float(num_part)
except ValueError:
return None
mult = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}[unit]
return int(val * mult)
def _format_bytes(n: int) -> str:
for unit, div in [("TiB", 1024**4), ("GiB", 1024**3), ("MiB", 1024**2), ("KiB", 1024)]:
if n >= div:
return f"{n / div:.1f}{unit}"
return f"{n}B"
def _seq_id_sql() -> str:
# Keep the stable row identifier aligned with the existing pipeline.
return "coalesce(protein_refseq_id, '') || '|' || coalesce(RefseqID, '')"
def _taxon_norm_sql(col: str = "taxon") -> str:
return f"regexp_replace(lower(trim(coalesce({col}, ''))), '\\\\s+', ' ', 'g')"
def _species_key_sql(mode: str, col: str = "taxon") -> str:
norm = _taxon_norm_sql(col)
if mode == "taxon":
return norm
if mode == "binomial":
return (
f"CASE "
f"WHEN strpos({norm}, ' ') > 0 "
f"THEN split_part({norm}, ' ', 1) || ' ' || split_part({norm}, ' ', 2) "
f"ELSE {norm} END"
)
raise ValueError(f"Unsupported species key mode: {mode}")
def _protein_norm_sql(col: str = "protein_seq") -> str:
cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')"
no_stop = f"regexp_replace({cleaned}, '[_*]+$', '')"
return f"regexp_replace({no_stop}, '[^A-Z]', 'X', 'g')"
def _cds_norm_sql(col: str = "cds_DNA") -> str:
cleaned = f"regexp_replace(upper(coalesce({col}, '')), '\\\\s+', '', 'g')"
return f"regexp_replace({cleaned}, '[^ACGTN]', 'N', 'g')"
def _seq_expr_sql(seq_space: str) -> str:
if seq_space == "protein":
return _protein_norm_sql("protein_seq")
if seq_space == "cds":
return _cds_norm_sql("cds_DNA")
raise ValueError(f"Unsupported seq space: {seq_space}")
def _seq_space_input_col(seq_space: str) -> str:
if seq_space == "protein":
return "protein_seq"
if seq_space == "cds":
return "cds_DNA"
raise ValueError(f"Unsupported seq space: {seq_space}")
def _mmseqs_dbtype(seq_space: str) -> str:
if seq_space == "protein":
return "1"
if seq_space == "cds":
return "2"
raise ValueError(f"Unsupported seq space: {seq_space}")
def _default_max_input_seq_len(seq_space: str) -> int:
if seq_space == "protein":
# MMseqs linclust hit an internal SW bug on a tiny tail of ultra-long proteins
# (~39k aa+). Filtering this tail removes <0.01% of rows and keeps the run stable.
return 20_000
return 0
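# Make the MMseqs2 binary runnable: set the user-executable bit if it is missing and
# prepend the current Python environment's lib directories (and its conda root, if any)
# to LD_LIBRARY_PATH so the binary can find its shared libraries.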
def _ensure_mmseqs_ready(mmseqs: str) -> Tuple[str, Dict[str, str]]:
path = Path(mmseqs)
env = os.environ.copy()
if path.exists():
mode = path.stat().st_mode
if not (mode & stat.S_IXUSR):
path.chmod(mode | stat.S_IXUSR)
py = Path(sys.executable).resolve()
env_root = py.parent.parent
conda_root = env_root.parent.parent if env_root.parent.name == "envs" else env_root.parent
lib_candidates = [env_root / "lib", conda_root / "lib"]
libs = [str(p) for p in lib_candidates if p.exists()]
if libs:
current = env.get("LD_LIBRARY_PATH", "")
env["LD_LIBRARY_PATH"] = ":".join(libs + ([current] if current else []))
return str(path if path.exists() else mmseqs), env
def _ensure_output_parent(path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
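# Stage 1: stream the normalized sequences out of the input parquet shards into a single
# FASTA file via DuckDB. Headers are the stable seq_id; empty and over-long sequences are dropped.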
def cmd_make_fasta(args: argparse.Namespace) -> None:
out_fasta = Path(args.output_fasta)
_ensure_output_parent(out_fasta)
import duckdb
con = duckdb.connect()
con.execute(f"PRAGMA threads={int(args.threads)};")
con.execute("PRAGMA enable_progress_bar=true;")
source_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
out_path = _sql_escape_path(str(out_fasta))
seq_id = _seq_id_sql()
seq_expr = _seq_expr_sql(args.seq_space)
raw_col = _seq_space_input_col(args.seq_space)
max_input_seq_len = int(args.max_input_seq_len)
if max_input_seq_len <= 0:
max_input_seq_len = _default_max_input_seq_len(args.seq_space)
len_filter = (
f"AND length({seq_expr}) <= {max_input_seq_len}"
if max_input_seq_len > 0
else ""
)
sql = f"""
COPY (
SELECT
'>' || ({seq_id}) AS header,
{seq_expr} AS seq
FROM {source_sql}
WHERE {raw_col} IS NOT NULL
AND length({seq_expr}) > 0
{len_filter}
AND length(({seq_id})) > 1
{f"LIMIT {int(args.limit_rows)}" if args.limit_rows and int(args.limit_rows) > 0 else ""}
)
TO '{out_path}'
(FORMAT CSV, DELIMITER '\n', QUOTE '', ESCAPE '', HEADER FALSE);
"""
t0 = time.time()
con.execute(sql)
print(
f"Wrote FASTA: {out_fasta} seq_space={args.seq_space} "
f"max_input_seq_len={max_input_seq_len if max_input_seq_len > 0 else 'none'} "
f"(elapsed_s={time.time() - t0:.1f})"
)
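# Stage 2: build an MMseqs2 sequence DB from the FASTA, run linclust, and dump the cluster
# assignments to a two-column TSV (representative_id <TAB> member_id).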
def cmd_mmseqs_cluster(args: argparse.Namespace) -> None:
mmseqs, env = _ensure_mmseqs_ready(args.mmseqs)
workdir = Path(args.workdir)
workdir.mkdir(parents=True, exist_ok=True)
fasta = Path(args.fasta)
if not fasta.exists():
raise SystemExit(f"FASTA not found: {fasta}")
seqdb = workdir / "seqdb"
clu = workdir / "clu"
tmp = workdir / "tmp"
tsv = workdir / "clu.tsv"
if args.overwrite:
for p in (seqdb, clu, tmp, tsv):
if p.is_dir():
shutil.rmtree(p, ignore_errors=True)
else:
for suffix in ("", ".dbtype", ".index", ".lookup", ".source"):
try:
os.remove(str(p) + suffix)
except OSError:
pass
tmp.mkdir(parents=True, exist_ok=True)
_run(
[
mmseqs,
"createdb",
str(fasta),
str(seqdb),
"--dbtype",
_mmseqs_dbtype(args.seq_space),
"--shuffle",
"0",
"--createdb-mode",
"1",
"--threads",
str(int(args.threads)),
],
env=env,
)
linclust_cmd = [
mmseqs,
"linclust",
str(seqdb),
str(clu),
str(tmp),
"--min-seq-id",
str(float(args.min_seq_id)),
"-c",
str(float(args.coverage)),
"--cov-mode",
str(int(args.cov_mode)),
"--cluster-mode",
str(int(args.cluster_mode)),
"--threads",
str(int(args.threads)),
"--max-seq-len",
str(int(args.max_seq_len)),
"--remove-tmp-files",
"1" if args.remove_tmp_files else "0",
]
if args.split_memory_limit:
mem_total = _mem_total_bytes()
limit_bytes = _parse_mmseqs_bytes(args.split_memory_limit)
if mem_total and limit_bytes and limit_bytes > mem_total:
print(
f"WARNING: --split-memory-limit={args.split_memory_limit} ({_format_bytes(limit_bytes)}) "
f"exceeds system MemTotal ({_format_bytes(mem_total)}). "
"MMseqs2 may under-split and crash; consider lowering it or leaving it empty.",
file=sys.stderr,
flush=True,
)
linclust_cmd += ["--split-memory-limit", str(args.split_memory_limit)]
if args.kmer_per_seq_scale is not None:
linclust_cmd += ["--kmer-per-seq-scale", str(float(args.kmer_per_seq_scale))]
_run(linclust_cmd, env=env)
_run([mmseqs, "createtsv", str(seqdb), str(seqdb), str(clu), str(tsv)], env=env)
print(f"Wrote cluster TSV: {tsv}")
def cmd_make_seq_cluster(args: argparse.Namespace) -> None:
import duckdb
tsv = Path(args.cluster_tsv)
if not tsv.exists():
raise SystemExit(f"Cluster TSV not found: {tsv}")
out = Path(args.output_parquet)
_ensure_output_parent(out)
con = duckdb.connect()
con.execute(f"PRAGMA threads={int(args.threads)};")
con.execute("PRAGMA enable_progress_bar=true;")
tsv_path = _sql_escape_path(str(tsv))
out_path = _sql_escape_path(str(out))
sql = f"""
COPY (
SELECT DISTINCT
seq_id,
cluster_id
FROM read_csv(
'{tsv_path}',
delim='\\t',
header=false,
columns={{'cluster_id':'VARCHAR','seq_id':'VARCHAR'}}
)
)
TO '{out_path}'
(FORMAT PARQUET);
"""
t0 = time.time()
con.execute(sql)
print(f"Wrote seq→cluster parquet: {out} (elapsed_s={time.time() - t0:.1f})")
def _write_cluster_split_parquet(
con,
*,
cluster_split_path: Path,
seed: int,
val_frac: float,
) -> Dict[str, int]:
import pyarrow as pa
import pyarrow.parquet as pq
cluster_split_path.parent.mkdir(parents=True, exist_ok=True)
if cluster_split_path.exists():
cluster_split_path.unlink()
total_seen_rows = int(
con.execute(
"SELECT coalesce(sum(n_total), 0)::BIGINT FROM cluster_flags WHERE n_test = 0"
).fetchone()[0]
)
target_val_rows = int(total_seen_rows * float(val_frac))
species_remaining = {
species_key: int(n_clusters)
for species_key, n_clusters in con.execute(
"""
SELECT
cc.species_key,
count(*)::BIGINT AS n_clusters
FROM cluster_counts cc
JOIN cluster_flags cf USING (cluster_id)
WHERE cf.n_test = 0
GROUP BY cc.species_key
"""
).fetchall()
}
cur = con.execute(
f"""
SELECT
cf.cluster_id,
cf.n_total,
abs(hash(cf.cluster_id || ':{seed}')) AS rnd,
cc.species_key
FROM cluster_flags cf
JOIN cluster_counts cc USING (cluster_id)
WHERE cf.n_test = 0
ORDER BY rnd, cf.cluster_id, cc.species_key
"""
)
writer = None
batch_cluster_ids: List[str] = []
batch_splits: List[str] = []
val_rows = 0
train_clusters = 0
val_clusters = 0
current_cluster: Optional[str] = None
current_n_total = 0
current_species: List[str] = []
def flush_current() -> None:
nonlocal writer, val_rows, train_clusters, val_clusters
nonlocal current_cluster, current_n_total, current_species
if current_cluster is None:
return
can_val = (
val_rows < target_val_rows
and all(species_remaining.get(species_key, 0) > 1 for species_key in current_species)
)
split = "val" if can_val else "train"
if can_val:
for species_key in current_species:
species_remaining[species_key] -= 1
val_rows += int(current_n_total)
val_clusters += 1
else:
train_clusters += 1
batch_cluster_ids.append(current_cluster)
batch_splits.append(split)
if len(batch_cluster_ids) >= 200_000:
table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits})
if writer is None:
writer = pq.ParquetWriter(str(cluster_split_path), table.schema)
writer.write_table(table)
batch_cluster_ids.clear()
batch_splits.clear()
while True:
rows = cur.fetchmany(200_000)
if not rows:
break
for cluster_id, n_total, _rnd, species_key in rows:
cluster_id = str(cluster_id)
species_key = str(species_key)
if current_cluster is None:
current_cluster = cluster_id
current_n_total = int(n_total)
current_species = [species_key]
continue
if cluster_id != current_cluster:
flush_current()
current_cluster = cluster_id
current_n_total = int(n_total)
current_species = [species_key]
continue
current_species.append(species_key)
flush_current()
if batch_cluster_ids:
table = pa.table({"cluster_id": batch_cluster_ids, "split": batch_splits})
if writer is None:
writer = pq.ParquetWriter(str(cluster_split_path), table.schema)
writer.write_table(table)
elif writer is None:
empty = pa.table(
{
"cluster_id": pa.array([], type=pa.string()),
"split": pa.array([], type=pa.string()),
}
)
writer = pq.ParquetWriter(str(cluster_split_path), empty.schema)
writer.write_table(empty)
if writer is not None:
writer.close()
return {
"nonheldout_total_rows": total_seen_rows,
"target_val_rows": target_val_rows,
"actual_val_rows": val_rows,
"train_clusters": train_clusters,
"val_clusters": val_clusters,
}
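# Stage 4: label every seq_id as train/val/test/drop. Heldout-species rows go to test, seen
# rows in clusters that touch a heldout species are dropped, the remaining clusters get the
# greedy train/val assignment, and an exact-protein guard drops train/val rows whose
# normalized protein also appears in test (and val rows that collide with train).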
def cmd_make_seq_split(args: argparse.Namespace) -> None:
import duckdb
seq_cluster = Path(args.seq_cluster_parquet)
if not seq_cluster.exists():
raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")
out = Path(args.output_parquet)
cluster_split = Path(args.cluster_split_parquet)
_ensure_output_parent(out)
_ensure_output_parent(cluster_split)
con = duckdb.connect()
con.execute(f"PRAGMA threads={int(args.threads)};")
con.execute("PRAGMA enable_progress_bar=true;")
input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0)
seq_cluster_path = _sql_escape_path(str(seq_cluster))
out_path = _sql_escape_path(str(out))
seq_id = _seq_id_sql()
species_key = _species_key_sql(args.species_key_mode, "taxon")
protein_norm = _protein_norm_sql("protein_seq")
con.execute(
f"""
CREATE TEMP TABLE heldout_species AS
SELECT DISTINCT {species_key} AS species_key
FROM {heldout_sql}
WHERE {species_key} != '';
"""
)
con.execute(
f"""
CREATE TEMP TABLE cluster_counts AS
WITH base AS (
SELECT
{seq_id} AS seq_id,
{species_key} AS species_key
FROM {input_sql}
WHERE length(({seq_id})) > 1
AND {species_key} != ''
)
SELECT
sc.cluster_id,
base.species_key,
count(*)::BIGINT AS n
FROM base
JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
GROUP BY sc.cluster_id, base.species_key;
"""
)
con.execute(
"""
CREATE TEMP TABLE cluster_flags AS
SELECT
cluster_id,
sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test,
sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen,
sum(n)::BIGINT AS n_total,
count(*)::BIGINT AS n_species
FROM cluster_counts
GROUP BY cluster_id;
"""
)
t0 = time.time()
split_summary = _write_cluster_split_parquet(
con,
cluster_split_path=cluster_split,
seed=int(args.seed),
val_frac=float(args.val_frac),
)
print(
"Cluster assignment summary: "
f"train_clusters={split_summary['train_clusters']:,} "
f"val_clusters={split_summary['val_clusters']:,} "
f"target_val_rows={split_summary['target_val_rows']:,} "
f"actual_val_rows={split_summary['actual_val_rows']:,} "
f"(elapsed_s={time.time() - t0:.1f})"
)
cluster_split_path = _sql_escape_path(str(cluster_split))
con.execute(
f"""
COPY (
WITH base AS (
SELECT DISTINCT
{seq_id} AS seq_id,
{species_key} AS species_key,
{protein_norm} AS protein_norm
FROM {input_sql}
WHERE length(({seq_id})) > 1
AND {species_key} != ''
),
joined AS (
SELECT
base.seq_id,
base.species_key,
base.protein_norm,
sc.cluster_id
FROM base
LEFT JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
),
labeled AS (
SELECT
j.seq_id,
j.species_key,
j.protein_norm,
CASE
WHEN j.cluster_id IS NULL THEN 'drop'
WHEN j.species_key IN (SELECT species_key FROM heldout_species) THEN 'test'
WHEN coalesce(cf.n_test, 0) > 0 THEN 'drop'
ELSE coalesce(cs.split, 'drop')
END AS split
FROM joined j
LEFT JOIN cluster_flags cf USING (cluster_id)
LEFT JOIN read_parquet('{cluster_split_path}') cs USING (cluster_id)
),
protein_flags AS (
SELECT
protein_norm,
max(CASE WHEN split = 'test' THEN 1 ELSE 0 END) AS has_test,
max(CASE WHEN split = 'train' THEN 1 ELSE 0 END) AS has_train
FROM labeled
WHERE length(protein_norm) > 0
GROUP BY protein_norm
),
guarded AS (
SELECT
l.seq_id,
l.species_key,
CASE
WHEN l.split = 'drop' THEN 'drop'
WHEN length(l.protein_norm) = 0 THEN l.split
WHEN coalesce(pf.has_test, 0) = 1 AND l.split IN ('train', 'val') THEN 'drop'
WHEN coalesce(pf.has_train, 0) = 1 AND l.split = 'val' THEN 'drop'
ELSE l.split
END AS split
FROM labeled l
LEFT JOIN protein_flags pf USING (protein_norm)
),
dedup AS (
SELECT
seq_id,
CASE
WHEN count(DISTINCT species_key) > 1 THEN 'drop'
WHEN count(DISTINCT split) > 1 THEN 'drop'
ELSE any_value(split)
END AS split
FROM guarded
GROUP BY seq_id
)
SELECT seq_id, split FROM dedup
)
TO '{out_path}'
(FORMAT PARQUET);
"""
)
rows = con.execute(
f"""
WITH base AS (
SELECT {seq_id} AS seq_id
FROM {input_sql}
)
SELECT s.split, count(*)::BIGINT AS n_rows
FROM base
JOIN read_parquet('{out_path}') s USING (seq_id)
GROUP BY s.split
ORDER BY n_rows DESC;
"""
).fetchall()
print("Split summary (rows):")
for split, n in rows:
print(f" {split}\t{n:,}")
print(f"Wrote cluster→split parquet: {cluster_split}")
print(f"Wrote seq→split parquet: {out}")
def cmd_write_data_v3(args: argparse.Namespace) -> None:
import duckdb
seq_split = Path(args.seq_split_parquet)
if not seq_split.exists():
raise SystemExit(f"seq_split parquet not found: {seq_split}")
seq_cluster = Path(args.seq_cluster_parquet)
if args.representatives_only and not seq_cluster.exists():
raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")
out_root = Path(args.output_root)
out_root.mkdir(parents=True, exist_ok=True)
(out_root / "_work").mkdir(parents=True, exist_ok=True)
for split_dir in (out_root / "train", out_root / "val", out_root / "test"):
if split_dir.exists():
if not args.overwrite:
raise SystemExit(f"Output split directory exists: {split_dir} (pass --overwrite)")
shutil.rmtree(split_dir)
split_dir.mkdir(parents=True, exist_ok=True)
con = duckdb.connect()
con.execute(f"PRAGMA threads={int(args.threads)};")
con.execute("PRAGMA enable_progress_bar=true;")
input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
seq_split_path = _sql_escape_path(str(seq_split))
seq_cluster_path = _sql_escape_path(str(seq_cluster))
seq_id = _seq_id_sql()
num_shards = int(args.num_shards)
if num_shards <= 0:
raise SystemExit("--num-shards must be > 0")
for split in ("train", "val", "test"):
out_dir = _sql_escape_path(str(out_root / split))
if args.representatives_only:
target_seq_ids_sql = f"""
SELECT min(s.seq_id) AS seq_id
FROM read_parquet('{seq_split_path}') s
JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
WHERE s.split = '{split}'
GROUP BY sc.cluster_id
"""
else:
target_seq_ids_sql = f"""
SELECT DISTINCT s.seq_id
FROM read_parquet('{seq_split_path}') s
WHERE s.split = '{split}'
"""
sql = f"""
COPY (
WITH target_seq_ids AS (
{target_seq_ids_sql}
),
rows AS (
SELECT
p.*,
abs(hash({seq_id})) % {num_shards} AS shard
FROM {input_sql} p
JOIN target_seq_ids t
ON t.seq_id = ({seq_id})
QUALIFY row_number() OVER (PARTITION BY ({seq_id}) ORDER BY ({seq_id})) = 1
)
SELECT * FROM rows
)
TO '{out_dir}'
(FORMAT PARQUET, PARTITION_BY (shard));
"""
t0 = time.time()
con.execute(sql)
print(
f"Wrote {split} parquets to {out_root / split} "
f"representatives_only={bool(args.representatives_only)} "
f"(elapsed_s={time.time() - t0:.1f})"
)
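# Stage 6: recompute the leakage constraints (cluster purity, heldout species in test,
# val species coverage, exact-protein overlap) plus audit counters, write the JSON report,
# and fail if any hard constraint is violated.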
def cmd_verify(args: argparse.Namespace) -> None:
import duckdb
seq_cluster = Path(args.seq_cluster_parquet)
seq_split = Path(args.seq_split_parquet)
if not seq_cluster.exists():
raise SystemExit(f"seq_cluster parquet not found: {seq_cluster}")
if not seq_split.exists():
raise SystemExit(f"seq_split parquet not found: {seq_split}")
con = duckdb.connect()
con.execute(f"PRAGMA threads={int(args.threads)};")
con.execute("PRAGMA enable_progress_bar=true;")
input_sql = _duckdb_parquet_source(args.input_glob, int(args.limit_files))
heldout_sql = _duckdb_parquet_source(args.heldout_test_glob, 0)
seq_cluster_path = _sql_escape_path(str(seq_cluster))
seq_split_path = _sql_escape_path(str(seq_split))
seq_id = _seq_id_sql()
species_key = _species_key_sql(args.species_key_mode, "taxon")
protein_norm = _protein_norm_sql("protein_seq")
con.execute(
f"""
CREATE TEMP TABLE heldout_species AS
SELECT DISTINCT {species_key} AS species_key
FROM {heldout_sql}
WHERE {species_key} != '';
"""
)
con.execute(
f"""
CREATE TEMP TABLE cluster_counts AS
WITH base AS (
SELECT
{seq_id} AS seq_id,
{species_key} AS species_key
FROM {input_sql}
WHERE length(({seq_id})) > 1
AND {species_key} != ''
)
SELECT
sc.cluster_id,
base.species_key,
count(*)::BIGINT AS n
FROM base
JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
GROUP BY sc.cluster_id, base.species_key;
"""
)
con.execute(
"""
CREATE TEMP TABLE cluster_flags AS
SELECT
cluster_id,
sum(CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_test,
sum(CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN n ELSE 0 END)::BIGINT AS n_seen,
sum(n)::BIGINT AS n_total,
count(*)::BIGINT AS n_species
FROM cluster_counts
GROUP BY cluster_id;
"""
)
split_seq_ids = {
split: int(n)
for split, n in con.execute(
f"""
SELECT split, count(*)::BIGINT AS n
FROM read_parquet('{seq_split_path}')
GROUP BY split
"""
).fetchall()
}
split_rows = {
split: int(n)
for split, n in con.execute(
f"""
WITH base AS (
SELECT {seq_id} AS seq_id FROM {input_sql}
)
SELECT s.split, count(*)::BIGINT AS n
FROM base
JOIN read_parquet('{seq_split_path}') s USING (seq_id)
GROUP BY s.split
"""
).fetchall()
}
bad_clusters = int(
con.execute(
f"""
WITH keep AS (
SELECT sc.cluster_id, ss.split
FROM read_parquet('{seq_cluster_path}') sc
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
WHERE ss.split != 'drop'
)
SELECT count(*)::BIGINT
FROM (
SELECT cluster_id
FROM keep
GROUP BY cluster_id
HAVING count(DISTINCT split) > 1
);
"""
).fetchone()[0]
)
print(f"clusters_spanning_splits(excluding drop) = {bad_clusters}")
bad_test = int(
con.execute(
f"""
WITH base AS (
SELECT {seq_id} AS seq_id, {species_key} AS species_key
FROM {input_sql}
)
SELECT count(*)::BIGINT
FROM base
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
WHERE ss.split = 'test'
AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
"""
).fetchone()[0]
)
print(f"test_rows_with_seen_species = {bad_test}")
bad_val_species = int(
con.execute(
f"""
WITH base AS (
SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key
FROM {input_sql}
),
labeled AS (
SELECT base.species_key, ss.split
FROM base
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
WHERE ss.split IN ('train', 'val')
),
train_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'train'),
val_species AS (SELECT DISTINCT species_key FROM labeled WHERE split = 'val')
SELECT count(*)::BIGINT
FROM (SELECT species_key FROM val_species EXCEPT SELECT species_key FROM train_species);
"""
).fetchone()[0]
)
print(f"val_species_not_in_train = {bad_val_species}")
protein_overlap_train_val = int(
con.execute(
f"""
WITH base AS (
SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm
FROM {input_sql}
WHERE length({protein_norm}) > 0
),
labeled AS (
SELECT base.protein_norm, ss.split
FROM base
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
WHERE ss.split IN ('train', 'val')
),
train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'),
val_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'val')
SELECT count(*)::BIGINT
FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM val_p);
"""
).fetchone()[0]
)
protein_overlap_train_test = int(
con.execute(
f"""
WITH base AS (
SELECT DISTINCT {seq_id} AS seq_id, {protein_norm} AS protein_norm
FROM {input_sql}
WHERE length({protein_norm}) > 0
),
labeled AS (
SELECT base.protein_norm, ss.split
FROM base
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
WHERE ss.split IN ('train', 'test')
),
train_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'train'),
test_p AS (SELECT DISTINCT protein_norm FROM labeled WHERE split = 'test')
SELECT count(*)::BIGINT
FROM (SELECT protein_norm FROM train_p INTERSECT SELECT protein_norm FROM test_p);
"""
).fetchone()[0]
)
print(f"exact_protein_overlap_train_val = {protein_overlap_train_val}")
print(f"exact_protein_overlap_train_test = {protein_overlap_train_test}")
mixed_test_clusters = int(
con.execute(
"SELECT count(*)::BIGINT FROM cluster_flags WHERE n_test > 0 AND n_seen > 0"
).fetchone()[0]
)
exact_holdout_seen_conflicts = int(
con.execute(
f"""
WITH base AS (
SELECT DISTINCT
{protein_norm} AS protein_norm,
{species_key} AS species_key
FROM {input_sql}
WHERE length({protein_norm}) > 0
AND {species_key} != ''
)
SELECT count(*)::BIGINT
FROM (
SELECT protein_norm
FROM base
GROUP BY protein_norm
HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
);
"""
).fetchone()[0]
)
dropped_seen_rows_exact_holdout = int(
con.execute(
f"""
WITH base AS (
SELECT DISTINCT
{seq_id} AS seq_id,
{species_key} AS species_key,
{protein_norm} AS protein_norm
FROM {input_sql}
WHERE length(({seq_id})) > 1
AND {species_key} != ''
),
conflict_proteins AS (
SELECT protein_norm
FROM base
WHERE length(protein_norm) > 0
GROUP BY protein_norm
HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
)
SELECT count(*)::BIGINT
FROM base
JOIN conflict_proteins USING (protein_norm)
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
WHERE ss.split = 'drop'
AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
"""
).fetchone()[0]
)
dropped_val_rows_exact_train = int(
con.execute(
f"""
WITH base AS (
SELECT DISTINCT
{seq_id} AS seq_id,
{protein_norm} AS protein_norm
FROM {input_sql}
WHERE length(({seq_id})) > 1
AND length({protein_norm}) > 0
),
labeled AS (
SELECT base.protein_norm, ss.split
FROM base
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
),
train_proteins AS (
SELECT DISTINCT protein_norm
FROM labeled
WHERE split = 'train'
)
SELECT count(*)::BIGINT
FROM base
JOIN train_proteins USING (protein_norm)
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
WHERE ss.split = 'drop';
"""
).fetchone()[0]
)
dropped_seen_rows_mixed = int(
con.execute(
f"""
WITH base AS (
SELECT {seq_id} AS seq_id, {species_key} AS species_key
FROM {input_sql}
)
SELECT count(*)::BIGINT
FROM base
JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
JOIN cluster_flags cf USING (cluster_id)
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
WHERE ss.split = 'drop'
AND cf.n_test > 0
AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
"""
).fetchone()[0]
)
dropped_seen_seqids_mixed = int(
con.execute(
f"""
WITH base AS (
SELECT DISTINCT {seq_id} AS seq_id, {species_key} AS species_key
FROM {input_sql}
)
SELECT count(*)::BIGINT
FROM base
JOIN read_parquet('{seq_cluster_path}') sc USING (seq_id)
JOIN cluster_flags cf USING (cluster_id)
JOIN read_parquet('{seq_split_path}') ss USING (seq_id)
WHERE ss.split = 'drop'
AND cf.n_test > 0
AND base.species_key NOT IN (SELECT species_key FROM heldout_species);
"""
).fetchone()[0]
)
same_protein_multi_species = int(
con.execute(
f"""
WITH base AS (
SELECT DISTINCT
{protein_norm} AS protein_norm,
{species_key} AS species_key
FROM {input_sql}
WHERE length({protein_norm}) > 0
AND {species_key} != ''
)
SELECT count(*)::BIGINT
FROM (
SELECT protein_norm
FROM base
GROUP BY protein_norm
HAVING count(DISTINCT species_key) > 1
);
"""
).fetchone()[0]
)
same_protein_cross_holdout = int(
con.execute(
f"""
WITH base AS (
SELECT DISTINCT
{protein_norm} AS protein_norm,
{species_key} AS species_key
FROM {input_sql}
WHERE length({protein_norm}) > 0
AND {species_key} != ''
)
SELECT count(*)::BIGINT
FROM (
SELECT protein_norm
FROM base
GROUP BY protein_norm
HAVING count(DISTINCT CASE WHEN species_key IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
AND count(DISTINCT CASE WHEN species_key NOT IN (SELECT species_key FROM heldout_species) THEN species_key END) > 0
);
"""
).fetchone()[0]
)
report = {
"parameters": {
"input_glob": args.input_glob,
"heldout_test_glob": args.heldout_test_glob,
"seq_cluster_parquet": str(seq_cluster),
"seq_split_parquet": str(seq_split),
"seq_space": args.seq_space,
"species_key_mode": args.species_key_mode,
"limit_files": int(args.limit_files),
},
"split_seq_ids": split_seq_ids,
"split_rows": split_rows,
"verification": {
"clusters_spanning_splits_excluding_drop": bad_clusters,
"test_rows_with_seen_species": bad_test,
"val_species_not_in_train": bad_val_species,
"exact_protein_overlap_train_val": protein_overlap_train_val,
"exact_protein_overlap_train_test": protein_overlap_train_test,
},
"audit": {
"mixed_test_clusters": mixed_test_clusters,
"exact_protein_cross_holdout_seen_groups": exact_holdout_seen_conflicts,
"dropped_seen_rows_from_exact_protein_holdout_overlap": dropped_seen_rows_exact_holdout,
"dropped_rows_from_exact_protein_train_overlap": dropped_val_rows_exact_train,
"dropped_seen_rows_from_mixed_test_clusters": dropped_seen_rows_mixed,
"dropped_seen_seqids_from_mixed_test_clusters": dropped_seen_seqids_mixed,
"same_protein_multi_species_exact_matches": same_protein_multi_species,
"same_protein_cross_holdout_species_exact_matches": same_protein_cross_holdout,
},
}
if args.report_json:
report_path = Path(args.report_json)
report_path.parent.mkdir(parents=True, exist_ok=True)
with open(report_path, "w", encoding="utf-8") as f:
json.dump(report, f, indent=2, sort_keys=True)
print(f"Wrote audit report: {report_path}")
if (
bad_clusters != 0
or bad_test != 0
or bad_val_species != 0
or protein_overlap_train_val != 0
or protein_overlap_train_test != 0
):
raise SystemExit("Verification FAILED (see counts above).")
print("Verification OK.")
def build_parser() -> argparse.ArgumentParser:
ap = argparse.ArgumentParser(
description="Resplit data_v2 to data_v3_rebuild using MMseqs2 protein clustering."
)
sub = ap.add_subparsers(dest="cmd", required=True)
p = sub.add_parser("make-fasta", help="Generate MMseqs FASTA from parquet shards.")
p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
p.add_argument("--output-fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
p.add_argument(
"--max-input-seq-len",
type=int,
default=0,
help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
)
p.add_argument("--threads", type=int, default=32)
p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
p.add_argument("--limit-rows", type=int, default=0, help="Debug: limit number of rows written (0=all)")
p.set_defaults(func=cmd_make_fasta)
p = sub.add_parser("mmseqs-cluster", help="Run MMseqs2 createdb+linclust and emit clustering TSV.")
p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
p.add_argument("--fasta", type=str, default="data_v3_rebuild/_work/mmseqs_input.fasta")
p.add_argument("--workdir", type=str, default="data_v3_rebuild/_work/mmseqs")
p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
p.add_argument("--threads", type=int, default=32)
p.add_argument("--min-seq-id", type=float, default=0.90)
p.add_argument("-c", "--coverage", type=float, default=0.80)
p.add_argument("--cov-mode", type=int, default=2, help="2=enforce representative/query coverage")
p.add_argument("--cluster-mode", type=int, default=2, help="2=greedy clustering by sequence length")
p.add_argument("--max-seq-len", type=int, default=200000)
p.add_argument(
"--kmer-per-seq-scale",
type=float,
default=None,
help="Optional MMseqs2 override; leave empty to use MMseqs defaults.",
)
p.add_argument("--split-memory-limit", type=str, default="", help="e.g. 120G (empty=use MMseqs default)")
g = p.add_mutually_exclusive_group()
g.add_argument(
"--remove-tmp-files",
dest="remove_tmp_files",
action="store_true",
default=True,
help="Remove MMseqs2 tmp files (default).",
)
g.add_argument(
"--keep-tmp-files",
dest="remove_tmp_files",
action="store_false",
help="Keep MMseqs2 tmp files.",
)
p.add_argument("--overwrite", action="store_true")
p.set_defaults(func=cmd_mmseqs_cluster)
p = sub.add_parser("make-seq-cluster", help="Convert MMseqs TSV to parquet mapping seq_id→cluster_id.")
p.add_argument("--cluster-tsv", type=str, default="data_v3_rebuild/_work/mmseqs/clu.tsv")
p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
p.add_argument("--threads", type=int, default=32)
p.set_defaults(func=cmd_make_seq_cluster)
p = sub.add_parser(
"make-seq-split",
help="Create seq_id→{train,val,test,drop} using cluster assignments and heldout species.",
)
p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
p.add_argument("--cluster-split-parquet", type=str, default="data_v3_rebuild/_work/cluster_split.parquet")
p.add_argument("--output-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
p.add_argument("--val-frac", type=float, default=0.01)
p.add_argument("--seed", type=int, default=13)
p.add_argument("--threads", type=int, default=32)
p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
p.set_defaults(func=cmd_make_seq_split)
p = sub.add_parser("write-data-v3", help="Write data_v3 parquet directories from seq_split mapping.")
p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
p.add_argument("--output-root", type=str, default="data_v3_rebuild")
p.add_argument("--num-shards", type=int, default=256, help="Partition each split into N shards")
p.add_argument("--threads", type=int, default=32)
p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
g = p.add_mutually_exclusive_group()
g.add_argument(
"--representatives-only",
dest="representatives_only",
action="store_true",
default=True,
help="Write only one representative seq_id per MMseqs cluster (default).",
)
g.add_argument(
"--all-cluster-members",
dest="representatives_only",
action="store_false",
help="Write all seq_ids assigned to the split instead of one representative per cluster.",
)
p.add_argument("--overwrite", action="store_true")
p.set_defaults(func=cmd_write_data_v3)
p = sub.add_parser("verify", help="Verify leakage/species constraints and write an audit report.")
p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
p.add_argument("--seq-cluster-parquet", type=str, default="data_v3_rebuild/_work/seq_cluster.parquet")
p.add_argument("--seq-split-parquet", type=str, default="data_v3_rebuild/_work/seq_split.parquet")
p.add_argument("--report-json", type=str, default="data_v3_rebuild/_work/split_report.json")
p.add_argument("--threads", type=int, default=32)
p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
p.set_defaults(func=cmd_verify)
p = sub.add_parser(
"all",
help="Run the full pipeline: make-fasta → mmseqs-cluster → make-seq-cluster → make-seq-split → write-data-v3 → verify.",
)
p.add_argument("--input-glob", type=str, default="data_v2/*/*.parquet")
p.add_argument("--heldout-test-glob", type=str, default="data_v2/test/*.parquet")
p.add_argument("--output-root", type=str, default="data_v3_rebuild")
p.add_argument("--seq-space", type=str, choices=["protein", "cds"], default="protein")
p.add_argument("--species-key-mode", type=str, choices=["binomial", "taxon"], default="binomial")
p.add_argument(
"--max-input-seq-len",
type=int,
default=0,
help="Drop sequences longer than this from the MMseqs input FASTA (0=use seq-space default).",
)
p.add_argument("--threads", type=int, default=32)
p.add_argument("--limit-files", type=int, default=0, help="Only read the first N parquet files (0=all)")
p.add_argument("--num-shards", type=int, default=256)
g = p.add_mutually_exclusive_group()
g.add_argument(
"--representatives-only",
dest="representatives_only",
action="store_true",
default=True,
help="Write only one representative seq_id per MMseqs cluster (default).",
)
g.add_argument(
"--all-cluster-members",
dest="representatives_only",
action="store_false",
help="Write all seq_ids assigned to the split instead of one representative per cluster.",
)
p.add_argument("--mmseqs", type=str, default=_default_mmseqs_path())
p.add_argument("--min-seq-id", type=float, default=0.90)
p.add_argument("-c", "--coverage", type=float, default=0.80)
p.add_argument("--cov-mode", type=int, default=2)
p.add_argument("--cluster-mode", type=int, default=2)
p.add_argument("--max-seq-len", type=int, default=200000)
p.add_argument("--kmer-per-seq-scale", type=float, default=None)
p.add_argument("--split-memory-limit", type=str, default="")
p.add_argument("--val-frac", type=float, default=0.01)
p.add_argument("--seed", type=int, default=13)
p.add_argument("--overwrite", action="store_true")
def _run_all(a: argparse.Namespace) -> None:
out_root = Path(a.output_root)
work = out_root / "_work"
fasta = work / "mmseqs_input.fasta"
mmseqs_work = work / "mmseqs"
cluster_tsv = mmseqs_work / "clu.tsv"
seq_cluster = work / "seq_cluster.parquet"
cluster_split = work / "cluster_split.parquet"
seq_split = work / "seq_split.parquet"
report_json = work / "split_report.json"
cmd_make_fasta(
argparse.Namespace(
input_glob=a.input_glob,
output_fasta=str(fasta),
seq_space=a.seq_space,
max_input_seq_len=a.max_input_seq_len,
threads=a.threads,
limit_files=a.limit_files,
limit_rows=0,
)
)
cmd_mmseqs_cluster(
argparse.Namespace(
mmseqs=a.mmseqs,
fasta=str(fasta),
workdir=str(mmseqs_work),
seq_space=a.seq_space,
threads=a.threads,
min_seq_id=a.min_seq_id,
coverage=a.coverage,
cov_mode=a.cov_mode,
cluster_mode=a.cluster_mode,
max_seq_len=a.max_seq_len,
kmer_per_seq_scale=a.kmer_per_seq_scale,
split_memory_limit=a.split_memory_limit,
remove_tmp_files=True,
overwrite=a.overwrite,
)
)
cmd_make_seq_cluster(
argparse.Namespace(
cluster_tsv=str(cluster_tsv),
output_parquet=str(seq_cluster),
threads=a.threads,
)
)
cmd_make_seq_split(
argparse.Namespace(
input_glob=a.input_glob,
heldout_test_glob=a.heldout_test_glob,
species_key_mode=a.species_key_mode,
seq_cluster_parquet=str(seq_cluster),
cluster_split_parquet=str(cluster_split),
output_parquet=str(seq_split),
val_frac=a.val_frac,
seed=a.seed,
threads=a.threads,
limit_files=a.limit_files,
)
)
cmd_write_data_v3(
argparse.Namespace(
input_glob=a.input_glob,
seq_cluster_parquet=str(seq_cluster),
seq_split_parquet=str(seq_split),
output_root=str(out_root),
num_shards=a.num_shards,
threads=a.threads,
limit_files=a.limit_files,
representatives_only=a.representatives_only,
overwrite=a.overwrite,
)
)
cmd_verify(
argparse.Namespace(
input_glob=a.input_glob,
heldout_test_glob=a.heldout_test_glob,
species_key_mode=a.species_key_mode,
seq_space=a.seq_space,
seq_cluster_parquet=str(seq_cluster),
seq_split_parquet=str(seq_split),
report_json=str(report_json),
threads=a.threads,
limit_files=a.limit_files,
)
)
p.set_defaults(func=_run_all)
return ap
def main(argv: Optional[List[str]] = None) -> int:
ap = build_parser()
args = ap.parse_args(argv)
args.func(args)
return 0
if __name__ == "__main__":
raise SystemExit(main())