Add online learning: per-symbol adaptive models with meta-learning, concept drift adaptation
aab4bbb verified | """Online Learning — Per-Symbol Adaptive Models | |
| Why this matters at a Jane Street level of rigor: | |
| - Markets CHANGE. A model trained on SPY 2022 fails on SPY 2024. | |
| - Each asset has unique microstructure, seasonality, regime behavior. | |
| - Static models lose predictive power over time (model decay). | |
| Solution: Online / Continual Learning | |
| - Update models incrementally on every new observation | |
| - Per-symbol parameters (some assets trend, others mean-revert) | |
| - Meta-learning: learn HOW to adapt quickly | |
| - Concept drift detection: auto-detect when old model is wrong | |
| Based on: | |
| - Vapnik (1998): Online SVM | |
| - Cesa-Bianchi & Lugosi (2006): Prediction, Learning, and Games | |
| - Finn et al. (2017): MAML (Model-Agnostic Meta-Learning) | |
| - Gama et al. (2014): A survey on concept drift adaptation | |
| """ | |
| import numpy as np | |
| import pandas as pd | |
| from typing import Dict, List, Tuple, Optional, Callable | |
| from collections import defaultdict | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
def sigmoid(x):
    """Return the logistic function 1 / (1 + e^-x) of ``x``.

    The input is clipped to [-500, 500] before exponentiation so that very
    large magnitudes do not overflow ``np.exp``. Works on scalars and on
    NumPy arrays (elementwise).
    """
    bounded = np.clip(x, -500, 500)
    return 1 / (1 + np.exp(-bounded))
