#!/usr/bin/env python3
"""
CLI wrapper for KAUST Infectious Diseases Genomic Risk Prediction
Allows running predictions from command line without Streamlit UI
"""

import re
import gzip
import argparse
import sys
import tempfile
import warnings
from glob import glob
from pathlib import Path
import numpy as np
import pandas as pd
import joblib
import xgboost as xgb
import shap
import matplotlib.pyplot as plt
from scipy import stats
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing
import subprocess
import os
from ngboost import NGBRegressor
from ngboost.distns import Normal

# Suppress XGBoost serialization warnings
warnings.filterwarnings("ignore", category=UserWarning, module="pickle")

# ------------------- CONFIG -------------------
UNITIGS_CSV_DEATH = "./Unitigs_predictor_DEATH.csv"
MODEL_PATH_DEATH = "./xgb_fold1_8_Death.joblib"
UNITIGS_CSV_ICU = "./Unitigs_predictor_ICU.csv"
MODEL_PATH_ICU = "./xgb_fold5_8_ICU.joblib"
UNITIGS_CSV_LOS = "./Unitigs_predictor_los.csv"
MODEL_PATH_LOS = "./Unitig_model_ngb_log1p_fold1.joblib"
UNITIGS_ARE_COLUMNS = True
UNITIGS_SEQ_COLUMN = "unitig"
CARD_FASTA = "./CARD.fasta"
VFDB_FASTA = "./VFDB.fasta"
CARD_DB = "./card_db"
VFDB_DB = "./vfdb_db"
BLAST_MAX_RESULTS = 20

# NOTE(review): the prefixes of the status messages below look mis-encoded
# (UTF-8 emoji decoded with a single-byte codec). They are kept exactly as
# found because some bytes were lost and the original glyphs cannot be
# recovered with certainty - TODO restore from a clean copy of the file.

# Translation table for complementing bases; built once instead of per call.
_COMPLEMENT = str.maketrans("ACGTacgtnN", "TGCAtgcanN")


# ------------------- HELPERS -------------------
def reverse_complement(seq: str) -> str:
    """Return the reverse complement of a DNA sequence (case is preserved)."""
    return seq.translate(_COMPLEMENT)[::-1]


def load_unitigs(unitigs_csv: str, are_columns: bool, seq_col: str):
    """Load unitigs from CSV file.

    Args:
        unitigs_csv: Path of the CSV file to read.
        are_columns: When true, the unitigs are the column names that look
            like DNA (at least 5 characters of A/C/G/T/N); if no column
            qualifies, every column name is used. When false, the unitigs
            are read from the ``seq_col`` column.
        seq_col: Column holding the sequences when ``are_columns`` is false.

    Returns:
        The unitigs as a list, de-duplicated with first-seen order preserved.

    Raises:
        ValueError: If ``are_columns`` is false and ``seq_col`` is missing.
    """
    df = pd.read_csv(unitigs_csv)

    if are_columns:
        unitigs = [
            c
            for c in df.columns
            if isinstance(c, str)
            and len(c) >= 5
            and re.fullmatch(r"[ACGTNacgtn]+", c)
        ]
        if not unitigs:
            unitigs = list(df.columns)
    else:
        if seq_col not in df.columns:
            raise ValueError(f"Column '{seq_col}' not found in {unitigs_csv}")
        unitigs = df[seq_col].astype(str).tolist()

    # dict preserves insertion order, so this de-duplicates stably.
    return list(dict.fromkeys(unitigs))


def parse_fasta(file_bytes: bytes) -> dict:
    """Parse FASTA sequences from bytes.

    Undecodable bytes are ignored. Returns a dict mapping each header (text
    after ``>``, stripped) to its concatenated sequence lines. Sequence lines
    seen before the first header are discarded, and a repeated header
    overwrites the earlier record.
    """
    text = file_bytes.decode(errors="ignore")
    seqs = {}
    header, chunks = None, []

    for line in text.splitlines():
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                seqs[header] = "".join(chunks)
            header = line[1:].strip()
            chunks = []
        else:
            chunks.append(line.strip())

    if header is not None:
        seqs[header] = "".join(chunks)
    return seqs


def concat_sequences(fasta_dict: dict) -> str:
    """Concatenate all sequences with separator.

    The ``NNNNN`` spacer keeps a unitig from matching across the junction of
    two unrelated records.
    """
    return "NNNNN".join(fasta_dict.values())


