File size: 8,579 Bytes
18d4089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d317049
 
18d4089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d317049
 
 
 
18d4089
d317049
18d4089
d317049
18d4089
d317049
 
 
 
 
 
 
 
18d4089
 
 
 
 
 
d317049
 
 
18d4089
 
 
 
 
d317049
18d4089
d317049
18d4089
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
"""
Feature Selection Pipeline for TFT-ASRO.

Two-stage dimensionality reduction to combat the curse of dimensionality
(~200+ features for ~500 training samples):

    Stage 1 — MRMR Pre-Filter (before training):
        Statistical filter using Mutual Information for relevance and
        Pearson correlation for redundancy.  Reduces 200+ → top-K features.

    Stage 2 — VSN Importance Pruning (after initial training):
        Uses TFT's Variable Selection Network weights to identify which
        features the model actually attends to, then prunes the bottom tier.

References:
    - Ding & Peng (2005) "Minimum Redundancy Feature Selection"
    - Lim et al. (2021) "Temporal Fusion Transformers" — VSN interpretability
"""

from __future__ import annotations

import logging
from typing import Optional

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Stage 1: MRMR Pre-Filter
# ---------------------------------------------------------------------------

def _mutual_info_relevance(
    X: pd.DataFrame,
    y: pd.Series,
    n_neighbors: int = 5,
) -> pd.Series:
    """
    Score every column of ``X`` by its mutual information with ``y``.

    Args:
        X:           Feature frame; one score is produced per column.
        y:           Target values, row-aligned with ``X``.
        n_neighbors: Forwarded to scikit-learn's ``mutual_info_regression``.

    Returns:
        Series of mutual-information scores indexed by the column names of ``X``.
    """
    # Imported lazily so scikit-learn is only needed when MRMR actually runs.
    from sklearn.feature_selection import mutual_info_regression

    feature_matrix = X.values
    target_vector = y.values

    # random_state is pinned so repeated calls on the same data agree.
    scores = mutual_info_regression(
        feature_matrix,
        target_vector,
        n_neighbors=n_neighbors,
        random_state=42,
    )
    return pd.Series(data=scores, index=X.columns)


def _pairwise_correlation(X: pd.DataFrame) -> pd.DataFrame:
    """Return the column-by-column Pearson correlation matrix of ``X`` in absolute value."""
    signed = X.corr(method="pearson")
    return signed.abs()


def mrmr_select(
    df: pd.DataFrame,
    target_col: str = "target",
    top_k: int = 80,
    exclude_cols: Optional[list[str]] = None,
) -> list[str]:
    """
    Minimum Redundancy Maximum Relevance feature selection.

    Greedily selects features that are highly relevant to the target while
    being minimally redundant with already-selected features.  Each round
    scores every remaining candidate as

        relevance(candidate) - mean(|corr(candidate, already selected)|)

    and picks the highest score.  Ties are resolved in favour of the column
    that appears first in ``df``, so the result does not depend on hash
    ordering and is reproducible between runs.

    Args:
        df:          DataFrame containing features and the target column.
        target_col:  Name of the target column.
        top_k:       Number of features to select.  Values <= 0 select nothing.
        exclude_cols: Columns to exclude from selection (e.g. time_idx, group_id).
                      ``target_col``, ``time_idx`` and ``group_id`` are always
                      excluded.

    Returns:
        List of selected feature names, ordered by selection round.
        All candidate features are returned unchanged (in column order) when
        there are no more than ``top_k`` of them, or when fewer than 30 rows
        are free of missing values in the target and every candidate.
    """
    if top_k <= 0:
        # Without this guard the unconditional first pick below would still
        # return one feature even though none were requested.
        return []

    meta_cols = {target_col, "time_idx", "group_id"} | set(exclude_cols or ())
    feature_cols = [c for c in df.columns if c not in meta_cols]

    if len(feature_cols) <= top_k:
        logger.info("MRMR: %d features <= top_k=%d, skipping selection", len(feature_cols), top_k)
        return feature_cols

    X = df[feature_cols]
    y = df[target_col]

    # Only rows that are complete in the target and in every candidate are used.
    mask = y.notna() & X.notna().all(axis=1)
    X = X[mask]
    y = y[mask]

    if len(X) < 30:
        logger.warning("MRMR: only %d valid samples, skipping", len(X))
        return feature_cols

    relevance = _mutual_info_relevance(X, y)
    corr_matrix = _pairwise_correlation(X)

    # Seed the selection with the single most relevant feature.
    first = relevance.idxmax()
    selected: list[str] = [first]

    # An ordered list (not a set) keeps tie-breaking deterministic.
    candidates = [c for c in feature_cols if c != first]

    while candidates and len(selected) < top_k:
        # Mean absolute correlation of each candidate with the selected set.
        redundancy = corr_matrix.loc[candidates, selected].mean(axis=1)

        # Candidates with an undefined (NaN) score are never eligible.
        scores = (relevance.loc[candidates] - redundancy).dropna()
        if scores.empty:
            break

        # idxmax returns the first maximum, i.e. the earliest column on ties.
        best_feat = scores.idxmax()
        selected.append(best_feat)
        candidates.remove(best_feat)

    logger.info(
        "MRMR selected %d/%d features (top relevance=%.4f, bottom=%.4f)",
        len(selected), len(feature_cols),
        relevance[selected[0]],
        relevance[selected[-1]],
    )
    return selected


