# alphaforge-quant-system / online_learning.py
# Online learning: per-symbol adaptive models with meta-learning and
# concept-drift adaptation. (Originally committed as aab4bbb.)
"""Online Learning — Per-Symbol Adaptive Models
Why this matters for Jane Street level:
- Markets CHANGE. A model trained on SPY 2022 fails on SPY 2024.
- Each asset has unique microstructure, seasonality, regime behavior.
- Static models lose predictive power over time (model decay).
Solution: Online / Continual Learning
- Update models incrementally on every new observation
- Per-symbol parameters (some assets trend, others mean-revert)
- Meta-learning: learn HOW to adapt quickly
- Concept drift detection: auto-detect when old model is wrong
Based on:
- Vapnik (1998): Online SVM
- Cesa-Bianchi & Lugosi (2006): Prediction, Learning, Games
- Finn et al. (2017): MAML (Model-Agnostic Meta-Learning)
- Gama et al. (2014): A survey on concept drift adaptation
"""
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Callable
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')
def sigmoid(x):
    """Numerically safe logistic function 1 / (1 + e^{-x}).

    The argument is clipped to [-500, 500] before exponentiation so that
    np.exp can never overflow; works on scalars and arrays alike.
    """
    z = np.clip(x, -500, 500)
    return 1.0 / (1.0 + np.exp(-z))
class OnlineLogisticRegression:
    """Online logistic regression with an AdaGrad-style adaptive learning rate.

    The model is updated one (x, y) observation at a time. Each coordinate
    gets a step size lr / sqrt(accumulated squared gradient), and the base
    rate additionally decays multiplicatively each step, so the model learns
    fast when young and conservatively when mature.
    """

    def __init__(self,
                 n_features: int = 10,
                 initial_lr: float = 0.01,
                 lr_decay: float = 0.999,
                 l2_reg: float = 0.01,
                 min_lr: float = 1e-6):
        """
        Args:
            n_features: dimensionality of the input feature vector
            initial_lr: base learning rate before decay / AdaGrad scaling
            lr_decay: multiplicative decay of the base rate per update
            l2_reg: L2 penalty on the weights (bias is not regularized)
            min_lr: floor for the decayed base learning rate
        """
        self.n_features = n_features
        self.lr = initial_lr
        self.initial_lr = initial_lr
        self.lr_decay = lr_decay
        self.l2_reg = l2_reg
        self.min_lr = min_lr
        self.weights = np.zeros(n_features)
        self.bias = 0.0
        # Adaptive state: accumulated squared gradients (AdaGrad denominators)
        self.grad_moment2 = np.zeros(n_features)
        self.bias_moment2 = 0.0
        self.t = 0  # number of updates applied so far
        # Performance tracking (unbounded lists; acceptable for demo horizons,
        # but NOTE(review): a long-running deployment should cap these)
        self.predictions: List[float] = []
        self.actuals: List[int] = []
        self.grad_norms: List[float] = []

    @staticmethod
    def _sigmoid(z):
        """Numerically safe logistic function (input clipped to avoid overflow)."""
        return 1.0 / (1.0 + np.exp(-np.clip(z, -500, 500)))

    def predict_proba(self, x: np.ndarray) -> float:
        """Predict probability of the positive class for feature vector x."""
        return self._sigmoid(np.dot(x, self.weights) + self.bias)

    def predict(self, x: np.ndarray) -> int:
        """Hard 0/1 prediction at the 0.5 probability threshold."""
        return 1 if self.predict_proba(x) > 0.5 else 0

    def update(self, x: np.ndarray, y: int) -> Dict:
        """
        Single-step online update (logistic loss + L2 on the weights).

        Args:
            x: feature vector (n_features,)
            y: label (0 or 1)
        Returns:
            Dict with 'pred', 'error', 'grad_norm', 'lr' for this step.
        """
        self.t += 1
        # Forward pass
        pred = self._sigmoid(np.dot(x, self.weights) + self.bias)
        # Gradient of the logistic loss; L2 applied to weights only
        error = pred - y
        grad_w = error * x + self.l2_reg * self.weights
        grad_b = error
        # Adaptive learning rate (AdaGrad-like, per-coordinate)
        self.grad_moment2 += grad_w ** 2
        self.bias_moment2 += grad_b ** 2
        lr_w = self.lr / (np.sqrt(self.grad_moment2) + 1e-8)
        lr_b = self.lr / (np.sqrt(self.bias_moment2) + 1e-8)
        # Gradient step
        self.weights -= lr_w * grad_w
        self.bias -= lr_b * grad_b
        # Decay the base learning rate, floored at min_lr
        self.lr = max(self.lr * self.lr_decay, self.min_lr)
        # Track for performance / drift diagnostics
        grad_norm = float(np.linalg.norm(grad_w))
        self.predictions.append(pred)
        self.actuals.append(y)
        self.grad_norms.append(grad_norm)
        return {
            'pred': pred,
            'error': error,
            'grad_norm': grad_norm,
            'lr': self.lr
        }

    def get_performance(self, last_n: int = 100) -> Dict:
        """Recent performance metrics over the last `last_n` updates.

        Always returns the full key set ('accuracy', 'directional_accuracy',
        'avg_grad_norm', 'current_lr', 'n_updates'), including on a model
        with fewer than 2 observations.
        """
        if len(self.actuals) < 2:
            # BUGFIX: previously returned only {'accuracy': 0.5}; callers like
            # get_symbol_ranking() index the other keys and would KeyError on a
            # model that was created (e.g. via predict) but never updated.
            return {
                'accuracy': 0.5,
                'directional_accuracy': 0.5,
                'avg_grad_norm': 0.0,
                'current_lr': self.lr,
                'n_updates': self.t
            }
        n = min(last_n, len(self.actuals))
        preds = np.array(self.predictions[-n:]) > 0.5
        actuals = np.array(self.actuals[-n:])
        accuracy = np.mean(preds == actuals)
        # Directional accuracy: do consecutive changes in predicted probability
        # match the changes in the actual labels over the last 10 observations?
        if len(actuals) >= 10:
            pred_returns = np.diff(self.predictions[-10:])
            actual_returns = np.diff(self.actuals[-10:])
            directional = np.mean(np.sign(pred_returns) == np.sign(actual_returns)) if len(pred_returns) > 0 else 0.5
        else:
            directional = accuracy
        return {
            'accuracy': accuracy,
            'directional_accuracy': directional,
            'avg_grad_norm': np.mean(self.grad_norms[-n:]) if self.grad_norms else 0,
            'current_lr': self.lr,
            'n_updates': self.t
        }