class OnlineLogisticRegression:
    """Binary logistic regression trained one observation at a time.

    Each weight gets its own step size: the base learning rate divided by
    the square root of that weight's accumulated squared gradients
    (AdaGrad-style). On top of that, the base rate decays geometrically by
    ``lr_decay`` after every update and is floored at ``min_lr``.

    Every prediction, label and gradient norm passed through ``update`` is
    appended to a plain list so that recent performance can be computed
    later. These histories are never trimmed, so they grow with the number
    of updates.
    """

    def __init__(self,
                 n_features: int = 10,
                 initial_lr: float = 0.01,
                 lr_decay: float = 0.999,
                 l2_reg: float = 0.01,
                 min_lr: float = 1e-6):
        """Create an untrained model with zero weights and zero bias.

        Args:
            n_features: Length of the feature vectors passed to ``update``.
            initial_lr: Starting base learning rate.
            lr_decay: Multiplicative decay applied to the base rate after
                each update.
            l2_reg: L2 penalty coefficient applied to the weights (the bias
                is not regularized).
            min_lr: Lower bound for the decayed base learning rate.
        """
        self.n_features = n_features
        self.lr = initial_lr
        self.initial_lr = initial_lr
        self.lr_decay = lr_decay
        self.l2_reg = l2_reg
        self.min_lr = min_lr
        self.weights = np.zeros(n_features)
        self.bias = 0.0
        # Running sums of squared gradients (AdaGrad accumulators)
        self.grad_moment2 = np.zeros(n_features)
        self.bias_moment2 = 0.0
        # Number of calls to update()
        self.t = 0
        # Histories used by get_performance(); one entry per update
        self.predictions = []
        self.actuals = []
        self.grad_norms = []

    def predict_proba(self, x: np.ndarray) -> float:
        """Return the predicted probability of the positive class for ``x``."""
        z = np.dot(x, self.weights) + self.bias
        return sigmoid(z)

    def predict(self, x: np.ndarray) -> int:
        """Return 1 if the predicted probability exceeds 0.5, else 0."""
        return 1 if self.predict_proba(x) > 0.5 else 0

    def update(self, x: np.ndarray, y: int) -> Dict:
        """Perform a single online gradient step on one observation.

        Args:
            x: Feature vector of shape ``(n_features,)``.
            y: Label, 0 or 1.

        Returns:
            Dict with ``pred`` (probability predicted *before* the step),
            ``error`` (``pred - y``), ``grad_norm`` (L2 norm of the weight
            gradient) and ``lr`` (base learning rate *after* decay).
        """
        self.t += 1
        x = np.asarray(x, dtype=float)

        # Forward pass with the pre-update parameters
        pred = self.predict_proba(x)

        # Log-loss gradient plus L2 penalty on the weights
        error = pred - y
        grad_w = error * x + self.l2_reg * self.weights
        grad_b = error

        # Per-coordinate step sizes from accumulated squared gradients
        self.grad_moment2 += grad_w ** 2
        self.bias_moment2 += grad_b ** 2
        lr_w = self.lr / (np.sqrt(self.grad_moment2) + 1e-8)
        lr_b = self.lr / (np.sqrt(self.bias_moment2) + 1e-8)

        self.weights -= lr_w * grad_w
        self.bias -= lr_b * grad_b

        # Geometric decay of the base rate, never below min_lr
        self.lr = max(self.lr * self.lr_decay, self.min_lr)

        grad_norm = np.linalg.norm(grad_w)
        self.predictions.append(pred)
        self.actuals.append(y)
        self.grad_norms.append(grad_norm)

        return {
            'pred': pred,
            'error': error,
            'grad_norm': grad_norm,
            'lr': self.lr
        }

    def get_performance(self, last_n: int = 100) -> Dict:
        """Summarize performance over the most recent ``last_n`` updates.

        The same set of keys is always returned, so callers can index the
        result without checking how much history exists. With fewer than two
        recorded updates the accuracies default to 0.5 (coin flip).

        Args:
            last_n: Size of the trailing window; values below 1 are treated
                as 1.

        Returns:
            Dict with ``accuracy``, ``directional_accuracy``,
            ``avg_grad_norm``, ``current_lr`` and ``n_updates``.
        """
        if len(self.actuals) < 2:
            return {
                'accuracy': 0.5,
                'directional_accuracy': 0.5,
                'avg_grad_norm': np.mean(self.grad_norms) if self.grad_norms else 0,
                'current_lr': self.lr,
                'n_updates': self.t
            }

        # Clamp to >= 1: a window of 0 would make `[-0:]` select everything
        n = max(1, min(last_n, len(self.actuals)))
        preds = np.array(self.predictions[-n:]) > 0.5
        actuals = np.array(self.actuals[-n:])
        accuracy = np.mean(preds == actuals)

        if len(actuals) >= 10:
            # Compare the sign of step-to-step changes in the predicted
            # probability with the sign of changes in the label, over the
            # last 10 updates
            pred_returns = np.diff(self.predictions[-10:])
            actual_returns = np.diff(self.actuals[-10:])
            if len(pred_returns) > 0:
                directional = np.mean(np.sign(pred_returns) == np.sign(actual_returns))
            else:
                directional = 0.5
        else:
            directional = accuracy

        return {
            'accuracy': accuracy,
            'directional_accuracy': directional,
            'avg_grad_norm': np.mean(self.grad_norms[-n:]) if self.grad_norms else 0,
            'current_lr': self.lr,
            'n_updates': self.t
        }
