# mgnify-evo2-probes/code/scripts/build_human_viral_jsonl.py
# Uploaded by JG1310
# Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs
# Commit eb69de4 (verified)
"""
Build a JSONL of paired (human-viral-sequence, length+GC-matched bacterial-CDS) records
for a "is this a human virus?" probe.
Positives: 6000 human viral sequences from human_viral_sequences.xlsx.
Negatives: VFDB CDSs (14,695 bacterial CDSs from MGnify MAGs, length 78-9492 bp, GC 0.20-0.75).
Match per positive: one VFDB CDS drawn at random from the not-yet-used candidates
within ±20% length and ±0.05 GC (no replacement; deterministic via seed=42).
Viral records longer than 9492 bp are dropped without attempting a match — these
are mostly complete viral genomes (~44% of the corpus).
Output: ~/MGnify/data/targeted_jsonl/human_viral/human_viral_v1.jsonl (paired records)
and human_viral_v1_unpaired.jsonl (dropped/unpaired viral records).
"""
from __future__ import annotations
import json
import random
from collections import defaultdict
from pathlib import Path
import pandas as pd
# Excel workbook of viral positives, loaded by main() with pandas.read_excel.
VIRAL_XLSX = "/home/ror25cal/MGnify/human_viral_sequences.xlsx"
# Directory whose *.jsonl files are read by main() to build the negative pool.
VFDB_DIR = Path("/home/ror25cal/MGnify/data/targeted_jsonl/vfdb")
# Directory the output JSONL files are written to (created by main() if missing).
OUT_DIR = Path("/home/ror25cal/MGnify/data/targeted_jsonl/human_viral")
# Seed for the random.Random instance that drives the shuffle and candidate choice.
SEED = 42
# Relative length tolerance: a candidate's cds_length must lie in
# [vlen * (1 - LEN_TOL), vlen * (1 + LEN_TOL)].
LEN_TOL = 0.20
# Absolute GC tolerance: a candidate's gc_content must lie in
# [vgc - GC_TOL, vgc + GC_TOL].
GC_TOL = 0.05
def gc_content(seq: str) -> float:
    """Return the GC fraction of *seq*, measured over unambiguous bases only.

    The sequence is upper-cased before counting. Only A, C, G and T count
    towards the denominator, so every other character (ambiguity codes,
    gaps, ...) is ignored. A sequence containing no A/C/G/T yields 0.0.
    """
    upper = seq.upper()
    gc_total = upper.count("G") + upper.count("C")
    acgt_total = gc_total + upper.count("A") + upper.count("T")
    if not acgt_total:
        return 0.0
    return gc_total / acgt_total
def _opt_str(row, key):
    """Return ``str(row[key])``, or None when the cell is absent or NaN/None."""
    value = row.get(key)
    return str(value) if pd.notna(value) else None


def _str_or_default(row, key, default):
    """Return ``str(row[key])``, or *default* when the cell is absent, NaN or empty.

    A bare ``row.get(key) or default`` does not work here: pandas represents
    empty Excel cells as NaN, and NaN is truthy, so it would be stringified
    to "nan" instead of falling back to *default*.
    """
    value = row.get(key)
    if pd.isna(value) or not value:
        return default
    return str(value)


def _unpaired_record(row, drop_reason, **extra):
    """Build the inspection record for a viral row that could not be paired.

    Keeps every column except the full sequence (NaN cells become None so
    they serialise as JSON null) plus the first 200 bp of the sequence.
    """
    return {
        "drop_reason": drop_reason,
        **extra,
        **{k: (v if pd.notna(v) else None)
           for k, v in row.to_dict().items()
           if k != "sequence"},
        "sequence_first_200bp": str(row["sequence"])[:200],
    }


def _load_vfdb_records(vfdb_dir):
    """Load every record with extract_status == "ok" from vfdb_dir/*.jsonl.

    Files are visited in sorted order so the resulting list order (and hence
    the seeded random matching) is reproducible. Blank lines are skipped.
    """
    records = []
    for fp in sorted(vfdb_dir.glob("*.jsonl")):
        with fp.open(encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():
                    continue
                rec = json.loads(line)
                if rec.get("extract_status") == "ok":
                    records.append(rec)
    return records


def _write_jsonl(path, records):
    """Write *records* to *path*, one JSON object per line (empty file if none)."""
    with path.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(r) + "\n" for r in records)