class PerSymbolAdaptiveModel:
    """Maintain a separate online model for each symbol.

    Key insight: SPY behaves differently from TSLA. Each asset gets its own
    feature weights and learning-rate schedule; brand-new symbols are
    meta-initialized from the best-performing existing models.
    """

    def __init__(self,
                 n_features: int = 10,
                 base_lr: float = 0.01,
                 symbols: Optional[List[str]] = None):
        """
        Args:
            n_features: feature dimensionality shared by all per-symbol models
            base_lr: base learning rate; each new model gets a random jitter
            symbols: optional initial symbol list (models are created lazily)
        """
        self.n_features = n_features
        self.base_lr = base_lr
        self.symbols = symbols or []
        # Per-symbol models, created lazily on first update/predict
        self.models: Dict[str, OnlineLogisticRegression] = {}
        # Per-symbol rolling performance snapshots (one dict per update)
        self.symbol_performance: Dict[str, List[Dict]] = defaultdict(list)
        # Symbols seen so far (auto-detected)
        self.seen_symbols = set()

    def _get_or_create_model(self, symbol: str) -> OnlineLogisticRegression:
        """Get the model for `symbol`, lazily creating and meta-initializing it."""
        if symbol not in self.models:
            # Meta-learn initial weights from the best existing models
            init_weights = self._meta_initialize(symbol)
            model = OnlineLogisticRegression(
                n_features=self.n_features,
                # Small jitter decorrelates learning-rate schedules across symbols
                initial_lr=self.base_lr * np.random.uniform(0.8, 1.2)
            )
            if init_weights is not None:
                model.weights = init_weights
            self.models[symbol] = model
            self.seen_symbols.add(symbol)
        return self.models[symbol]

    def _meta_initialize(self, new_symbol: str) -> Optional[np.ndarray]:
        """
        Meta-learning: initialize a new symbol's weights from existing models.

        Returns the average weight vector of the top-3 models by recent
        accuracy, or None when fewer than 3 models exist yet.
        """
        if len(self.models) < 3:
            return None
        perf = []
        for sym, model in self.models.items():
            p = model.get_performance(last_n=50)
            perf.append((sym, p.get('accuracy', 0.5), model.weights))
        # np.mean produces a fresh array, so the new model never aliases
        # an existing model's weight vector.
        perf.sort(key=lambda x: x[1], reverse=True)
        top_weights = [p[2] for p in perf[:3]]
        return np.mean(top_weights, axis=0)

    def update(self, symbol: str, x: np.ndarray, y: int) -> Dict:
        """Online-update the model for `symbol` with one (x, y) observation."""
        model = self._get_or_create_model(symbol)
        metrics = model.update(x, y)
        # Snapshot recent performance for drift diagnostics
        perf = model.get_performance(last_n=20)
        self.symbol_performance[symbol].append(perf)
        metrics['symbol'] = symbol
        return metrics

    def predict(self, symbol: str, x: np.ndarray) -> Dict:
        """Predict for `symbol`; creates a fresh (untrained) model if unseen."""
        model = self._get_or_create_model(symbol)
        prob = model.predict_proba(x)
        return {
            'symbol': symbol,
            'probability': prob,
            'prediction': 1 if prob > 0.5 else 0,
            'confidence': abs(prob - 0.5) * 2,  # 0 = unsure, 1 = certain
            'model_age': model.t
        }

    def get_symbol_ranking(self) -> pd.DataFrame:
        """Rank symbols by recent directional accuracy (best first)."""
        rows = []
        for symbol, model in self.models.items():
            perf = model.get_performance(last_n=100)
            rows.append({
                'symbol': symbol,
                # .get() defaults guard against a model that has never been
                # updated (its performance dict may be sparsely populated)
                'accuracy': perf.get('accuracy', 0.5),
                'directional_accuracy': perf.get('directional_accuracy', 0.5),
                'n_samples': model.t,
                'current_lr': perf.get('current_lr', model.lr),
                'grad_norm': perf.get('avg_grad_norm', 0.0)
            })
        df = pd.DataFrame(rows)
        if not df.empty:
            df = df.sort_values('directional_accuracy', ascending=False)
        return df

    def detect_concept_drift(self, symbol: str,
                             window_short: int = 50,
                             window_long: int = 200) -> Dict:
        """
        Detect whether the feature→target relationship has changed.

        Compares accuracy over a short recent window against a longer window
        (which includes the recent one); a large drop signals concept drift
        and the need for adaptation.
        """
        model = self.models.get(symbol)
        if model is None or len(model.actuals) < window_long:
            return {'drift_detected': False, 'reason': 'insufficient_data'}
        recent = model.get_performance(last_n=window_short)['accuracy']
        older = model.get_performance(last_n=window_long)['accuracy']
        # Drift if recent accuracy is significantly worse than the long window
        drift_threshold = -0.15  # 15% accuracy drop
        drift_score = recent - older
        drift_detected = drift_score < drift_threshold
        return {
            'drift_detected': drift_detected,
            'drift_score': drift_score,
            'recent_accuracy': recent,
            'older_accuracy': older,
            'threshold': drift_threshold,
            'action': 'reset_learning_rate' if drift_detected else 'continue',
            'symbol': symbol
        }

    def adapt_to_drift(self, symbol: str):
        """Reset the symbol's optimizer state so it relearns the new regime."""
        model = self.models.get(symbol)
        if model is None:
            return
        # Boosted LR + cleared AdaGrad accumulators => forget old, learn new fast
        model.lr = model.initial_lr * 2
        # Use the model's own feature width (was self.n_features; same value
        # today, but this stays correct if model widths ever diverge)
        model.grad_moment2 = np.zeros(model.n_features)
        model.bias_moment2 = 0.0
        print(f" [Drift] Reset learning rate for {symbol} to {model.lr:.4f}")

    def get_full_state(self) -> Dict:
        """Export full state (weights, biases, LRs) for persistence."""
        return {
            'n_features': self.n_features,
            'base_lr': self.base_lr,
            'symbols': list(self.seen_symbols),
            'models': {
                sym: {
                    'weights': model.weights.tolist(),
                    'bias': model.bias,
                    'n_updates': model.t,
                    'lr': model.lr
                }
                for sym, model in self.models.items()
            }
        }
