"""
CLI wrapper for KAUST Infectious Diseases Genomic Risk Prediction.

Allows running predictions from the command line without the Streamlit UI.
"""

import argparse
import gzip
import multiprocessing
import os
import re
import subprocess
import sys
import tempfile
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import xgboost as xgb
from ngboost import NGBRegressor
from ngboost.distns import Normal
from scipy import stats


# Ignore UserWarnings issued from modules whose name matches "pickle".
warnings.filterwarnings("ignore", category=UserWarning, module="pickle")


# Unitig table and model file that main() selects for the "death" outcome.
UNITIGS_CSV_DEATH = "./Unitigs_predictor_DEATH.csv"
MODEL_PATH_DEATH = "./xgb_fold1_8_Death.joblib"

# Unitig table and model file that main() selects for the "icu" outcome.
UNITIGS_CSV_ICU = "./Unitigs_predictor_ICU.csv"
MODEL_PATH_ICU = "./xgb_fold5_8_ICU.joblib"

# Unitig table and model file that main() selects for the "los" outcome.
UNITIGS_CSV_LOS = "./Unitigs_predictor_los.csv"
MODEL_PATH_LOS = "./Unitig_model_ngb_log1p_fold1.joblib"

# Passed to load_unitigs(): when True the unitig sequences are taken from the
# CSV column headers; otherwise from the column named by UNITIGS_SEQ_COLUMN.
UNITIGS_ARE_COLUMNS = True
UNITIGS_SEQ_COLUMN = "unitig"

# Reference FASTA files and the BLAST database prefixes that
# create_blast_databases() builds from them.
CARD_FASTA = "./CARD.fasta"
VFDB_FASTA = "./VFDB.fasta"
CARD_DB = "./card_db"
VFDB_DB = "./vfdb_db"
# NOTE(review): not referenced anywhere in the visible code — confirm it is
# still needed.
BLAST_MAX_RESULTS = 20


def reverse_complement(seq: str) -> str:
    """Return the reverse complement of a DNA string.

    A/C/G/T are exchanged with their pairing base (case is preserved);
    ``N``/``n`` and every other character are kept as they are. The
    sequence is then read back to front.
    """
    pairing = str.maketrans({
        "A": "T", "C": "G", "G": "C", "T": "A",
        "a": "t", "c": "g", "g": "c", "t": "a",
    })
    # Complementing is per-character, so reversing first gives the same result.
    return seq[::-1].translate(pairing)


def load_unitigs(unitigs_csv: str, are_columns: bool, seq_col: str):
    """Read the unitig sequences listed in a CSV file.

    Args:
        unitigs_csv: Path of the CSV file.
        are_columns: When true, unitigs are the column headers that look like
            DNA (only A/C/G/T/N in either case, at least 5 characters); if no
            header qualifies, every header is used. When false, unitigs are
            the values of the ``seq_col`` column.
        seq_col: Column holding the sequences when ``are_columns`` is false.

    Returns:
        The unitigs as a list, duplicates removed, first-seen order kept.

    Raises:
        ValueError: ``are_columns`` is false and ``seq_col`` is not a column.
    """
    table = pd.read_csv(unitigs_csv)

    if not are_columns:
        if seq_col not in table.columns:
            raise ValueError(f"Column '{seq_col}' not found in {unitigs_csv}")
        candidates = table[seq_col].astype(str).tolist()
    else:
        dna_only = re.compile(r"[ACGTNacgtn]+")
        candidates = [
            name
            for name in table.columns
            if isinstance(name, str) and len(name) >= 5 and dna_only.fullmatch(name)
        ]
        if not candidates:
            candidates = list(table.columns)

    # dict keys keep insertion order, which drops repeats but keeps first-seen order
    return list(dict.fromkeys(candidates))