def unitig_presence_in_text_single(args):
    """Helper for parallel unitig scanning.

    Args:
        args: Tuple ``(unitigs_chunk, genome_text)``; a single argument so the
            function can be submitted to an executor as-is.

    Returns:
        A list with 1 for every unitig found in the genome (on either strand,
        case-insensitively) and 0 otherwise, in input order.
    """
    unitigs_chunk, genome_text = args
    genome_text_upper = genome_text.upper()

    calls = []
    for u in unitigs_chunk:
        u_upper = u.upper()
        rc = reverse_complement(u_upper)
        present = (u_upper in genome_text_upper) or (rc in genome_text_upper)
        calls.append(1 if present else 0)
    return calls


def unitig_presence_in_text_parallel(unitigs, genome_text, n_jobs=None):
    """Scan a genome for unitigs, splitting the unitig list across processes.

    Args:
        unitigs: Sequences to look for.
        genome_text: Genome text to search.
        n_jobs: Number of worker processes; defaults to CPU count minus one
            (at least 1).

    Returns:
        A list of 0/1 presence calls aligned with ``unitigs``.
    """
    # Without this guard chunk_size would be 0 and range() would raise.
    if not unitigs:
        return []

    if n_jobs is None:
        n_jobs = max(1, multiprocessing.cpu_count() - 1)
    # Never start more workers than there are unitigs to hand out.
    n_jobs = max(1, min(n_jobs, len(unitigs)))

    chunk_size = int(np.ceil(len(unitigs) / n_jobs))
    chunks = [unitigs[i:i + chunk_size] for i in range(0, len(unitigs), chunk_size)]

    # A single chunk gains nothing from a process pool; run it in-process.
    if len(chunks) == 1:
        return unitig_presence_in_text_single((chunks[0], genome_text))

    with ProcessPoolExecutor(max_workers=len(chunks)) as executor:
        futures = [
            executor.submit(unitig_presence_in_text_single, (chunk, genome_text))
            for chunk in chunks
        ]
        # Collect in submission order so the calls stay aligned with unitigs.
        results = [future.result() for future in futures]

    return [call for chunk_result in results for call in chunk_result]


def wilson_ci_vectorized(p: np.ndarray, n_eff: int = 200, z: float = 1.96):
    """Fast, vectorized CI via Wilson interval (95% with the default ``z``).

    Args:
        p: Probabilities; clipped into (0, 1) before use.
        n_eff: Effective sample size controlling the interval width.
        z: Normal quantile for the desired confidence level.

    Returns:
        Tuple ``(lo, hi)`` of arrays clipped to [0, 1].

    Raises:
        ValueError: If ``n_eff`` is not positive.
    """
    if n_eff <= 0:
        raise ValueError("n_eff must be positive")

    p = np.clip(p, 1e-9, 1 - 1e-9)
    z2 = z ** 2
    denom = 1 + z2 / n_eff
    center = (p + z2 / (2 * n_eff)) / denom
    margin = z * np.sqrt(p * (1 - p) / n_eff + z2 / (4 * n_eff**2)) / denom
    lo = np.clip(center - margin, 0, 1)
    hi = np.clip(center + margin, 0, 1)
    return lo, hi


def get_z_score(ci_level: int) -> float:
    """Return z-score for given confidence level (1.96 for unknown levels)."""
    z_scores = {90: 1.645, 95: 1.96, 99: 2.576}
    return z_scores.get(ci_level, 1.96)


def predict_los_distribution(model, X):
    """Return ``model.pred_dist(X)`` when available, else ``model.predict(X)``."""
    if hasattr(model, "pred_dist"):
        return model.pred_dist(X)
    return model.predict(X)


