SVSTR-Score / seqonly /src /model.py
khyeom's picture
Add sequence-only headline model (svspr_v14_seq, 11-feature) + inference package
90d0b4b verified
Raw
History Blame
6.04 kB
"""SVSPR wrapper around the trained RandomForest.
Public surface:
- SVSPR class : load + predict on DataFrame / VCF / single SV
- score(vcf, ref) : convenience for whole-VCF scoring (returns DataFrame)
- classify(chrom, ...) : convenience for one SV (returns dict)
- load_default() : load the bundled model
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional, Union
import numpy as np
import pandas as pd
import pickle
from .features import (SVCall, FEATURE_COLS, extract_one, extract_batch,
from_vcf, WINDOW)
# Location of the bundled weights: a `model/` directory that is a sibling of
# the directory containing this file.
_DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent.joinpath(
    'model', 'svspr_v14_seq.pkl')
def _assign_tier(cs: float) -> str:
    """Translate a confidence score into its tier label.

    Lower bounds are inclusive: >= 0.9 is 'high', >= 0.7 is 'moderate',
    >= 0.5 is 'warning'; anything else (including NaN, which fails every
    comparison) is 'low'.
    """
    # Thresholds match Methods 2.7.2 (manuscript). CS is uncalibrated; if you
    # apply isotonic/Platt calibration, re-derive these cutoffs on calibrated CS.
    cutoffs = ((0.9, 'high'), (0.7, 'moderate'), (0.5, 'warning'))
    # Cutoffs are ordered from strictest to loosest, so the first hit wins.
    for floor, label in cutoffs:
        if cs >= floor:
            return label
    return 'low'
class SVSPR:
    """SV-SPR confidence scorer.

    Wraps a pickled classifier exposing ``predict_proba`` and reports, for
    each SV, column 1 of its output as the confidence score (CS) together
    with a tier label derived from CS.

    Parameters
    ----------
    model_path : str | Path | None
        Path to a pickled sklearn classifier. If None, the bundled
        svspr_v14_seq.pkl is loaded.
    feature_cols : list[str] | None
        Override the feature column ordering if the model was trained on
        a different set. Defaults to the list stored in the pickle (keys
        'features' or 'feature_cols'), then to FEATURE_COLS.

    Raises
    ------
    FileNotFoundError
        If the model file does not exist.
    ValueError
        If the pickle holds a dict with neither a 'model' nor an
        'estimator' entry.
    """

    # Legacy (training-time) column name -> name produced by the extractor.
    _LEGACY_ALIASES = {'svlen_abs_manta': 'svlen_abs',
                       'svtype_DEL_manta': 'svtype_DEL',
                       'svtype_INS_manta': 'svtype_INS',
                       'svtype_DUP_manta': 'svtype_DUP',
                       'svtype_BND_manta': 'svtype_BND'}

    def __init__(self, model_path: Optional[Union[str, Path]] = None,
                 feature_cols: Optional[list] = None):
        path = Path(model_path) if model_path else _DEFAULT_MODEL_PATH
        if not path.exists():
            raise FileNotFoundError(f'Model file not found: {path}')
        # SECURITY: unpickling can execute arbitrary code embedded in the
        # file. Only pass model_path values that come from a trusted source.
        with open(path, 'rb') as f:
            obj = pickle.load(f)

        # The training pickle is a dict {model, features, ...} or a raw estimator.
        if isinstance(obj, dict):
            # Explicit None checks: relying on the truthiness of an estimator
            # or of an array-like feature list is fragile (arrays raise).
            model_key = next((k for k in ('model', 'estimator')
                              if obj.get(k) is not None), None)
            if model_key is None:
                raise ValueError(
                    f"Pickled dict at {path} has no 'model' or 'estimator' "
                    f"entry")
            self.model = obj[model_key]
            cols = self._first_nonempty(feature_cols,
                                        obj.get('features'),
                                        obj.get('feature_cols'))
            self.feature_cols = cols if cols is not None else FEATURE_COLS
            # Keep everything except the estimator itself as metadata.
            self.metadata = {k: v for k, v in obj.items()
                             if k not in ('model', model_key)}
        else:
            self.model = obj
            cols = self._first_nonempty(feature_cols)
            self.feature_cols = cols if cols is not None else FEATURE_COLS
            self.metadata = {}

    @staticmethod
    def _first_nonempty(*candidates):
        """Return the first candidate that is neither None nor empty, as a list.

        Works for lists, tuples, numpy arrays and pandas Index objects;
        returns None when every candidate is missing or empty.
        """
        for cand in candidates:
            if cand is not None and len(cand) > 0:
                return list(cand)
        return None

    # ── Scoring methods ──────────────────────────────────────────────────────
    def _align(self, feat_df: pd.DataFrame) -> pd.DataFrame:
        """Map user-facing feature names to the model's expected names.

        The trained model uses legacy column names like 'svlen_abs_manta',
        'svtype_DEL_manta', etc. The feature extractor produces clean names
        like 'svlen_abs', 'svtype_DEL'. This method bridges the two.

        Returns a new frame whose columns are exactly ``self.feature_cols``,
        in that order. Columns the model expects but the input lacks, and
        any NaN cells, are filled with 0. The input frame is not modified.
        """
        df = feat_df.copy()
        # Add legacy aliases for any column the model expects but isn't present.
        for legacy, clean in self._LEGACY_ALIASES.items():
            if legacy not in df.columns and clean in df.columns:
                df[legacy] = df[clean]
        return df.reindex(columns=self.feature_cols).fillna(0)

    def _positive_proba(self, X: pd.DataFrame) -> np.ndarray:
        """Return column 1 of ``predict_proba`` for an aligned feature frame.

        A zero-row frame short-circuits to an empty array without calling
        the model, because estimators typically reject empty input.
        NOTE(review): assumes index 1 is the positive class; confirm against
        the trained model's ``classes_``.
        """
        if len(X) == 0:
            return np.empty(0, dtype=float)
        return self.model.predict_proba(X.values)[:, 1]

    def predict_df(self, feat_df: pd.DataFrame) -> pd.DataFrame:
        """Score a feature DataFrame. Adds CS and tier columns.

        Returns a copy of ``feat_df`` with 'CS' and 'tier' appended; the
        input is left untouched. An empty input yields an empty result that
        still carries both columns.
        """
        X = self._align(feat_df)
        cs = self._positive_proba(X)
        out = feat_df.copy()
        out['CS'] = cs
        out['tier'] = [_assign_tier(v) for v in cs]
        return out

    def predict_vcf(self, vcf_path: str, ref_path: str) -> pd.DataFrame:
        """Score every SV in a VCF. Returns DataFrame with coord + CS + tier."""
        feat_df = from_vcf(vcf_path, ref_path)
        return self.predict_df(feat_df)

    def predict_one(self, chrom: str, pos: int, end: int, svtype: str,
                    svlen: int, total_alt_support: float, ref_path: str
                    ) -> dict:
        """Score one SV call. Returns {'CS': float, 'tier': str}."""
        import pysam  # kept local to this method, as in the original
        fa = pysam.FastaFile(ref_path)
        try:
            call = SVCall(chrom=chrom, pos=pos, end=end, svtype=svtype,
                          svlen=svlen, total_alt_support=total_alt_support)
            row = extract_one(call, fa)
        finally:
            # Always release the FASTA handle, even if extraction fails.
            fa.close()
        X = self._align(pd.DataFrame([row]))
        cs = float(self._positive_proba(X)[0])
        return {'CS': cs, 'tier': _assign_tier(cs)}
# ── Module-level convenience functions ───────────────────────────────────────
# Module-level cache holding the SVSPR built by load_default(); None until
# the first call.
_DEFAULT: Optional[SVSPR] = None


def load_default() -> SVSPR:
    """Return the shared SVSPR built from the bundled default weights.

    The instance is created on the first call and stored in the module-level
    ``_DEFAULT``; every later call returns that same object.
    """
    global _DEFAULT
    if _DEFAULT is not None:
        return _DEFAULT
    _DEFAULT = SVSPR()
    return _DEFAULT
def score(vcf_path: str, ref_path: str,
          model_path: Optional[str] = None) -> pd.DataFrame:
    """Score every SV in a VCF in a single call and return a DataFrame.

    Uses the model at ``model_path`` when one is given (a fresh SVSPR is
    loaded for the call); otherwise the cached default from
    ``load_default()`` is used.

    Returns columns: chrom, pos, end, svtype, svlen, CS, tier, plus the 11
    feature columns used by the model.
    """
    if model_path:
        scorer = SVSPR(model_path)
    else:
        scorer = load_default()
    return scorer.predict_vcf(vcf_path, ref_path)
def classify(chrom: str, pos: int, end: int, svtype: str, svlen: int,
             total_alt_support: float, ref_path: str,
             model_path: Optional[str] = None) -> dict:
    """Classify a single SV in one call. Returns {'CS': ..., 'tier': ...}.

    Uses the model at ``model_path`` when one is given (a fresh SVSPR is
    loaded for the call); otherwise the cached default from
    ``load_default()`` is used.
    """
    if model_path:
        scorer = SVSPR(model_path)
    else:
        scorer = load_default()
    return scorer.predict_one(chrom=chrom, pos=pos, end=end, svtype=svtype,
                              svlen=svlen,
                              total_alt_support=total_alt_support,
                              ref_path=ref_path)