class ConceptDriftMonitor:
    """System-wide concept drift monitoring across all symbols.

    Polls every tracked symbol once per `check_interval` calls; any symbol
    whose recent accuracy has degraded is adapted and the event recorded.
    """

    def __init__(self,
                 per_symbol_model: PerSymbolAdaptiveModel,
                 check_interval: int = 100,
                 drift_threshold: float = -0.15):
        self.model = per_symbol_model
        self.check_interval = check_interval
        self.drift_threshold = drift_threshold
        self.step_count = 0
        self.drift_history = []
        self.adaptation_log = []

    def check_all_symbols(self) -> List[Dict]:
        """Check all symbols for drift, adapting those that drifted.

        Throttled internally: only every `check_interval`-th call does real
        work; other calls return an empty list.
        """
        self.step_count += 1
        if self.step_count % self.check_interval:
            return []
        outcomes = []
        for sym in self.model.seen_symbols:
            verdict = self.model.detect_concept_drift(sym)
            outcomes.append(verdict)
            if not verdict['drift_detected']:
                continue
            self.model.adapt_to_drift(sym)
            self.drift_history.append({
                'step': self.step_count,
                'symbol': sym,
                'score': verdict['drift_score'],
                'recent_acc': verdict['recent_accuracy'],
                'older_acc': verdict['older_accuracy']
            })
        return outcomes

    def get_drift_summary(self) -> pd.DataFrame:
        """Summary of all detected drifts as a DataFrame."""
        return pd.DataFrame(self.drift_history)
