File size: 6,035 Bytes
90d0b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""SVSPR wrapper around the trained RandomForest.

Public surface:
    - SVSPR class         : load + predict on VCF / DataFrame / single SV
    - score(vcf, ref)     : convenience for whole-VCF scoring (returns DataFrame)
    - classify(chrom, ...) : convenience for one SV (returns dict)
    - load_default()      : load the bundled model
"""
from __future__ import annotations

from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
import pickle

from .features import (SVCall, FEATURE_COLS, extract_one, extract_batch,
                       from_vcf, WINDOW)


# Location of the bundled weights: <directory two levels above this file>/model/svspr_v14_seq.pkl
_DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent.joinpath(
    'model', 'svspr_v14_seq.pkl')


def _assign_tier(cs: float) -> str:
    """Translate a confidence score into its tier label.

    Returns 'high' for cs >= 0.9, 'moderate' for cs >= 0.7, 'warning' for
    cs >= 0.5 and 'low' for everything else (including NaN, which fails
    every comparison).
    """
    # Thresholds match Methods 2.7.2 (manuscript). CS is uncalibrated; if you
    # apply isotonic/Platt calibration, re-derive these cutoffs on calibrated CS.
    cutoffs = ((0.9, 'high'), (0.7, 'moderate'), (0.5, 'warning'))
    # Cutoffs are listed in descending order, so the first hit is the best tier.
    for floor, label in cutoffs:
        if cs >= floor:
            return label
    return 'low'


class SVSPR:
    """SV-SPR confidence scorer.

    Wraps a pickled classifier that exposes ``predict_proba``. The second
    column of ``predict_proba`` is reported as the confidence score (CS) and
    mapped to a tier label with ``_assign_tier``.

    Parameters
    ----------
    model_path : str | Path | None
        Path to a pickled sklearn classifier. If None (or empty), the bundled
        svspr_v14_seq.pkl is loaded.
    feature_cols : list[str] | None
        Override the feature column ordering if the model was trained on
        a different set. Defaults to the list stored in the pickle (keys
        'features' / 'feature_cols'), then to FEATURE_COLS.

    Attributes
    ----------
    model
        The unpickled estimator.
    feature_cols : list
        Column order handed to the estimator.
    metadata : dict
        Remaining entries of a dict-style pickle (everything except the
        estimator); empty for a raw-estimator pickle.

    Raises
    ------
    FileNotFoundError
        If the model file does not exist.
    ValueError
        If the pickle does not contain an estimator.
    """

    def __init__(self, model_path: Optional[Union[str, Path]] = None,
                 feature_cols: Optional[list] = None):
        path = Path(model_path) if model_path else _DEFAULT_MODEL_PATH
        if not path.exists():
            raise FileNotFoundError(f'Model file not found: {path}')
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only point model_path at pickles obtained from a trusted source.
        with open(path, 'rb') as f:
            obj = pickle.load(f)
        # The training pickle is a dict {model, features, ...} or a raw estimator.
        if isinstance(obj, dict):
            # Explicit None checks instead of `or`: estimators may define
            # __len__/__bool__, and array-like column lists raise on truth-testing.
            model_key = 'model' if obj.get('model') is not None else 'estimator'
            model = obj.get(model_key)
            stored_cols = self._first_nonempty(obj.get('features'),
                                               obj.get('feature_cols'))
            # Keep the estimator itself out of the metadata, whichever key held it.
            self.metadata = {k: v for k, v in obj.items()
                             if k not in ('model', model_key)}
        else:
            model = obj
            stored_cols = None
            self.metadata = {}
        if model is None:
            # Fail at load time rather than with an AttributeError on first predict.
            raise ValueError(f'No model found in pickle: {path}')
        self.model = model
        chosen = self._first_nonempty(feature_cols, stored_cols)
        self.feature_cols = (list(chosen) if chosen is not None
                             else list(FEATURE_COLS))

    @staticmethod
    def _first_nonempty(*candidates):
        """Return the first candidate that is neither None nor empty, else None.

        Uses ``len`` rather than truthiness so numpy arrays and pandas Index
        objects are accepted alongside plain lists and tuples.
        """
        for cand in candidates:
            if cand is not None and len(cand) > 0:
                return cand
        return None

    # ── Scoring methods ──────────────────────────────────────────────────────
    def _align(self, feat_df: pd.DataFrame) -> pd.DataFrame:
        """Map user-facing feature names to the model's expected names.

        The trained model uses legacy column names like 'svlen_abs_manta',
        'svtype_DEL_manta', etc. The feature extractor produces clean names
        like 'svlen_abs', 'svtype_DEL'. This method bridges the two.

        Returns a new frame whose columns are exactly ``self.feature_cols``
        (in that order); columns absent from the input and NaN cells are
        filled with 0. The input frame is not modified.
        """
        df = feat_df.copy()
        # Add legacy aliases for any column the model expects but isn't present.
        aliases = {'svlen_abs_manta': 'svlen_abs',
                   'svtype_DEL_manta': 'svtype_DEL',
                   'svtype_INS_manta': 'svtype_INS',
                   'svtype_DUP_manta': 'svtype_DUP',
                   'svtype_BND_manta': 'svtype_BND'}
        for legacy, clean in aliases.items():
            if legacy not in df.columns and clean in df.columns:
                df[legacy] = df[clean]
        # reindex drops extra columns and creates missing ones as NaN.
        return df.reindex(columns=self.feature_cols).fillna(0)

    def predict_df(self, feat_df: pd.DataFrame) -> pd.DataFrame:
        """Score a feature DataFrame. Adds CS and tier columns.

        Returns a copy of ``feat_df`` with two extra columns: 'CS' (second
        column of the model's ``predict_proba``) and 'tier'. A frame with no
        rows is returned with empty 'CS'/'tier' columns and the model is not
        called.
        """
        out = feat_df.copy()
        if len(feat_df) == 0:
            # scikit-learn estimators reject zero-sample input, so an input
            # without any SV rows (e.g. an empty VCF) is short-circuited here.
            out['CS'] = pd.Series([], index=out.index, dtype=float)
            out['tier'] = pd.Series([], index=out.index, dtype=object)
            return out
        X = self._align(feat_df)
        cs = self.model.predict_proba(X.values)[:, 1]
        out['CS'] = cs
        out['tier'] = [_assign_tier(v) for v in cs]
        return out

    def predict_vcf(self, vcf_path: str, ref_path: str) -> pd.DataFrame:
        """Score every SV in a VCF. Returns DataFrame with coord + CS + tier.

        Features are produced by ``from_vcf(vcf_path, ref_path)`` and then
        scored with ``predict_df``.
        """
        feat_df = from_vcf(vcf_path, ref_path)
        return self.predict_df(feat_df)

    def predict_one(self, chrom: str, pos: int, end: int, svtype: str,
                    svlen: int, total_alt_support: float, ref_path: str
                    ) -> dict:
        """Score one SV call. Returns {'CS': float, 'tier': str}.

        ``ref_path`` is opened with ``pysam.FastaFile`` for the duration of
        feature extraction and closed again even if extraction fails.
        """
        # Imported at call time rather than at module import
        # (presumably to keep pysam optional for DataFrame-only use).
        import pysam
        fa = pysam.FastaFile(ref_path)
        try:
            call = SVCall(chrom=chrom, pos=pos, end=end, svtype=svtype,
                          svlen=svlen, total_alt_support=total_alt_support)
            row = extract_one(call, fa)
        finally:
            fa.close()
        X = self._align(pd.DataFrame([row]))
        cs = float(self.model.predict_proba(X.values)[0, 1])
        return {'CS': cs, 'tier': _assign_tier(cs)}


# ── Module-level convenience functions ───────────────────────────────────────
# Process-wide cache for the default scorer; filled on the first load_default() call.
_DEFAULT: Optional[SVSPR] = None


def load_default() -> SVSPR:
    """Return the shared SVSPR built from the bundled default weights.

    The instance is constructed with no arguments on the first call and the
    same object is handed back on every later call.
    """
    global _DEFAULT
    if _DEFAULT is not None:
        return _DEFAULT
    _DEFAULT = SVSPR()
    return _DEFAULT


def score(vcf_path: str, ref_path: str,
          model_path: Optional[str] = None) -> pd.DataFrame:
    """One-call API: score every SV in a VCF and return a DataFrame.

    A fresh SVSPR is loaded from ``model_path`` when one is given; otherwise
    the cached default scorer from ``load_default()`` is used.

    Returns columns: chrom, pos, end, svtype, svlen, CS, tier, plus the 11
    feature columns used by the model.
    """
    if model_path:
        scorer = SVSPR(model_path)
    else:
        scorer = load_default()
    return scorer.predict_vcf(vcf_path, ref_path)


def classify(chrom: str, pos: int, end: int, svtype: str, svlen: int,
             total_alt_support: float, ref_path: str,
             model_path: Optional[str] = None) -> dict:
    """One-call API: classify a single SV. Returns {'CS': ..., 'tier': ...}.

    A fresh SVSPR is loaded from ``model_path`` when one is given; otherwise
    the cached default scorer from ``load_default()`` is used. The call is
    forwarded to ``SVSPR.predict_one``.
    """
    if model_path:
        scorer = SVSPR(model_path)
    else:
        scorer = load_default()
    return scorer.predict_one(chrom=chrom, pos=pos, end=end, svtype=svtype,
                              svlen=svlen,
                              total_alt_support=total_alt_support,
                              ref_path=ref_path)