class PerSymbolAdaptiveModel:
    """Registry of independent online models, one per symbol.

    A model is created lazily the first time a symbol is seen (by either
    ``update`` or ``predict``). Each model has its own weights, its own
    learning-rate schedule and its own performance history, so symbols with
    different behavior do not interfere with each other.
    """

    def __init__(self,
                 n_features: int = 10,
                 base_lr: float = 0.01,
                 symbols: Optional[List[str]] = None):
        """Create an empty registry.

        Args:
            n_features: Feature-vector length used for every per-symbol model.
            base_lr: Base learning rate; each new model jitters it by a
                random factor in [0.8, 1.2].
            symbols: Optional list of symbols, stored on ``self.symbols``.
                Models are still created lazily on first use.
        """
        self.n_features = n_features
        self.base_lr = base_lr
        self.symbols = symbols or []
        # symbol -> its online model
        self.models: Dict[str, "OnlineLogisticRegression"] = {}
        # symbol -> list of performance snapshots, one per update() call
        self.symbol_performance: Dict[str, List[Dict]] = defaultdict(list)
        # Every symbol for which a model has been created
        self.seen_symbols = set()

    def _get_or_create_model(self, symbol: str) -> "OnlineLogisticRegression":
        """Return the model for ``symbol``, creating and registering it if new."""
        if symbol not in self.models:
            # Warm-start from existing models when enough of them exist
            init_weights = self._meta_initialize(symbol)
            model = OnlineLogisticRegression(
                n_features=self.n_features,
                initial_lr=self.base_lr * np.random.uniform(0.8, 1.2)
            )
            if init_weights is not None:
                model.weights = init_weights
            self.models[symbol] = model
            self.seen_symbols.add(symbol)
        return self.models[symbol]

    def _meta_initialize(self, new_symbol: str) -> Optional[np.ndarray]:
        """Return warm-start weights for a new symbol, or None.

        With at least three existing models, returns the elementwise mean of
        the weights of the three models with the highest accuracy over their
        last 50 updates. ``new_symbol`` itself is not used in the selection.
        """
        if len(self.models) < 3:
            return None

        # sorted() is stable, so ties keep registry insertion order
        ranked = sorted(
            self.models.values(),
            key=lambda m: m.get_performance(last_n=50).get('accuracy', 0.5),
            reverse=True,
        )
        # np.mean allocates a new array, so the result shares no storage
        # with the donor models
        return np.mean([m.weights for m in ranked[:3]], axis=0)

    def update(self, symbol: str, x: np.ndarray, y: int) -> Dict:
        """Train the model for ``symbol`` on one observation.

        Returns:
            The model's update metrics with an added ``symbol`` key.
        """
        model = self._get_or_create_model(symbol)
        metrics = model.update(x, y)

        # Snapshot of performance over the last 20 updates
        self.symbol_performance[symbol].append(model.get_performance(last_n=20))

        metrics['symbol'] = symbol
        return metrics

    def predict(self, symbol: str, x: np.ndarray) -> Dict:
        """Predict for ``symbol``; creates an untrained model if it is new.

        Returns:
            Dict with ``symbol``, ``probability``, ``prediction`` (0/1),
            ``confidence`` (0 = unsure, 1 = certain) and ``model_age``
            (number of updates the model has received).
        """
        model = self._get_or_create_model(symbol)
        prob = model.predict_proba(x)
        return {
            'symbol': symbol,
            'probability': prob,
            'prediction': 1 if prob > 0.5 else 0,
            'confidence': abs(prob - 0.5) * 2,  # 0 = unsure, 1 = certain
            'model_age': model.t
        }

    def get_symbol_ranking(self) -> pd.DataFrame:
        """Return one row per symbol, sorted by directional accuracy (best first).

        Metrics are read with defaults so that a model with little or no
        history (for example one created by ``predict`` alone) still yields
        a row instead of raising ``KeyError``.
        """
        rows = []
        for symbol, model in self.models.items():
            perf = model.get_performance(last_n=100)
            accuracy = perf.get('accuracy', 0.5)
            rows.append({
                'symbol': symbol,
                'accuracy': accuracy,
                'directional_accuracy': perf.get('directional_accuracy', accuracy),
                'n_samples': model.t,
                'current_lr': perf.get('current_lr', getattr(model, 'lr', None)),
                'grad_norm': perf.get('avg_grad_norm', 0)
            })

        df = pd.DataFrame(rows)
        if not df.empty:
            df = df.sort_values('directional_accuracy', ascending=False)
        return df

    def detect_concept_drift(self, symbol: str,
                             window_short: int = 50,
                             window_long: int = 200,
                             drift_threshold: float = -0.15) -> Dict:
        """Detect a recent drop in accuracy for ``symbol``.

        Looks at the last ``window_long`` updates and compares accuracy on
        the most recent ``window_short`` of them against accuracy on the
        preceding ``window_long - window_short``. The two spans are
        disjoint, so a recent collapse is not averaged into the baseline.

        Args:
            symbol: Symbol whose model is inspected.
            window_short: Number of most recent updates treated as "recent".
            window_long: Total number of trailing updates considered.
            drift_threshold: Drift is flagged when
                ``recent - older < drift_threshold`` (default: a drop of
                more than 0.15 in accuracy).

        Returns:
            On success a dict with ``drift_detected``, ``drift_score``,
            ``recent_accuracy``, ``older_accuracy``, ``threshold``,
            ``action`` and ``symbol``. If the windows are invalid or there is
            not enough history, a dict with ``drift_detected`` False and a
            ``reason``.
        """
        if not 0 < window_short < window_long:
            return {'drift_detected': False, 'reason': 'invalid_windows'}

        model = self.models.get(symbol)
        if model is None or len(model.actuals) < window_long:
            return {'drift_detected': False, 'reason': 'insufficient_data'}

        # Per-update correctness over the long window, oldest first
        predicted = np.asarray(model.predictions[-window_long:]) > 0.5
        observed = np.asarray(model.actuals[-window_long:])
        hits = predicted == observed

        recent = float(np.mean(hits[-window_short:]))
        older = float(np.mean(hits[:-window_short]))

        drift_score = recent - older
        drift_detected = bool(drift_score < drift_threshold)

        return {
            'drift_detected': drift_detected,
            'drift_score': drift_score,
            'recent_accuracy': recent,
            'older_accuracy': older,
            'threshold': drift_threshold,
            'action': 'reset_learning_rate' if drift_detected else 'continue',
            'symbol': symbol
        }

    def adapt_to_drift(self, symbol: str):
        """Make the model for ``symbol`` re-learn quickly after drift.

        Sets the base learning rate to twice its initial value and clears
        the accumulated squared gradients so per-weight step sizes start
        large again. The weights themselves are kept. Does nothing if the
        symbol has no model.
        """
        model = self.models.get(symbol)
        if model is None:
            return

        # Boost above the initial rate so the model adapts faster
        model.lr = model.initial_lr * 2
        # Clear accumulators using the model's own shape
        model.grad_moment2 = np.zeros_like(model.grad_moment2)
        model.bias_moment2 = 0.0

        print(f" [Drift] Reset learning rate for {symbol} to {model.lr:.4f}")

    def get_full_state(self) -> Dict:
        """Export configuration and per-symbol parameters as plain containers.

        Only weights, bias, update count and current learning rate are
        exported per model; gradient accumulators and histories are not.
        """
        return {
            'n_features': self.n_features,
            'base_lr': self.base_lr,
            'symbols': list(self.seen_symbols),
            'models': {
                sym: {
                    'weights': model.weights.tolist(),
                    'bias': model.bias,
                    'n_updates': model.t,
                    'lr': model.lr
                }
                for sym, model in self.models.items()
            }
        }
