| """ |
| Extract VFDB virulence positives + species-matched MAG-derived negatives. |
| |
| Companion to `extract_targeted.py`. See `vfdb_negative_pipeline_plan.md` for |
| design rationale. |
| |
| Phase 0: filter VFDB to species present in local MAG catalogue |
| Phase 1: build per-species candidate negative pool from MAG annotations |
| Phase 2: match VFDB positives 1:1 with negatives (length+GC, fallback hierarchy) |
| Phase 3: emit one JSONL per species with MGnify-compatible schema |
| |
| Output: data/targeted_jsonl/vfdb/<species_slug>.jsonl |
| """ |
| import argparse |
| import json |
| import random |
| import re |
| from collections import defaultdict |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional |
|
|
| import pandas as pd |
| from pyfaidx import Fasta |
|
|
| |
| from extract_targeted import ( |
| CDS, |
| CANONICAL_STARTS, |
| revcomp, |
| parse_master_gff, |
| parse_interval_gff, |
| cds_overlaps_any_interval, |
| gc_content, |
| ) |
|
|
|
|
| |
| |
| |
|
|
def vfdb_species(org: str) -> Optional[str]:
    """Reduce a VFDB organism string to a binomial 'Genus epithet'.

    The first whitespace-separated token is kept verbatim as the genus and
    the second is lower-cased as the epithet. Returns None when *org* is not
    a string, has fewer than two tokens, or when the second token is one of
    the qualifiers 'sp.', 'subsp.', 'str.', 'strain' or 'virus' (such rows
    carry no usable species epithet).
    """
    qualifiers = {"sp.", "subsp.", "str.", "strain", "virus"}
    if not isinstance(org, str):
        return None
    tokens = org.split()  # split() with no argument already ignores outer whitespace
    if len(tokens) < 2:
        return None
    genus, epithet = tokens[0], tokens[1].lower()
    return None if epithet in qualifiers else f"{genus} {epithet}"
|
|
|
|
def gtdb_species(lineage: str) -> Optional[str]:
    """Parse a GTDB lineage string into a VFDB-comparable 'Genus epithet'.

    Finds the first ``s__`` rank in a ';'-separated lineage. It strips GTDB
    placeholder suffixes (``_A``, ``_B``, ...) from BOTH the genus and the
    epithet, so that the result can be compared with NCBI-style VFDB names.
    For example, ``s__Prevotella_A copri_B`` becomes ``Prevotella copri``.

    Returns None for non-string input, a lineage with no ``s__`` rank, an
    empty ``s__`` value, or a species value with fewer than two tokens.
    """
    if not isinstance(lineage, str):
        return None
    for part in lineage.split(";"):
        part = part.strip()
        if not part.startswith("s__"):
            continue
        tokens = part[3:].split()
        if len(tokens) < 2:
            # Covers both an empty "s__" and a single-token value.
            return None
        genus = tokens[0].split("_")[0]
        # GTDB also splits species with epithet suffixes (e.g. "copri_A");
        # VFDB never carries them, so drop them like the genus suffix.
        epithet = tokens[1].split("_")[0]
        return f"{genus} {epithet}"
    return None
|
|
|
|
def gtdb_genus(lineage: str) -> Optional[str]:
    """Return the genus of a GTDB lineage with any ``_X`` suffix removed.

    Only the first ``g__`` rank is considered. Returns None for non-string
    input, a lineage without a ``g__`` rank, or an empty ``g__`` value.
    """
    if not isinstance(lineage, str):
        return None
    genus_field = next(
        (p.strip()[3:] for p in lineage.split(";") if p.strip().startswith("g__")),
        None,
    )
    if not genus_field:
        return None
    return genus_field.split("_")[0]
|
|
|
|
def gtdb_family(lineage: str) -> Optional[str]:
    """Return the raw ``f__`` value of a GTDB lineage (suffixes are kept).

    Only the first ``f__`` rank is considered. Returns None for non-string
    input, a lineage without an ``f__`` rank, or an empty ``f__`` value.
    """
    if not isinstance(lineage, str):
        return None
    ranks = (p.strip() for p in lineage.split(";"))
    for rank in ranks:
        if rank.startswith("f__"):
            return rank[3:] or None
    return None