def scan_genomes(files, unitigs):
    """Scan multiple FASTA files for unitig presence.

    Files whose name ends in ``.gz`` are decompressed first. Files that yield
    no sequences are reported and skipped.

    Returns:
        A DataFrame with a ``sample`` column (file stem) followed by one
        ``uint8`` presence/absence column per unitig.
    """
    rows = []
    print(f"šŸ“ Processing {len(files)} files...")

    for i, file_path in enumerate(files, start=1):
        # flush so the file name is visible while the (slow) scan is running
        print(f" [{i}/{len(files)}] {Path(file_path).name}...", end=" ", flush=True)

        with open(file_path, "rb") as f:
            raw = f.read()
        if str(file_path).endswith(".gz"):
            raw = gzip.decompress(raw)

        fasta_dict = parse_fasta(raw)
        if not fasta_dict:
            print("āš ļø No sequences found")
            continue

        concat = concat_sequences(fasta_dict)
        calls = unitig_presence_in_text_parallel(unitigs, concat)

        row = {"sample": Path(file_path).stem}
        row.update(zip(unitigs, calls))
        rows.append(row)
        print("āœ…")

    pa_df = pd.DataFrame(rows, columns=["sample"] + unitigs)
    if unitigs:
        pa_df[unitigs] = pa_df[unitigs].astype(np.uint8)
    return pa_df


def run_binary_prediction(pa_df, unitigs, model_path, outcome, n_eff, threshold):
    """Run binary classification prediction (Death/ICU).

    Args:
        pa_df: Presence/absence table from ``scan_genomes``.
        unitigs: Feature column names, in model order.
        model_path: Path of the joblib-serialised model.
        outcome: Label used in the probability column name.
        n_eff: Effective sample size for the Wilson interval.
        threshold: Probability at or above which the prediction is 1.

    Returns:
        Tuple ``(results, model, X)``: the results table, the loaded model
        and the float32 feature matrix (reused by the SHAP step).
    """
    print("\nšŸ”„ Loading model...")
    model = joblib.load(model_path)
    print("āœ… Model loaded")

    X = pa_df[unitigs].astype(np.float32)

    print("🧬 Running inference...")
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(X)
        prob = proba[:, 1] if proba.shape[1] > 1 else np.zeros(len(X), dtype=float)
    else:
        # No probabilities available: fall back to hard labels as 0.0/1.0.
        pred_raw = model.predict(X)
        uniq = np.unique(pred_raw)
        if set(uniq) - {0, 1}:
            # Labels are not already 0/1: map smallest -> 0, largest -> 1.
            mapping = {uniq.min(): 0, uniq.max(): 1}
            prob = np.vectorize(mapping.get)(pred_raw).astype(float)
        else:
            prob = pred_raw.astype(float)
    print("āœ… Inference complete")

    print("šŸ“Š Computing confidence intervals...")
    ci_lo, ci_hi = wilson_ci_vectorized(prob, n_eff=n_eff, z=1.96)
    pred = (prob >= threshold).astype(int)
    print("āœ… CIs computed")

    results = pd.DataFrame({
        "Sample": pa_df["sample"],
        f"Predicted_Probability_{outcome}": prob,
        "CI_95_Lower": ci_lo,
        "CI_95_Upper": ci_hi,
        f"Prediction_threshold_{threshold:.2f}": pred,
    })
    return results, model, X


def run_los_prediction(pa_df, unitigs, model_path):
    """Run LOS prediction with uncertainty quantification.

    Predictions are treated as log1p-scale values and mapped back with
    ``expm1``; 90/95/99% prediction intervals are derived from the predicted
    mean and standard deviation on that scale.

    Returns:
        Tuple ``(results, model, X)``. On any failure the error is printed
        and the process exits with status 1.
    """
    print("\nšŸ”„ Loading model...")
    model = joblib.load(model_path)
    print("āœ… Model loaded")

    X = pa_df[unitigs].astype(np.float32)

    print("🧬 Running inference...")
    try:
        prediction = predict_los_distribution(model, X)

        # NumPy arrays also expose .mean, so a bare hasattr() check would
        # treat plain point predictions as a distribution and collapse them
        # to a single scalar. Detect the array case explicitly.
        if isinstance(prediction, np.ndarray) or not hasattr(prediction, "mean"):
            mean_los = np.asarray(prediction, dtype=float)
            std_los = np.ones_like(mean_los) * np.std(mean_los)
        else:
            mean_los = prediction.mean()
            std_los = prediction.std()
        print("āœ… Inference complete")

        print("šŸ“Š Computing prediction intervals...")
        pi_levels = [90, 95, 99]
        mean_los_original = np.expm1(mean_los)

        results_dict = {
            "Sample": pa_df["sample"],
            "Predicted_LOS_days": mean_los_original,
            "Std_Dev_log_scale": std_los,
        }

        for pi_level in pi_levels:
            z_score = get_z_score(pi_level)
            pi_lo_log = mean_los - z_score * std_los
            pi_hi_log = mean_los + z_score * std_los
            # A length of stay cannot be negative, so floor the lower bound.
            pi_lo_original = np.maximum(np.expm1(pi_lo_log), 0)
            pi_hi_original = np.expm1(pi_hi_log)
            results_dict[f"{pi_level}pct_PI_Lower"] = pi_lo_original
            results_dict[f"{pi_level}pct_PI_Upper"] = pi_hi_original

        results = pd.DataFrame(results_dict)
        print("āœ… PIs computed")
        return results, model, X
    except Exception as e:
        print(f"āŒ LOS prediction failed: {e}")
        sys.exit(1)


