"""Sample ~8,000 AMR records from SynGenome with full ORFs (start AND stop
codons present, >=50 codons) and emit JSONL files in the same schema as the
VFDB pipeline so they slot into the existing Modal embed flow.

Sampling strategy: take ALL records from rare drug classes (those with at
most --small-class-threshold records, 1,000 by default) so probes/SAE see all
13 drug-class secondary_labels. The remaining budget is then split across the
larger classes in proportion to their size (macrolide dominates, with 14,591
records).

Output: data/targeted_jsonl/syngenome/<drug_class>.jsonl (one file per
drug-class slug) plus _sample_summary.csv with per-file record counts.
Each record's mag_id = drug_class slug so embed_*_lean's
/<output>/<label>/<mag_id>/<region_id>.npz path layout works as-is.
"""
| import argparse |
| import json |
| import random |
| import re |
| from collections import Counter, defaultdict |
| from pathlib import Path |
|
|
# Codons that may open an ORF: ATG plus the alternative starts GTG, TTG, CTG.
CANONICAL_STARTS = {"ATG", "GTG", "TTG", "CTG"}
# Standard stop codons.
STOP_CODONS = {"TAA", "TAG", "TGA"}
# Complement table for A/C/G/T in either case; str.translate leaves any
# other character (e.g. N) unchanged.
COMP = str.maketrans("ACGTacgt", "TGCAtgca")


def revcomp(s: str) -> str:
    """Return the reverse complement of ``s``.

    Case is preserved. Characters other than A/C/G/T (either case) are not
    complemented, only moved to their reversed position.
    """
    return s.translate(COMP)[::-1]


def has_full_orf(seq: str, min_aa: int = 50) -> bool:
    """Return True if any of the six reading frames holds a complete ORF.

    A complete ORF opens at a codon in CANONICAL_STARTS and closes at the
    first in-frame codon in STOP_CODONS. It qualifies when it spans at least
    ``min_aa`` codons, counting the start codon and excluding the stop codon.
    ORFs that run off either end of the sequence without a stop never
    qualify. Once an ORF is open, further start codons are counted as
    ordinary codons, so each ORF is measured from its first start.

    Args:
        seq: Nucleotide sequence. Matching is case-insensitive.
        min_aa: Minimum ORF length in codons.

    Returns:
        True as soon as one qualifying ORF is found, otherwise False.
    """
    # The codon sets are uppercase. Normalise here so lowercase or
    # soft-masked input is not silently treated as ORF-free.
    seq = seq.upper()
    for strand in (seq, revcomp(seq)):
        for frame in range(3):
            # Number of codons since the open start codon. 0 means no ORF is
            # open, because an open ORF always counts at least its start.
            orf_len = 0
            # Visit only complete codons; a trailing partial codon is ignored.
            for i in range(frame, len(strand) - 2, 3):
                codon = strand[i:i + 3]
                if orf_len == 0:
                    if codon in CANONICAL_STARTS:
                        orf_len = 1
                elif codon in STOP_CODONS:
                    if orf_len >= min_aa:
                        return True
                    orf_len = 0
                else:
                    orf_len += 1
    return False
|
|
|
|
def gc_content(seq: str) -> float:
    """Return the fraction of G/C among the A/C/G/T bases of ``seq``.

    Matching is case-insensitive. Characters other than A/C/G/T (e.g. N or
    gaps) are excluded from both the numerator and the denominator. A
    sequence with no A/C/G/T at all yields 0.0.
    """
    upper = seq.upper()
    n_gc = upper.count("G") + upper.count("C")
    n_acgt = n_gc + upper.count("A") + upper.count("T")
    if n_acgt == 0:
        return 0.0
    return n_gc / n_acgt
|
|
|
|
# Any run of characters outside ASCII letters/digits.
_NON_ALNUM_RUN = re.compile(r"[^A-Za-z0-9]+")


def class_slug(s: str) -> str:
    """Return a filesystem-safe slug for a drug-class label.

    ``s`` is converted with ``str()`` first, so non-string labels are
    accepted. Each run of characters outside [A-Za-z0-9] becomes a single
    underscore, and leading/trailing underscores are removed. The result is
    empty for a label containing no ASCII letters or digits.
    """
    text = str(s)
    underscored = _NON_ALNUM_RUN.sub("_", text)
    return underscored.strip("_")
|
|
|
|
| def parse_score(desc: str) -> float | None: |
| if not isinstance(desc, str): |
| return None |
| m = re.search(r"score=(-?\d+\.\d+)", desc) |
| return float(m.group(1)) if m else None |
|
|
|
|
def _is_nan(value: object) -> bool:
    """Return True for a float NaN, the only value not equal to itself."""
    return isinstance(value, float) and value != value


def _drug_class_of(record: dict) -> str:
    """Return the record's secondary_label, or "UNKNOWN" when it is missing.

    Falsy labels (None, "") and NaN all map to "UNKNOWN". json.load turns a
    bare NaN token into a float, which is truthy and would otherwise form a
    separate "nan" class.
    """
    label = record.get("secondary_label")
    if not label or _is_nan(label):
        return "UNKNOWN"
    return label


def _load_amr_records(in_json: Path) -> list[dict]:
    """Load the SynGenome JSON and keep records with functional_class == "AMR"."""
    print(f"Loading {in_json} ...")
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(in_json, encoding="utf-8") as f:
        d = json.load(f)
    # Assumes the file is a top-level JSON array of record dicts.
    amrs = [r for r in d if r.get("functional_class") == "AMR"]
    print(f" AMR records: {len(amrs)}")
    return amrs


