# Source: mgnify-evo2-probes / code/scripts/sample_syngenome_negatives.py
# Committed by: JG1310
# Commit eb69de4 (verified): "Probe artifacts: code, manifests, plots, scores,
# summaries, checkpoints, docs"
"""Sample ~8,000 NON-AMR records with full ORFs from SynGenome to serve as
validation negatives for the AMR probe. Sampling is stratified across the four
non-AMR functional classes: each class is capped at --per-class-cap records,
and smaller classes are taken in full. Matching is looser than in the VFDB
pipeline (no length/GC pairing, no organism matching) because these are
validation negatives, not training negatives. The goal is simply "diverse
non-AMR sequences with a real CDS inside", written in the same JSONL schema as
the AMR positives so the embed pipeline picks them up.
Output: data/targeted_jsonl/syngenome_neg/neg_<class_slug>.jsonl (one file per
functional class), plus _sample_summary.csv with per-file record counts.
Each record's mag_id = "neg_<class_slug>" so the embed path layout
segregates negatives from positives.
"""
import argparse
import json
import random
import re
from collections import Counter, defaultdict
from pathlib import Path
# Reuse the ORF finder and record helpers from the sibling AMR sampler script
import sys
sys.path.insert(0, str(Path(__file__).parent))
from sample_syngenome_amr import has_full_orf, gc_content, class_slug, parse_score
def _stratified_sample(by_class: dict[str, list[dict]], cap: int,
                       rng: random.Random) -> list[dict]:
    """Return up to ``cap`` records per functional class, largest class first.

    Classes with at most ``cap`` records are taken in full without drawing
    from the RNG; larger classes are down-sampled with ``rng.sample``.
    Classes are visited in a fixed order: descending size, with ties keeping
    ``by_class`` insertion order because ``sorted`` is stable. The draw is
    therefore reproducible for a given seed and input ordering.
    """
    chosen: list[dict] = []
    for c, rs in sorted(by_class.items(), key=lambda kv: -len(kv[1])):
        n_take = min(len(rs), cap)
        sampled = rng.sample(rs, n_take) if n_take < len(rs) else rs
        chosen.extend(sampled)
        print(f" {c}: taking {n_take} of {len(rs)}")
    return chosen


def _to_negative_record(r: dict, seed: int) -> tuple[str, dict]:
    """Convert one SynGenome record into the negative-JSONL schema.

    Returns ``(slug, record)``. ``slug`` is ``"neg_<class_slug>"`` and is
    used both as the output file stem and as the record's ``mag_id``.
    """
    seq = r["sequence"].upper()
    fc = r.get("functional_class") or "UNKNOWN"
    slug = f"neg_{class_slug(fc)}"
    rec = {
        "region_id": f"{r['seq_hash']}_NEG_SYN",
        "is_positive": False,
        "label": "negative",
        "label_class": fc,  # e.g. "virulence", "stress", "mobile_element"
        "label_subclass": r.get("primary_label"),
        "secondary_label": r.get("secondary_label"),
        "gene_symbol": r.get("gene_name"),
        "product_name": r.get("product_name"),
        "organism": r.get("organism"),
        "source_db": "SynGenome",
        "source_accession": r.get("source_accession"),
        "evidence": r.get("evidence"),
        "evo2_score": parse_score(r.get("description", "")),
        "functional_class": fc,
        "primary_label": r.get("primary_label"),
        "pmids": r.get("pmids"),
        "cds_length": len(seq),
        "gc_content": round(gc_content(seq), 4),
        "mag_id": slug,  # output-path grouping key
        "random_seed": seed,
        "sequence": seq,
    }
    # json.dumps writes float NaN as the non-standard token `NaN`. Map NaN to
    # None so the output is valid JSON (null). `v != v` is true only for NaN.
    clean = {k: (None if (isinstance(v, float) and v != v) else v)
             for k, v in rec.items()}
    return slug, clean


def _write_outputs(by_path: dict[str, list[dict]], out_dir: Path) -> None:
    """Write one ``<slug>.jsonl`` per group and a ``_sample_summary.csv``."""
    summary: list[tuple[str, int]] = []
    for slug, recs in sorted(by_path.items()):
        out_path = out_dir / f"{slug}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(r) + "\n" for r in recs)
        summary.append((slug, len(recs)))
        print(f" → {out_path.name} ({len(recs)} records)")
    print(f"\nTotal: {sum(n for _, n in summary)} records across {len(summary)} files")
    summary_path = out_dir / "_sample_summary.csv"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("class_slug,n_records\n")
        f.writelines(f"{slug},{n}\n" for slug, n in summary)
    print(f"Summary: {summary_path}")


def main():
    """CLI entry point: sample non-AMR SynGenome records as AMR-probe negatives.

    Steps:
      1. Load the SynGenome JSON dump (assumed to be a list of record dicts).
      2. Keep non-AMR records for which ``has_full_orf`` is true.
      3. Take up to ``--per-class-cap`` records per functional class.
      4. Write one JSONL file per class plus a CSV summary into ``--out-dir``.

    ``--target-n`` is only reported next to the achieved total; it does not
    limit the sample size.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-json", type=Path,
                    default=Path("/home/ror25cal/MGnify/syngenome_db.json"))
    ap.add_argument("--out-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/targeted_jsonl/syngenome_neg"))
    ap.add_argument("--target-n", type=int, default=8000,
                    help="nominal total; reported against the achieved total but "
                         "not enforced (tune --per-class-cap to approach it)")
    ap.add_argument("--per-class-cap", type=int, default=2500,
                    help="cap per functional_class for the 3 large classes; small classes taken in full")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    # A negative cap would make rng.sample() raise mid-run; fail fast instead.
    if args.per_class_cap < 0:
        ap.error("--per-class-cap must be >= 0")
    args.out_dir.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    print(f"Loading {args.in_json} ...")
    with open(args.in_json, encoding="utf-8") as f:
        d = json.load(f)
    non_amr = [r for r in d if r.get("functional_class") != "AMR"]
    print(f" non-AMR records: {len(non_amr)}")

    print("Filtering to full-ORF records (start + stop, ≥50 aa) ...")
    full_orf = [r for r in non_amr if has_full_orf(r["sequence"].upper())]
    # Guard the percentage against an input with no non-AMR records.
    pct = 100 * len(full_orf) / len(non_amr) if non_amr else 0.0
    print(f" {len(full_orf)} pass ({pct:.1f}%)")

    by_class: dict[str, list[dict]] = defaultdict(list)
    for r in full_orf:
        by_class[r.get("functional_class") or "UNKNOWN"].append(r)
    print(" available per functional_class:")
    for c, rs in sorted(by_class.items(), key=lambda kv: -len(kv[1])):
        print(f" {c}: {len(rs)}")

    # Stratified pick
    print("\nSampling:")
    chosen = _stratified_sample(by_class, args.per_class_cap, rng)
    print(f"\nTotal chosen: {len(chosen)} records (--target-n {args.target_n})")

    # Group records by output file (one JSONL per functional_class slug)
    by_path: dict[str, list[dict]] = defaultdict(list)
    for r in chosen:
        slug, rec = _to_negative_record(r, args.seed)
        by_path[slug].append(rec)
    _write_outputs(by_path, args.out_dir)
if __name__ == "__main__":
    # Executed directly (not imported): parse CLI args and run the sampler.
    main()