def run_shap_analysis(model, X, unitigs, pa_df, top_n=20):
    """Run SHAP analysis to identify predictive biomarkers.

    Returns:
        A long-format DataFrame with the ``top_n`` features per sample ranked
        by absolute SHAP value, or ``None`` if the analysis fails (best
        effort: the error is printed, not raised).
    """
    print("\nšŸ’” Computing SHAP values...")
    try:
        explainer = shap.TreeExplainer(model, feature_names=unitigs)
        shap_vals = explainer(X)
        print("āœ… SHAP values computed")

        # Create summary
        summary_data = []
        for i, sample_name in enumerate(pa_df["sample"]):
            sv = shap_vals[i]
            shap_values_abs = np.abs(sv.values)
            # Indices of the top_n largest |SHAP| values, largest first.
            top_indices = np.argsort(shap_values_abs)[-top_n:][::-1]

            for rank, idx in enumerate(top_indices, 1):
                feature_name = (
                    sv.feature_names[idx]
                    if hasattr(sv, "feature_names")
                    else unitigs[idx]
                )
                summary_data.append({
                    "Sample": sample_name,
                    "Rank": rank,
                    "Biomarker": feature_name,
                    "SHAP_Value": sv.values[idx],
                    "SHAP_Abs": shap_values_abs[idx],
                })

        shap_summary = pd.DataFrame(summary_data)
        print(f"āœ… Top {top_n} biomarkers identified per sample")
        return shap_summary
    except Exception as e:
        print(f"āŒ SHAP analysis failed: {e}")
        return None


def _ensure_blast_db(fasta_path, db_path, label):
    """Build one nucleotide BLAST database unless its index already exists."""
    if os.path.exists(db_path + ".nin") or not os.path.exists(fasta_path):
        return

    print(f"šŸ“š Creating {label} BLAST database...")
    try:
        subprocess.run(
            ["makeblastdb", "-in", fasta_path, "-dbtype", "nucl", "-out", db_path],
            check=True,
            capture_output=True,
        )
        print(f"āœ… {label} database created")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing makeblastdb executable.
        print(f"āš ļø Could not create {label} BLAST database: {e}")


def create_blast_databases():
    """Create BLAST databases if they don't exist."""
    _ensure_blast_db(CARD_FASTA, CARD_DB, "CARD")
    _ensure_blast_db(VFDB_FASTA, VFDB_DB, "VFDB")


def blast_unitig(unitig_seq, unitig_id, db_path, db_name):
    """Run BLAST search for a unitig sequence.

    Args:
        unitig_seq: Query sequence.
        unitig_id: Identifier written as the FASTA header of the query.
        db_path: BLAST database path passed to ``blastn -db``.
        db_name: Short database name; upper-cased into the ``Database`` field.

    Returns:
        A dict describing the first reported hit, or ``None`` when there is
        no hit, the search times out (30 s) or any other error occurs.
    """
    query_file = None
    try:
        # A unique temp file avoids collisions between concurrent runs that
        # a timestamp-based name in /tmp could produce.
        with tempfile.NamedTemporaryFile(
            "w", prefix=f"query_{db_name}_", suffix=".fasta", delete=False
        ) as handle:
            query_file = handle.name
            handle.write(f">{unitig_id}\n{unitig_seq}\n")

        result = subprocess.run(
            [
                "blastn",
                "-query", query_file,
                "-db", db_path,
                "-max_target_seqs", "1",
                "-outfmt",
                "6 qseqid sseqid pident length mismatch gapopen qstart qend sstart send evalue bitscore",
            ],
            capture_output=True,
            text=True,
            timeout=30,
        )

        if not result.stdout or not result.stdout.strip():
            return None

        parts = result.stdout.strip().split('\n')[0].split('\t')
        if len(parts) < 12:
            return None

        return {
            'Unitig_Sequence': unitig_seq,
            'Subject': parts[1][:100],
            'Identity_pct': float(parts[2]),
            'Length': int(parts[3]),
            'Mismatches': int(parts[4]),
            'Gaps': int(parts[5]),
            'Query_Start': int(parts[6]),
            'Query_End': int(parts[7]),
            'Subject_Start': int(parts[8]),
            'Subject_End': int(parts[9]),
            'Evalue': float(parts[10]),
            'Bitscore': float(parts[11]),
            'Database': db_name.upper(),
        }
    except subprocess.TimeoutExpired:
        print(f"āš ļø BLAST timeout for unitig {unitig_id[:30]}")
        return None
    except Exception as e:
        print(f"āš ļø BLAST search issue: {str(e)[:100]}")
        return None
    finally:
        # Always remove the query file, including on timeout or error.
        if query_file is not None:
            try:
                os.remove(query_file)
            except OSError:
                pass


