| """ |
| Extract per-MAG JSONL of (positive, matched-negative) pairs for the targeted Evo2/SAE run. |
| |
| Per the plan in targeted_pipeline_plan.md: |
| - Positives: AMR / STRESS / VIRULENCE hits (from AMRFinderPlus TSV) |
| - Negatives: same-MAG CDSs, length ±20% and GC fraction ±0.05, excluding AMR/STRESS/VIRULENCE/BGC/CRISPR/defence |
| (mobilome CDSs are also excluded by default; fallbacks in order: allow mobilome CDSs, then GC ±0.10, then length ±50%) |
| - 1:1 pairing, sampling without replacement, seed=42 |
| - Forward strand always (matches Goodfire's reference notebook) |
| - Sanity checks run on every record (length, start codon by strand, boundary clamps) |
| |
| Output: one JSONL per MAG at <out_dir>/<mag_id>.jsonl, interleaved positive then negative. |
| """ |
| import argparse |
| import json |
| import random |
| import re |
| from collections import Counter |
| from dataclasses import dataclass, asdict, field |
| from pathlib import Path |
| from typing import Optional |
|
|
| import pandas as pd |
| from pyfaidx import Fasta |
|
|
|
|
# Start codons counted as "canonical" by the start-codon sanity check.
CANONICAL_STARTS = {"ATG", "GTG", "TTG", "CTG"}

# Complement table for A/C/G/T/N plus the IUPAC ambiguity codes, in both cases.
# Ambiguity codes must be complemented too (R<->Y, K<->M, B<->V, D<->H); S, W
# and N are their own complements. Characters outside the table pass through.
_IUPAC_BASES = "ACGTNRYSWKMBDHV"
_IUPAC_COMPLEMENTS = "TGCANYRSWMKVHDB"
COMP = str.maketrans(
    _IUPAC_BASES + _IUPAC_BASES.lower(),
    _IUPAC_COMPLEMENTS + _IUPAC_COMPLEMENTS.lower(),
)


def revcomp(s: str) -> str:
    """Return the reverse complement of nucleotide string *s*.

    Case is preserved. IUPAC ambiguity codes are complemented; any character
    not in the complement table is left unchanged (only its position flips).
    """
    return s.translate(COMP)[::-1]
|
|
|
|
| |
| |
| |
|
|
@dataclass
class CDS:
    """A single coding-sequence feature read from the master GFF."""

    locus_tag: str  # value of the `locus_tag=` GFF attribute
    contig: str  # GFF column 1 (sequence name)
    start: int  # GFF start column
    end: int  # GFF end column
    strand: str  # GFF strand column
    partial: str = "00"  # two-digit `partial=` attribute; "00" when absent

    @property
    def length(self) -> int:
        """Span of the feature in bases, counting both endpoints."""
        span = self.end - self.start
        return span + 1
|
|
|
|
_LOCUS_TAG_RE = re.compile(r"locus_tag=([^;]+)")
_PARTIAL_RE = re.compile(r"partial=(\d{2})")


def parse_master_gff(path: Path) -> list[CDS]:
    """Parse the master GFF at *path* into a list of CDS records.

    Only rows whose feature type (column 3) is ``CDS`` and whose attribute
    column carries a ``locus_tag=`` entry are kept. The ``partial=`` attribute
    is recorded when present and defaults to ``"00"`` otherwise.

    Returns an empty list when the file does not exist. Raises ``ValueError``
    if a kept row has non-integer start/end coordinates.
    """
    cds_list: list[CDS] = []
    if not path.exists():
        return cds_list
    # Stream the file rather than loading it whole: GFFs may carry an embedded
    # FASTA section, which never contains feature rows.
    with path.open() as handle:
        for raw in handle:
            line = raw.rstrip("\r\n")
            if line.startswith("##FASTA"):
                break
            if not line or line.startswith("#"):
                continue
            cols = line.split("\t")
            if len(cols) < 9 or cols[2] != "CDS":
                continue
            # Slice to the nine standard GFF columns so a stray trailing tab or
            # extra field does not break the unpacking.
            contig, _, _, start, end, _, strand, _, attrs = cols[:9]
            tag_match = _LOCUS_TAG_RE.search(attrs)
            if not tag_match:
                continue
            partial_match = _PARTIAL_RE.search(attrs)
            cds_list.append(CDS(
                locus_tag=tag_match.group(1),
                contig=contig,
                start=int(start),
                end=int(end),
                strand=strand,
                partial=partial_match.group(1) if partial_match else "00",
            ))
    return cds_list
