mgnify-evo2-probes / code /scripts /sample_syngenome_amr.py
JG1310's picture
Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs
eb69de4 verified
"""Sample ~8,000 AMR records from SynGenome with full ORFs (start AND stop
codons present) and emit JSONL files in the same schema as the VFDB pipeline
so they slot into the existing Modal embed flow.
Sampling strategy: take ALL records from rare drug classes (anything with at
most --small-class-threshold records, 1k by default) so probes/SAE see all 13
drug-class secondary_labels. Sample the remaining budget proportionally from
the large classes (in practice macrolide, which dominates with 14,591 records).
Output: data/targeted_jsonl/syngenome/<drug_class>.jsonl
Each record's mag_id = drug_class slug so embed_*_lean's
/<output>/<label>/<mag_id>/<region_id>.npz path layout works as-is.
"""
import argparse
import json
import random
import re
from collections import Counter, defaultdict
from pathlib import Path
# Codons accepted as the beginning of an ORF: ATG plus the alternative
# starts GTG, TTG and CTG.
CANONICAL_STARTS = {
    "ATG",
    "GTG",
    "TTG",
    "CTG",
}

# Codons that terminate an ORF.
STOP_CODONS = {
    "TAA",
    "TAG",
    "TGA",
}

# str.translate() table mapping every base to its complement, preserving
# case. Characters outside ACGT/acgt are not in the table, so translate()
# leaves them untouched.
COMP = str.maketrans(
    "ACGTacgt",
    "TGCAtgca",
)
def revcomp(s: str) -> str:
    """Return the reverse complement of ``s``.

    The sequence is reversed and each base is swapped for its complement via
    the module-level ``COMP`` translation table (case is preserved).
    Characters that are not in the table are carried over unchanged.
    """
    reversed_seq = s[::-1]
    return reversed_seq.translate(COMP)
def has_full_orf(seq: str, min_aa: int = 50) -> bool:
    """Return True if any of the six reading frames holds a complete ORF.

    A complete ORF begins at a codon from ``CANONICAL_STARTS`` and ends at the
    first in-frame codon from ``STOP_CODONS``. Its length is the number of
    codons from the start codon up to, but not including, the stop codon, and
    must be at least ``min_aa``. ORFs that run off either end of the sequence
    (no start codon, or no stop codon before the sequence ends) never count.

    Args:
        seq: Nucleotide sequence. Matching is case-insensitive, so lowercase
            or soft-masked input is handled the same as uppercase.
        min_aa: Minimum ORF length in codons (start codon included, stop
            codon excluded).

    Returns:
        True as soon as one qualifying ORF is found on either strand,
        otherwise False.
    """
    # The codon tables are uppercase; normalise once so lowercase input is
    # not silently reported as having no ORF.
    seq = seq.upper()
    for strand in (seq, revcomp(seq)):
        for frame in range(3):
            # Length in codons of the ORF currently being extended;
            # 0 means we are not inside an ORF.
            orf_len = 0
            # Stop at len - 2 so only complete codons are examined.
            for pos in range(frame, len(strand) - 2, 3):
                codon = strand[pos:pos + 3]
                if orf_len == 0:
                    if codon in CANONICAL_STARTS:
                        orf_len = 1
                elif codon in STOP_CODONS:
                    if orf_len >= min_aa:
                        return True
                    # Too short: drop it and look for the next start codon.
                    orf_len = 0
                else:
                    # Start codons met inside an open ORF are ordinary sense
                    # codons and simply extend it.
                    orf_len += 1
    return False
def gc_content(seq: str) -> float:
    """Return the G+C fraction of ``seq`` among its A/C/G/T characters.

    The comparison is case-insensitive. Any character other than A, C, G or T
    is excluded from both numerator and denominator. A sequence without any
    A/C/G/T character yields 0.0.
    """
    upper = seq.upper()
    gc_total = upper.count("G") + upper.count("C")
    at_total = upper.count("A") + upper.count("T")
    denominator = gc_total + at_total
    if not denominator:
        return 0.0
    return gc_total / denominator
def class_slug(s: str) -> str:
    """Turn a label into a slug made of ASCII letters, digits and underscores.

    The value is passed through ``str()`` first. Every run of characters
    other than A-Z, a-z and 0-9 acts as a separator; the remaining pieces are
    joined with single underscores, so the result never starts or ends with
    an underscore. A label with no alphanumeric characters gives "".
    """
    pieces = re.split(r"[^A-Za-z0-9]+", str(s))
    # Empty pieces only arise from leading/trailing separators; dropping them
    # is what keeps underscores off the ends.
    return "_".join(piece for piece in pieces if piece)