def parse_fasta(file_bytes: bytes) -> dict:
    """Parse FASTA-formatted bytes into ``{header: sequence}``.

    Undecodable bytes are dropped. The header is the whole ``>`` line minus
    the marker and surrounding whitespace; sequence lines are stripped and
    joined. Lines before the first header are discarded, and a repeated
    header replaces the earlier record's sequence.
    """
    records = {}
    name = None
    parts = []

    for raw_line in file_bytes.decode(errors="ignore").splitlines():
        if raw_line == "":
            continue
        if not raw_line.startswith(">"):
            parts.append(raw_line.strip())
            continue
        # A new header closes the record collected so far.
        if name is not None:
            records[name] = "".join(parts)
        name, parts = raw_line[1:].strip(), []

    if name is not None:
        records[name] = "".join(parts)
    return records


def concat_sequences(fasta_dict: dict) -> str:
    """Join all sequences of a parsed FASTA into one string.

    Records are separated by ``NNNNN``, in the dict's iteration order.
    """
    spacer = "NNNNN"
    sequences = list(fasta_dict.values())
    return spacer.join(sequences)


def unitig_presence_in_text_single(args):
    """Worker for the parallel scan: flag which unitigs occur in a genome.

    Args:
        args: ``(unitigs_chunk, genome_text)`` tuple.

    Returns:
        One int per unitig, in order: 1 when the unitig or its reverse
        complement is a substring of the genome text (case-insensitive),
        0 otherwise.
    """
    unitigs_chunk, genome_text = args
    haystack = genome_text.upper()

    def _hit(unitig):
        forward = unitig.upper()
        if forward in haystack:
            return 1
        return 1 if reverse_complement(forward) in haystack else 0

    return [_hit(unitig) for unitig in unitigs_chunk]


def unitig_presence_in_text_parallel(unitigs, genome_text, n_jobs=None):
    """Scan a genome for many unitigs using a pool of worker processes.

    The unitigs are split into contiguous chunks, each chunk is handed to
    ``unitig_presence_in_text_single`` in a separate process, and the
    per-chunk results are stitched back together in the original order.

    Args:
        unitigs: Sequence of unitig strings.
        genome_text: Genome text to search.
        n_jobs: Number of worker processes. Defaults to the CPU count minus
            one; values below 1 are treated as 1, and no more workers are
            used than there are unitigs.

    Returns:
        A list with one 0/1 call per unitig; empty when ``unitigs`` is empty.
    """
    n_unitigs = len(unitigs)
    if n_unitigs == 0:
        # Nothing to scan; a zero chunk size would also make range() raise.
        return []

    if n_jobs is None:
        n_jobs = multiprocessing.cpu_count() - 1
    n_jobs = max(1, min(int(n_jobs), n_unitigs))

    chunk_size = int(np.ceil(n_unitigs / n_jobs))
    chunks = [unitigs[i:i + chunk_size] for i in range(0, n_unitigs, chunk_size)]

    with ProcessPoolExecutor(max_workers=len(chunks)) as executor:
        futures = [
            executor.submit(unitig_presence_in_text_single, (chunk, genome_text))
            for chunk in chunks
        ]
        # Collect in submission order so calls stay aligned with ``unitigs``.
        results = [future.result() for future in futures]

    return [call for chunk_calls in results for call in chunk_calls]


def wilson_ci_vectorized(p: np.ndarray, n_eff: int = 200, z: float = 1.96):
    """Vectorized Wilson score interval for an array of probabilities.

    Args:
        p: Probabilities; clipped to ``[1e-9, 1 - 1e-9]`` before use.
        n_eff: Sample size plugged into the Wilson formula.
        z: Normal quantile for the interval (1.96 by default).

    Returns:
        ``(lower, upper)`` arrays, each clipped to ``[0, 1]``.
    """
    probs = np.clip(p, 1e-9, 1 - 1e-9)
    z_sq = z ** 2
    scale = 1 + z_sq / n_eff

    midpoint = (probs + z_sq / (2 * n_eff)) / scale
    half_width = z * np.sqrt(probs * (1 - probs) / n_eff + z_sq / (4 * n_eff**2)) / scale

    lower = np.clip(midpoint - half_width, 0, 1)
    upper = np.clip(midpoint + half_width, 0, 1)
    return lower, upper