|
|
|
|
def parse_interval_gff(path: Path) -> list[tuple[str, int, int]]:
    """Collect (contig, start, end) tuples for region-level features in a GFF.

    Used for the antismash, crispr, defense_finder and mobilome GFFs. A row is
    kept only when its feature type (column 3) is one of the region-style types
    listed below; every other row (sub-features such as 'gene' or 'CRISPRdr')
    is ignored. A missing file yields an empty list.
    """
    if not path.exists():
        return []

    wanted_types = frozenset((
        "region",
        "CRISPR",
        "Defense system", "Antidefense system",
        "plasmid", "viral_sequence", "prophage", "integron",
        "conjugative_integron", "phage_plasmid",
        "insertion_sequence", "terminal_inverted_repeat_element", "attC_site",
    ))

    found = []
    for raw in path.read_text().splitlines():
        # Blank lines and comment/directive lines carry no features.
        if raw == "" or raw.startswith("#"):
            continue
        fields = raw.split("\t")
        # Need at least contig, source, type, start, end.
        if len(fields) >= 5 and fields[2] in wanted_types:
            found.append((fields[0], int(fields[3]), int(fields[4])))
    return found
|
|
|
|
def cds_overlaps_any_interval(cds: CDS, intervals: list[tuple[str, int, int]]) -> bool:
    """Return True if *cds* shares at least one base with any interval.

    Coordinates are treated as inclusive on both ends, so touching endpoints
    count as an overlap. Only intervals on the same contig as the CDS match.
    """
    return any(
        name == cds.contig and lo <= cds.end and cds.start <= hi
        for name, lo, hi in intervals
    )
|
|
|
|
| |
| |
| |
|
|
def extract_region(
    fa: Fasta,
    contig: str,
    gene_start: int,
    gene_end: int,
    strand: str,
    flank: int,
) -> dict:
    """Fetch a gene plus up to *flank* bases on each side, as stored in the FASTA.

    The window is clamped to [1, contig length]. *strand* is accepted but not
    used: the sequence is returned exactly as fetched, never reverse-complemented.

    Returns a dict with keys ``ext_start``, ``ext_end`` (the clamped window
    bounds passed to ``fa.get_seq``), ``contig_len`` and ``sequence``.
    Raises ``AssertionError`` if the fetched length differs from the window size.
    """
    total = len(fa[contig])

    # Clamp the flanked window to the contig boundaries.
    left = gene_start - flank
    if left < 1:
        left = 1
    right = gene_end + flank
    if right > total:
        right = total

    fetched = fa.get_seq(contig, left, right)
    seq = str(fetched.seq)

    # Sanity check: the window is inclusive on both ends.
    expected_len = right - left + 1
    assert len(seq) == expected_len, (
        f"length mismatch: contig={contig} ext={left}-{right} got {len(seq)} expected {expected_len}"
    )
    return {
        "ext_start": left,
        "ext_end": right,
        "contig_len": total,
        "sequence": seq,
    }