class ConceptDriftMonitor:
    """Periodically check every symbol for concept drift and adapt its model.

    Wraps a per-symbol model object that exposes ``seen_symbols``,
    ``detect_concept_drift(symbol)`` and ``adapt_to_drift(symbol)``.
    ``check_all_symbols`` is meant to be called once per step; the actual
    check only runs on every ``check_interval``-th call.
    """

    def __init__(self,
                 per_symbol_model: "PerSymbolAdaptiveModel",
                 check_interval: int = 100,
                 drift_threshold: float = -0.15):
        """Configure the monitor.

        Args:
            per_symbol_model: The model registry to monitor.
            check_interval: Run the drift check on every N-th call to
                ``check_all_symbols``. Must be at least 1.
            drift_threshold: A symbol is treated as drifted when its
                reported ``drift_score`` is below this value.

        Raises:
            ValueError: If ``check_interval`` is less than 1.
        """
        if check_interval < 1:
            raise ValueError("check_interval must be >= 1")
        self.model = per_symbol_model
        self.check_interval = check_interval
        self.drift_threshold = drift_threshold
        # Number of calls to check_all_symbols()
        self.step_count = 0
        # One entry per detected drift (scores and accuracies)
        self.drift_history = []
        # One entry per adaptation actually triggered
        self.adaptation_log = []

    def check_all_symbols(self) -> List[Dict]:
        """Count one step and, on every ``check_interval``-th step, check all symbols.

        For each symbol the wrapped model's drift report is taken and the
        monitor's own ``drift_threshold`` is applied to the reported
        ``drift_score``. Reports without a score (for example, insufficient
        data) are passed through unchanged.

        Returns:
            An empty list on steps where no check runs; otherwise one drift
            report per symbol, in sorted symbol order.
        """
        self.step_count += 1
        if self.step_count % self.check_interval != 0:
            return []

        results = []
        # Sorted snapshot: deterministic order, unaffected by set ordering
        for symbol in sorted(self.model.seen_symbols):
            # Copy so the monitor's decision does not mutate the model's dict
            report = dict(self.model.detect_concept_drift(symbol))
            score = report.get('drift_score')
            if score is not None:
                drifted = bool(score < self.drift_threshold)
                report['drift_detected'] = drifted
                report['threshold'] = self.drift_threshold
                report['action'] = 'reset_learning_rate' if drifted else 'continue'
            results.append(report)

            if report['drift_detected']:
                self.model.adapt_to_drift(symbol)
                self.drift_history.append({
                    'step': self.step_count,
                    'symbol': symbol,
                    'score': report['drift_score'],
                    'recent_acc': report['recent_accuracy'],
                    'older_acc': report['older_accuracy']
                })
                self.adaptation_log.append({
                    'step': self.step_count,
                    'symbol': symbol,
                    'action': 'reset_learning_rate'
                })

        return results

    def get_drift_summary(self) -> pd.DataFrame:
        """Return all detected drifts as a DataFrame (empty if none)."""
        return pd.DataFrame(self.drift_history)