def parse_score(desc: str) -> float | None:
if not isinstance(desc, str):
return None
m = re.search(r"score=(-?\d+\.\d+)", desc)
return float(m.group(1)) if m else None
def _stratified_sample(
    by_class: dict[str, list[dict]],
    target_n: int,
    small_class_threshold: int,
    rng: random.Random,
) -> list[tuple[str, dict]]:
    """Take every record of each small class, then sample the large classes."""
    small_classes = {c: rs for c, rs in by_class.items() if len(rs) <= small_class_threshold}
    large_classes = {c: rs for c, rs in by_class.items() if len(rs) > small_class_threshold}

    chosen: list[tuple[str, dict]] = [
        (c, r) for c, rs in small_classes.items() for r in rs
    ]
    n_taken_from_small = len(chosen)
    print(f" taking all from small classes (≤{small_class_threshold}): {n_taken_from_small} records "
          f"across {len(small_classes)} classes")

    # The small classes alone can exceed the target. Clamp at zero because
    # random.Random.sample() raises ValueError for a negative sample size.
    remaining_budget = max(0, target_n - n_taken_from_small)
    print(f" remaining budget for large classes: {remaining_budget}")

    # Distribute proportionally among large classes (just macrolide here usually)
    total_large = sum(len(rs) for rs in large_classes.values())
    for c, rs in large_classes.items():
        frac = len(rs) / total_large
        n_take = min(round(remaining_budget * frac), len(rs))
        chosen.extend((c, r) for r in rng.sample(rs, n_take))
        print(f" {c}: sampling {n_take} of {len(rs)}")
    return chosen


def _build_record(drug_class: str, r: dict, seed: int) -> dict:
    """Convert one SynGenome record into the output JSONL schema."""
    seq = r["sequence"].upper()
    # mag_id is used as an output-path component downstream, so never let it
    # be empty when the class label has no alphanumeric characters.
    slug = class_slug(drug_class) or "UNKNOWN"
    rec = {
        "region_id": f"{r['seq_hash']}_AMR_SYN",
        "is_positive": True,
        "label": "AMR",
        "label_class": drug_class,  # drug class (secondary_label)
        "label_subclass": r.get("primary_label"),
        "gene_symbol": r.get("gene_name"),
        "product_name": r.get("product_name"),
        "organism": r.get("organism"),
        "source_db": "SynGenome",
        "source_accession": r.get("source_accession"),
        "evidence": r.get("evidence"),
        "evo2_score": parse_score(r.get("description", "")),
        "functional_class": r.get("functional_class"),
        "primary_label": r.get("primary_label"),
        "secondary_label": r.get("secondary_label"),
        "pmids": r.get("pmids"),
        "cds_length": len(seq),
        "gc_content": round(gc_content(seq), 4),
        "mag_id": slug,  # output-path grouping key
        "random_seed": seed,
        "sequence": seq,
    }
    # JSON-safe coercion: a top-level float NaN (the only float for which
    # v != v holds) becomes None; every other value passes through unchanged.
    return {
        k: (None if isinstance(v, float) and v != v else v)
        for k, v in rec.items()
    }


def main():
    """Sample full-ORF AMR records from SynGenome and write per-class JSONL.

    Steps: load the input JSON, keep records whose ``functional_class`` is
    "AMR" and whose sequence passes ``has_full_orf``, group them by
    ``secondary_label`` ("UNKNOWN" when missing or empty), take every record
    of the small classes plus a seeded proportional sample of the large ones,
    then write one ``<slug>.jsonl`` file per drug class and a
    ``_sample_summary.csv`` with per-file record counts into ``--out-dir``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-json", type=Path,
                    default=Path("/home/ror25cal/MGnify/syngenome_db.json"))
    ap.add_argument("--out-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/targeted_jsonl/syngenome"))
    ap.add_argument("--target-n", type=int, default=8000)
    ap.add_argument("--small-class-threshold", type=int, default=1000,
                    help="classes with ≤this many records are taken in full")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    print(f"Loading {args.in_json} ...")
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(args.in_json, encoding="utf-8") as f:
        d = json.load(f)
    amrs = [r for r in d if r.get("functional_class") == "AMR"]
    print(f" AMR records: {len(amrs)}")

    # Filter: full ORF
    print("Filtering to full-ORF records (start + stop, ≥50 aa) ...")
    full_orf = [r for r in amrs if has_full_orf(r["sequence"].upper())]
    # Guard the percentage so an input without AMR records does not crash.
    pass_pct = 100 * len(full_orf) / len(amrs) if amrs else 0.0
    print(f" {len(full_orf)} records pass ({pass_pct:.1f}%)")

    # Stratified sample: take all from "small" classes; sample remainder from large classes
    by_class: dict[str, list[dict]] = defaultdict(list)
    for r in full_orf:
        by_class[r.get("secondary_label") or "UNKNOWN"].append(r)
    print(f" drug classes: {len(by_class)}")

    chosen = _stratified_sample(
        by_class, args.target_n, args.small_class_threshold, rng
    )
    print(f"\nTotal chosen: {len(chosen)} records")

    # Write JSONL per drug class with our standard schema
    by_drug_out: dict[str, list[dict]] = defaultdict(list)
    for c, r in chosen:
        rec = _build_record(c, r, args.seed)
        by_drug_out[rec["mag_id"]].append(rec)

    summary = []
    for slug, recs in sorted(by_drug_out.items()):
        out_path = args.out_dir / f"{slug}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(rec) + "\n" for rec in recs)
        summary.append((slug, len(recs)))
        print(f" → {out_path.name} ({len(recs)} records)")
    print(f"\nTotal: {sum(n for _, n in summary)} records across {len(summary)} files")

    # Save summary
    summary_path = args.out_dir / "_sample_summary.csv"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("drug_class_slug,n_records\n")
        f.writelines(f"{slug},{n}\n" for slug, n in summary)
    print(f"Summary: {summary_path}")


if __name__ == "__main__":
    main()