"""Prepare apples-to-apples external benchmark inputs. This script pins the same BacDive/MediaDive held-out strains used by ``scripts/41_benchmark_media_recommender.py`` and checks whether the local machine can run external baselines such as GenomeSPOT, CarveMe, and gapseq. It deliberately separates preparation from heavy external execution: those tools need raw genome FASTA files and optional third-party databases that are much larger than the repository. """ from __future__ import annotations import argparse import gzip import json import shutil import time from pathlib import Path from typing import Any import pandas as pd from sklearn.model_selection import GroupKFold, KFold from microbe_model import config from microbe_model.pipeline import _fetch_fasta_bytes from microbe_model.train.media_recommender import build_training_table TOOL_CANDIDATES = { "GenomeSPOT": ("genomespot", "genome-spot", "genome_spot"), "CarveMe": ("carve",), "gapseq": ("gapseq",), } def load_recommender_features() -> pd.DataFrame: """Load the same feature stack used by the media recommender.""" feats = pd.read_parquet(config.DATA / "features.parquet") hmm_path = config.DATA / "hmm_features.parquet" if hmm_path.exists(): hmm = pd.read_parquet(hmm_path) feats = feats.merge(hmm, on="genome_accession", how="left") kegg_path = config.DATA / "kegg_modules.parquet" if kegg_path.exists(): kegg = pd.read_parquet(kegg_path) feats = feats.merge(kegg, on="genome_accession", how="left") iso_meta_path = config.DATA / "isolation_metadata.parquet" if iso_meta_path.exists(): iso_meta = pd.read_parquet(iso_meta_path) iso_meta["bacdive_id"] = iso_meta["bacdive_id"].astype(int) feats["bacdive_id"] = feats["bacdive_id"].astype(int) keep = ["bacdive_id", "iso_lat", "iso_lon", "iso_collection_year"] keep += [ c for c in iso_meta.columns if c.startswith(("iso_continent_", "iso_country_", "iso_host_kingdom_")) ] feats = feats.merge(iso_meta[keep], on="bacdive_id", how="left") return feats def group_labels(pheno: pd.DataFrame, index: pd.Index) -> pd.Series: """Return stable taxonomic groups with family, then genus, then species fallback.""" tax = pheno.set_index("bacdive_id").reindex(index) groups = tax["family"].copy() groups = groups.fillna(tax["genus"]).fillna(tax["species"]).fillna("__unknown__") return groups.astype(str) def assign_folds( X: pd.DataFrame, y_matrix: pd.DataFrame, pheno: pd.DataFrame, *, split_mode: str, n_splits: int, seed: int, ) -> pd.Series: """Assign the same fold IDs used by the dry-lab recommender benchmark.""" if split_mode == "family": groups = group_labels(pheno, X.index) splitter = GroupKFold(n_splits=min(n_splits, groups.nunique())) splits = list(splitter.split(X, y_matrix, groups)) else: splitter = KFold(n_splits=n_splits, shuffle=True, random_state=seed) splits = list(splitter.split(X)) fold_by_id = pd.Series(index=X.index, dtype="int64") for fold_idx, (_, test_idx) in enumerate(splits): fold_by_id.iloc[test_idx] = fold_idx return fold_by_id.astype(int) def build_manifest(*, split_mode: str, n_splits: int, seed: int) -> tuple[pd.DataFrame, list[str]]: """Build the external benchmark manifest and return selected medium IDs.""" pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet") feats = load_recommender_features() strain_media = pd.read_parquet(config.DATA / "strain_media.parquet") media_meta = pd.read_parquet(config.DATA / "media_metadata.parquet") X, y_matrix, medium_ids = build_training_table(feats, strain_media, pheno) folds = assign_folds(X, y_matrix, pheno, split_mode=split_mode, n_splits=n_splits, seed=seed) labels = pheno.set_index("bacdive_id").reindex(X.index) medium_names = dict(zip(media_meta["medium_id"].astype(str), media_meta["name"], strict=True)) rows: list[dict[str, Any]] = [] for bacdive_id in X.index: y_row = y_matrix.loc[bacdive_id] true_ids = [str(mid) for mid, value in y_row.items() if int(value) == 1] rows.append( { "bacdive_id": int(bacdive_id), "fold": int(folds.loc[bacdive_id]), "genome_accession": str(labels.loc[bacdive_id, "genome_accession"]), "species": _clean_str(labels.loc[bacdive_id, "species"]), "genus": _clean_str(labels.loc[bacdive_id, "genus"]), "family": _clean_str(labels.loc[bacdive_id, "family"]), "optimal_temperature_c": _clean_float(labels.loc[bacdive_id, "optimal_temperature_c"]), "optimal_ph": _clean_float(labels.loc[bacdive_id, "optimal_ph"]), "salt_tolerance_pct": _clean_float(labels.loc[bacdive_id, "salt_tolerance_pct"]), "oxygen_requirement": _clean_str(labels.loc[bacdive_id, "oxygen_requirement"]), "true_media_ids": "|".join(true_ids), "true_media_names": "|".join(medium_names.get(mid, "") for mid in true_ids), "n_true_media": len(true_ids), } ) manifest = pd.DataFrame(rows).sort_values(["fold", "bacdive_id"]).reset_index(drop=True) return manifest, [str(mid) for mid in medium_ids] def _clean_str(value: Any) -> str: if pd.isna(value): return "" return str(value) def _clean_float(value: Any) -> float | None: if pd.isna(value): return None return float(value) def detect_tools() -> dict[str, dict[str, str | None]]: """Detect external command-line tools without installing anything.""" out: dict[str, dict[str, str | None]] = {} for label, candidates in TOOL_CANDIDATES.items(): found_name = None found_path = None for candidate in candidates: path = shutil.which(candidate) if path: found_name = candidate found_path = path break if label == "GenomeSPOT" and found_path is None: local_source = config.DATA / "external_tools" / "GenomeSPOT-main" if (local_source / "genome_spot" / "genome_spot.py").exists() and (local_source / "models").exists(): found_name = "uv run python -m genome_spot.genome_spot" found_path = str(local_source.relative_to(config.ROOT)) if label == "CarveMe" and found_path is None and shutil.which("diamond"): found_name = "uv run --with carveme carve" found_path = shutil.which("diamond") out[label] = {"command": found_name, "path": found_path} return out def local_fasta_path(fasta_dir: Path, accession: str) -> Path | None: """Return an existing FASTA path for an accession, if present.""" for suffix in (".fna", ".fna.gz", ".fa", ".fa.gz", ".fasta", ".fasta.gz"): candidate = fasta_dir / f"{accession}{suffix}" if candidate.exists(): return candidate return None def fasta_coverage(manifest: pd.DataFrame, fasta_dir: Path) -> dict[str, Any]: """Count how many manifest accessions already have local FASTA files.""" accessions = manifest["genome_accession"].dropna().astype(str).unique().tolist() present = [acc for acc in accessions if local_fasta_path(fasta_dir, acc)] return { "fasta_dir": str(fasta_dir), "unique_accessions": len(accessions), "present_fastas": len(present), "missing_fastas": len(accessions) - len(present), "coverage_pct": 0.0 if not accessions else 100.0 * len(present) / len(accessions), } def download_fastas(manifest: pd.DataFrame, fasta_dir: Path, *, limit: int) -> dict[str, Any]: """Download a small number of missing genome FASTAs from NCBI for smoke tests.""" fasta_dir.mkdir(parents=True, exist_ok=True) accessions = manifest["genome_accession"].dropna().astype(str).drop_duplicates().tolist() missing = [acc for acc in accessions if local_fasta_path(fasta_dir, acc) is None] if limit <= 0: return {"attempted": 0, "downloaded": 0, "failed": 0} attempted = 0 downloaded = 0 failed = 0 for accession in missing[:limit]: attempted += 1 contigs = _fetch_fasta_bytes(accession) if not contigs: failed += 1 continue out_path = fasta_dir / f"{accession}.fna.gz" with gzip.open(out_path, "wt") as handle: for contig_id, sequence in contigs: handle.write(f">{contig_id}\n") for i in range(0, len(sequence), 80): handle.write(sequence[i : i + 80] + "\n") downloaded += 1 return {"attempted": attempted, "downloaded": downloaded, "failed": failed} def write_status_report( *, path: Path, manifest: pd.DataFrame, medium_ids: list[str], tools: dict[str, dict[str, str | None]], coverage: dict[str, Any], download: dict[str, Any], out_manifest: Path, ) -> None: """Write a human-readable status report for the external baseline run.""" label_counts = { "temperature": int(manifest["optimal_temperature_c"].notna().sum()), "ph": int(manifest["optimal_ph"].notna().sum()), "salt": int(manifest["salt_tolerance_pct"].notna().sum()), "oxygen": int((manifest["oxygen_requirement"] != "").sum()), "medium": int((manifest["n_true_media"] > 0).sum()), } fold_counts = manifest["fold"].value_counts().sort_index().to_dict() lines = [ "# External Tool Benchmark Status", "", "This file tracks the apples-to-apples benchmark setup for external tools", "on the same held-out BacDive/MediaDive strains used by the dry-lab media", "recommender benchmark.", "", "## Held-Out Manifest", "", f"- Manifest: `{display_path(out_manifest)}`", f"- Rows: {len(manifest):,}", f"- Unique genome accessions: {coverage['unique_accessions']:,}", f"- Media labels retained: {len(medium_ids):,}", f"- Fold counts: {json.dumps({str(k): int(v) for k, v in fold_counts.items()})}", "", "Label coverage:", "", "| Target | Labeled rows |", "|---|---:|", f"| Temperature | {label_counts['temperature']:,} |", f"| pH | {label_counts['ph']:,} |", f"| Salt | {label_counts['salt']:,} |", f"| Oxygen | {label_counts['oxygen']:,} |", f"| Medium | {label_counts['medium']:,} |", "", "## Local Requirements", "", f"- FASTA directory: `{display_path(Path(str(coverage['fasta_dir'])))}`", f"- FASTAs present: {coverage['present_fastas']:,} / {coverage['unique_accessions']:,} " f"({coverage['coverage_pct']:.2f}%)", f"- FASTA download smoke run: {json.dumps(download)}", "", "| Tool | Local command | Status |", "|---|---|---|", ] for tool, info in tools.items(): command = info["command"] or "" status = "available" if info["path"] else "missing" lines.append(f"| {tool} | `{command}` | {status} |") lines += [ "", "## Verdict", "", ] if all(info["path"] for info in tools.values()) and coverage["present_fastas"] == coverage["unique_accessions"]: lines.append("External baseline execution is ready on this machine.") else: lines.append( "External baseline execution is not ready on this machine yet: the full " "held-out FASTA set and one or more external tool binaries/databases are missing." ) lines += [ "", "## Next Commands", "", "Use the manifest to run each external tool against the same rows and folds.", "The medium-feasibility tools should be scored by whether at least one known", "MediaDive medium is feasible or closest among the tool's predicted feasible", "media/metabolite environments.", "", "```bash", "PYTHONPATH=src uv run --python 3.11 python scripts/42_prepare_external_benchmarks.py \\", " --download-fastas 10", "```", "", "For the full benchmark, download the complete FASTA set into the FASTA", "directory above, install the external tools plus their databases, then run", "tool-specific inference using the `bacdive_id`, `fold`, and", "`genome_accession` columns from the manifest.", "", ] path.write_text("\n".join(lines)) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--split-mode", choices=("family", "random"), default="family") parser.add_argument("--n-splits", type=int, default=5) parser.add_argument("--seed", type=int, default=7) parser.add_argument("--fasta-dir", type=Path, default=config.DATA / "external_benchmark_fastas") parser.add_argument( "--manifest-parquet", type=Path, default=config.ARTIFACTS / "external_benchmark_manifest.parquet", ) parser.add_argument( "--manifest-csv", type=Path, default=config.ARTIFACTS / "external_benchmark_manifest.csv", ) parser.add_argument( "--status-json", type=Path, default=config.ARTIFACTS / "external_benchmark_status.json", ) parser.add_argument( "--status-md", type=Path, default=config.ARTIFACTS / "external_benchmark_status.md", ) parser.add_argument( "--download-fastas", type=int, default=0, help="Download this many missing FASTAs from NCBI for smoke testing. Default: 0.", ) return parser.parse_args() def display_path(path: Path) -> str: """Format project-local paths relative to the repository root.""" try: return str(path.resolve().relative_to(config.ROOT.resolve())) except ValueError: return str(path) def main() -> None: args = parse_args() t0 = time.time() manifest, medium_ids = build_manifest(split_mode=args.split_mode, n_splits=args.n_splits, seed=args.seed) args.manifest_parquet.parent.mkdir(parents=True, exist_ok=True) manifest.to_parquet(args.manifest_parquet, index=False) manifest.to_csv(args.manifest_csv, index=False) download = download_fastas(manifest, args.fasta_dir, limit=args.download_fastas) tools = detect_tools() coverage = fasta_coverage(manifest, args.fasta_dir) payload = { "split_mode": args.split_mode, "n_splits": args.n_splits, "seed": args.seed, "elapsed_s": time.time() - t0, "manifest_parquet": display_path(args.manifest_parquet), "manifest_csv": display_path(args.manifest_csv), "rows": int(len(manifest)), "media_labels": len(medium_ids), "tools": tools, "fasta_coverage": {**coverage, "fasta_dir": display_path(Path(str(coverage["fasta_dir"])))}, "download": download, } args.status_json.write_text(json.dumps(payload, indent=2)) write_status_report( path=args.status_md, manifest=manifest, medium_ids=medium_ids, tools=tools, coverage=coverage, download=download, out_manifest=args.manifest_parquet, ) print(json.dumps(payload, indent=2)) print(f"Wrote {args.manifest_parquet}") print(f"Wrote {args.manifest_csv}") print(f"Wrote {args.status_json}") print(f"Wrote {args.status_md}") if __name__ == "__main__": main()