|
|
|
|
def species_slug(species: str) -> str:
    """Make a filesystem-safe slug.

    Runs of ASCII letters/digits are kept and joined with single
    underscores. All other characters act as separators, and no leading or
    trailing underscore is produced.
    """
    return "_".join(re.findall(r"[A-Za-z0-9]+", species))
|
|
|
|
| |
| |
| |
|
|
@dataclass
class CandidateCDS:
    """One CDS from a local MAG that passed the negative-pool filters in
    ``build_mag_pool`` and may be paired with a VFDB positive."""
    mag_id: str                  # MGnify genome accession (MGYG...) the CDS comes from
    locus_tag: str               # GFF locus_tag; used as the "already paired" key in main()
    contig: str                  # FASTA record name passed to Fasta.get_seq
    start: int                   # start coordinate as produced by parse_master_gff
    end: int                     # end coordinate as produced by parse_master_gff
    strand: str                  # strand symbol; "-" means the coding sequence is the reverse complement
    cds_length: int              # length reported by the parsed CDS record (c.length)
    gc: float                    # gc_content() of the forward-strand sequence
    gene_symbol: Optional[str]   # GFF "gene" attribute, or None
    product: Optional[str]       # GFF "product" attribute, or None
    in_mobilome: bool            # overlaps a mobilome interval (flagged, not excluded)
    species: str                 # GTDB-derived 'Genus species' of the source MAG
    genus: Optional[str]         # GTDB genus (suffix stripped), or None
    family: Optional[str]        # GTDB family, or None
    catalogue: str               # catalogue label passed in by main(), e.g. "skin" / "chicken-gut"
|
|
|
|
def parse_gff_attrs(gff_path: Path) -> dict[str, dict[str, Optional[str]]]:
    """Map each CDS locus_tag in a GFF3 file to its gene symbol and product.

    Returns ``{locus_tag: {"gene": str | None, "product": str | None}}`` for
    every ``CDS`` row that carries a non-empty ``locus_tag`` attribute. If
    *gff_path* does not exist, returns an empty dict.

    Parsing details:

    - Attributes are parsed as exact ``key=value`` pairs, so a key such as
      ``old_locus_tag`` can never be mistaken for ``locus_tag``. The first
      occurrence of a key wins.
    - ``gene`` and ``product`` values are percent-decoded, as the GFF3 spec
      requires (``%2C`` becomes ``,``). The locus_tag is kept verbatim so it
      stays comparable with the tags from ``parse_master_gff``.
    - Reading stops at a ``##FASTA`` directive, because everything after it
      is sequence data, not features.
    """
    from urllib.parse import unquote

    out: dict[str, dict[str, Optional[str]]] = {}
    if not gff_path.exists():
        return out
    with gff_path.open() as fh:
        for raw in fh:
            line = raw.rstrip("\r\n")
            if line.startswith("##FASTA"):
                break
            if not line or line.startswith("#"):
                continue
            cols = line.split("\t")
            if len(cols) < 9 or cols[2] != "CDS":
                continue
            attrs: dict[str, str] = {}
            for field in cols[8].split(";"):
                key, sep, value = field.partition("=")
                if sep:
                    attrs.setdefault(key.strip(), value)
            lt = attrs.get("locus_tag")
            if not lt:
                continue
            gene = attrs.get("gene")
            product = attrs.get("product")
            out[lt] = {
                "gene": unquote(gene) if gene else None,
                "product": unquote(product) if product else None,
            }
    return out
