| """Sample ~20 records per genuinely-hierarchical secondary-level category from |
| master_annotations_clean.parquet for SAE qualitative exploration. |
| |
| Categories sampled (rationale — see THREADS.md Thread C): |
| - 24 AMR drug-class secondary_labels (e.g. BETA-LACTAM, MACROLIDE) |
| - 12 STRESS metals/biocides secondary_labels (e.g. MERCURY, ARSENIC) |
| - 6 iGEM type secondary_labels (CRISPR, fluorescent, …) |
| - 14 VFDB vfcategory_name values (true virulence subclasses, e.g. Effector delivery system) |
| |
| Skipped (non-hierarchical / degenerate / provenance-only): |
| - virulence → full / core / VIRULENCE (VFDB curation tier, not mechanism) |
| - STRESS → STRESS (echo of primary label) |
| - AMR → AMR_gene (CARD's catch-all "unspecified mechanism") |
| - AMR → synthetic_AMR (provenance label, not mechanism) |
| |
| Output: one JSONL per category at data/targeted_jsonl/qual/<slug>.jsonl, plus a |
| per-category summary table at data/targeted_jsonl/qual/_sample_summary.csv. |
| Schema mirrors the VFDB pipeline; positives only (no matched negatives needed |
| for qualitative SAE exploration). |
| |
| Caveat: iGEM-derived categories sampled here are multi-component constructs |
| (~84% have internal stop codons). SAE features will reflect the whole |
| construct rather than a single CDS — flagged to the consumer. |
| """ |
| import argparse |
| import json |
| import re |
| from collections import defaultdict |
| from pathlib import Path |
|
|
| import pandas as pd |
|
|
|
|
| |
# Secondary labels that are never sampled because they are non-hierarchical,
# degenerate, or provenance-only (the "Skipped" list in the module docstring).
EXCLUDE_SECONDARY = {
    # virulence: VFDB curation tiers, not mechanisms
    "full",
    "core",
    "VIRULENCE",
    # STRESS: echo of the primary label
    "STRESS",
    # AMR: catch-all "unspecified mechanism" label
    "AMR_gene",
    # AMR: provenance label, not a mechanism
    "synthetic_AMR",
}
|
|
|
|
def category_slug(s: str) -> str:
    """Return a filename-safe slug for a category label.

    ``s`` is converted with ``str()``; every run of characters outside
    ``[A-Za-z0-9]`` becomes a single underscore, and no underscore is left at
    either end of the result.
    """
    label = str(s)
    # Splitting on the separator runs and dropping empty edge tokens is the
    # same as substituting "_" for each run and then stripping "_".
    tokens = re.split(r"[^A-Za-z0-9]+", label)
    return "_".join(token for token in tokens if token)
|
|
|
|
# Column order of the per-category summary table; passed explicitly so the
# table keeps its columns even when no category was found.
_SUMMARY_COLUMNS = ["group", "category", "available", "sampled", "slug"]


def _collect_categories(df: pd.DataFrame) -> "dict[tuple[str, str], pd.DataFrame]":
    """Map ``(label_group, category)`` to the rows of ``df`` in that category.

    AMR, STRESS and the iGEM primaries are split by ``secondary_label`` (minus
    ``EXCLUDE_SECONDARY``); VFDB rows are split by ``vfcategory_name``.
    """
    igem_primaries = {"reporter", "gene_editing", "integration", "toxin", "biosafety", "containment"}
    not_excluded = ~df["secondary_label"].isin(EXCLUDE_SECONDARY)
    selections = [
        ("AMR", "secondary_label", (df["primary_label"] == "AMR") & not_excluded),
        ("STRESS", "secondary_label", (df["primary_label"] == "STRESS") & not_excluded),
        ("iGEM_synthetic", "secondary_label", df["primary_label"].isin(igem_primaries) & not_excluded),
        ("VFDB_virulence", "vfcategory_name", df["vfcategory_name"].notna()),
    ]

    rows_by_cat = {}
    for group, column, mask in selections:
        # observed=True: if the column is categorical, do not emit empty groups
        # for categories that only occur outside this selection.
        for value, grp in df[mask].groupby(column, observed=True):
            rows_by_cat[(group, str(value))] = grp
    return rows_by_cat


def _json_safe(value):
    """Convert one record value into something ``json.dumps`` can emit.

    Missing scalars (None, NaN, pd.NA, NaT) become ``None``; scalars exposing
    ``.item()`` (e.g. numpy scalars) are unwrapped; non-scalar values exposing
    ``.tolist()`` are converted to lists; everything else passes through.
    """
    if pd.api.types.is_scalar(value):
        if pd.isna(value):
            return None
        if hasattr(value, "item"):
            return value.item()
        return value
    # pd.isna() on a non-scalar returns an array, so it must not be used as a
    # truth value here.
    if hasattr(value, "tolist"):
        return value.tolist()
    return value