def get_z_score(ci_level: int) -> float:
    """Look up the z-score for a confidence level given in percent.

    Levels 90, 95 and 99 are known; any other level yields 1.96.
    """
    known_levels = dict(zip((90, 95, 99), (1.645, 1.96, 2.576)))
    try:
        return known_levels[ci_level]
    except KeyError:
        return 1.96


def predict_los_distribution(model, X):
    """Return the model's predicted distribution for ``X`` when available.

    Models exposing ``pred_dist`` return its result; all others fall back
    to the plain output of ``predict``.
    """
    if not hasattr(model, 'pred_dist'):
        return model.predict(X)
    return model.pred_dist(X)


def scan_genomes(files, unitigs):
    """Scan FASTA files for unitig presence and build a presence/absence table.

    Args:
        files: Paths of the FASTA files to scan. A path ending in ``.gz`` is
            gunzipped in memory before parsing.
        unitigs: List of unitig sequences; each becomes one result column.

    Returns:
        A DataFrame with a ``sample`` column (file name without its last
        suffix) followed by one ``uint8`` 0/1 column per unitig. Files that
        contain no FASTA records are reported and left out.
    """
    rows = []
    total = len(files)
    print(f"📂 Processing {total} files...")

    for i, file_path in enumerate(files, start=1):
        path = Path(file_path)
        # Flush so the progress line is visible while the slow scan runs.
        print(f" [{i}/{total}] {path.name}...", end=" ", flush=True)

        raw = path.read_bytes()
        if str(file_path).endswith(".gz"):
            raw = gzip.decompress(raw)

        fasta_dict = parse_fasta(raw)
        if not fasta_dict:
            print("⚠️ No sequences found")
            continue

        concat = concat_sequences(fasta_dict)
        calls = unitig_presence_in_text_parallel(unitigs, concat)

        row = {"sample": path.stem}
        row.update(zip(unitigs, calls))
        rows.append(row)
        print("✅")

    pa_df = pd.DataFrame(rows, columns=["sample"] + list(unitigs))
    pa_df[unitigs] = pa_df[unitigs].astype(np.uint8)
    return pa_df


def run_binary_prediction(pa_df, unitigs, model_path, outcome, n_eff, threshold):
    """Run binary classification prediction (Death/ICU).

    Args:
        pa_df: Presence/absence table with a ``sample`` column and one column
            per unitig.
        unitigs: Unitig column names, in the order fed to the model.
        model_path: Path of the joblib-serialized model.
        outcome: Label embedded in the probability column name.
        n_eff: Sample size used for the Wilson interval.
        threshold: Probability at or above which the prediction is 1.

    Returns:
        ``(results, model, X)``: the per-sample results table, the loaded
        model, and the float32 feature matrix passed to it.
    """
    print("\n🔄 Loading model...")
    # NOTE: joblib.load unpickles the file; only load models from a trusted source.
    model = joblib.load(model_path)
    print("✅ Model loaded")

    X = pa_df[unitigs].astype(np.float32)

    print("🧬 Running inference...")
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(X)
        # A single-column output has no positive-class column: report 0.
        prob = proba[:, 1] if proba.shape[1] > 1 else np.zeros(len(X), dtype=float)
    else:
        pred_raw = model.predict(X)
        uniq = np.unique(pred_raw)
        if set(uniq) - {0, 1}:
            # Labels are not 0/1: map the smallest to 0 and the largest to 1.
            mapping = {uniq.min(): 0, uniq.max(): 1}
            prob = np.vectorize(mapping.get)(pred_raw).astype(float)
        else:
            prob = pred_raw.astype(float)
    print("✅ Inference complete")

    print("📊 Computing confidence intervals...")
    ci_lo, ci_hi = wilson_ci_vectorized(prob, n_eff=n_eff, z=1.96)
    pred = (prob >= threshold).astype(int)
    print("✅ CIs computed")

    results = pd.DataFrame({
        "Sample": pa_df["sample"],
        f"Predicted_Probability_{outcome}": prob,
        "CI_95_Lower": ci_lo,
        "CI_95_Upper": ci_hi,
        f"Prediction_threshold_{threshold:.2f}": pred,
    })

    return results, model, X


