# Source: mgnify-evo2-probes / code / scripts / sample_qual_jsonl.py
# JG1310 — commit eb69de4 (verified):
# "Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs"
"""Sample ~20 records per genuinely-hierarchical secondary-level category from
master_annotations_clean.parquet for SAE qualitative exploration.
Categories sampled (rationale — see THREADS.md Thread C):
- 24 AMR drug-class secondary_labels (e.g. BETA-LACTAM, MACROLIDE)
- 12 STRESS metals/biocides secondary_labels (e.g. MERCURY, ARSENIC)
- 6 iGEM type secondary_labels (CRISPR, fluorescent, …)
- 14 VFDB vfcategory_name (true virulence subclass) (Effector delivery system, …)
Skipped (non-hierarchical / degenerate / provenance-only):
- virulence → full / core / VIRULENCE (VFDB curation tier, not mechanism)
- STRESS → STRESS (echo of primary label)
- AMR → AMR_gene (CARD's catch-all "unspecified mechanism")
- AMR → synthetic_AMR (provenance label, not mechanism)
Output: one JSONL per category at data/targeted_jsonl/qual/<slug>.jsonl
Schema mirrors the VFDB pipeline; positives only (no matched negatives needed
for qualitative SAE exploration).
Caveat: iGEM-derived categories sampled here are multi-component constructs
(~84% have internal stop codons). SAE features will reflect the whole
construct rather than a single CDS — flagged to the consumer.
"""
import argparse
import json
import re
from collections import defaultdict
from pathlib import Path
import pandas as pd
# ---- Secondary labels that are never sampled ----
# These values are non-hierarchical or degenerate, so they carry no
# finer-grained category information worth sampling.
EXCLUDE_SECONDARY = {
    # Virulence tier, not a mechanism
    "full",
    "core",
    "VIRULENCE",
    # Echo of the primary label
    "STRESS",
    # Catch-all with no finer information
    "AMR_gene",
    # Provenance label, not a mechanism
    "synthetic_AMR",
}
def category_slug(s: str) -> str:
    """Turn a category label into a filename-friendly slug.

    The input is stringified, every run of characters outside
    ``[A-Za-z0-9]`` is collapsed into a single underscore, and underscores
    at either end of the result are trimmed.
    """
    text = str(s)
    collapsed = re.sub(r"[^A-Za-z0-9]+", "_", text)
    return collapsed.strip("_")
def _json_safe(value):
    """Return *value* converted to something ``json.dumps`` can serialise.

    Missing markers (``None``, NaN, ``pd.NA``, ``pd.NaT``) become ``None`` and
    objects exposing ``.item()`` (numpy scalars) are unwrapped.  List-like
    cells are converted element by element: calling ``pd.isna`` on a
    list-like returns an array, whose truth value is ambiguous and would
    raise inside an ``if``.
    """
    if value is None:
        return None
    if isinstance(value, (list, tuple)):
        return [_json_safe(v) for v in value]
    if getattr(value, "ndim", 0):
        # Array-like with at least one dimension (e.g. a numpy array cell).
        return [_json_safe(v) for v in value.tolist()]
    if pd.isna(value):
        return None
    if hasattr(value, "item"):
        return value.item()
    return value


def _build_record(r, group, sec, slug, seed):
    """Assemble the JSONL record for one sampled row *r* (a pandas Series)."""
    seq = str(r["actual_sequence"]).upper()
    return {
        "region_id": f"{r.get('seq_hash')}_QUAL",  # seq_hash unique per row in master
        "is_positive": True,
        "label": "QUAL",
        "label_group": group,  # AMR / STRESS / iGEM_synthetic / VFDB_virulence
        "label_class": sec,  # the secondary value
        "primary_label": r.get("primary_label"),
        "functional_class": r.get("functional_class"),
        "gene_symbol": r.get("gene_name"),
        "product_name": r.get("product_name"),
        "organism": r.get("organism"),
        "source_db": r.get("db"),
        "source_accession": r.get("source_accession"),
        "vf_id": r.get("vf_id"),
        "vfcategory_name": r.get("vfcategory_name"),
        "vfcategory_id": r.get("vfcategory_id"),
        "vf_prototype_name": r.get("vf_prototype_name"),
        "seq_hash": r.get("seq_hash"),
        "cds_length": len(seq),
        "mag_id": slug,  # placeholder for path layout (matches embed_vfdb_lean)
        "random_seed": seed,
        "sequence": seq,
    }