def main():
    """Pair each viral sequence with one unused VFDB CDS and write the JSONLs.

    Steps:
      1. Load the viral workbook (VIRAL_XLSX) and compute GC content per row.
      2. Load the VFDB pool from VFDB_DIR/*.jsonl (extract_status == "ok").
      3. Visit the viral rows in a seeded random order; for each one, pick a
         random not-yet-used VFDB record within ±LEN_TOL relative length and
         ±GC_TOL absolute GC. Rows longer than the hard-coded cutoff, or with
         no candidate, go to the unpaired file instead.
      4. Write OUT_DIR/human_viral_v1.jsonl (positive record followed by its
         negative) and OUT_DIR/human_viral_v1_unpaired.jsonl, then print a
         summary and a genome_type breakdown of the paired positives.
    """
    # Hard-coded "max VFDB length" cutoff: longer viral records are dropped
    # without attempting a match.
    # NOTE(review): with the ±20% window a viral record somewhat longer than
    # this could still match the longest CDS; the cutoff is kept as-is so the
    # produced dataset does not change.
    max_vfdb_len = 9492
    no_match_reason = "no VFDB candidate within length+GC tolerance"

    rng = random.Random(SEED)
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # ---- Load viral positives ----
    print("loading viral xlsx…")
    df = pd.read_excel(VIRAL_XLSX)
    print(f" {len(df)} viral records loaded")
    df["gc_content"] = df["sequence"].astype(str).map(gc_content)
    print(" computed GC content for all viral records")

    # ---- Load VFDB negative pool ----
    print("loading VFDB JSONLs…")
    vfdb = _load_vfdb_records(VFDB_DIR)
    print(f" {len(vfdb)} VFDB records (both VF positives and matched-CDS negatives)")

    # The pool is scanned linearly for every viral record; at this scale
    # (pool size × number of viral rows) that is cheap enough that no
    # length-bucketed index is needed.
    used_vfdb_idx: set[int] = set()
    paired_records = []
    unpaired_records = []
    n_too_long = 0
    n_no_match = 0

    # Shuffle viral records so that no-replacement matching is not biased
    # towards whichever rows happen to come first in the workbook.
    viral_rows = list(df.iterrows())
    rng.shuffle(viral_rows)

    for _orig_idx, vr in viral_rows:
        vlen = int(vr["sequence_len"])
        vgc = float(vr["gc_content"])

        if vlen > max_vfdb_len:
            n_too_long += 1
            unpaired_records.append(_unpaired_record(
                vr, f"viral length exceeds VFDB max ({max_vfdb_len} bp)"))
            continue

        # Candidates: unused VFDB records inside both tolerance windows.
        len_lo = vlen * (1 - LEN_TOL)
        len_hi = vlen * (1 + LEN_TOL)
        gc_lo = vgc - GC_TOL
        gc_hi = vgc + GC_TOL
        candidates = [
            (i, rec) for i, rec in enumerate(vfdb)
            if i not in used_vfdb_idx
            and len_lo <= rec["cds_length"] <= len_hi
            and gc_lo <= rec["gc_content"] <= gc_hi
        ]
        if not candidates:
            n_no_match += 1
            unpaired_records.append(_unpaired_record(
                vr, no_match_reason, viral_length=vlen, viral_gc=vgc))
            continue

        # Pick one candidate at random (seeded, so runs are reproducible)
        # and retire it from the pool.
        idx, neg = rng.choice(candidates)
        used_vfdb_idx.add(idx)

        # Construct paired records (positive then negative).
        viral_id = str(vr["sequence_id"])
        if "locus_tag" in neg:
            neg_locus = neg["locus_tag"]
        else:
            neg_locus = neg.get("region_id", f"vfdb_{idx}")
        product_name = _opt_str(vr, "product_name")

        pos_record = {
            "region_id": f"VIRAL_{viral_id}",
            "is_positive": True,
            "label": "HUMAN_VIRAL",
            "label_class": _str_or_default(vr, "genome_type", "unknown"),
            "label_subclass": product_name,
            "source_db": _str_or_default(vr, "db", "GenBank"),
            "source_accession": _opt_str(vr, "source_accession"),
            "sequence_id": viral_id,
            "seq_hash": _opt_str(vr, "seq_hash"),
            "organism": _opt_str(vr, "organism"),
            "gene_name": _opt_str(vr, "gene_name"),
            "product_name": product_name,
            "description": _opt_str(vr, "description"),
            "sample_source": _opt_str(vr, "sample_source"),
            "cds_length": vlen,
            "gc_content": vgc,
            "paired_with": neg_locus,
            "sequence": str(vr["sequence"]),
        }
        neg_record = {
            "region_id": f"NEG_{neg_locus}_for_{viral_id}",
            "is_positive": False,
            "label": "negative",
            "label_class": neg.get("label_class"),
            "label_subclass": neg.get("label_subclass"),
            "source_db": "VFDB-derived MGnify-MAG-CDS",
            "source_accession": None,
            "vfdb_origin": neg.get("region_id"),  # original VFDB JSONL region_id
            "vfdb_paired_with": neg.get("paired_with"),  # the VFDB record's own original pairing
            "vfdb_is_positive": bool(neg.get("is_positive")),
            "locus_tag": neg.get("locus_tag"),
            "mag_id": neg.get("mag_id"),
            "species": neg.get("species"),
            "cds_length": int(neg["cds_length"]),
            "gc_content": float(neg["gc_content"]),
            "paired_with": viral_id,
            "sequence": str(neg["sequence"]),
        }
        paired_records.append(pos_record)
        paired_records.append(neg_record)

    print("\n=== matching summary ===")
    print(f" viral positives: {len(df)}")
    print(f" too long (>{max_vfdb_len} bp, dropped): {n_too_long}")
    print(f" no length+GC match: {n_no_match}")
    print(f" successfully paired: {len(paired_records) // 2}")

    out_paired = OUT_DIR / "human_viral_v1.jsonl"
    _write_jsonl(out_paired, paired_records)
    print(f"\n wrote {out_paired} ({out_paired.stat().st_size/1024/1024:.1f} MB)")

    out_unpaired = OUT_DIR / "human_viral_v1_unpaired.jsonl"
    _write_jsonl(out_unpaired, unpaired_records)
    print(f" wrote {out_unpaired} ({out_unpaired.stat().st_size/1024:.1f} KB) ← dropped/unpaired viral records, for inspection")

    # Quick stats by genome type for the paired viral set
    print("\n=== paired-viral genome_type breakdown ===")
    by_gt = defaultdict(int)
    for rec in paired_records:
        if rec["is_positive"]:
            by_gt[rec["label_class"]] += 1
    for gt, count in sorted(by_gt.items(), key=lambda item: -item[1]):
        print(f" {gt:20s} {count}")
# Run the pairing pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()