if __name__ == '__main__':
    # Demo: train three synthetic symbols side by side and watch for drift.
    print("=" * 70)
    print(" ONLINE LEARNING — PER-SYMBOL ADAPTIVE MODELS")
    print("=" * 70)

    # Fixed seed so the simulated data is reproducible
    np.random.seed(42)

    # Symbol A: Strong momentum signal
    # Symbol B: Weak/noise
    # Symbol C: Regime switch at step 500
    model = PerSymbolAdaptiveModel(n_features=5, base_lr=0.05)
    monitor = ConceptDriftMonitor(model, check_interval=100)

    n_steps = 800
    for step in range(n_steps):
        # Symbol A: sign of feature 0 gives the label, flipped 35% of the time
        x_a = np.random.randn(5)
        true_dir_a = 1 if x_a[0] > 0 else 0
        if np.random.rand() > 0.65:
            true_dir_a = 1 - true_dir_a  # 35% noise

        # Symbol B: no signal, pure noise
        x_b = np.random.randn(5)
        true_dir_b = np.random.randint(0, 2)

        # Symbol C: regime switch at step 500
        x_c = np.random.randn(5)
        if step < 500:
            true_dir_c = 1 if x_c[0] > 0 else 0  # feature 0 matters
            if np.random.rand() > 0.6:
                true_dir_c = 1 - true_dir_c
        else:
            # Regime switch: now feature 1 predicts (opposite!)
            true_dir_c = 1 if x_c[1] < 0 else 0
            if np.random.rand() > 0.6:
                true_dir_c = 1 - true_dir_c

        # Update models
        model.update('AAPL', x_a, true_dir_a)
        model.update('JUNK', x_b, true_dir_b)
        model.update('REGIME', x_c, true_dir_c)

        # The monitor counts calls itself and only checks every
        # `check_interval` calls, so it must be invoked on every step.
        # Calling it only every 100th step would mean the check never runs.
        monitor.check_all_symbols()

    # Results
    print(f"\nTrained on {n_steps} steps per symbol")
    print("\nPer-Symbol Performance:")
    ranking = model.get_symbol_ranking()
    print(ranking.to_string(index=False))

    # Drift events the monitor flagged (and adapted to) during training
    drift_events = monitor.get_drift_summary()
    print(f"\nDrift events flagged by monitor: {len(drift_events)}")
    if not drift_events.empty:
        print(drift_events.to_string(index=False))

    # Drift detection for REGIME symbol
    drift_result = model.detect_concept_drift('REGIME', window_short=50, window_long=300)
    print("\nREGIME Symbol Drift Detection:")
    print(f" Drift detected: {drift_result['drift_detected']}")
    print(f" Recent accuracy: {drift_result['recent_accuracy']:.3f}")
    print(f" Older accuracy: {drift_result['older_accuracy']:.3f}")
    print(f" Drift score: {drift_result['drift_score']:+.3f}")

    print("\n Key Insights:")
    print(" - AAPL model should have ~60-65% accuracy (real signal)")
    print(" - JUNK model should have ~50% accuracy (pure noise)")
    print(" - REGIME model should detect drift at step 500")
    print(" - Each symbol gets its OWN learning rate and weights")
    print(" - Drift triggers adaptive LR reset")