|
|
|
|
def check_start_codon(
    sequence: str, gene_start: int, ext_start: int, gene_end: int, strand: str,
    partial: str = "00",
) -> Optional[str]:
    """Return the first three coding-strand bases of a gene, upper-cased.

    *sequence* is the extracted window beginning at coordinate *ext_start*.
    For strand "+", the codon is read at the gene start. For any other strand
    value it is the reverse complement of the last three bases of the gene.

    Returns None when the start codon is not checked: partial flag "10" or
    "11" on the plus strand; partial flag "01", "10" or "11" otherwise.
    Intended for sanity checks: most complete genes should begin with
    ATG/GTG/TTG/CTG.
    """
    if strand == "+":
        if partial in ("10", "11"):
            return None
        begin = gene_start - ext_start
        return sequence[begin : begin + 3].upper()

    # Non-plus strand: every listed partial flag skips the check.
    if partial in ("01", "10", "11"):
        return None
    stop = gene_end - ext_start + 1
    tail = sequence[stop - 3 : stop].upper()
    return revcomp(tail)
|
|
|
|
| |
| |
| |
|
|
@dataclass
class Stats:
    """Run-wide counters accumulated across all processed MAGs."""

    mags_processed: int = 0
    positives_found: int = 0
    pairs_emitted: int = 0
    no_match: int = 0
    # Pairs emitted per negative-selection tier; a fresh dict for every instance.
    fallback_used: dict = field(
        default_factory=lambda: dict.fromkeys(
            (
                "strict_no_mob_l20_g5",
                "strict_w_mob_l20_g5",
                "strict_w_mob_l20_g10",
                "strict_w_mob_l50_g10",
            ),
            0,
        )
    )
    # Start-codon sanity-check tallies, kept separately per strand.
    start_codon_pass_plus: int = 0
    start_codon_total_plus: int = 0
    start_codon_pass_minus: int = 0
    start_codon_total_minus: int = 0
    length_check_failures: int = 0
|
|
|
|
def gc_content(seq: str) -> float:
    """Return the G+C fraction of *seq*, computed over A/C/G/T only.

    Case-insensitive. Any other character (e.g. N) is left out of the
    denominator. Empty input, or input with no A/C/G/T at all, gives 0.0.
    """
    if not seq:
        return 0.0
    upper = seq.upper()
    strong = upper.count("G") + upper.count("C")
    called = strong + upper.count("A") + upper.count("T")
    if called == 0:
        return 0.0
    return strong / called