if __name__ == '__main__':
    print("=" * 70)
    print(" ONLINE LEARNING — PER-SYMBOL ADAPTIVE MODELS")
    print("=" * 70)
    # Simulate three symbols with different behaviors:
    #   AAPL   — real momentum-like signal (feature 0 predicts direction)
    #   JUNK   — pure noise, nothing learnable
    #   REGIME — signal switches from feature 0 to inverted feature 1 at step 500
    np.random.seed(42)
    model = PerSymbolAdaptiveModel(n_features=5, base_lr=0.05)
    monitor = ConceptDriftMonitor(model, check_interval=100)
    n_steps = 800
    for step in range(n_steps):
        # Symbol A: feature 0 predicts direction with 65% accuracy
        x_a = np.random.randn(5)
        true_dir_a = 1 if x_a[0] > 0 else 0
        if np.random.rand() > 0.65:
            true_dir_a = 1 - true_dir_a  # 35% label noise
        # Symbol B: no signal, pure noise
        x_b = np.random.randn(5)
        true_dir_b = np.random.randint(0, 2)
        # Symbol C: regime switch at step 500
        x_c = np.random.randn(5)
        if step < 500:
            true_dir_c = 1 if x_c[0] > 0 else 0  # feature 0 matters
            if np.random.rand() > 0.6:
                true_dir_c = 1 - true_dir_c
        else:
            # Regime switch: now feature 1 predicts (with opposite sign!)
            true_dir_c = 1 if x_c[1] < 0 else 0
            if np.random.rand() > 0.6:
                true_dir_c = 1 - true_dir_c
        # Update models
        model.update('AAPL', x_a, true_dir_a)
        model.update('JUNK', x_b, true_dir_b)
        model.update('REGIME', x_c, true_dir_c)
        # BUGFIX: call the monitor every step — it throttles itself via
        # check_interval. The old `if step % 100 == 0` gate combined with the
        # monitor's internal `step_count % check_interval` meant its counter
        # reached only 8 in 800 steps, so drift was never actually checked.
        monitor.check_all_symbols()
    # Results
    print(f"\nTrained on {n_steps} steps per symbol")
    print("\nPer-Symbol Performance:")
    ranking = model.get_symbol_ranking()
    print(ranking.to_string(index=False))
    # Drift detection for the REGIME symbol (800 samples >= window_long=300)
    drift_result = model.detect_concept_drift('REGIME', window_short=50, window_long=300)
    print("\nREGIME Symbol Drift Detection:")
    print(f" Drift detected: {drift_result['drift_detected']}")
    print(f" Recent accuracy: {drift_result['recent_accuracy']:.3f}")
    print(f" Older accuracy: {drift_result['older_accuracy']:.3f}")
    print(f" Drift score: {drift_result['drift_score']:+.3f}")
    print("\n Key Insights:")
    print(" - AAPL model should have ~60-65% accuracy (real signal)")
    print(" - JUNK model should have ~50% accuracy (pure noise)")
    print(" - REGIME model should detect drift at step 500")
    print(" - Each symbol gets its OWN learning rate and weights")
    print(" - Drift triggers adaptive LR reset")