|
|
|
|
def _amr_positive_locus_tags(amr_tsv: Path, mag_id: str) -> set[str]:
    """Return the locus tags AMRFinderPlus flagged as AMR/STRESS/VIRULENCE.

    Best-effort: a missing, unreadable or unrecognised table yields an empty
    set. Every case except a missing file prints a warning, so the loss of
    these exclusions is visible.
    """
    if not amr_tsv.exists():
        return set()
    try:
        amr_df = pd.read_csv(amr_tsv, sep="\t")
    except Exception as err:  # deliberate best-effort, but no longer silent
        print(f"  [{mag_id}] WARNING: could not read {amr_tsv.name} ({err}); "
              f"no AMRFinderPlus exclusions applied")
        return set()
    # AMRFinderPlus v3 names the columns "Element type" / "Protein identifier";
    # v4 renamed them to "Type" / "Protein id".
    type_col = next((c for c in ("Element type", "Type") if c in amr_df.columns), None)
    id_col = next((c for c in ("Protein identifier", "Protein id") if c in amr_df.columns), None)
    if type_col is None or id_col is None:
        print(f"  [{mag_id}] WARNING: unrecognised columns in {amr_tsv.name}; "
              f"no AMRFinderPlus exclusions applied")
        return set()
    pos_rows = amr_df[amr_df[type_col].isin(["AMR", "STRESS", "VIRULENCE"])]
    return set(pos_rows[id_col].dropna().astype(str))


def build_mag_pool(
    mag_dir: Path,
    mag_id: str,
    species: str,
    genus: Optional[str],
    family: Optional[str],
    catalogue: str,
    vfdb_seq_hashes: Optional[set] = None,
) -> tuple[list[CandidateCDS], Optional[Fasta]]:
    """Parse one MAG into candidate negatives.

    Returns ``(candidates, fasta)``. *fasta* is the opened ``Fasta`` for the
    MAG's ``.fna``, kept so sequences can be extracted later. If the ``.fna``
    or ``.gff`` file is missing, returns ``([], None)``.

    A CDS from the master GFF becomes a candidate only if it is:

    - complete (``partial == "00"``);
    - not listed by AMRFinderPlus as AMR / STRESS / VIRULENCE;
    - not overlapping an antiSMASH, CRISPRCasFinder or DefenseFinder interval;
    - retrievable from the FASTA; and
    - not an exact coding-strand match to an entry of *vfdb_seq_hashes*.
      This catches VFDB virulence genes that AMRFinderPlus missed.

    Overlap with the mobilome is recorded in ``in_mobilome`` but does not
    exclude the CDS.

    Despite its name, *vfdb_seq_hashes* is compared by exact string equality
    against the upper-cased coding sequence. It is presumably a set of
    upper-cased sequences (that is how main() builds it). ``None`` disables
    the check.
    """
    fna = mag_dir / f"{mag_id}.fna"
    gff = mag_dir / f"{mag_id}.gff"
    if not (fna.exists() and gff.exists()):
        return [], None

    fa = Fasta(str(fna))
    all_cds = parse_master_gff(gff)
    attrs = parse_gff_attrs(gff)
    positive_locus_tags = _amr_positive_locus_tags(mag_dir / f"{mag_id}_amrfinderplus.tsv", mag_id)

    # Hard exclusions: BGCs, CRISPR arrays and defense systems.
    bgc_iv = parse_interval_gff(mag_dir / f"{mag_id}_antismash.gff")
    crispr_iv = parse_interval_gff(mag_dir / f"{mag_id}_crisprcasfinder.gff")
    defense_iv = parse_interval_gff(mag_dir / f"{mag_id}_defense_finder.gff")
    # Soft flag only: mobilome overlap is recorded, not excluded.
    mobilome_iv = parse_interval_gff(mag_dir / f"{mag_id}_mobilome.gff")
    strict_iv = bgc_iv + crispr_iv + defense_iv

    candidates: list[CandidateCDS] = []
    for c in all_cds:
        if c.partial != "00":
            continue
        if c.locus_tag in positive_locus_tags or cds_overlaps_any_interval(c, strict_iv):
            continue
        try:
            seq_fwd = str(fa.get_seq(c.contig, c.start, c.end).seq).upper()
        except (KeyError, ValueError, IndexError):
            # Unknown contig or coordinates pyfaidx cannot serve
            # (pyfaidx.FetchError subclasses IndexError).
            continue
        coding = revcomp(seq_fwd) if c.strand == "-" else seq_fwd
        if vfdb_seq_hashes is not None and coding in vfdb_seq_hashes:
            continue
        a = attrs.get(c.locus_tag, {})
        candidates.append(CandidateCDS(
            mag_id=mag_id, locus_tag=c.locus_tag, contig=c.contig,
            start=c.start, end=c.end, strand=c.strand, cds_length=c.length,
            gc=gc_content(seq_fwd),  # GC is strand-symmetric; forward strand suffices
            gene_symbol=a.get("gene"), product=a.get("product"),
            in_mobilome=cds_overlaps_any_interval(c, mobilome_iv),
            species=species, genus=genus, family=family,
            catalogue=catalogue,
        ))
    return candidates, fa