def run_los_prediction(pa_df, unitigs, model_path):
    """Run LOS prediction with uncertainty quantification.

    Model outputs are mapped back to days with ``np.expm1``. Prediction
    intervals at 90/95/99 % are built as ``mean ± z * std`` before that
    mapping, and lower bounds are floored at 0.

    Args:
        pa_df: Presence/absence table with a ``sample`` column and one column
            per unitig.
        unitigs: Unitig column names, in the order fed to the model.
        model_path: Path of the joblib-serialized model.

    Returns:
        ``(results, model, X)``: the per-sample results table, the loaded
        model, and the float32 feature matrix passed to it.

    Exits the process with status 1 if inference or interval computation
    fails.
    """
    print("\n🔄 Loading model...")
    # NOTE: joblib.load unpickles the file; only load models from a trusted source.
    model = joblib.load(model_path)
    print("✅ Model loaded")

    X = pa_df[unitigs].astype(np.float32)

    print("🧬 Running inference...")
    try:
        pred_dist = predict_los_distribution(model, X)
        # Arrays also expose .mean()/.std(), so test for plain point
        # predictions first; otherwise every sample would receive the
        # global mean instead of its own prediction.
        if isinstance(pred_dist, np.ndarray) or not hasattr(pred_dist, 'mean'):
            mean_los = np.asarray(pred_dist, dtype=float)
            std_los = np.ones_like(mean_los) * np.std(mean_los)
        else:
            mean_los = pred_dist.mean()
            std_los = pred_dist.std()
        print("✅ Inference complete")

        print("📊 Computing prediction intervals...")
        pi_levels = [90, 95, 99]
        mean_los_original = np.expm1(mean_los)

        results_dict = {
            "Sample": pa_df["sample"],
            "Predicted_LOS_days": mean_los_original,
            "Std_Dev_log_scale": std_los,
        }

        for pi_level in pi_levels:
            z_score = get_z_score(pi_level)
            pi_lo_original = np.expm1(mean_los - z_score * std_los)
            pi_hi_original = np.expm1(mean_los + z_score * std_los)
            # A stay cannot be negative.
            pi_lo_original = np.maximum(pi_lo_original, 0)
            results_dict[f"{pi_level}pct_PI_Lower"] = pi_lo_original
            results_dict[f"{pi_level}pct_PI_Upper"] = pi_hi_original

        results = pd.DataFrame(results_dict)
        print("✅ PIs computed")
        return results, model, X

    except Exception as e:
        print(f"❌ LOS prediction failed: {e}")
        sys.exit(1)


def run_shap_analysis(model, X, unitigs, pa_df, top_n=20):
    """Run SHAP analysis to identify predictive biomarkers.

    For every sample the ``top_n`` features with the largest absolute SHAP
    value are listed, ranked from most to least influential.

    Args:
        model: Fitted model handed to ``shap.TreeExplainer``.
        X: Feature matrix to explain.
        unitigs: Feature names, aligned with the columns of ``X``.
        pa_df: Table whose ``sample`` column names the rows of ``X``.
        top_n: Number of features reported per sample.

    Returns:
        A DataFrame with columns ``Sample``, ``Rank``, ``Biomarker``,
        ``SHAP_Value`` and ``SHAP_Abs``, or ``None`` if the analysis fails
        (the failure is printed, not raised).
    """
    print("\n💡 Computing SHAP values...")
    try:
        explainer = shap.TreeExplainer(model, feature_names=unitigs)
        shap_vals = explainer(X)
        print("✅ SHAP values computed")

        summary_data = []
        for i, sample_name in enumerate(pa_df["sample"]):
            sv = shap_vals[i]
            shap_values_abs = np.abs(sv.values)
            # argsort is ascending: take the tail and flip it for descending order
            top_indices = np.argsort(shap_values_abs)[-top_n:][::-1]

            # The attribute can exist but be unset; fall back to ``unitigs`` then.
            names = getattr(sv, 'feature_names', None)
            if names is None:
                names = unitigs

            for rank, idx in enumerate(top_indices, 1):
                summary_data.append({
                    "Sample": sample_name,
                    "Rank": rank,
                    "Biomarker": names[idx],
                    "SHAP_Value": sv.values[idx],
                    "SHAP_Abs": shap_values_abs[idx],
                })

        shap_summary = pd.DataFrame(summary_data)
        print(f"✅ Top {top_n} biomarkers identified per sample")
        return shap_summary

    except Exception as e:
        # Best-effort step: report and let the caller continue without SHAP.
        print(f"❌ SHAP analysis failed: {e}")
        return None