def _build_record(row, group: str, category: str, slug: str, seed: int) -> dict:
    """Build the JSON-ready record for one sampled row."""
    seq = str(row["actual_sequence"]).upper()
    rec = {
        "region_id": f"{row.get('seq_hash')}_QUAL",
        "is_positive": True,
        "label": "QUAL",
        "label_group": group,
        "label_class": category,
        "primary_label": row.get("primary_label"),
        "functional_class": row.get("functional_class"),
        "gene_symbol": row.get("gene_name"),
        "product_name": row.get("product_name"),
        "organism": row.get("organism"),
        "source_db": row.get("db"),
        "source_accession": row.get("source_accession"),
        "vf_id": row.get("vf_id"),
        "vfcategory_name": row.get("vfcategory_name"),
        "vfcategory_id": row.get("vfcategory_id"),
        "vf_prototype_name": row.get("vf_prototype_name"),
        "seq_hash": row.get("seq_hash"),
        "cds_length": len(seq),
        "mag_id": slug,
        "random_seed": seed,
        "sequence": seq,
    }
    return {key: _json_safe(value) for key, value in rec.items()}


def _unique_slug(base: str, used: set) -> str:
    """Return ``base``, or ``base__<k>`` if ``base`` is already in ``used``.

    The returned slug is added to ``used``.
    """
    slug = base
    suffix = 1
    while slug in used:
        suffix += 1
        slug = f"{base}__{suffix}"
    used.add(slug)
    return slug


def main(argv=None):
    """Sample records per category and write one JSONL file per category.

    Reads the master parquet, drops rows without ``actual_sequence``, groups
    the remaining rows into categories, samples up to ``--per-category`` rows
    from each with ``--seed`` as the random state, and writes
    ``<out-dir>/<slug>.jsonl`` per category plus ``_sample_summary.csv``.
    A summary is printed to stdout.

    Args:
        argv: Optional list of command-line arguments; ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.
    """
    import sys  # local import: only needed for the collision warning below

    ap = argparse.ArgumentParser()
    ap.add_argument("--master-parquet", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/master_annotations_clean.parquet"))
    ap.add_argument("--out-dir", type=Path,
                    default=Path("/home/ror25cal/MGnify/data/targeted_jsonl/qual"))
    ap.add_argument("--per-category", type=int, default=20)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args(argv)
    if args.per_category < 0:
        # Fail at the CLI boundary instead of inside DataFrame.sample().
        ap.error("--per-category must be non-negative")

    args.out_dir.mkdir(parents=True, exist_ok=True)
    df = pd.read_parquet(args.master_parquet)
    df = df[df["actual_sequence"].notna()].copy()

    rows_by_cat = _collect_categories(df)

    summary_rows = []
    total_records = 0
    used_slugs = set()
    # Sorting the keys gives a stable processing (and slug-suffix) order.
    for group, sec in sorted(rows_by_cat):
        grp = rows_by_cat[(group, sec)]
        n_avail = len(grp)
        n_take = min(args.per_category, n_avail)
        sampled = grp.sample(n=n_take, random_state=args.seed).reset_index(drop=True)

        base_slug = f"{group}__{category_slug(sec)}"
        slug = _unique_slug(base_slug, used_slugs)
        if slug != base_slug:
            # Two category names reduced to the same slug; without the suffix
            # the second file would silently overwrite the first.
            print(f"WARNING: slug collision for {group}/{sec!r}; writing to {slug}.jsonl",
                  file=sys.stderr)

        out_path = args.out_dir / f"{slug}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            for _, r in sampled.iterrows():
                rec = _build_record(r, group, sec, slug, args.seed)
                f.write(json.dumps(rec) + "\n")
        summary_rows.append({"group": group, "category": sec, "available": n_avail,
                             "sampled": n_take, "slug": slug})
        total_records += n_take

    # Explicit columns keep sort_values() working when no category was found.
    summary_df = pd.DataFrame(summary_rows, columns=_SUMMARY_COLUMNS).sort_values(["group", "category"])
    summary_path = args.out_dir / "_sample_summary.csv"
    summary_df.to_csv(summary_path, index=False)
    print(f"Sampled {len(summary_rows)} categories, {total_records} total records")
    print(f"Output: {args.out_dir}")
    print(f"Summary: {summary_path}")
    print()
    print(summary_df.to_string(index=False))
|
|
|
|
# Run the sampler only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|