|
|
|
|
def get_coding_strand_seq(fa: Fasta, c: CandidateCDS) -> str:
    """Fetch the gene span of *c* from *fa* in coding orientation.

    The sequence is upper-cased. It is reverse-complemented when
    ``c.strand`` is ``"-"``, so it always reads 5'->3' on the coding strand.
    """
    forward = str(fa.get_seq(c.contig, c.start, c.end).seq).upper()
    return revcomp(forward) if c.strand == "-" else forward
|
|
|
|
| |
| |
| |
|
|
# Negative-selection fallback hierarchy, tried in order by pick_negative().
# The first level that yields any eligible candidate wins.
# Each tuple is (include_mobilome, len_tol, gc_tol, label):
#   include_mobilome -- if False, candidates flagged in_mobilome are skipped
#   len_tol          -- max |cds_length - pos_len| / pos_len (relative)
#   gc_tol           -- max |gc - pos_gc| (absolute, in gc_content units)
#   label            -- written to the "negative_pool_fallback" JSONL field
FALLBACK_LEVELS = [
    # strictest: mobilome excluded, +/-20% length, +/-0.05 GC
    (False, 0.20, 0.05, "strict_no_mob_l20_g5"),
    # admit mobilome-overlapping CDSs
    (True, 0.20, 0.05, "strict_w_mob_l20_g5"),
    # widen GC tolerance
    (True, 0.20, 0.10, "strict_w_mob_l20_g10"),
    # loosest: widen length tolerance to +/-50%
    (True, 0.50, 0.10, "strict_w_mob_l50_g10"),
]
|
|
|
|
def pick_negative(
    pos_len: int,
    pos_gc: float,
    species_pool: "list[CandidateCDS]",
    used_locus_tags: set[str],
    rng: random.Random,
    levels: Optional[list] = None,
) -> "tuple[Optional[CandidateCDS], Optional[str]]":
    """Choose one unused negative matching a positive's length and GC.

    Walks *levels* (default ``FALLBACK_LEVELS``) from strictest to loosest.
    At the first level with any eligible candidate, it draws one uniformly
    with *rng*. A candidate is eligible when all of these hold:

    - its locus_tag is not in *used_locus_tags*;
    - it is outside the mobilome, unless the level admits mobilome CDSs;
    - ``|cds_length - pos_len| / pos_len <= len_tol``;
    - ``|gc - pos_gc| <= gc_tol``.

    Each entry of *levels* is ``(include_mobilome, len_tol, gc_tol, label)``.

    Returns ``(candidate, label)``, or ``(None, None)`` in two cases: no
    level yields a candidate, or *pos_len* is not positive (the relative
    length tolerance is undefined for an empty positive).

    The function does not modify *used_locus_tags*. The caller records the
    chosen tag.
    """
    if pos_len <= 0:
        return None, None
    if levels is None:
        levels = FALLBACK_LEVELS
    # The used-tag exclusion is identical at every level, so apply it once.
    available = [c for c in species_pool if c.locus_tag not in used_locus_tags]
    for include_mob, len_tol, gc_tol, label in levels:
        pool = [
            c for c in available
            if (include_mob or not c.in_mobilome)
            and abs(c.cds_length - pos_len) / pos_len <= len_tol
            and abs(c.gc - pos_gc) <= gc_tol
        ]
        if pool:
            return rng.choice(pool), label
    return None, None
