ifieryarrows committed on
Commit
b8a1997
·
verified ·
1 Parent(s): 167e1a2

Sync from GitHub (tests passed)

Browse files
Files changed (1) hide show
  1. app/inference.py +80 -4
app/inference.py CHANGED
@@ -10,6 +10,7 @@ Handles:
10
 
11
  import json
12
  import logging
 
13
 
14
  # Suppress httpx request logging to prevent API keys in URLs from appearing in logs
15
  logging.getLogger("httpx").setLevel(logging.WARNING)
@@ -39,6 +40,72 @@ logging.basicConfig(level=logging.INFO)
39
  logger = logging.getLogger(__name__)
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def get_current_price(session: Session, symbol: str) -> Optional[float]:
43
  """
44
  Get the current price for a symbol.
@@ -205,6 +272,10 @@ def build_features_for_prediction(
205
  Build feature vector for live prediction.
206
  Uses the most recent available data.
207
  MUST use training_symbols to match the model's training data.
 
 
 
 
208
  """
209
  settings = get_settings()
210
  # Use training_symbols (not symbols_list) to match model training
@@ -256,12 +327,17 @@ def build_features_for_prediction(
256
  # Get latest row
257
  latest = all_features.iloc[[-1]].copy()
258
 
259
- # Robust feature alignment:
260
- # - Reindex to exactly match model's expected features
 
 
 
 
 
261
  # - Missing features get 0.0 (same as missing data handling in training)
262
  # - Extra features are dropped
263
- # This prevents ValueError when symbol set changes between training and inference
264
- latest = latest.reindex(columns=feature_names, fill_value=0.0)
265
 
266
  # Ensure float dtype for XGBoost
267
  latest = latest.astype(float)
 
10
 
11
  import json
12
  import logging
13
+ import re
14
 
15
  # Suppress httpx request logging to prevent API keys in URLs from appearing in logs
16
  logging.getLogger("httpx").setLevel(logging.WARNING)
 
40
  logger = logging.getLogger(__name__)
41
 
42
 
43
+ # =============================================================================
44
+ # Feature Alignment Helpers (Train/Inference compatibility)
45
+ # =============================================================================
46
+
47
def _sanitize_symbol(sym: str) -> str:
    """Convert a symbol to a safe column prefix (e.g. ``HG=F`` -> ``HG_F``).

    Every run of characters outside ``[A-Za-z0-9]`` collapses into a single
    underscore, and leading/trailing underscores are stripped afterwards
    (so ``^GSPC`` becomes ``GSPC``).
    """
    return re.sub(r"[^A-Za-z0-9]+", "_", sym).strip("_")


def _rename_sanitized_to_raw(df: pd.DataFrame, symbols: list[str]) -> pd.DataFrame:
    """Rename sanitized column prefixes back to raw symbol names.

    Example: ``HG_F_ret1`` -> ``HG=F_ret1`` when ``"HG=F"`` is in *symbols*.

    Args:
        df: Frame whose columns may carry sanitized symbol prefixes.
        symbols: Raw symbol names; symbols that are unchanged by
            sanitization are ignored.

    Returns:
        A renamed copy of *df*, or *df* itself when nothing needs renaming.

    Notes:
        - A column is left untouched when its raw-named counterpart already
          exists, so the rename never introduces duplicate column labels.
        - When several sanitized prefixes match one column, the longest
          prefix wins, making the result independent of the order of
          *symbols*.
        - Non-string column labels are skipped.
    """
    # Collect (sanitized_prefix, raw_prefix) pairs for symbols that change.
    prefix_pairs = []
    for sym in symbols:
        sanitized = _sanitize_symbol(sym)
        if sanitized == sym:
            continue  # Already a safe name; nothing to map back.
        prefix_pairs.append((sanitized + "_", sym + "_"))

    # Most specific (longest) prefix first so "A_B_C_" beats "A_B_".
    prefix_pairs.sort(key=lambda pair: len(pair[0]), reverse=True)

    rename_map = {}
    taken = set(df.columns)  # Labels that must not be produced a second time.

    for col in df.columns:
        if not isinstance(col, str):
            continue
        for sanitized_prefix, raw_prefix in prefix_pairs:
            if not col.startswith(sanitized_prefix):
                continue
            new_name = raw_prefix + col[len(sanitized_prefix):]
            if new_name not in taken:
                rename_map[col] = new_name
                taken.add(new_name)
            break  # Only the most specific matching prefix is considered.

    if not rename_map:
        return df

    logger.debug("Renaming %d columns from sanitized to raw", len(rename_map))
    return df.rename(columns=rename_map)
77
+
78
+
79
def _align_features_to_model(
    df: pd.DataFrame, expected_features: list[str]
) -> pd.DataFrame:
    """Align DataFrame columns to match the model's expected feature names.

    - Missing features are filled with 0.0
    - Extra features are dropped
    - Column order matches ``expected_features``
    - Duplicate column labels in *df* are collapsed to their first
      occurrence, because ``DataFrame.reindex`` refuses to operate on an
      axis with duplicate labels.

    Args:
        df: Feature frame to align.
        expected_features: Feature names in the order the model expects.
            ``None`` or an empty sequence skips alignment.

    Returns:
        *df* unchanged when no expected features are given, otherwise a new
        frame with exactly the expected columns.
    """
    # Explicit None/length check so array-like feature names do not trip the
    # ambiguous truth-value error that ``not expected_features`` would raise.
    if expected_features is None or len(expected_features) == 0:
        logger.warning("No expected features provided; skipping alignment")
        return df

    expected = list(expected_features)

    duplicated = df.columns.duplicated()
    if duplicated.any():
        logger.warning(
            "Dropping %d duplicate feature column(s) before alignment (keeping first)",
            int(duplicated.sum()),
        )
        df = df.loc[:, ~duplicated]

    present = set(df.columns)
    expected_set = set(expected)

    # Lists (not sets) keep the logged samples in a stable, meaningful order.
    missing = [name for name in expected if name not in present]
    extra = [name for name in df.columns if name not in expected_set]

    if missing or extra:
        logger.info(
            "Feature alignment: expected=%d present=%d missing=%d extra=%d",
            len(expected),
            len(df.columns),
            len(missing),
            len(extra),
        )
        if missing:
            logger.debug("Missing features (first 10): %s", missing[:10])
        if extra:
            logger.debug("Extra features (first 10): %s", extra[:10])

    return df.reindex(columns=expected, fill_value=0.0)
107
+
108
+
109
  def get_current_price(session: Session, symbol: str) -> Optional[float]:
110
  """
111
  Get the current price for a symbol.
 
272
  Build feature vector for live prediction.
273
  Uses the most recent available data.
274
  MUST use training_symbols to match the model's training data.
275
+
276
+ Includes robust alignment to handle:
277
+ - Sanitized vs raw symbol name differences (HG_F vs HG=F)
278
+ - Missing/extra features between training and inference
279
  """
280
  settings = get_settings()
281
  # Use training_symbols (not symbols_list) to match model training
 
327
  # Get latest row
328
  latest = all_features.iloc[[-1]].copy()
329
 
330
+ # STEP 1: Rename sanitized prefixes to raw symbol names if needed
331
+ # This handles cases where feature generation used sanitized names (HG_F)
332
+ # but model was trained with raw names (HG=F)
333
+ all_symbols = [target_symbol] + list(symbols)
334
+ latest = _rename_sanitized_to_raw(latest, all_symbols)
335
+
336
+ # STEP 2: Align to model's expected features
337
  # - Missing features get 0.0 (same as missing data handling in training)
338
  # - Extra features are dropped
339
+ # - Column order matches expected feature_names
340
+ latest = _align_features_to_model(latest, feature_names)
341
 
342
  # Ensure float dtype for XGBoost
343
  latest = latest.astype(float)