#!/usr/bin/env python3
"""
CLI script to generate all precomputed data.
Usage: python scripts/run_precompute.py --data-dir data/ --output-dir precomputed/
"""
import argparse
import os
import sys
import time
from pathlib import Path

# Make the project root importable so the `src` package resolves
# regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.data_loader import (
    load_pav, parse_gff_genes, parse_protein_fasta,
    build_contig_index, build_contig_name_mapping, validate_joins,
)
from src.precompute import (
    compute_gene_frequency, compute_line_stats, compute_line_embedding,
    compute_similarity_topk, build_gff_gene_parquet, build_protein_parquet,
    save_contig_index, compute_hotspot_bins, compute_cluster_markers,
    compute_line_embedding_3d, build_sunburst_data, build_polar_contig_layout,
    compute_radar_axes,
)
from src.utils import logger, find_file


def main():
    parser = argparse.ArgumentParser(description="Precompute pangenome data")
    parser.add_argument("--data-dir", default="data/", help="Input data directory")
    parser.add_argument("--output-dir", default="precomputed/", help="Output directory")
    args = parser.parse_args()

    data_dir = os.path.abspath(args.data_dir)
    output_dir = os.path.abspath(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)
    t_total = time.time()

    # 1. Load raw data
    logger.info("=== Phase 1: Loading raw data ===")
    pav_path = os.path.join(data_dir, "89_line_PAV.txt")
    data_p = Path(data_dir)
    gff_files = list(data_p.glob("*.gff"))
    protein_files = list(data_p.glob("*protein*.fasta"))
    genome_files = [f for f in data_p.glob("*.fasta") if "protein" not in f.name]
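    # Only the first match of each glob is used below (the [0] indexing);
    # any additional candidate files are silently ignored.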
    if not gff_files:
        logger.error("No GFF file found in data directory")
        sys.exit(1)
    if not protein_files:
        logger.error("No protein FASTA file found in data directory")
        sys.exit(1)
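
    # PAV = presence/absence variation matrix for the 89 lines; parsing
    # details (orientation, dtypes) live in src.data_loader.load_pav.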
    pav = load_pav(pav_path)
    gff_genes = parse_gff_genes(str(gff_files[0]))
    protein_index = parse_protein_fasta(str(protein_files[0]))
    contig_index = {}
    if genome_files:
        contig_index = build_contig_index(str(genome_files[0]))
    else:
        logger.warning("No genome FASTA found; contig index will be empty")

    # Validation
    logger.info("=== Validation ===")
    contig_mapping = build_contig_name_mapping(gff_genes, contig_index)
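    # validate_joins cross-checks that identifiers join across PAV, GFF,
    # proteins, and contigs; it returns a mapping of per-check values
    # (contents defined in src.data_loader), which we log below.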
    report = validate_joins(pav, gff_genes, protein_index, contig_index)
    for k, v in report.items():
        logger.info(f" {k}: {v}")

    # 2. Compute derived data
    logger.info("=== Phase 2: Computing derived data ===")
    gene_freq = compute_gene_frequency(pav)
    gene_freq.to_parquet(os.path.join(output_dir, "pav_gene_frequency.parquet"), index=False)
    line_stats = compute_line_stats(pav)
    line_stats.to_parquet(os.path.join(output_dir, "line_stats.parquet"), index=False)
    embedding = compute_line_embedding(pav)
    embedding.to_parquet(os.path.join(output_dir, "line_embedding.parquet"), index=False)
    similarity = compute_similarity_topk(pav, k=15)
    similarity.to_parquet(os.path.join(output_dir, "line_similarity_topk.parquet"), index=False)
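
    # Flat lookup tables consumed by the app: gene coordinates from the GFF,
    # protein metadata, and the contig index plus its contig-name mapping.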
    build_gff_gene_parquet(gff_genes, os.path.join(output_dir, "gff_gene_index.parquet"))
    build_protein_parquet(protein_index, os.path.join(output_dir, "protein_index.parquet"))
    save_contig_index(contig_index, contig_mapping, os.path.join(output_dir, "genome_contig_index.json"))

    hotspots = compute_hotspot_bins(gff_genes, gene_freq, contig_index)
    hotspots.to_parquet(os.path.join(output_dir, "hotspot_bins.parquet"), index=False)
    markers = compute_cluster_markers(pav, embedding)
    markers.to_parquet(os.path.join(output_dir, "cluster_markers.parquet"), index=False)

    # Also save the PAV matrix as Parquet for efficient loading
    # (note: the index is kept here, unlike the derived tables above).
    pav.to_parquet(os.path.join(output_dir, "pav_matrix.parquet"))

    # 3. New derived data for UI overhaul
    logger.info("=== Phase 3: New UI overhaul artifacts ===")
    t_step = time.time()
    embedding_3d = compute_line_embedding_3d(pav, embedding)
    embedding_3d.to_parquet(os.path.join(output_dir, "line_embedding_3d.parquet"), index=False)
    logger.info(f" -> line_embedding_3d.parquet ({time.time() - t_step:.1f}s)")

    t_step = time.time()
    build_sunburst_data(gene_freq, os.path.join(output_dir, "sunburst_hierarchy.json"))
    logger.info(f" -> sunburst_hierarchy.json ({time.time() - t_step:.1f}s)")

    t_step = time.time()
    build_polar_contig_layout(hotspots, contig_index,
                              os.path.join(output_dir, "polar_contig_layout.json"))
    logger.info(f" -> polar_contig_layout.json ({time.time() - t_step:.1f}s)")
    t_step = time.time()
    compute_radar_axes(protein_index, os.path.join(output_dir, "radar_axes.json"))
    logger.info(f" -> radar_axes.json ({time.time() - t_step:.1f}s)")

    dt = time.time() - t_total
    logger.info(f"=== All precomputation done in {dt:.1f}s ===")

    # List output files
    for f in sorted(Path(output_dir).glob("*")):
        size_mb = f.stat().st_size / 1024 / 1024
        logger.info(f" {f.name}: {size_mb:.2f} MB")


if __name__ == "__main__":
    main()