|
|
|
|
| |
| |
| |
|
|
def positive_record(
    pos_row: pd.Series, species: str,
    paired_neg: Optional[CandidateCDS],
    fallback_label: Optional[str],
    seed: int,
) -> dict:
    """Assemble the JSONL record for one VFDB virulence positive.

    The sequence comes from ``actual_sequence`` and is upper-cased; length
    and GC are recomputed from it. Optional VFDB metadata columns are read
    with ``Series.get``, which yields None when a column is absent.

    If *paired_neg* is given, its locus_tag goes into ``paired_with`` and
    ``extract_status`` is ``"ok"``. Otherwise ``paired_with`` is None and
    the status is ``"no_matching_negative"``.
    """
    sequence = str(pos_row["actual_sequence"]).upper()
    field = pos_row.get
    if paired_neg:
        partner, status = paired_neg.locus_tag, "ok"
    else:
        partner, status = None, "no_matching_negative"

    return {
        # identity / labels
        "region_id": f"{pos_row['vfg_id']}_VIRULENCE",
        "is_positive": True,
        "label": "VIRULENCE",
        "label_class": field("vfcategory_name"),
        "label_subclass": field("vf_prototype_name"),
        "gene_symbol": field("gene_name"),
        "vf_id": field("vf_id"),
        "vfcategory_id": field("vfcategory_id"),
        # provenance
        "species": species,
        "organism": field("organism"),
        "source_db": "VFDB",
        "source_accession": field("source_accession"),
        # sequence statistics
        "cds_length": len(sequence),
        "gc_content": round(gc_content(sequence), 4),
        # pairing bookkeeping
        "paired_with": partner,
        "negative_pool_fallback": fallback_label,
        "extract_status": status,
        "random_seed": seed,
        "sequence": sequence,
    }
|
|
|
|
def negative_record(
    neg: CandidateCDS, fa: Fasta, paired_pos_acc: str, paired_pos_class,
    paired_pos_subclass, fallback_label: str, seed: int,
) -> dict:
    """Assemble the JSONL record for a MAG-derived negative.

    The sequence is re-read from *fa* in coding orientation. ``label_class``
    and ``label_subclass`` are copied from the paired positive, and
    ``paired_with`` holds that positive's accession. ``partial`` is always
    written as ``"00"``, since build_mag_pool admits only ``partial == "00"``
    CDSs.
    """
    coding_seq = get_coding_strand_seq(fa, neg)
    return {
        # identity / labels (class labels mirror the paired positive)
        "region_id": f"{neg.locus_tag}_negative",
        "is_positive": False,
        "label": "negative",
        "label_class": paired_pos_class,
        "label_subclass": paired_pos_subclass,
        "gene_symbol": neg.gene_symbol,
        "product": neg.product,
        # taxonomy / provenance
        "species": neg.species,
        "genus": neg.genus,
        "family": neg.family,
        "catalogue": neg.catalogue,
        "mag_id": neg.mag_id,
        "locus_tag": neg.locus_tag,
        # location
        "contig": neg.contig,
        "gene_start": neg.start,
        "gene_end": neg.end,
        "strand": neg.strand,
        # sequence statistics
        "cds_length": neg.cds_length,
        "partial": "00",
        "cds_in_mobilome": neg.in_mobilome,
        "gc_content": round(neg.gc, 4),
        # pairing bookkeeping
        "paired_with": paired_pos_acc,
        "negative_pool_fallback": fallback_label,
        "extract_status": "ok",
        "random_seed": seed,
        "sequence": coding_seq,
    }
|
|
|
|
| |
| |
| |
|
|
def discover_local_mags(catalogue_root: Path, catalogue_name: str, metadata_tsv: Path) -> dict[str, dict]:
    """Index the MAGs of one MGnify catalogue that are present on disk.

    Walks ``<catalogue_root>/MGYG*/MGYG*/genome``. A MAG is kept when its
    accession appears in the ``Genome`` column of *metadata_tsv* and
    ``gtdb_species`` can parse a species from its ``Lineage``. If an
    accession appears on several metadata rows, the first row is used.

    Returns ``{mag_id: {"mag_dir", "species", "genus", "family",
    "catalogue"}}``, where ``mag_dir`` is the MAG's ``genome`` directory.

    Directories are visited in sorted order. This keeps the returned dict,
    and hence candidate-pool order and every seeded random draw downstream,
    reproducible across filesystems.
    """
    meta = pd.read_csv(metadata_tsv, sep="\t")
    # Build the accession -> lineage lookup once instead of masking the whole
    # table per MAG. setdefault keeps the FIRST row for each accession.
    lineage_by_genome: dict = {}
    for genome, lineage in zip(meta["Genome"], meta["Lineage"]):
        lineage_by_genome.setdefault(genome, lineage)

    out: dict[str, dict] = {}
    for prefix_dir in sorted(catalogue_root.glob("MGYG*")):
        if not prefix_dir.is_dir():
            continue
        for mag_dir in sorted(prefix_dir.glob("MGYG*")):
            mag_id = mag_dir.name
            genome_dir = mag_dir / "genome"
            if not genome_dir.is_dir() or mag_id not in lineage_by_genome:
                continue
            lineage = lineage_by_genome[mag_id]
            species = gtdb_species(lineage)
            if not isinstance(species, str):
                continue
            out[mag_id] = {
                "mag_dir": genome_dir,
                "species": species,
                "genus": gtdb_genus(lineage),
                "family": gtdb_family(lineage),
                "catalogue": catalogue_name,
            }
    return out