# ---------------------------------------------------------------------------
# Stage 2: VSN Importance Pruning
# ---------------------------------------------------------------------------

def vsn_prune(
    importance: dict[str, float],
    feature_list: list[str],
    min_features: int = 40,
    cumulative_threshold: float = 0.92,
) -> list[str]:
    """
    Prune features using TFT Variable Selection Network importance scores.

    Features are ranked by descending importance (ties keep their order in
    ``feature_list``) and kept until their cumulative share of the total
    importance reaches the threshold, with a minimum floor to avoid
    over-pruning.

    Args:
        importance:           {feature_name: normalised_importance} from
                              ``get_variable_importance()``.  Features missing
                              from the mapping are scored 0.0; entries not in
                              ``feature_list`` are ignored.
        feature_list:         Full list of features currently in use.
        min_features:         Never prune below this count.
        cumulative_threshold: Keep features until cumulative importance hits this.

    Returns:
        Pruned feature list (subset of feature_list), ordered by descending
        importance.  ``feature_list`` itself is returned unchanged when
        ``importance`` is empty or the total importance is effectively zero.
    """
    if not importance:
        logger.info("VSN prune: no importance scores, returning all %d features", len(feature_list))
        return feature_list

    scored = {f: importance.get(f, 0.0) for f in feature_list}
    ranked = sorted(scored.items(), key=lambda item: -item[1])

    total = sum(score for _, score in ranked)
    if total < 1e-12:
        # No usable signal to rank by, so nothing can be pruned safely.
        logger.info(
            "VSN prune: total importance is ~0, returning all %d features",
            len(feature_list),
        )
        return feature_list

    kept: list[str] = []
    cumulative = 0.0

    for feat, score in ranked:
        kept.append(feat)
        cumulative += score / total
        # Both conditions must hold before stopping, so the floor is always
        # honoured (or every feature is kept when there are fewer than the
        # floor); no post-loop fallback is needed.
        if len(kept) >= min_features and cumulative >= cumulative_threshold:
            break

    logger.info(
        "VSN pruned %d → %d features (cumulative importance=%.2f%%)",
        len(feature_list), len(kept), cumulative * 100,
    )
    return kept


# ---------------------------------------------------------------------------
# Combined pipeline
# ---------------------------------------------------------------------------

def select_features(
    df: pd.DataFrame,
    target_col: str = "target",
    mrmr_top_k: int = 80,
    known_features: Optional[list[str]] = None,
    forced_unknown_features: Optional[list[str]] = None,
    forbidden_features: Optional[list[str]] = None,
) -> tuple[pd.DataFrame, list[str], list[str]]:
    """
    Run MRMR selection on unknown features while preserving known features.

    Args:
        df:              Master DataFrame from feature_store.
        target_col:      Target column name.
        mrmr_top_k:      How many unknown features MRMR may keep.  Forced
                         features are added on top, so the returned list can
                         be longer than this.
        known_features:  List of time_varying_known_reals (calendar etc.) — always kept.
        forced_unknown_features:
                         Unknown features that are always part of the returned
                         selection, provided they exist in ``df`` and are not
                         meta, known or forbidden columns.
        forbidden_features:
                         Columns that are never offered to MRMR and never
                         appear in the returned unknown-feature list.  If
                         present in ``df`` they are still carried through in
                         the returned frame.

    Returns:
        (filtered_df, new_unknown_features, known_features)

        ``filtered_df`` is a copy whose columns are, in this fixed order: the
        meta columns (``time_idx``, ``group_id``, target), the known features,
        the forbidden features, then the selected unknown features sorted by
        name.  Columns missing from ``df`` are skipped.
    """
    if known_features is None:
        known_features = []
    if forced_unknown_features is None:
        forced_unknown_features = []
    if forbidden_features is None:
        forbidden_features = []

    meta_cols = ["time_idx", "group_id", target_col]

    # De-duplicate while keeping a stable order; iterating a plain set here
    # would make the output column order vary with string hash randomisation.
    preserve_order = list(
        dict.fromkeys([*meta_cols, *known_features, *forbidden_features])
    )
    preserve_cols = set(preserve_order)

    unknown_candidates = [c for c in df.columns if c not in preserve_cols]
    forced_unknown = [
        c for c in forced_unknown_features
        if c in df.columns and c not in preserve_cols
    ]

    if len(unknown_candidates) <= mrmr_top_k:
        logger.info(
            "Feature selection: %d unknown features <= top_k=%d, no pruning needed",
            len(unknown_candidates), mrmr_top_k,
        )
        selected = sorted(set(unknown_candidates) | set(forced_unknown))
        keep_cols = [c for c in preserve_order + selected if c in df.columns]
        return df[keep_cols].copy(), selected, known_features

    selected_unknown = mrmr_select(
        df,
        target_col=target_col,
        top_k=mrmr_top_k,
        exclude_cols=preserve_order,
    )
    # Forced features are merged in after MRMR so they can never be dropped.
    selected_unknown = sorted(set(selected_unknown) | set(forced_unknown))

    keep_cols = [c for c in preserve_order + selected_unknown if c in df.columns]
    filtered = df[keep_cols].copy()

    logger.info(
        "Feature selection complete: %d cols → %d cols "
        "(%d unknown, %d known, %d meta)",
        len(df.columns), len(filtered.columns),
        len(selected_unknown), len(known_features), len(meta_cols),
    )

    return filtered, selected_unknown, known_features