def _select_records(
    by_class: dict[str, list[dict]],
    target_n: int,
    small_class_threshold: int,
    rng: random.Random,
) -> list[tuple[str, dict]]:
    """Choose (drug_class, record) pairs for output.

    Classes with at most ``small_class_threshold`` records are taken in
    full. Whatever budget remains out of ``target_n`` is split across the
    larger classes in proportion to their size, sampled without replacement
    using ``rng``. Classes are visited in insertion order, so a fixed seed
    reproduces the same selection.
    """
    small_classes = {c: rs for c, rs in by_class.items() if len(rs) <= small_class_threshold}
    large_classes = {c: rs for c, rs in by_class.items() if len(rs) > small_class_threshold}

    chosen = [(c, r) for c, rs in small_classes.items() for r in rs]
    n_taken_from_small = len(chosen)
    print(f" taking all from small classes (≤{small_class_threshold}): {n_taken_from_small} records "
          f"across {len(small_classes)} classes")

    remaining_budget = target_n - n_taken_from_small
    if remaining_budget < 0:
        # Small classes alone overshoot the target. A negative size would
        # make rng.sample() raise ValueError, so sample nothing from the
        # large classes instead.
        print(f" WARNING: small classes alone exceed the target of {target_n}; "
              f"no records will be sampled from large classes")
        remaining_budget = 0
    print(f" remaining budget for large classes: {remaining_budget}")

    # Only used inside the loop, which runs only when large classes exist,
    # so total_large > 0 whenever it is divided by.
    total_large = sum(len(rs) for rs in large_classes.values())
    for c, rs in large_classes.items():
        frac = len(rs) / total_large
        # Per-class rounding means the total may differ slightly from the
        # budget; the cap guards against asking for more than exist.
        n_take = min(round(remaining_budget * frac), len(rs))
        chosen.extend((c, r) for r in rng.sample(rs, n_take))
        print(f" {c}: sampling {n_take} of {len(rs)}")
    return chosen


def _to_output_record(drug_class: str, slug: str, r: dict, seed: int) -> dict:
    """Build one output record in the VFDB-pipeline JSONL schema.

    ``slug`` becomes mag_id so the embed step's
    <label>/<mag_id>/<region_id> path layout groups records by drug class.
    Top-level NaN values are replaced with None so they serialise as null.
    """
    seq = r["sequence"].upper()
    rec = {
        "region_id": f"{r['seq_hash']}_AMR_SYN",
        "is_positive": True,
        "label": "AMR",
        "label_class": drug_class,
        "label_subclass": r.get("primary_label"),
        "gene_symbol": r.get("gene_name"),
        "product_name": r.get("product_name"),
        "organism": r.get("organism"),
        "source_db": "SynGenome",
        "source_accession": r.get("source_accession"),
        "evidence": r.get("evidence"),
        "evo2_score": parse_score(r.get("description", "")),
        "functional_class": r.get("functional_class"),
        "primary_label": r.get("primary_label"),
        "secondary_label": r.get("secondary_label"),
        "pmids": r.get("pmids"),
        "cds_length": len(seq),
        "gc_content": round(gc_content(seq), 4),
        "mag_id": slug,
        "random_seed": seed,
        "sequence": seq,
    }
    return {k: (None if _is_nan(v) else v) for k, v in rec.items()}


def _write_outputs(by_slug: dict[str, list[dict]], out_dir: Path) -> None:
    """Write one <slug>.jsonl per drug class plus a _sample_summary.csv of counts."""
    summary = []
    for slug, recs in sorted(by_slug.items()):
        out_path = out_dir / f"{slug}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(r) + "\n" for r in recs)
        summary.append((slug, len(recs)))
        print(f" → {out_path.name} ({len(recs)} records)")

    print(f"\nTotal: {sum(n for _, n in summary)} records across {len(summary)} files")

    summary_path = out_dir / "_sample_summary.csv"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("drug_class_slug,n_records\n")
        # Slugs contain only [A-Za-z0-9_], so no CSV quoting is needed.
        f.writelines(f"{slug},{n}\n" for slug, n in summary)
    print(f"Summary: {summary_path}")


def main() -> None:
    """CLI entry point.

    Load SynGenome AMR records, keep those with a full ORF, sample them per
    drug class, and write one JSONL file per class under --out-dir.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-json", type=Path,
                    default=Path("/home/ror25cal/MGnify/syngenome_db.json"))
    ap.add_argument("--out-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/targeted_jsonl/syngenome"))
    ap.add_argument("--target-n", type=int, default=8000)
    ap.add_argument("--small-class-threshold", type=int, default=1000,
                    help="classes with ≤this many records are taken in full")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)
    rng = random.Random(args.seed)

    amrs = _load_amr_records(args.in_json)

    print("Filtering to full-ORF records (start + stop, ≥50 aa) ...")
    full_orf = [r for r in amrs if has_full_orf(r["sequence"].upper())]
    # Guard the percentage: an input with no AMR records would otherwise
    # raise ZeroDivisionError here.
    pct_pass = 100 * len(full_orf) / len(amrs) if amrs else 0.0
    print(f" {len(full_orf)} records pass ({pct_pass:.1f}%)")

    by_class: dict[str, list[dict]] = defaultdict(list)
    for r in full_orf:
        by_class[_drug_class_of(r)].append(r)
    print(f" drug classes: {len(by_class)}")

    chosen = _select_records(by_class, args.target_n, args.small_class_threshold, rng)
    print(f"\nTotal chosen: {len(chosen)} records")

    by_drug_out: dict[str, list[dict]] = defaultdict(list)
    for c, r in chosen:
        slug = class_slug(c)
        by_drug_out[slug].append(_to_output_record(c, slug, r, args.seed))

    _write_outputs(by_drug_out, args.out_dir)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import. main()
    # returns None, so the process exits with status 0 as before.
    raise SystemExit(main())
|
|