# 43v3r8/knowledge/vector_db.py
"""
Vector database for pattern recognition and similarity search
Uses ChromaDB for efficient vector storage and retrieval
"""
import chromadb
from chromadb.config import Settings
import numpy as np
import pandas as pd
from typing import List, Dict, Optional
import logging
import json
from datetime import datetime
from config import Config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TradingPatternVectorDB:
    """
    Vector database for storing and searching trading patterns.
    Uses embeddings to find similar market conditions and patterns.
    """

    # Fixed embedding sizes, so every stored vector has the same dimensionality
    # (ChromaDB requires consistent dimensions within a collection)
    PRICE_EMBEDDING_SIZE = 20
    INDICATOR_EMBEDDING_SIZE = 10
def __init__(self, persist_dir: Optional[str] = None):
self.persist_dir = persist_dir or Config.CHROMA_PERSIST_DIR
        # Initialize a persistent ChromaDB client (chromadb >= 0.4 API;
        # the old Settings(persist_directory=...) configuration is deprecated)
        self.client = chromadb.PersistentClient(
            path=self.persist_dir,
            settings=Settings(anonymized_telemetry=False)
        )
# Get or create collection
try:
self.collection = self.client.get_or_create_collection(
name=Config.CHROMA_COLLECTION_NAME,
metadata={"description": "Trading patterns and market conditions"}
)
logger.info(f"Initialized vector DB with {self.collection.count()} patterns")
except Exception as e:
logger.error(f"Error initializing ChromaDB: {e}")
self.collection = None
def create_price_pattern_embedding(self, prices: List[float]) -> List[float]:
"""
Create embedding from price sequence
Normalizes and extracts features from price movements
Args:
prices: List of price values
Returns:
Embedding vector
"""
        if len(prices) < 2:
            # Return a zero vector of the same fixed size produced below, so
            # every stored embedding shares one dimensionality
            return [0.0] * self.PRICE_EMBEDDING_SIZE
prices_array = np.array(prices)
# Normalize prices
normalized = (prices_array - prices_array.mean()) / (prices_array.std() + 1e-8)
# Extract features
features = []
# 1. Price returns
returns = np.diff(normalized)
features.extend([
returns.mean(),
returns.std(),
returns.min(),
returns.max()
])
# 2. Trend features
x = np.arange(len(normalized))
slope = np.polyfit(x, normalized, 1)[0]
features.append(slope)
# 3. Volatility features
rolling_std = pd.Series(prices).rolling(5).std().fillna(0).values
features.extend([
rolling_std.mean(),
rolling_std.std()
])
# 4. Momentum features
if len(prices) >= 10:
momentum_5 = (prices_array[-1] - prices_array[-5]) / prices_array[-5]
momentum_10 = (prices_array[-1] - prices_array[-10]) / prices_array[-10]
features.extend([momentum_5, momentum_10])
else:
features.extend([0.0, 0.0])
# 5. Pattern recognition features
# Higher highs, lower lows detection
highs = [i for i in range(1, len(prices_array)-1)
if prices_array[i] > prices_array[i-1] and prices_array[i] > prices_array[i+1]]
lows = [i for i in range(1, len(prices_array)-1)
if prices_array[i] < prices_array[i-1] and prices_array[i] < prices_array[i+1]]
features.extend([
len(highs) / len(prices_array),
len(lows) / len(prices_array)
])
        # Pad or truncate to the fixed price-embedding size
        target_size = self.PRICE_EMBEDDING_SIZE
if len(features) < target_size:
features.extend([0.0] * (target_size - len(features)))
else:
features = features[:target_size]
# Normalize final embedding
features = np.array(features)
norm = np.linalg.norm(features)
if norm > 0:
features = features / norm
return features.tolist()
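
    # Illustrative usage (synthetic prices, made up for the example):
    #
    #   db = TradingPatternVectorDB()
    #   emb = db.create_price_pattern_embedding([101.2, 101.5, 101.1, 102.0, 102.3])
    #   assert len(emb) == TradingPatternVectorDB.PRICE_EMBEDDING_SIZE  # unit-norm, fixed length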
def create_indicator_embedding(self, indicators: Dict) -> List[float]:
"""
Create embedding from technical indicators
Args:
indicators: Dict with indicator values
Returns:
Embedding vector
"""
features = []
# RSI normalized
if 'rsi' in indicators and indicators['rsi'] is not None:
features.append((indicators['rsi'] - 50) / 50)
else:
features.append(0.0)
# MACD
if 'macd' in indicators and indicators['macd'] is not None:
features.append(np.tanh(indicators['macd'] / 100))
else:
features.append(0.0)
# Stochastic
if 'stoch_k' in indicators and indicators['stoch_k'] is not None:
features.append((indicators['stoch_k'] - 50) / 50)
else:
features.append(0.0)
# Bollinger Band position
        # Bollinger Band position (guard against missing or None values)
        bb_keys = ['close', 'bb_upper', 'bb_lower']
        if (all(indicators.get(k) is not None for k in bb_keys)
                and indicators['bb_upper'] != indicators['bb_lower']):
            bb_pos = (indicators['close'] - indicators['bb_lower']) / \
                     (indicators['bb_upper'] - indicators['bb_lower'])
            features.append(bb_pos * 2 - 1)  # Scale to [-1, 1]
        else:
            features.append(0.0)
        # Volume relative to its moving average (guard against missing/None/zero)
        volume = indicators.get('volume')
        volume_sma = indicators.get('volume_sma')
        if volume is not None and volume_sma is not None and volume_sma > 0:
            vol_ratio = volume / volume_sma
            features.append(np.tanh(vol_ratio - 1))
        else:
            features.append(0.0)
        # ATR (volatility) as a fraction of price (guard against missing/None/zero)
        atr = indicators.get('atr')
        close = indicators.get('close')
        if atr is not None and close is not None and close > 0:
            atr_pct = atr / close
            features.append(np.tanh(atr_pct * 10))
        else:
            features.append(0.0)
        # Pad to the fixed indicator-embedding size
        min_size = self.INDICATOR_EMBEDDING_SIZE
if len(features) < min_size:
features.extend([0.0] * (min_size - len(features)))
return features
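
    # Illustrative usage (hypothetical indicator values; missing or None
    # indicators each contribute a 0.0 feature):
    #
    #   emb = db.create_indicator_embedding({'rsi': 62.0, 'macd': 1.4,
    #                                        'close': 101.0, 'atr': 0.8})
    #   assert len(emb) == TradingPatternVectorDB.INDICATOR_EMBEDDING_SIZE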
def add_pattern(
self,
pattern_id: str,
symbol: str,
timeframe: str,
prices: List[float],
indicators: Dict,
outcome: Optional[str] = None,
metadata: Optional[Dict] = None
):
"""
Add a trading pattern to the vector database
Args:
pattern_id: Unique pattern identifier
symbol: Trading symbol
timeframe: Timeframe of pattern
prices: Price sequence
indicators: Technical indicators
outcome: Actual outcome (e.g., 'bullish_success', 'bearish_success')
metadata: Additional metadata
"""
if self.collection is None:
return
try:
# Create embeddings
price_embedding = self.create_price_pattern_embedding(prices)
indicator_embedding = self.create_indicator_embedding(indicators)
# Combine embeddings
combined_embedding = price_embedding + indicator_embedding
# Prepare metadata
meta = {
'symbol': symbol,
'timeframe': timeframe,
'timestamp': datetime.now().isoformat(),
'outcome': outcome or 'unknown',
'price_count': len(prices),
**(metadata or {})
}
# Add to collection
self.collection.add(
ids=[pattern_id],
embeddings=[combined_embedding],
metadatas=[meta],
documents=[json.dumps({
'prices': prices[-20:], # Store last 20 prices
'indicators': {k: v for k, v in indicators.items() if v is not None}
})]
)
logger.info(f"Added pattern {pattern_id} to vector DB")
except Exception as e:
logger.error(f"Error adding pattern to vector DB: {e}")
def find_similar_patterns(
self,
prices: List[float],
indicators: Dict,
n_results: int = 10,
symbol: Optional[str] = None
) -> List[Dict]:
"""
Find similar historical patterns
Args:
prices: Current price sequence
indicators: Current indicators
n_results: Number of results to return
symbol: Optional symbol filter
Returns:
List of similar patterns with metadata
"""
if self.collection is None or self.collection.count() == 0:
return []
try:
# Create embedding for current pattern
price_embedding = self.create_price_pattern_embedding(prices)
indicator_embedding = self.create_indicator_embedding(indicators)
combined_embedding = price_embedding + indicator_embedding
# Search
where_filter = {'symbol': symbol} if symbol else None
results = self.collection.query(
query_embeddings=[combined_embedding],
n_results=min(n_results, self.collection.count()),
where=where_filter
)
# Format results
similar_patterns = []
            if results['ids'] and results['ids'][0]:
                for i, pattern_id in enumerate(results['ids'][0]):
                    similar_patterns.append({
                        'pattern_id': pattern_id,
                        # 1 - distance is only a rough similarity: it assumes
                        # distances near [0, 1] (e.g. cosine on normalized vectors)
                        # and can go negative under Chroma's default L2 space
                        'similarity': 1 - results['distances'][0][i],
                        'metadata': results['metadatas'][0][i],
                        'document': json.loads(results['documents'][0][i]) if results['documents'][0][i] else {}
                    })
return similar_patterns
except Exception as e:
logger.error(f"Error finding similar patterns: {e}")
return []
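
    # Illustrative usage (same hypothetical inputs as add_pattern above):
    #
    #   matches = db.find_similar_patterns(closes, latest_indicators,
    #                                      n_results=5, symbol='BTCUSDT')
    #   for m in matches:
    #       print(m['pattern_id'], m['similarity'], m['metadata']['outcome'])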
def get_pattern_statistics(self, symbol: Optional[str] = None) -> Dict:
"""
Get statistics about stored patterns
Args:
symbol: Optional symbol filter
Returns:
Dict with statistics
"""
if self.collection is None:
return {}
try:
total_count = self.collection.count()
if total_count == 0:
return {'total_patterns': 0}
# Get all patterns (or filtered by symbol)
where_filter = {'symbol': symbol} if symbol else None
# Sample some patterns to get stats
sample_size = min(100, total_count)
results = self.collection.get(
limit=sample_size,
where=where_filter
)
# Count outcomes
outcomes = {}
symbols_count = {}
if results['metadatas']:
for meta in results['metadatas']:
outcome = meta.get('outcome', 'unknown')
outcomes[outcome] = outcomes.get(outcome, 0) + 1
sym = meta.get('symbol', 'unknown')
symbols_count[sym] = symbols_count.get(sym, 0) + 1
return {
'total_patterns': total_count,
'sampled': len(results['metadatas']) if results['metadatas'] else 0,
'outcomes_distribution': outcomes,
'symbols_distribution': symbols_count
}
except Exception as e:
logger.error(f"Error getting pattern statistics: {e}")
return {}
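
    # Example return shape (all counts are hypothetical):
    #
    #   {'total_patterns': 412, 'sampled': 100,
    #    'outcomes_distribution': {'bullish_success': 38, 'bearish_success': 27, ...},
    #    'symbols_distribution': {'BTCUSDT': 61, 'ETHUSDT': 39}}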
def predict_outcome(self, prices: List[float], indicators: Dict, symbol: Optional[str] = None) -> Dict:
"""
Predict likely outcome based on similar historical patterns
Args:
prices: Current price sequence
indicators: Current indicators
symbol: Optional symbol filter
Returns:
Dict with prediction and confidence
"""
similar_patterns = self.find_similar_patterns(prices, indicators, n_results=20, symbol=symbol)
if not similar_patterns:
return {
'prediction': 'unknown',
'confidence': 0.0,
'sample_size': 0
}
# Count outcomes weighted by similarity
outcome_scores = {}
total_weight = 0
        for pattern in similar_patterns:
            outcome = pattern['metadata'].get('outcome', 'unknown')
            # Clamp at zero so a negative "similarity" (see find_similar_patterns)
            # cannot subtract weight from an outcome
            similarity = max(pattern['similarity'], 0.0)
            if outcome != 'unknown':
                outcome_scores[outcome] = outcome_scores.get(outcome, 0) + similarity
                total_weight += similarity
if total_weight == 0:
return {
'prediction': 'unknown',
'confidence': 0.0,
'sample_size': len(similar_patterns)
}
# Normalize scores
outcome_probs = {k: v / total_weight for k, v in outcome_scores.items()}
# Get top prediction
top_outcome = max(outcome_probs.items(), key=lambda x: x[1])
return {
'prediction': top_outcome[0],
'confidence': top_outcome[1],
'probabilities': outcome_probs,
'sample_size': len(similar_patterns),
'similar_patterns': similar_patterns[:5] # Top 5 for reference
}
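
    # Illustrative usage: the prediction is a similarity-weighted vote over
    # the outcomes of the 20 nearest stored patterns (hypothetical inputs):
    #
    #   result = db.predict_outcome(closes, latest_indicators, symbol='BTCUSDT')
    #   print(result['prediction'], result['confidence'], result['sample_size'])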
    def clear_old_patterns(self, days: int = 90):
        """Delete patterns older than the specified number of days."""
        if self.collection is None:
            return
        try:
            from datetime import timedelta
            cutoff = (datetime.now() - timedelta(days=days)).isoformat()
            # ChromaDB has no built-in time-based deletion, so fetch metadata,
            # compare ISO-8601 timestamps lexicographically, and delete by id.
            # Note: this loads all metadata, which is fine for modest collections.
            results = self.collection.get(include=['metadatas'])
            old_ids = [
                pattern_id
                for pattern_id, meta in zip(results['ids'], results['metadatas'])
                if (ts := meta.get('timestamp')) and ts < cutoff
            ]
            if old_ids:
                self.collection.delete(ids=old_ids)
                logger.info(f"Deleted {len(old_ids)} patterns older than {days} days")
        except Exception as e:
            logger.error(f"Error clearing old patterns: {e}")