def _ensure_blast_database(label, fasta_path, db_prefix):
    """Build one nucleotide BLAST database unless its ``.nin`` index exists."""
    if os.path.exists(db_prefix + ".nin") or not os.path.exists(fasta_path):
        return

    print(f"🔨 Creating {label} BLAST database...")
    try:
        subprocess.run(
            ["makeblastdb", "-in", fasta_path, "-dbtype", "nucl", "-out", db_prefix],
            check=True,
            capture_output=True,
        )
        print(f"✅ {label} database created")
    except Exception as e:
        # Best-effort: a missing makeblastdb binary or a failed build is
        # reported and must not abort the run.
        print(f"⚠️ Could not create {label} BLAST database: {e}")


def create_blast_databases():
    """Create BLAST databases if they don't exist.

    Handles the CARD and VFDB databases in turn. A database is built with
    ``makeblastdb`` only when its ``.nin`` file is absent and its source
    FASTA is present; failures are printed and otherwise ignored.
    """
    _ensure_blast_database("CARD", CARD_FASTA, CARD_DB)
    _ensure_blast_database("VFDB", VFDB_FASTA, VFDB_DB)


def _parse_blast_hit(stdout, unitig_seq, db_name):
    """Turn the first row of tabular BLAST output into a hit dict, or None."""
    if not stdout or not stdout.strip():
        return None

    parts = stdout.strip().split('\n')[0].split('\t')
    if len(parts) < 12:
        return None

    return {
        'Unitig_Sequence': unitig_seq,
        'Subject': parts[1][:100],
        'Identity_pct': float(parts[2]),
        'Length': int(parts[3]),
        'Mismatches': int(parts[4]),
        'Gaps': int(parts[5]),
        'Query_Start': int(parts[6]),
        'Query_End': int(parts[7]),
        'Subject_Start': int(parts[8]),
        'Subject_End': int(parts[9]),
        'Evalue': float(parts[10]),
        'Bitscore': float(parts[11]),
        'Database': db_name.upper(),
    }


def blast_unitig(unitig_seq, unitig_id, db_path, db_name):
    """Run BLAST search for a unitig sequence.

    The sequence is written to a uniquely named temporary FASTA file, which
    is removed again whether or not the search succeeds.

    Args:
        unitig_seq: Nucleotide sequence to search with.
        unitig_id: Identifier used as the FASTA header of the query.
        db_path: BLAST database prefix passed to ``blastn -db``.
        db_name: Short database name; upper-cased into the ``Database`` field.

    Returns:
        A dict describing the first reported hit, or ``None`` when there is
        no hit, the search times out (30 s), or any other error occurs.
        Errors are printed, not raised.
    """
    query_file = None
    try:
        # delete=False: blastn must be able to open the file after we close it.
        with tempfile.NamedTemporaryFile(
            mode="w", prefix=f"query_{db_name}_", suffix=".fasta", delete=False
        ) as handle:
            query_file = handle.name
            handle.write(f">{unitig_id}\n{unitig_seq}\n")

        result = subprocess.run(
            [
                "blastn",
                "-query", query_file,
                "-db", db_path,
                "-max_target_seqs", "1",
                "-outfmt", "6 qseqid sseqid pident length mismatch gapopen qstart qend sstart send evalue bitscore",
            ],
            capture_output=True,
            text=True,
            timeout=30,
        )
        return _parse_blast_hit(result.stdout, unitig_seq, db_name)

    except subprocess.TimeoutExpired:
        print(f"⚠️ BLAST timeout for unitig {unitig_id[:30]}")
        return None
    except Exception as e:
        print(f"⚠️ BLAST search issue: {str(e)[:100]}")
        return None
    finally:
        # Clean up on every path, including timeouts and a missing blastn.
        if query_file is not None:
            try:
                os.remove(query_file)
            except OSError:
                pass