|
|
|
|
def main():
    """Run the VFDB-positive / MAG-negative extraction end to end.

    1. Load the VFDB rows from the master parquet and keep only species that
       have at least one local MAG (skin and chicken-gut catalogues).
    2. Build the per-species candidate-negative pools.
    3. Pair each positive 1:1 with a length/GC-matched negative via
       ``pick_negative``.
    4. Write one JSONL file per species to ``--out-dir``.

    Progress and a summary are printed to stdout.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--master-parquet", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/master_annotations_clean.parquet"))
    ap.add_argument("--skin-root", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/human-skin"))
    ap.add_argument("--chicken-gut-root", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/chicken-gut"))
    ap.add_argument("--out-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/targeted_jsonl/vfdb"))
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)

    # ---- Phase 0: VFDB positives restricted to species with local MAGs ----
    clean = pd.read_parquet(args.master_parquet)
    vfdb_all = clean[clean["db"] == "VFDB"]
    # The negative-exclusion set uses EVERY VFDB sequence, not only the species
    # retained below: a catalogued virulence gene can also be present in a MAG
    # assigned to a different species.
    vfdb_seq_set = set(vfdb_all["actual_sequence"].dropna().astype(str).str.upper())

    vir = vfdb_all.copy()
    vir["species"] = vir["organism"].apply(vfdb_species)
    print(f"VFDB rows total: {len(vir)}")
    print(f"  with parseable species: {vir['species'].notna().sum()}")

    # Rows without a usable sequence cannot be length/GC-matched. str(NaN)
    # would otherwise become the 3-letter "sequence" 'NAN', and '' has
    # length 0 (a division by zero in the relative length tolerance).
    has_seq = vir["actual_sequence"].notna() & (
        vir["actual_sequence"].astype(str).str.strip().str.len() > 0
    )
    print(f"  with a non-empty sequence: {int(has_seq.sum())}")
    vir = vir[has_seq]

    mags_skin = discover_local_mags(
        args.skin_root / "species_catalogue", "skin",
        args.skin_root / "genomes-all_metadata.tsv",
    )
    mags_cg = discover_local_mags(
        args.chicken_gut_root / "species_catalogue", "chicken-gut",
        args.chicken_gut_root / "genomes-all_metadata.tsv",
    )
    all_mags = {**mags_skin, **mags_cg}
    print(f"Local MAGs on disk: {len(all_mags)} (skin={len(mags_skin)}, chicken-gut={len(mags_cg)})")

    mag_species_set = {m["species"] for m in all_mags.values()}
    overlap = set(vir["species"].dropna()) & mag_species_set
    vir = vir[vir["species"].isin(overlap)].reset_index(drop=True)
    print(f"Species in VFDB ∩ MAG: {len(overlap)}")
    print(f"VFDB positives retained: {len(vir)}")

    mags_by_species: dict[str, list[tuple[str, dict]]] = defaultdict(list)
    for mag_id, info in all_mags.items():
        if info["species"] in overlap:
            mags_by_species[info["species"]].append((mag_id, info))
    print(f"MAGs covering retained species: {sum(len(v) for v in mags_by_species.values())}")
    print()

    print(f"VFDB sequence-hash exclusion set: {len(vfdb_seq_set)} sequences")
    print()

    # ---- Phase 1: per-species candidate-negative pools ----
    pools_by_species: dict[str, list[CandidateCDS]] = defaultdict(list)
    fa_by_mag: dict[str, Fasta] = {}
    for sp, mag_list in mags_by_species.items():
        for mag_id, info in mag_list:
            cands, fa = build_mag_pool(
                info["mag_dir"], mag_id, sp,
                info["genus"], info["family"], info["catalogue"],
                vfdb_seq_hashes=vfdb_seq_set,
            )
            if fa is None:
                print(f"  [{mag_id}] missing files; skipped")
                continue
            pools_by_species[sp].extend(cands)
            fa_by_mag[mag_id] = fa
        print(f"  {sp:35s} pool={len(pools_by_species[sp]):5d} candidates "
              f"(MAGs: {[m for m, _ in mag_list]})")
    print()

    # ---- Phases 2 + 3: 1:1 matching and per-species JSONL output ----
    fallback_counter = defaultdict(int)
    total_pairs = 0
    total_unpaired = 0
    total_skipped = 0

    for sp, pos_subset in vir.groupby("species"):
        slug = species_slug(sp)
        species_pool = pools_by_species.get(sp, [])
        if not species_pool:
            print(f"  {sp}: no candidate pool, skipping {len(pos_subset)} positives")
            total_skipped += len(pos_subset)
            continue

        # Fresh RNG per species so each species' output is reproducible on its own.
        sp_rng = random.Random(args.seed)
        positives_shuffled = pos_subset.sample(frac=1, random_state=args.seed).reset_index(drop=True)
        used_locus_tags: set[str] = set()
        records: list[dict] = []
        sp_pairs = 0
        sp_unpaired = 0

        for _, prow in positives_shuffled.iterrows():
            seq = str(prow["actual_sequence"]).upper()
            neg, fallback = pick_negative(len(seq), gc_content(seq), species_pool, used_locus_tags, sp_rng)
            if neg is None:
                records.append(positive_record(prow, sp, None, None, args.seed))
                sp_unpaired += 1
                continue
            used_locus_tags.add(neg.locus_tag)
            pos_rec = positive_record(prow, sp, neg, fallback, args.seed)
            neg_rec = negative_record(
                neg, fa_by_mag[neg.mag_id],
                paired_pos_acc=prow["vfg_id"],
                paired_pos_class=prow.get("vfcategory_name"),
                paired_pos_subclass=prow.get("vf_prototype_name"),
                fallback_label=fallback, seed=args.seed,
            )
            records.extend((pos_rec, neg_rec))
            fallback_counter[fallback] += 1
            sp_pairs += 1

        out_path = args.out_dir / f"{slug}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            for r in records:
                f.write(json.dumps(_jsonable(r)) + "\n")
        total_pairs += sp_pairs
        total_unpaired += sp_unpaired
        print(f"  [{sp:35s}] pairs={sp_pairs:5d} unpaired={sp_unpaired:4d} → {out_path.name}")

    # All sequences have been extracted; release the FASTA index handles.
    for fa in fa_by_mag.values():
        fa.close()

    print()
    print("=" * 65)
    print("SUMMARY")
    print("=" * 65)
    print(f"Total VFDB positives processed: {len(vir)}")
    print(f"Pairs emitted: {total_pairs}")
    print(f"Unpaired (no_matching_negative): {total_unpaired}")
    print(f"Skipped (species without candidate pool): {total_skipped}")
    print("\nFallback usage (negative selection):")
    for *_, label in FALLBACK_LEVELS:
        print(f"  {label:25s} {fallback_counter[label]}")
|
|
|
|
| def _jsonable(d: dict) -> dict: |
| """Coerce numpy / pandas scalars and NaN to JSON-safe types.""" |
| out = {} |
| for k, v in d.items(): |
| if v is None: |
| out[k] = None |
| elif isinstance(v, float) and v != v: |
| out[k] = None |
| elif hasattr(v, "item"): |
| out[k] = v.item() |
| else: |
| out[k] = v |
| return out |
|
|
|
|
# Script entry point: run the pipeline only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|