| """Sample ~8,000 NON-AMR records from SynGenome with full ORFs for AMR-probe |
| validation negatives. Stratified across the four non-AMR functional classes. |
| |
| Less-careful match than the VFDB pipeline (no length/GC pairing, no organism |
| matching) — these are validation negatives, not training negatives. Goal is |
| just "diverse non-AMR sequences with a real CDS inside", in the same JSONL |
| schema as the AMR positives so the embed pipeline picks them up. |
| |
| Output: data/targeted_jsonl/syngenome_neg/<class_slug>.jsonl |
| Each record's mag_id = "neg_<functional_class_slug>" so the embed |
| path layout segregates negatives from positives. |
| """ |
| import argparse |
| import json |
| import random |
| import re |
| from collections import Counter, defaultdict |
| from pathlib import Path |
|
|
| |
| import sys |
| sys.path.insert(0, str(Path(__file__).parent)) |
| from sample_syngenome_amr import has_full_orf, gc_content, class_slug, parse_score |
|
|
|
|
def _functional_class(record):
    """Return the record's functional_class, or "UNKNOWN" when missing or empty."""
    return record.get("functional_class") or "UNKNOWN"


def _nan_to_none(value):
    """Map a float NaN to None; return every other value unchanged.

    json.dumps would otherwise emit the bare token ``NaN``, which is not
    valid JSON for strict downstream parsers.
    """
    # NaN is the only float that compares unequal to itself.
    if isinstance(value, float) and value != value:
        return None
    return value


def _build_record(r, seed):
    """Convert one raw SynGenome record into the output JSONL schema.

    Args:
        r: raw record dict; must contain "sequence" and "seq_hash".
        seed: RNG seed of this sampling run, stored for provenance.

    Returns:
        ``(slug, record)`` where ``slug`` is ``"neg_<class_slug>"`` (used both
        as the record's ``mag_id`` and as the output file stem) and ``record``
        is the NaN-cleaned output dict.
    """
    seq = r["sequence"].upper()
    fc = _functional_class(r)
    slug = f"neg_{class_slug(fc)}"
    rec = {
        "region_id": f"{r['seq_hash']}_NEG_SYN",
        "is_positive": False,
        "label": "negative",
        "label_class": fc,
        "label_subclass": r.get("primary_label"),
        "secondary_label": r.get("secondary_label"),
        "gene_symbol": r.get("gene_name"),
        "product_name": r.get("product_name"),
        "organism": r.get("organism"),
        "source_db": "SynGenome",
        "source_accession": r.get("source_accession"),
        "evidence": r.get("evidence"),
        "evo2_score": parse_score(r.get("description", "")),
        "functional_class": fc,
        "primary_label": r.get("primary_label"),
        "pmids": r.get("pmids"),
        "cds_length": len(seq),
        "gc_content": round(gc_content(seq), 4),
        "mag_id": slug,
        "random_seed": seed,
        "sequence": seq,
    }
    # Only top-level values are cleaned; nested containers pass through as-is.
    return slug, {k: _nan_to_none(v) for k, v in rec.items()}


def _parse_args():
    """Build the CLI parser and return the parsed arguments."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-json", type=Path,
                    default=Path("/home/ror25cal/MGnify/syngenome_db.json"))
    ap.add_argument("--out-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/targeted_jsonl/syngenome_neg"))
    # NOTE: --target-n is accepted for CLI compatibility but is not consulted
    # by the sampling below; the per-class cap alone decides the final count.
    ap.add_argument("--target-n", type=int, default=8000,
                    help="approximate intended total (informational only; "
                         "the sample size is governed by --per-class-cap)")
    ap.add_argument("--per-class-cap", type=int, default=2500,
                    help="cap per functional_class for the 3 large classes; small classes taken in full")
    ap.add_argument("--seed", type=int, default=42)
    return ap.parse_args()


def main():
    """Sample non-AMR SynGenome records and write them as per-class JSONL files.

    Steps:
      1. Load the JSON database and drop records whose functional_class is "AMR".
      2. Keep only records accepted by ``has_full_orf`` (imported from
         sample_syngenome_amr).
      3. Group by functional_class and take at most ``--per-class-cap`` records
         from each group using a seeded RNG.
      4. Write one ``<slug>.jsonl`` file per class plus ``_sample_summary.csv``
         into ``--out-dir``.

    Raises:
        KeyError: if a record lacks "sequence", or a sampled record lacks
            "seq_hash".
    """
    args = _parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    print(f"Loading {args.in_json} ...")
    # Explicit UTF-8: JSON is UTF-8 by spec, and the locale default may differ.
    with open(args.in_json, encoding="utf-8") as f:
        d = json.load(f)
    non_amr = [r for r in d if r.get("functional_class") != "AMR"]
    print(f" non-AMR records: {len(non_amr)}")

    print("Filtering to full-ORF records (start + stop, ≥50 aa) ...")
    full_orf = [r for r in non_amr if has_full_orf(r["sequence"].upper())]
    # Guard the percentage: an input with no non-AMR records must not crash.
    pct = 100 * len(full_orf) / len(non_amr) if non_amr else 0.0
    print(f" {len(full_orf)} pass ({pct:.1f}%)")

    by_class = defaultdict(list)
    for r in full_orf:
        by_class[_functional_class(r)].append(r)
    # Largest class first; the sort is stable, so ties keep first-seen order.
    ordered = sorted(by_class.items(), key=lambda kv: -len(kv[1]))
    print(" available per functional_class:")
    for c, rs in ordered:
        print(f" {c}: {len(rs)}")

    chosen = []
    print("\nSampling:")
    for c, rs in ordered:
        n_take = min(len(rs), args.per_class_cap)
        # Classes at or under the cap are taken whole without consuming RNG state.
        sampled = rng.sample(rs, n_take) if n_take < len(rs) else rs
        chosen.extend(sampled)
        print(f" {c}: taking {n_take} of {len(rs)}")
    print(f"\nTotal chosen: {len(chosen)} records")

    by_path = defaultdict(list)
    for r in chosen:
        slug, clean = _build_record(r, args.seed)
        by_path[slug].append(clean)

    summary = []
    for slug, recs in sorted(by_path.items()):
        out_path = args.out_dir / f"{slug}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(r) + "\n" for r in recs)
        summary.append((slug, len(recs)))
        print(f" → {out_path.name} ({len(recs)} records)")

    print(f"\nTotal: {sum(n for _, n in summary)} records across {len(summary)} files")
    summary_path = args.out_dir / "_sample_summary.csv"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("class_slug,n_records\n")
        for slug, n in summary:
            f.write(f"{slug},{n}\n")
    print(f"Summary: {summary_path}")
|
|
|
# Run the sampler only when this file is executed as a script, so importing
# the module does not trigger any file I/O.
if __name__ == "__main__":
    main()
|
|