def run_blast_annotation(shap_summary, unitigs, top_n=50):
    """Run BLAST annotation on top biomarkers.

    Args:
        shap_summary: Table produced by ``run_shap_analysis``.
        unitigs: Unused; kept for interface compatibility.
        top_n: Number of highest-|SHAP| rows whose biomarkers are searched.

    Returns:
        A DataFrame of CARD/VFDB hits, or ``None`` when there is nothing to
        search or nothing was found.
    """
    print("\nšŸ” Running BLAST annotation...")

    if shap_summary is None or len(shap_summary) == 0:
        print("āš ļø No BLAST hits found")
        return None

    create_blast_databases()

    # Get unique biomarkers from top results
    top_biomarkers = shap_summary.nlargest(top_n, 'SHAP_Abs')['Biomarker'].unique()

    all_results = []
    for idx, unitig in enumerate(top_biomarkers):
        for db_path, db_name in ((CARD_DB, "card"), (VFDB_DB, "vfdb")):
            hit = blast_unitig(unitig, f"unitig_{idx}", db_path, db_name)
            if hit:
                all_results.append(hit)

    if all_results:
        results_df = pd.DataFrame(all_results)
        print(f"āœ… BLAST complete - {len(results_df)} hits found")
        return results_df

    print("āš ļø No BLAST hits found")
    return None


