#!/usr/bin/env python3
"""
CLI script to generate all precomputed data.
Usage: python scripts/run_precompute.py --data-dir data/ --output-dir precomputed/
"""

import argparse
import sys
import os
import time
from pathlib import Path

# Add project root to path so `src` imports resolve from any working directory
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.data_loader import (
    load_pav, parse_gff_genes, parse_protein_fasta,
    build_contig_index, build_contig_name_mapping, validate_joins,
)
from src.precompute import (
    compute_gene_frequency, compute_line_stats, compute_line_embedding,
    compute_similarity_topk, build_gff_gene_parquet, build_protein_parquet,
    save_contig_index, compute_hotspot_bins, compute_cluster_markers,
    compute_line_embedding_3d, build_sunburst_data, build_polar_contig_layout,
    compute_radar_axes,
)
from src.utils import logger, find_file


def main():
    parser = argparse.ArgumentParser(description="Precompute pangenome data")
    parser.add_argument("--data-dir", default="data/", help="Input data directory")
    parser.add_argument("--output-dir", default="precomputed/", help="Output directory")
    args = parser.parse_args()

    data_dir = os.path.abspath(args.data_dir)
    output_dir = os.path.abspath(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    t_total = time.time()

    # 1. Load raw data
    logger.info("=== Phase 1: Loading raw data ===")
    pav_path = os.path.join(data_dir, "89_line_PAV.txt")
    data_p = Path(data_dir)

    # Discover inputs by naming convention: one GFF annotation, one protein
    # FASTA (name contains "protein"), and any remaining FASTA is the genome.
    gff_files = list(data_p.glob("*.gff"))
    protein_files = list(data_p.glob("*protein*.fasta"))
    genome_files = [f for f in data_p.glob("*.fasta") if "protein" not in f.name]

    if not os.path.isfile(pav_path):
        logger.error(f"PAV file not found: {pav_path}")
        sys.exit(1)
    if not gff_files:
        logger.error("No GFF file found in data directory")
        sys.exit(1)
    if not protein_files:
        logger.error("No protein FASTA file found in data directory")
        sys.exit(1)

    pav = load_pav(pav_path)
    gff_genes = parse_gff_genes(str(gff_files[0]))
    protein_index = parse_protein_fasta(str(protein_files[0]))

    contig_index = {}
    if genome_files:
        contig_index = build_contig_index(str(genome_files[0]))
    else:
        logger.warning("No genome FASTA found; contig index will be empty")

    # Validation
    logger.info("=== Validation ===")
    contig_mapping = build_contig_name_mapping(gff_genes, contig_index)
    report = validate_joins(pav, gff_genes, protein_index, contig_index)
    for k, v in report.items():
        logger.info(f"  {k}: {v}")

    # 2. Compute derived data
    logger.info("=== Phase 2: Computing derived data ===")

    gene_freq = compute_gene_frequency(pav)
    gene_freq.to_parquet(os.path.join(output_dir, "pav_gene_frequency.parquet"), index=False)

    line_stats = compute_line_stats(pav)
    line_stats.to_parquet(os.path.join(output_dir, "line_stats.parquet"), index=False)

    embedding = compute_line_embedding(pav)
    embedding.to_parquet(os.path.join(output_dir, "line_embedding.parquet"), index=False)

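    # Assumption from the function name: this keeps only each line's k most
    # similar lines (top-k) rather than the full pairwise similarity matrix.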
    similarity = compute_similarity_topk(pav, k=15)
    similarity.to_parquet(os.path.join(output_dir, "line_similarity_topk.parquet"), index=False)

    build_gff_gene_parquet(gff_genes, os.path.join(output_dir, "gff_gene_index.parquet"))
    build_protein_parquet(protein_index, os.path.join(output_dir, "protein_index.parquet"))
    save_contig_index(contig_index, contig_mapping, os.path.join(output_dir, "genome_contig_index.json"))

    hotspots = compute_hotspot_bins(gff_genes, gene_freq, contig_index)
    hotspots.to_parquet(os.path.join(output_dir, "hotspot_bins.parquet"), index=False)

    markers = compute_cluster_markers(pav, embedding)
    markers.to_parquet(os.path.join(output_dir, "cluster_markers.parquet"), index=False)

    # Also save the full PAV matrix as parquet for efficient loading; unlike the
    # artifacts above, the index is kept (no index=False) since it holds the row labels.
    pav.to_parquet(os.path.join(output_dir, "pav_matrix.parquet"))

    # 3. New derived data for UI overhaul
    logger.info("=== Phase 3: New UI overhaul artifacts ===")

    t_step = time.time()
    embedding_3d = compute_line_embedding_3d(pav, embedding)
    embedding_3d.to_parquet(os.path.join(output_dir, "line_embedding_3d.parquet"), index=False)
    logger.info(f"  -> line_embedding_3d.parquet ({time.time() - t_step:.1f}s)")

    t_step = time.time()
    build_sunburst_data(gene_freq, os.path.join(output_dir, "sunburst_hierarchy.json"))
    logger.info(f"  -> sunburst_hierarchy.json ({time.time() - t_step:.1f}s)")

    t_step = time.time()
    build_polar_contig_layout(hotspots, contig_index,
                              os.path.join(output_dir, "polar_contig_layout.json"))
    logger.info(f"  -> polar_contig_layout.json ({time.time() - t_step:.1f}s)")

    t_step = time.time()
    compute_radar_axes(protein_index, os.path.join(output_dir, "radar_axes.json"))
    logger.info(f"  -> radar_axes.json ({time.time() - t_step:.1f}s)")

    dt = time.time() - t_total
    logger.info(f"=== All precomputation done in {dt:.1f}s ===")

    # List output files
    for f in sorted(Path(output_dir).glob("*")):
        size_mb = f.stat().st_size / 1024 / 1024
        logger.info(f"  {f.name}: {size_mb:.2f} MB")


if __name__ == "__main__":
    main()