|
|
|
|
def extract_for_mag(
    mag_dir: Path,
    mag_id: str,
    out_root: Path,
    split_by_label: bool,
    flank: int,
    length_tol_strict: float,
    length_tol_relaxed: float,
    gc_tol_strict: float,
    gc_tol_relaxed: float,
    seed: int,
    max_pairs: Optional[int],
    stats: Stats,
) -> int:
    """Extract positives and matched negatives for one MAG and write them as JSONL.

    Positives are AMRFinderPlus rows whose ``Element type`` is AMR, STRESS or
    VIRULENCE and whose protein identifier matches a CDS locus tag in the GFF.
    Each positive is paired with one not-yet-used CDS from the same MAG, chosen
    at random from the first non-empty tier of a four-step fallback (exclude
    mobilome CDSs -> allow them -> relaxed GC tolerance -> relaxed length
    tolerance). Positives with no candidate in any tier are still written, with
    ``extract_status="no_matching_negative"``.

    Args:
        mag_dir: directory holding ``<mag_id>.fna``, ``<mag_id>.gff``,
            ``<mag_id>_amrfinderplus.tsv`` and the optional annotation GFFs.
        mag_id: MAG identifier used to build file names.
        out_root: output directory. Records go to ``out_root/<mag_id>.jsonl``,
            or to ``out_root/<label>/<mag_id>.jsonl`` when *split_by_label* is
            set (negatives under ``MISC``).
        split_by_label: write one file per label folder instead of one per MAG.
        flank: bases of context extracted on each side of every gene.
        length_tol_strict, length_tol_relaxed: relative CDS-length tolerances.
        gc_tol_strict, gc_tol_relaxed: absolute GC-fraction tolerances.
        seed: seed for both the positive shuffle and the negative choice.
        max_pairs: maximum number of pairs to emit in THIS call, or None for
            no limit. It is a per-call budget, independent of what *stats*
            already holds.
        stats: run-wide counters, updated in place.

    Returns:
        Number of records written (positives and negatives combined); 0 when
        required inputs are missing/unreadable or nothing was extracted.
    """
    fna = mag_dir / f"{mag_id}.fna"
    gff = mag_dir / f"{mag_id}.gff"
    amr_tsv = mag_dir / f"{mag_id}_amrfinderplus.tsv"
    if not (fna.exists() and gff.exists() and amr_tsv.exists()):
        print(f" [{mag_id}] missing required files, skipping")
        return 0

    # Read the AMRFinderPlus table first so MAGs without positives never pay
    # for GFF parsing, FASTA indexing or per-CDS GC computation.
    try:
        amr_df = pd.read_csv(amr_tsv, sep="\t")
    except Exception as exc:  # best-effort: an unreadable table skips this MAG
        print(f" [{mag_id}] could not read {amr_tsv.name} ({exc}), skipping")
        return 0
    if "Element type" not in amr_df.columns:
        return 0
    positive_rows = amr_df[amr_df["Element type"].isin(["AMR", "STRESS", "VIRULENCE"])]
    if len(positive_rows) == 0:
        return 0

    all_cds = parse_master_gff(gff)
    cds_by_locus = {c.locus_tag: c for c in all_cds}

    # Locus tags of every positive that maps onto a known CDS; none of these
    # may ever be drawn as a negative.
    positive_locus_tags = set()
    for _, row in positive_rows.iterrows():
        pi = row.get("Protein identifier") or row.get("Protein id")
        if pi and pi in cds_by_locus:
            positive_locus_tags.add(pi)
    if not positive_locus_tags:
        # No positive maps to a CDS, so the pairing loop would emit nothing.
        return 0

    # Annotation intervals that disqualify a CDS from the negative pool.
    bgc_iv = parse_interval_gff(mag_dir / f"{mag_id}_antismash.gff")
    crispr_iv = parse_interval_gff(mag_dir / f"{mag_id}_crisprcasfinder.gff")
    defense_iv = parse_interval_gff(mag_dir / f"{mag_id}_defense_finder.gff")
    mobilome_iv = parse_interval_gff(mag_dir / f"{mag_id}_mobilome.gff")
    label_intervals_strict = bgc_iv + crispr_iv + defense_iv

    cds_in_mobilome = {c.locus_tag for c in all_cds if cds_overlaps_any_interval(c, mobilome_iv)}
    cds_in_strict_excluded = {c.locus_tag for c in all_cds
                              if cds_overlaps_any_interval(c, label_intervals_strict)}

    # Fallback tiers, built once: (excluded locus tags, length tol, GC tol, tag).
    base_excluded = positive_locus_tags | cds_in_strict_excluded
    tiers = [
        (base_excluded | cds_in_mobilome, length_tol_strict, gc_tol_strict, "strict_no_mob_l20_g5"),
        (base_excluded, length_tol_strict, gc_tol_strict, "strict_w_mob_l20_g5"),
        (base_excluded, length_tol_strict, gc_tol_relaxed, "strict_w_mob_l20_g10"),
        (base_excluded, length_tol_relaxed, gc_tol_relaxed, "strict_w_mob_l50_g10"),
    ]

    rng = random.Random(seed)
    used_negs: set[str] = set()
    records: list[dict] = []
    pairs_this_mag = 0  # per-call budget counter; stats.pairs_emitted is run-wide
    shuffled = positive_rows.sample(frac=1, random_state=seed)

    # The context manager closes the FASTA handle even if extraction raises.
    with Fasta(str(fna)) as fa:
        # GC fraction per CDS; CDSs whose sequence cannot be fetched are left
        # out and therefore never qualify as negatives.
        cds_gc: dict[str, float] = {}
        for c in all_cds:
            try:
                cds_seq = str(fa.get_seq(c.contig, c.start, c.end).seq)
            except (KeyError, ValueError):
                continue
            cds_gc[c.locus_tag] = gc_content(cds_seq)

        for _, prow in shuffled.iterrows():
            if max_pairs is not None and pairs_this_mag >= max_pairs:
                break

            pi = prow.get("Protein identifier") or prow.get("Protein id")
            if not pi or pi not in cds_by_locus:
                continue
            pos_cds = cds_by_locus[pi]
            pos_len = pos_cds.length

            pos_ext = extract_region(fa, pos_cds.contig, pos_cds.start, pos_cds.end,
                                     pos_cds.strand, flank)

            # Walk the tiers and keep the first non-empty candidate pool.
            candidates = None
            fallback_used = None
            pos_gc = cds_gc.get(pos_cds.locus_tag, 0.0)
            for excluded, len_tol, gc_t, tag in tiers:
                pool = [
                    c for c in all_cds
                    if c.locus_tag not in excluded
                    and c.locus_tag not in used_negs
                    and c.locus_tag in cds_gc
                    and abs(c.length - pos_len) / pos_len <= len_tol
                    and abs(cds_gc[c.locus_tag] - pos_gc) <= gc_t
                ]
                if pool:
                    candidates = pool
                    fallback_used = tag
                    break

            if not candidates:
                # Keep the unpaired positive so it stays visible downstream.
                pos_record = build_record(
                    pos_cds, prow, pos_ext, is_positive=True, paired_with=None,
                    in_mobilome=(pos_cds.locus_tag in cds_in_mobilome),
                    fallback_used=None, extract_status="no_matching_negative",
                    seed=seed, mag_id=mag_id,
                    gc_content_val=cds_gc.get(pos_cds.locus_tag),
                )
                records.append(pos_record)
                stats.no_match += 1
                stats.positives_found += 1
                _track_start_codon(pos_record, stats)
                continue

            # Sample without replacement across the whole MAG.
            neg_cds = rng.choice(candidates)
            used_negs.add(neg_cds.locus_tag)
            neg_ext = extract_region(fa, neg_cds.contig, neg_cds.start, neg_cds.end,
                                     neg_cds.strand, flank)

            pos_record = build_record(
                pos_cds, prow, pos_ext, is_positive=True, paired_with=neg_cds.locus_tag,
                in_mobilome=(pos_cds.locus_tag in cds_in_mobilome),
                fallback_used=fallback_used, extract_status="ok",
                seed=seed, mag_id=mag_id,
                gc_content_val=cds_gc.get(pos_cds.locus_tag),
            )
            neg_record = build_record(
                neg_cds, None, neg_ext, is_positive=False, paired_with=pos_cds.locus_tag,
                in_mobilome=(neg_cds.locus_tag in cds_in_mobilome),
                fallback_used=None, extract_status="ok",
                seed=seed, mag_id=mag_id,
                gc_content_val=cds_gc.get(neg_cds.locus_tag),
            )
            # The negative carries its partner's class/subclass.
            neg_record["label_class"] = pos_record["label_class"]
            neg_record["label_subclass"] = pos_record["label_subclass"]
            records.append(pos_record)
            records.append(neg_record)
            stats.fallback_used[fallback_used] += 1
            stats.pairs_emitted += 1
            stats.positives_found += 1
            pairs_this_mag += 1
            _track_start_codon(pos_record, stats)
            _track_start_codon(neg_record, stats)

    if not records:
        return 0
    if split_by_label:
        # One folder per positive label; all negatives go under MISC.
        buckets: dict[str, list[dict]] = {}
        for r in records:
            folder = "MISC" if not r["is_positive"] else r["label"]
            buckets.setdefault(folder, []).append(r)
        for folder, recs in buckets.items():
            out_dir = out_root / folder
            out_dir.mkdir(parents=True, exist_ok=True)
            (out_dir / f"{mag_id}.jsonl").write_text(
                "\n".join(json.dumps(r) for r in recs) + "\n"
            )
    else:
        out_root.mkdir(parents=True, exist_ok=True)
        (out_root / f"{mag_id}.jsonl").write_text(
            "\n".join(json.dumps(r) for r in records) + "\n"
        )
    return len(records)