# ------------------- MAIN CLI -------------------
def main():
    """Parse the command line, run the selected prediction and write CSVs."""
    parser = argparse.ArgumentParser(
        description="KAUST Genomic Risk Prediction CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Mortality prediction
  python cli_wrapper.py -i sample1.fasta sample2.fasta -o mortality.csv -t death

  # ICU prediction with SHAP analysis
  python cli_wrapper.py -i genomes/*.fasta -o icu.csv -t icu --shap --top-biomarkers 15

  # LOS prediction with BLAST annotation
  python cli_wrapper.py -i genome.fasta -o los.csv -t los --shap --blast
        """
    )
    parser.add_argument(
        "-i", "--input", nargs="+", required=True,
        help="Input FASTA files (can use wildcards: genomes/*.fasta)"
    )
    parser.add_argument(
        "-o", "--output", default="predictions.csv",
        help="Output CSV file for predictions (default: predictions.csv)"
    )
    parser.add_argument(
        "-t", "--outcome", choices=["death", "icu", "los"], required=True,
        help="Prediction outcome: death (mortality), icu, or los (length of stay)"
    )
    parser.add_argument(
        "--threshold", type=float, default=0.5,
        help="Decision threshold for binary predictions (default: 0.5)"
    )
    parser.add_argument(
        "--n-eff", type=int, default=200,
        help="Uncertainty strength for CI (default: 200)"
    )
    parser.add_argument(
        "--shap", action="store_true",
        help="Run SHAP analysis for predictive biomarker identification"
    )
    parser.add_argument(
        "--top-biomarkers", type=int, default=20,
        help="Number of top biomarkers to display in SHAP analysis (default: 20)"
    )
    parser.add_argument(
        "--blast", action="store_true",
        help="Run BLAST annotation against CARD and VFDB databases"
    )
    parser.add_argument(
        "--shap-output", default=None,
        help="Output file for SHAP results (default: predictions_shap.csv)"
    )
    parser.add_argument(
        "--blast-output", default=None,
        help="Output file for BLAST results (default: predictions_blast.csv)"
    )

    args = parser.parse_args()

    # A non-positive n_eff would divide by zero in the Wilson interval.
    if args.n_eff <= 0:
        parser.error("--n-eff must be a positive integer")

    # Expand wildcards and validate files
    files = []
    for pattern in args.input:
        expanded = glob(pattern)
        if expanded:
            files.extend(expanded)
        elif os.path.isfile(pattern):
            # Literal file name containing glob metacharacters.
            files.append(pattern)
        else:
            print(f"Warning: no files matched: {pattern}", file=sys.stderr)

    if not files:
        print("āŒ No FASTA files found")
        sys.exit(1)

    print(f"\n{'='*60}")
    print(f"🧬 KAUST Genomic Risk Prediction - CLI Tool")
    print(f"{'='*60}")
    print(f"šŸ“‹ Outcome: {args.outcome.upper()}")
    print(f"šŸ“ Input files: {len(files)}")
    print(f"šŸ’¾ Output: {args.output}")
    if args.shap:
        print(f"šŸ“Š SHAP analysis: Enabled (top {args.top_biomarkers} biomarkers)")
    if args.blast:
        print(f"šŸ” BLAST annotation: Enabled")
    print(f"{'='*60}\n")

    # Select outcome and paths
    if args.outcome == "death":
        unitigs_csv = UNITIGS_CSV_DEATH
        model_path = MODEL_PATH_DEATH
        outcome_label = "mortality"
        is_los = False
    elif args.outcome == "icu":
        unitigs_csv = UNITIGS_CSV_ICU
        model_path = MODEL_PATH_ICU
        outcome_label = "icu_admission"
        is_los = False
    else:  # los
        unitigs_csv = UNITIGS_CSV_LOS
        model_path = MODEL_PATH_LOS
        outcome_label = "los"
        is_los = True

    # Check if required files exist
    if not os.path.exists(unitigs_csv):
        print(f"āŒ Unitigs CSV not found: {unitigs_csv}")
        sys.exit(1)
    if not os.path.exists(model_path):
        print(f"āŒ Model not found: {model_path}")
        sys.exit(1)

    print("šŸ“– Loading unitigs...")
    unitigs = load_unitigs(unitigs_csv, UNITIGS_ARE_COLUMNS, UNITIGS_SEQ_COLUMN)
    print(f"āœ… Loaded {len(unitigs)} unitigs")

    # Scan genomes
    pa_df = scan_genomes(files, unitigs)
    if len(pa_df) == 0:
        print("āŒ No valid genomes processed")
        sys.exit(1)
    print(f"āœ… Scanned {len(pa_df)} genomes")

    # Run prediction
    if is_los:
        results, model, X = run_los_prediction(pa_df, unitigs, model_path)
    else:
        results, model, X = run_binary_prediction(
            pa_df, unitigs, model_path, outcome_label, args.n_eff, args.threshold
        )

    # Save predictions
    results.to_csv(args.output, index=False)
    print(f"\nāœ… Predictions saved to: {args.output}")

    # SHAP analysis. BLAST annotation ranks biomarkers by SHAP value, so the
    # analysis also runs when only --blast was requested.
    shap_results = None
    if args.shap or args.blast:
        shap_results = run_shap_analysis(model, X, unitigs, pa_df, args.top_biomarkers)

    if args.shap and shap_results is not None:
        shap_output = args.shap_output or f"{Path(args.output).stem}_shap.csv"
        shap_results.to_csv(shap_output, index=False)
        print(f"āœ… SHAP results saved to: {shap_output}")

    # BLAST annotation
    if args.blast:
        if shap_results is None:
            print(
                "BLAST annotation skipped: SHAP results are unavailable",
                file=sys.stderr,
            )
        else:
            blast_output = args.blast_output or f"{Path(args.output).stem}_blast.csv"
            blast_results = run_blast_annotation(shap_results, unitigs, args.top_biomarkers)
            if blast_results is not None:
                blast_results.to_csv(blast_output, index=False)
                print(f"āœ… BLAST results saved to: {blast_output}")

    print(f"\n{'='*60}")
    print("šŸŽ‰ Analysis complete!")
    print(f"{'='*60}\n")


if __name__ == "__main__":
    main()