| """ |
| Build a JSONL of paired (human-viral-sequence, length+GC-matched bacterial-CDS) records |
| for a "is this a human virus?" probe. |
| |
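Positives: 6000 human viral sequences from human_viral_sequences.xlsx.
Negatives: VFDB-derived CDSs (14,695 bacterial CDSs from MGnify MAGs, length 78-9492 bp,
GC 0.20-0.75). The pool is every VFDB JSONL record with extract_status == "ok", which
includes both the VF positives and their matched-CDS negatives.

Matching per positive: one VFDB CDS is drawn uniformly at random (seeded RNG, seed=42)
from the not-yet-used records within ±20% length and ±0.05 GC. Matching is without
replacement, and the result is deterministic for a given input. Viral records longer than
the longest VFDB CDS (9492 bp) are dropped because no match is possible. These are mostly
complete viral genomes (~44% of the corpus).

Output (under ~/MGnify/data/targeted_jsonl/human_viral/):
  human_viral_v1.jsonl           paired records (each positive followed by its negative)
  human_viral_v1_unpaired.jsonl  dropped viral records with a drop_reason, for inspection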
| """ |
from __future__ import annotations

import bisect
import json
import os
import random
from collections import defaultdict
from pathlib import Path

import pandas as pd
|
|
|
|
# Root of the MGnify working tree. Set the MGNIFY_ROOT environment variable to
# run elsewhere; the default is the original author's checkout, so existing runs
# resolve to exactly the same paths.
_MGNIFY_ROOT = Path(os.environ.get("MGNIFY_ROOT", "/home/ror25cal/MGnify"))

VIRAL_XLSX = str(_MGNIFY_ROOT / "human_viral_sequences.xlsx")  # positives spreadsheet
VFDB_DIR = _MGNIFY_ROOT / "data" / "targeted_jsonl" / "vfdb"  # *.jsonl VFDB pool
OUT_DIR = _MGNIFY_ROOT / "data" / "targeted_jsonl" / "human_viral"  # output directory
SEED = 42  # seeds the RNG used for shuffling positives and choosing candidates
LEN_TOL = 0.20  # relative length tolerance (±20%)
GC_TOL = 0.05  # absolute GC-fraction tolerance (±0.05)
|
|
|
|
def gc_content(seq: str) -> float:
    """Return the G+C fraction of *seq*, measured over unambiguous bases only.

    Matching is case-insensitive. Any character other than A/C/G/T (N, IUPAC
    ambiguity codes, gaps, ...) is left out of both numerator and denominator.
    Returns 0.0 when *seq* contains no A/C/G/T at all (including the empty string).
    """
    upper = seq.upper()
    base_counts = {base: upper.count(base) for base in "ACGT"}
    informative = sum(base_counts.values())
    if not informative:
        return 0.0
    return (base_counts["G"] + base_counts["C"]) / informative
|
|
|
|
def _opt_str(row: pd.Series, col: str) -> str | None:
    """Return ``row[col]`` as a string, or None when the column is absent or NaN."""
    val = row.get(col)
    return str(val) if pd.notna(val) else None


def _row_metadata(row: pd.Series) -> dict:
    """Return every column of a viral row except ``sequence``, with NaN mapped to None."""
    return {
        k: (v if pd.notna(v) else None)
        for k, v in row.to_dict().items()
        if k != "sequence"
    }


def _unpaired_record(row: pd.Series, reason: str, **extra) -> dict:
    """Build an inspection record for a viral row that could not be paired.

    Key order: drop_reason, any *extra* fields, the row's metadata columns, then a
    200-character preview of the sequence (None if the sequence is missing).
    """
    seq = row.get("sequence")
    return {
        "drop_reason": reason,
        **extra,
        **_row_metadata(row),
        "sequence_first_200bp": str(seq)[:200] if pd.notna(seq) else None,
    }


def _load_vfdb(vfdb_dir: Path) -> list[dict]:
    """Load all VFDB JSONL records whose ``extract_status`` is ``"ok"``.

    Files are read in sorted name order, so record indices (and therefore the
    seeded matching) are reproducible. Blank lines are skipped.
    """
    records: list[dict] = []
    for fp in sorted(vfdb_dir.glob("*.jsonl")):
        with fp.open(encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                rec = json.loads(line)
                if rec.get("extract_status") == "ok":
                    records.append(rec)
    return records


def _build_pair(vr: pd.Series, vlen: int, vgc: float, neg: dict, idx: int) -> tuple[dict, dict]:
    """Return the (positive, negative) JSONL records for one viral row and its match.

    *idx* is the negative's position in the loaded VFDB pool. It is used only as a
    last-resort identifier when the record has neither locus_tag nor region_id.
    """
    viral_id = str(vr["sequence_id"])
    # An `or` chain rather than an `in` test, so a present-but-null locus_tag falls
    # through instead of producing "NEG_None_for_...".
    neg_locus = neg.get("locus_tag") or neg.get("region_id") or f"vfdb_{idx}"
    product = _opt_str(vr, "product_name")

    pos_record = {
        "region_id": f"VIRAL_{viral_id}",
        "is_positive": True,
        "label": "HUMAN_VIRAL",
        # NaN is truthy, so check for missing values first; a plain `or` would emit "nan".
        "label_class": _opt_str(vr, "genome_type") or "unknown",
        "label_subclass": product,
        "source_db": _opt_str(vr, "db") or "GenBank",
        "source_accession": _opt_str(vr, "source_accession"),
        "sequence_id": viral_id,
        "seq_hash": _opt_str(vr, "seq_hash"),
        "organism": _opt_str(vr, "organism"),
        "gene_name": _opt_str(vr, "gene_name"),
        "product_name": product,
        "description": _opt_str(vr, "description"),
        "sample_source": _opt_str(vr, "sample_source"),
        "cds_length": vlen,
        "gc_content": vgc,
        "paired_with": neg_locus,
        "sequence": str(vr["sequence"]),
    }
    neg_record = {
        "region_id": f"NEG_{neg_locus}_for_{viral_id}",
        "is_positive": False,
        "label": "negative",
        "label_class": neg.get("label_class"),
        "label_subclass": neg.get("label_subclass"),
        "source_db": "VFDB-derived MGnify-MAG-CDS",
        "source_accession": None,
        "vfdb_origin": neg.get("region_id"),
        "vfdb_paired_with": neg.get("paired_with"),
        "vfdb_is_positive": bool(neg.get("is_positive")),
        "locus_tag": neg.get("locus_tag"),
        "mag_id": neg.get("mag_id"),
        "species": neg.get("species"),
        "cds_length": int(neg["cds_length"]),
        "gc_content": float(neg["gc_content"]),
        "paired_with": viral_id,
        "sequence": str(neg["sequence"]),
    }
    return pos_record, neg_record


def _write_jsonl(path: Path, records: list[dict]) -> None:
    """Write *records* as UTF-8 JSON Lines. An empty list produces an empty file."""
    with path.open("w", encoding="utf-8") as fh:
        for rec in records:
            fh.write(json.dumps(rec))
            fh.write("\n")


def main():
    """Pair each human viral sequence with a length+GC-matched VFDB CDS and write JSONL.

    Reads VIRAL_XLSX and VFDB_DIR/*.jsonl, creates OUT_DIR, and writes:
      * human_viral_v1.jsonl: alternating positive / negative records.
      * human_viral_v1_unpaired.jsonl: dropped viral rows, each with a drop_reason.

    Raises SystemExit if no usable VFDB records are found.
    """
    rng = random.Random(SEED)
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    print("loading viral xlsx…")
    df = pd.read_excel(VIRAL_XLSX)
    print(f" {len(df)} viral records loaded")
    # Missing sequences get NaN GC and are dropped below, rather than getting the GC
    # of the literal string "nan".
    df["gc_content"] = df["sequence"].map(
        lambda s: gc_content(str(s)) if pd.notna(s) else float("nan")
    )
    print(" computed GC content for all viral records")

    print("loading VFDB JSONLs…")
    vfdb = _load_vfdb(VFDB_DIR)
    print(f" {len(vfdb)} VFDB records (both VF positives and matched-CDS negatives)")
    if not vfdb:
        raise SystemExit(f"no usable VFDB records (extract_status == 'ok') in {VFDB_DIR}")

    # A viral record longer than the longest VFDB CDS can never be matched. The limit
    # is derived from the pool so it stays correct when the VFDB JSONLs change.
    vfdb_max_len = max(int(r["cds_length"]) for r in vfdb)

    # Sort the pool by length once, so each viral row scans only its length window.
    by_len = sorted(range(len(vfdb)), key=lambda i: vfdb[i]["cds_length"])
    sorted_lens = [vfdb[i]["cds_length"] for i in by_len]

    used_vfdb_idx: set[int] = set()
    paired_records: list[dict] = []
    unpaired_records: list[dict] = []
    n_missing_seq = n_too_long = n_no_match = 0

    too_long_reason = f"viral length exceeds VFDB max ({vfdb_max_len} bp)"
    no_match_reason = "no VFDB candidate within length+GC tolerance"

    # Matching is without replacement, so processing order decides which positive
    # gets first pick of the pool. Use a seeded shuffle rather than spreadsheet order.
    viral_rows = list(df.iterrows())
    rng.shuffle(viral_rows)

    for _orig_idx, vr in viral_rows:
        if pd.isna(vr["sequence"]):
            n_missing_seq += 1
            unpaired_records.append(_unpaired_record(vr, "missing viral sequence"))
            continue

        seq = str(vr["sequence"])
        seq_len = vr.get("sequence_len")
        vlen = int(seq_len) if pd.notna(seq_len) else len(seq)
        vgc = float(vr["gc_content"])

        if vlen > vfdb_max_len:
            n_too_long += 1
            unpaired_records.append(_unpaired_record(vr, too_long_reason))
            continue

        len_lo, len_hi = vlen * (1 - LEN_TOL), vlen * (1 + LEN_TOL)
        gc_lo, gc_hi = vgc - GC_TOL, vgc + GC_TOL
        window = by_len[
            bisect.bisect_left(sorted_lens, len_lo):bisect.bisect_right(sorted_lens, len_hi)
        ]
        # Sort back into pool order, so rng.choice sees the same list a full linear
        # scan would produce. This keeps pairings identical to earlier runs.
        candidates = sorted(
            i for i in window
            if i not in used_vfdb_idx and gc_lo <= vfdb[i]["gc_content"] <= gc_hi
        )
        if not candidates:
            n_no_match += 1
            unpaired_records.append(
                _unpaired_record(vr, no_match_reason, viral_length=vlen, viral_gc=vgc)
            )
            continue

        idx = rng.choice(candidates)
        used_vfdb_idx.add(idx)
        paired_records.extend(_build_pair(vr, vlen, vgc, vfdb[idx], idx))

    print("\n=== matching summary ===")
    print(f" viral positives: {len(df)}")
    print(f" missing sequence (dropped): {n_missing_seq}")
    print(f" too long (>{vfdb_max_len} bp, dropped): {n_too_long}")
    print(f" no length+GC match: {n_no_match}")
    print(f" successfully paired: {len(paired_records) // 2}")

    out_paired = OUT_DIR / "human_viral_v1.jsonl"
    _write_jsonl(out_paired, paired_records)
    print(f"\n wrote {out_paired} ({out_paired.stat().st_size/1024/1024:.1f} MB)")

    out_unpaired = OUT_DIR / "human_viral_v1_unpaired.jsonl"
    _write_jsonl(out_unpaired, unpaired_records)
    print(f" wrote {out_unpaired} ({out_unpaired.stat().st_size/1024:.1f} KB) ← dropped/unpaired viral records, for inspection")

    print("\n=== paired-viral genome_type breakdown ===")
    by_gt = defaultdict(int)
    for r in paired_records:
        if r["is_positive"]:
            by_gt[r["label_class"]] += 1
    for gt, n in sorted(by_gt.items(), key=lambda x: -x[1]):
        print(f" {gt:20s} {n}")


if __name__ == "__main__":
    main()
|
|