def run_blast_annotation(shap_summary, unitigs, top_n=50):
    """Run BLAST annotation on top biomarkers.

    The ``top_n`` rows of ``shap_summary`` with the largest ``SHAP_Abs`` are
    selected, and each distinct biomarker among them is searched against the
    CARD and VFDB databases.

    Args:
        shap_summary: SHAP summary table with ``Biomarker`` and ``SHAP_Abs``
            columns, or ``None`` if SHAP produced nothing.
        unitigs: Unused; kept for interface compatibility.
        top_n: Number of highest-|SHAP| rows to draw biomarkers from.

    Returns:
        A DataFrame with one row per BLAST hit, or ``None`` when there is
        nothing to annotate or no hit was found.
    """
    if shap_summary is None or len(shap_summary) == 0:
        print("⚠️ No SHAP biomarkers available for BLAST annotation")
        return None

    print("\n🔍 Running BLAST annotation...")

    create_blast_databases()

    top_biomarkers = shap_summary.nlargest(top_n, 'SHAP_Abs')['Biomarker'].unique()

    all_results = []
    for idx, unitig in enumerate(top_biomarkers):
        for db_path, db_name in ((CARD_DB, "card"), (VFDB_DB, "vfdb")):
            hit = blast_unitig(unitig, f"unitig_{idx}", db_path, db_name)
            if hit:
                all_results.append(hit)

    if not all_results:
        print("⚠️ No BLAST hits found")
        return None

    results_df = pd.DataFrame(all_results)
    print(f"✅ BLAST complete - {len(results_df)} hits found")
    return results_df