|
|
|
|
def build_record(
    cds: CDS, prow, ext: dict, *, is_positive: bool, paired_with: Optional[str],
    in_mobilome: bool, fallback_used: Optional[str], extract_status: str,
    seed: int, mag_id: str,
    gc_content_val: Optional[float] = None,
) -> dict:
    """Assemble one JSONL record for a CDS and its extracted region.

    *prow* is the AMRFinderPlus row for a positive (anything with ``.get``) or
    None for a negative. Annotation fields are filled only when *is_positive*
    is true and *prow* is given; otherwise the label is "negative" and the
    annotation fields are None. NaN cells become None, and an element type
    outside AMR/STRESS/VIRULENCE is labelled "AMR".
    """

    def _clean(v):
        """Map None and float NaN to None; return anything else untouched."""
        is_nan = isinstance(v, float) and v != v  # NaN is the only float != itself
        return None if (v is None or is_nan) else v

    # Defaults describe a negative record.
    label = "negative"
    label_class = None
    label_subclass = None
    gene_symbol = None
    pct_id = None

    if is_positive and prow is not None:
        element_type = _clean(prow.get("Element type"))
        label = element_type if element_type in ("AMR", "STRESS", "VIRULENCE") else "AMR"
        label_class = _clean(prow.get("Class"))
        label_subclass = _clean(prow.get("Subclass"))
        gene_symbol = _clean(prow.get("Gene symbol"))
        raw_pct = _clean(prow.get("% Identity to reference sequence"))
        if raw_pct is not None:
            # Values that cannot be read as a number are dropped, not raised.
            try:
                pct_id = float(raw_pct)
            except Exception:
                pct_id = None

    rounded_gc = None if gc_content_val is None else round(gc_content_val, 4)

    # Keyword order fixes the field order of the serialised JSON.
    return dict(
        mag_id=mag_id,
        locus_tag=cds.locus_tag,
        region_id=f"{cds.locus_tag}_{label}",
        is_positive=is_positive,
        label=label,
        label_class=label_class,
        label_subclass=label_subclass,
        gene_symbol=gene_symbol,
        pct_identity_to_ref=pct_id,
        contig=cds.contig,
        gene_start=cds.start,
        gene_end=cds.end,
        strand=cds.strand,
        cds_length=cds.length,
        partial=cds.partial,
        ext_start=ext["ext_start"],
        ext_end=ext["ext_end"],
        contig_len=ext["contig_len"],
        paired_with=paired_with,
        cds_in_mobilome=in_mobilome,
        gc_content=rounded_gc,
        negative_pool_fallback=fallback_used,
        extract_status=extract_status,
        random_seed=seed,
        sequence=ext["sequence"],
    )
