"""Sample ~8,000 NON-AMR records from SynGenome with full ORFs for AMR-probe
validation negatives.

Stratified across the four non-AMR functional classes. Less-careful match than
the VFDB pipeline (no length/GC pairing, no organism matching) — these are
validation negatives, not training negatives. Goal is just "diverse non-AMR
sequences with a real CDS inside", in the same JSONL schema as the AMR
positives so the embed pipeline picks them up.

Output: data/targeted_jsonl/syngenome_neg/neg_<class_slug>.jsonl (one file per
functional class) plus _sample_summary.csv with per-file record counts.
Each record's mag_id = "neg_<class_slug>" so the embed path layout segregates
negatives from positives.
"""
import argparse
import json
import random
import re
from collections import Counter, defaultdict
from pathlib import Path

# Reuse the ORF-finder from the AMR sampler
import sys

sys.path.insert(0, str(Path(__file__).parent))
from sample_syngenome_amr import has_full_orf, gc_content, class_slug, parse_score


def _stratified_sample(
    by_class: dict[str, list[dict]], per_class_cap: int, rng: random.Random
) -> list[dict]:
    """Take up to ``per_class_cap`` records per class, largest class first.

    Classes with at most ``per_class_cap`` records are taken in full without
    drawing from ``rng``; larger classes are sub-sampled without replacement.
    Classes are visited in a fixed order (descending size, ties in insertion
    order), so the same seed and input reproduce the same selection.
    """
    chosen: list[dict] = []
    print("\nSampling:")
    for c, rs in sorted(by_class.items(), key=lambda kv: -len(kv[1])):
        n_take = min(len(rs), per_class_cap)
        sampled = rng.sample(rs, n_take) if n_take < len(rs) else rs
        chosen.extend(sampled)
        print(f" {c}: taking {n_take} of {len(rs)}")
    return chosen


def _to_neg_record(r: dict, seed: int) -> tuple[str, dict]:
    """Convert one SynGenome record into the negative-set JSONL schema.

    Returns ``(slug, record)`` where ``slug`` ("neg_<class_slug>") is both the
    output-file stem and the record's ``mag_id``.
    """
    seq = r["sequence"].upper()
    fc = r.get("functional_class") or "UNKNOWN"
    slug = f"neg_{class_slug(fc)}"
    rec = {
        "region_id": f"{r['seq_hash']}_NEG_SYN",
        "is_positive": False,
        "label": "negative",
        "label_class": fc,  # e.g. "virulence", "stress", "mobile_element"
        "label_subclass": r.get("primary_label"),
        "secondary_label": r.get("secondary_label"),
        "gene_symbol": r.get("gene_name"),
        "product_name": r.get("product_name"),
        "organism": r.get("organism"),
        "source_db": "SynGenome",
        "source_accession": r.get("source_accession"),
        "evidence": r.get("evidence"),
        "evo2_score": parse_score(r.get("description", "")),
        "functional_class": fc,
        "primary_label": r.get("primary_label"),
        "pmids": r.get("pmids"),
        # Length of the whole stored sequence (not of the detected ORF).
        "cds_length": len(seq),
        "gc_content": round(gc_content(seq), 4),
        "mag_id": slug,  # output-path grouping key
        "random_seed": seed,
        "sequence": seq,
    }
    # NaN floats -> None: json.dumps would otherwise emit a bare NaN token,
    # which is not valid JSON for strict downstream parsers.
    clean = {
        k: (None if (isinstance(v, float) and v != v) else v)
        for k, v in rec.items()
    }
    return slug, clean


def _write_outputs(by_path: dict[str, list[dict]], out_dir: Path) -> None:
    """Write one JSONL file per slug plus a CSV summary of record counts."""
    summary: list[tuple[str, int]] = []
    for slug, recs in sorted(by_path.items()):
        out_path = out_dir / f"{slug}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(r) + "\n" for r in recs)
        summary.append((slug, len(recs)))
        print(f" → {out_path.name} ({len(recs)} records)")

    print(f"\nTotal: {sum(n for _, n in summary)} records across {len(summary)} files")

    summary_path = out_dir / "_sample_summary.csv"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("class_slug,n_records\n")
        for slug, n in summary:
            f.write(f"{slug},{n}\n")
    print(f"Summary: {summary_path}")


def main() -> None:
    """CLI entry point: load SynGenome, filter to full ORFs, sample, write JSONL."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-json", type=Path,
                    default=Path("/home/ror25cal/MGnify/syngenome_db.json"))
    ap.add_argument("--out-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/targeted_jsonl/syngenome_neg"))
    ap.add_argument("--target-n", type=int, default=8000,
                    help="expected total, reported for comparison only; the "
                         "actual count is governed by --per-class-cap")
    ap.add_argument("--per-class-cap", type=int, default=2500,
                    help="cap per functional_class for the 3 large classes; small classes taken in full")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    # A negative cap would otherwise surface as an opaque ValueError from
    # random.sample deep inside the sampling loop.
    if args.per_class_cap < 0:
        ap.error("--per-class-cap must be >= 0")

    args.out_dir.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    print(f"Loading {args.in_json} ...")
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default.
    with open(args.in_json, encoding="utf-8") as f:
        d = json.load(f)
    non_amr = [r for r in d if r.get("functional_class") != "AMR"]
    print(f" non-AMR records: {len(non_amr)}")

    print("Filtering to full-ORF records (start + stop, ≥50 aa) ...")
    full_orf = [r for r in non_amr if has_full_orf(r["sequence"].upper())]
    # Guard against an input with no non-AMR records (previously ZeroDivisionError).
    pct = 100 * len(full_orf) / len(non_amr) if non_amr else 0.0
    print(f" {len(full_orf)} pass ({pct:.1f}%)")

    by_class: dict[str, list[dict]] = defaultdict(list)
    for r in full_orf:
        by_class[r.get("functional_class") or "UNKNOWN"].append(r)
    print(" available per functional_class:")
    for c, rs in sorted(by_class.items(), key=lambda kv: -len(kv[1])):
        print(f" {c}: {len(rs)}")

    chosen = _stratified_sample(by_class, args.per_class_cap, rng)
    print(f"\nTotal chosen: {len(chosen)} records (target {args.target_n})")

    # Group JSONL records by output slug (one file per functional_class).
    by_path: dict[str, list[dict]] = defaultdict(list)
    for r in chosen:
        slug, rec = _to_neg_record(r, args.seed)
        by_path[slug].append(rec)

    _write_outputs(by_path, args.out_dir)


if __name__ == "__main__":
    main()