"""SVSPR wrapper around the trained RandomForest. Public surface: - SVSPR class : load + predict on DataFrame / DataFrame / single SV - score(vcf, ref) : convenience for whole-VCF scoring (returns DataFrame) - classify(chrom, ...) : convenience for one SV (returns dict) - load_default() : load the bundled model """ from __future__ import annotations from pathlib import Path from typing import Optional, Union import numpy as np import pandas as pd import pickle from .features import (SVCall, FEATURE_COLS, extract_one, extract_batch, from_vcf, WINDOW) _DEFAULT_MODEL_PATH = (Path(__file__).resolve().parent.parent / 'model' / 'svspr_v14_seq.pkl') def _assign_tier(cs: float) -> str: # Thresholds match Methods 2.7.2 (manuscript). CS is uncalibrated; if you # apply isotonic/Platt calibration, re-derive these cutoffs on calibrated CS. if cs >= 0.9: return 'high' if cs >= 0.7: return 'moderate' if cs >= 0.5: return 'warning' return 'low' class SVSPR: """SV-SPR confidence scorer. Parameters ---------- model_path : str | Path | None Path to a pickled sklearn classifier. If None, the bundled svspr_v14_seq.pkl is loaded. feature_cols : list[str] | None Override the feature column ordering if the model was trained on a different set. Defaults to FEATURE_COLS. """ def __init__(self, model_path: Optional[Union[str, Path]] = None, feature_cols: Optional[list] = None): path = Path(model_path) if model_path else _DEFAULT_MODEL_PATH if not path.exists(): raise FileNotFoundError(f'Model file not found: {path}') with open(path, 'rb') as f: obj = pickle.load(f) # The training pickle is a dict {model, features, ...} or a raw estimator. if isinstance(obj, dict): self.model = obj.get('model') or obj.get('estimator') cols = obj.get('features') or obj.get('feature_cols') self.feature_cols = feature_cols or cols or FEATURE_COLS self.metadata = {k: v for k, v in obj.items() if k != 'model'} else: self.model = obj self.feature_cols = feature_cols or FEATURE_COLS self.metadata = {} # ── Scoring methods ────────────────────────────────────────────────────── def _align(self, feat_df: pd.DataFrame) -> pd.DataFrame: """Map user-facing feature names to the model's expected names. The trained model uses legacy column names like 'svlen_abs_manta', 'svtype_DEL_manta', etc. The feature extractor produces clean names like 'svlen_abs', 'svtype_DEL'. This method bridges the two. """ df = feat_df.copy() # Add legacy aliases for any column the model expects but isn't present. aliases = {'svlen_abs_manta': 'svlen_abs', 'svtype_DEL_manta': 'svtype_DEL', 'svtype_INS_manta': 'svtype_INS', 'svtype_DUP_manta': 'svtype_DUP', 'svtype_BND_manta': 'svtype_BND'} for legacy, clean in aliases.items(): if legacy not in df.columns and clean in df.columns: df[legacy] = df[clean] return df.reindex(columns=self.feature_cols).fillna(0) def predict_df(self, feat_df: pd.DataFrame) -> pd.DataFrame: """Score a feature DataFrame. Adds CS and tier columns.""" X = self._align(feat_df) cs = self.model.predict_proba(X.values)[:, 1] out = feat_df.copy() out['CS'] = cs out['tier'] = [_assign_tier(v) for v in cs] return out def predict_vcf(self, vcf_path: str, ref_path: str) -> pd.DataFrame: """Score every SV in a VCF. Returns DataFrame with coord + CS + tier.""" feat_df = from_vcf(vcf_path, ref_path) return self.predict_df(feat_df) def predict_one(self, chrom: str, pos: int, end: int, svtype: str, svlen: int, total_alt_support: float, ref_path: str ) -> dict: """Score one SV call. Returns {'CS': float, 'tier': str}.""" import pysam fa = pysam.FastaFile(ref_path) try: call = SVCall(chrom=chrom, pos=pos, end=end, svtype=svtype, svlen=svlen, total_alt_support=total_alt_support) row = extract_one(call, fa) finally: fa.close() X = self._align(pd.DataFrame([row])) cs = float(self.model.predict_proba(X.values)[0, 1]) return {'CS': cs, 'tier': _assign_tier(cs)} # ── Module-level convenience functions ─────────────────────────────────────── _DEFAULT: Optional[SVSPR] = None def load_default() -> SVSPR: """Return a cached SVSPR loaded from the bundled default weights.""" global _DEFAULT if _DEFAULT is None: _DEFAULT = SVSPR() return _DEFAULT def score(vcf_path: str, ref_path: str, model_path: Optional[str] = None) -> pd.DataFrame: """One-call API: score every SV in a VCF and return a DataFrame. Returns columns: chrom, pos, end, svtype, svlen, CS, tier, plus the 11 feature columns used by the model. """ model = SVSPR(model_path) if model_path else load_default() return model.predict_vcf(vcf_path, ref_path) def classify(chrom: str, pos: int, end: int, svtype: str, svlen: int, total_alt_support: float, ref_path: str, model_path: Optional[str] = None) -> dict: """One-call API: classify a single SV. Returns {'CS': ..., 'tier': ...}.""" model = SVSPR(model_path) if model_path else load_default() return model.predict_one(chrom, pos, end, svtype, svlen, total_alt_support, ref_path)