|
|
|
|
def _track_start_codon(record: dict, stats: Stats):
    """Fold one record into the per-strand start-codon tallies on *stats*.

    Records for which `check_start_codon` returns None are not counted at all.
    Otherwise the strand's total goes up by one, and its pass count goes up
    too when the codon is in CANONICAL_STARTS. The per-strand pass rates are
    used to detect strand-handling bugs.
    """
    codon = check_start_codon(
        record["sequence"],
        record["gene_start"],
        record["ext_start"],
        record["gene_end"],
        record["strand"],
        record.get("partial", "00"),
    )
    if codon is None:
        return

    passed = int(codon in CANONICAL_STARTS)
    if record["strand"] == "+":
        stats.start_codon_total_plus += 1
        stats.start_codon_pass_plus += passed
    else:
        stats.start_codon_total_minus += 1
        stats.start_codon_pass_minus += passed
|
|
|
|
def cross_tool_check(out_dir: Path, mag_dirs: dict, n_samples: int = 5) -> str:
    """Re-fetch a random sample of emitted records with ``samtools faidx`` and
    compare them (case-insensitively) with the stored sequences.

    Args:
        out_dir: root of the JSONL output. Searched recursively, so both the
            flat layout and the per-label sub-folder layout are covered.
        mag_dirs: mapping of MAG id -> directory containing ``<mag_id>.fna``.
            Records whose ``mag_id`` is not in this mapping (e.g. JSONL left
            over from an earlier run) are ignored instead of raising.
        n_samples: maximum number of records to verify.

    Returns:
        A human-readable status string: skipped / no records / the first
        samtools failure or mismatch / passed.
    """
    import shutil
    import subprocess

    if not shutil.which("samtools"):
        return "skipped (samtools not on PATH)"

    records = []
    # Sorted so the seeded sample below does not depend on directory order.
    for jsonl in sorted(out_dir.rglob("*.jsonl")):
        with jsonl.open() as handle:
            for line in handle:
                if not line.strip():
                    continue
                rec = json.loads(line)
                if rec.get("mag_id") in mag_dirs:
                    records.append(rec)
    if not records:
        return "no records to check"

    rng = random.Random(0)
    sample = rng.sample(records, min(n_samples, len(records)))
    for r in sample:
        mag_id = r["mag_id"]
        fna_path = mag_dirs[mag_id] / f"{mag_id}.fna"
        region = f"{r['contig']}:{r['ext_start']}-{r['ext_end']}"
        result = subprocess.run(["samtools", "faidx", str(fna_path), region],
                                capture_output=True, text=True, check=False)
        if result.returncode != 0:
            return f"samtools failed on {region}: {result.stderr[:200]}"

        # Drop the FASTA header line and re-join the wrapped sequence lines.
        samtools_seq = "".join(
            ln for ln in result.stdout.splitlines() if not ln.startswith(">")
        )
        if samtools_seq.upper() != r["sequence"].upper():
            return (f"MISMATCH on {region}: pyfaidx len={len(r['sequence'])}, "
                    f"samtools len={len(samtools_seq)}")
    return f"passed ({len(sample)} records cross-tool verified)"