def main(argv=None):
    """Sample rows per category from the master parquet and write JSONL files.

    For every (group, category) pair, up to ``--per-category`` rows are drawn
    with ``DataFrame.sample`` seeded by ``--seed`` and written to
    ``<out-dir>/<group>__<slug>.jsonl``, one JSON object per line.  A
    ``_sample_summary.csv`` listing available/sampled counts per category is
    written next to them, and the same table is printed to stdout.

    Args:
        argv: Optional list of command-line arguments.  ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Raises:
        KeyError: if the parquet lacks a column the grouping logic indexes
            directly (``actual_sequence``, ``primary_label``,
            ``secondary_label``, ``vfcategory_name``).
    """
    ap = argparse.ArgumentParser()
    # NOTE: the default paths are absolute, user-specific locations; pass
    # --master-parquet / --out-dir explicitly when running elsewhere.
    ap.add_argument("--master-parquet", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/master_annotations_clean.parquet"))
    ap.add_argument("--out-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/targeted_jsonl/qual"))
    ap.add_argument("--per-category", type=int, default=20)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args(argv)
    args.out_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_parquet(args.master_parquet)
    required = ("actual_sequence", "primary_label", "secondary_label", "vfcategory_name")
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"{args.master_parquet} is missing required column(s): {missing}")
    df = df[df["actual_sequence"].notna()].copy()

    # Build the four category groupings.
    # 1. AMR drug classes, 2. STRESS metals/biocides, 3. iGEM types: grouped by
    #    secondary_label after dropping the labels in EXCLUDE_SECONDARY.
    # 4. VFDB subclasses: grouped by the vfcategory_name column.
    rows_by_cat: dict[tuple[str, str], pd.DataFrame] = {}

    # iGEM rows are selected by primary_label membership in this set;
    # functional_class is not consulted by this script.
    igem_primaries = {"reporter", "gene_editing", "integration", "toxin", "biosafety", "containment"}
    keep_secondary = ~df["secondary_label"].isin(EXCLUDE_SECONDARY)
    secondary_groups = (
        ("AMR", df["primary_label"] == "AMR"),
        ("STRESS", df["primary_label"] == "STRESS"),
        ("iGEM_synthetic", df["primary_label"].isin(igem_primaries)),
    )
    for group, in_group in secondary_groups:
        for sec, grp in df[in_group & keep_secondary].groupby("secondary_label"):
            rows_by_cat[(group, str(sec))] = grp

    # VFDB vfcategory_name (true virulence subclasses)
    for cat, grp in df[df["vfcategory_name"].notna()].groupby("vfcategory_name"):
        rows_by_cat[("VFDB_virulence", str(cat))] = grp

    # Sample, write JSONL, record provenance summary
    summary_columns = ["group", "category", "available", "sampled", "slug"]
    summary_rows = []
    total_records = 0
    used_slugs = set()
    for group, sec in sorted(rows_by_cat):
        grp = rows_by_cat[(group, sec)]
        n_avail = len(grp)
        n_take = min(args.per_category, n_avail)
        sampled = grp.sample(n=n_take, random_state=args.seed).reset_index(drop=True)

        # Distinct category names can collapse to the same slug; add a numeric
        # suffix so a later category never overwrites an earlier one's file.
        base_slug = f"{group}__{category_slug(sec)}"
        slug = base_slug
        suffix = 2
        while slug in used_slugs:
            slug = f"{base_slug}_{suffix}"
            suffix += 1
        used_slugs.add(slug)
        if slug != base_slug:
            print(f"WARNING: slug collision for {group}/{sec!r}; writing to {slug}.jsonl")

        out_path = args.out_dir / f"{slug}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            for _, r in sampled.iterrows():
                rec = _build_record(r, group, sec, slug, args.seed)
                # Coerce pandas NA / NaN / numpy scalars to JSON-safe types
                clean = {k: _json_safe(v) for k, v in rec.items()}
                f.write(json.dumps(clean) + "\n")
        summary_rows.append({"group": group, "category": sec, "available": n_avail,
                             "sampled": n_take, "slug": slug})
        total_records += n_take

    # Print + save summary.  Explicit columns keep sort_values working even
    # when no category was found and summary_rows is empty.
    summary_df = pd.DataFrame(summary_rows, columns=summary_columns).sort_values(["group", "category"])
    summary_path = args.out_dir / "_sample_summary.csv"
    summary_df.to_csv(summary_path, index=False)
    print(f"Sampled {len(summary_rows)} categories, {total_records} total records")
    print(f"Output: {args.out_dir}")
    print(f"Summary: {summary_path}")
    print()
    print(summary_df.to_string(index=False))
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()