def _build_arg_parser():
    """Build the command-line argument parser."""
    parser = argparse.ArgumentParser(
        description="KAUST Genomic Risk Prediction CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Mortality prediction
  python cli_wrapper.py -i sample1.fasta sample2.fasta -o mortality.csv -t death

  # ICU prediction with SHAP analysis
  python cli_wrapper.py -i genomes/*.fasta -o icu.csv -t icu --shap --top-biomarkers 15

  # LOS prediction with BLAST annotation
  python cli_wrapper.py -i genome.fasta -o los.csv -t los --shap --blast
""",
    )

    parser.add_argument(
        "-i", "--input",
        nargs="+",
        required=True,
        help="Input FASTA files (can use wildcards: genomes/*.fasta)",
    )
    parser.add_argument(
        "-o", "--output",
        default="predictions.csv",
        help="Output CSV file for predictions (default: predictions.csv)",
    )
    parser.add_argument(
        "-t", "--outcome",
        choices=["death", "icu", "los"],
        required=True,
        help="Prediction outcome: death (mortality), icu, or los (length of stay)",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Decision threshold for binary predictions (default: 0.5)",
    )
    parser.add_argument(
        "--n-eff",
        type=int,
        default=200,
        help="Uncertainty strength for CI (default: 200)",
    )
    parser.add_argument(
        "--shap",
        action="store_true",
        help="Run SHAP analysis for predictive biomarker identification",
    )
    parser.add_argument(
        "--top-biomarkers",
        type=int,
        default=20,
        help="Number of top biomarkers to display in SHAP analysis (default: 20)",
    )
    parser.add_argument(
        "--blast",
        action="store_true",
        help="Run BLAST annotation against CARD and VFDB databases",
    )
    parser.add_argument(
        "--shap-output",
        default=None,
        help="Output file for SHAP results (default: predictions_shap.csv)",
    )
    parser.add_argument(
        "--blast-output",
        default=None,
        help="Output file for BLAST results (default: predictions_blast.csv)",
    )
    return parser


def _expand_input_patterns(patterns):
    """Expand glob patterns; keep literal paths that name an existing file."""
    from glob import glob

    files = []
    for pattern in patterns:
        expanded = glob(pattern)
        if expanded:
            files.extend(expanded)
        elif os.path.isfile(pattern):
            # e.g. a file name containing glob metacharacters
            files.append(pattern)
    return files


def _derived_output_path(output, tag):
    """Return ``<stem>_<tag>.csv`` located next to the predictions file."""
    out = Path(output)
    return str(out.with_name(f"{out.stem}_{tag}.csv"))


def main():
    """Command-line entry point.

    Parses the arguments, scans the input genomes for the outcome's unitigs,
    runs the matching model, writes the predictions CSV and, on request,
    the SHAP and BLAST result files. Exits with status 1 when no input
    file, unitig table, model or valid genome is found.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The Wilson interval divides by n_eff.
    if args.n_eff <= 0:
        parser.error("--n-eff must be a positive integer")

    files = _expand_input_patterns(args.input)
    if not files:
        print("❌ No FASTA files found")
        sys.exit(1)

    print(f"\n{'='*60}")
    print("🧬 KAUST Genomic Risk Prediction - CLI Tool")
    print(f"{'='*60}")
    print(f"📋 Outcome: {args.outcome.upper()}")
    print(f"📁 Input files: {len(files)}")
    print(f"💾 Output: {args.output}")
    if args.shap:
        print(f"🔍 SHAP analysis: Enabled (top {args.top_biomarkers} biomarkers)")
    if args.blast:
        print("🔍 BLAST annotation: Enabled")
    print(f"{'='*60}\n")

    # outcome -> (unitig table, model file, label used in column names)
    outcome_settings = {
        "death": (UNITIGS_CSV_DEATH, MODEL_PATH_DEATH, "mortality"),
        "icu": (UNITIGS_CSV_ICU, MODEL_PATH_ICU, "icu_admission"),
        "los": (UNITIGS_CSV_LOS, MODEL_PATH_LOS, "los"),
    }
    unitigs_csv, model_path, outcome_label = outcome_settings[args.outcome]
    is_los = args.outcome == "los"

    if not os.path.exists(unitigs_csv):
        print(f"❌ Unitigs CSV not found: {unitigs_csv}")
        sys.exit(1)
    if not os.path.exists(model_path):
        print(f"❌ Model not found: {model_path}")
        sys.exit(1)

    print("📂 Loading unitigs...")
    unitigs = load_unitigs(unitigs_csv, UNITIGS_ARE_COLUMNS, UNITIGS_SEQ_COLUMN)
    print(f"✅ Loaded {len(unitigs)} unitigs")

    pa_df = scan_genomes(files, unitigs)
    if len(pa_df) == 0:
        print("❌ No valid genomes processed")
        sys.exit(1)
    print(f"✅ Scanned {len(pa_df)} genomes")

    if is_los:
        results, model, X = run_los_prediction(pa_df, unitigs, model_path)
    else:
        results, model, X = run_binary_prediction(
            pa_df, unitigs, model_path, outcome_label, args.n_eff, args.threshold
        )

    results.to_csv(args.output, index=False)
    print(f"\n✅ Predictions saved to: {args.output}")

    shap_results = None
    if args.shap:
        # Default companion files are written next to the predictions file.
        shap_output = args.shap_output or _derived_output_path(args.output, "shap")
        shap_results = run_shap_analysis(model, X, unitigs, pa_df, args.top_biomarkers)
        if shap_results is not None:
            shap_results.to_csv(shap_output, index=False)
            print(f"✅ SHAP results saved to: {shap_output}")

    if args.blast:
        if shap_results is None:
            # BLAST annotates the SHAP-ranked biomarkers, so it needs them.
            print("⚠️ BLAST annotation skipped: it requires SHAP results (use --shap)")
        else:
            blast_output = args.blast_output or _derived_output_path(args.output, "blast")
            blast_results = run_blast_annotation(shap_results, unitigs, args.top_biomarkers)
            if blast_results is not None:
                blast_results.to_csv(blast_output, index=False)
                print(f"✅ BLAST results saved to: {blast_output}")

    print(f"\n{'='*60}")
    print("🎉 Analysis complete!")
    print(f"{'='*60}\n")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|