|
|
|
|
| |
| |
| |
|
|
def main():
    """Command-line entry point.

    Parses the arguments, runs `extract_for_mag` on every selected MAG, then
    prints a summary: counts, fallback-tier usage, the per-strand start-codon
    sanity check and the pyfaidx-vs-samtools cross-check.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--skin-dir", "--catalogue-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/human-skin/species_catalogue"),
                    dest="skin_dir",
                    help="catalogue's species_catalogue directory (skin or chicken-gut)")
    ap.add_argument("--out-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/targeted_jsonl/skin"))
    ap.add_argument("--split-by-label", action="store_true",
                    help="write to {out_dir}/{AMR|VIRULENCE|STRESS|MISC}/{mag_id}.jsonl "
                         "(MISC = negatives). Default: single {out_dir}/{mag_id}.jsonl per MAG.")
    ap.add_argument("--mag-ids", nargs="*", default=None,
                    help="restrict to specific MAG IDs (default: all 579 skin MAGs)")
    ap.add_argument("--top-csv", type=Path, default=None,
                    help="if set, restrict to MAGs listed in this CSV file (col 'mag_id')")
    ap.add_argument("--max-pairs", type=int, default=None,
                    help="for testing: stop after this many positive-negative pairs total")
    ap.add_argument("--flank", type=int, default=2000)
    ap.add_argument("--length-tol-strict", type=float, default=0.20)
    ap.add_argument("--length-tol-relaxed", type=float, default=0.50)
    ap.add_argument("--gc-tol-strict", type=float, default=0.05,
                    help="absolute GC-fraction tolerance for paired negative (default ±0.05)")
    ap.add_argument("--gc-tol-relaxed", type=float, default=0.10,
                    help="relaxed GC tolerance used in later fallbacks (default ±0.10)")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)

    # MAG selection, in priority order: explicit ids, CSV list, whole catalogue.
    if args.mag_ids:
        mag_ids = args.mag_ids
    elif args.top_csv:
        mag_ids = pd.read_csv(args.top_csv)["mag_id"].tolist()
    else:
        mag_ids = sorted(p.name for p in args.skin_dir.glob("*/MGYG*") if p.is_dir())

    print(f"target: {len(mag_ids)} MAG(s)")
    print(f"flank={args.flank}, length_tol_strict=±{args.length_tol_strict*100:.0f}%, "
          f"relaxed=±{args.length_tol_relaxed*100:.0f}%, seed={args.seed}")
    if args.max_pairs:
        print(f"max_pairs={args.max_pairs} (testing mode)")
    print()

    stats = Stats()
    mag_dirs_map = {}

    for mag_id in mag_ids:
        if args.max_pairs and stats.pairs_emitted >= args.max_pairs:
            break
        # Catalogue layout: <catalogue>/<first 11 chars of id>/<id>/genome
        prefix = mag_id[:11]
        mag_dir = args.skin_dir / prefix / mag_id / "genome"
        if not mag_dir.exists():
            print(f" {mag_id}: dir not found at {mag_dir}, skipping")
            continue
        mag_dirs_map[mag_id] = mag_dir
        n_records = extract_for_mag(
            mag_dir, mag_id, args.out_dir,
            split_by_label=args.split_by_label,
            flank=args.flank,
            length_tol_strict=args.length_tol_strict,
            length_tol_relaxed=args.length_tol_relaxed,
            gc_tol_strict=args.gc_tol_strict,
            gc_tol_relaxed=args.gc_tol_relaxed,
            seed=args.seed,
            # What is passed down is the REMAINING global budget.
            max_pairs=(args.max_pairs - stats.pairs_emitted) if args.max_pairs else None,
            stats=stats,
        )
        # Count this MAG before reporting, so the progress line includes it.
        stats.mags_processed += 1
        if n_records and stats.mags_processed % 50 == 0:
            print(f" ({stats.mags_processed} MAGs processed, {stats.pairs_emitted} pairs so far)")

    print()
    print("=" * 65)
    print("SUMMARY")
    print("=" * 65)
    print(f"MAGs processed: {stats.mags_processed}")
    print(f"Positives found: {stats.positives_found}")
    print(f"Pairs emitted: {stats.pairs_emitted}")
    print(f"No-match positives: {stats.no_match}")
    print("\nFallback usage (negative selection):")
    for k, v in stats.fallback_used.items():
        print(f" {k:25s} {v}")

    print("\nStart-codon sanity check (canonical starts: ATG/GTG/TTG/CTG):")
    rate_p = rate_m = None
    if stats.start_codon_total_plus > 0:
        rate_p = stats.start_codon_pass_plus / stats.start_codon_total_plus * 100
        print(f" + strand: {stats.start_codon_pass_plus}/{stats.start_codon_total_plus} ({rate_p:.1f}%)")
    if stats.start_codon_total_minus > 0:
        rate_m = stats.start_codon_pass_minus / stats.start_codon_total_minus * 100
        print(f" - strand: {stats.start_codon_pass_minus}/{stats.start_codon_total_minus} ({rate_m:.1f}%)")
    if rate_p is not None and rate_m is not None:
        # A low rate on either strand, or a large gap between them, points at
        # a coordinate or strand mix-up in extraction.
        diff = abs(rate_p - rate_m)
        if rate_p < 80 or rate_m < 80 or diff > 15:
            print(" ⚠ WARNING: strand pass-rates suggest a strand-handling bug")
        else:
            print(" ✓ strand handling looks correct")

    print(f"\nCross-tool sanity check (pyfaidx vs samtools): "
          f"{cross_tool_check(args.out_dir, mag_dirs_map, n_samples=10)}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|