""" Vector database for pattern recognition and similarity search Uses ChromaDB for efficient vector storage and retrieval """ import chromadb from chromadb.config import Settings import numpy as np import pandas as pd from typing import List, Dict, Optional, Tuple import logging import json from datetime import datetime from config import Config logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class TradingPatternVectorDB: """ Vector database for storing and searching trading patterns Uses embeddings to find similar market conditions and patterns """ def __init__(self, persist_dir: Optional[str] = None): self.persist_dir = persist_dir or Config.CHROMA_PERSIST_DIR # Initialize ChromaDB self.client = chromadb.Client(Settings( persist_directory=self.persist_dir, anonymized_telemetry=False )) # Get or create collection try: self.collection = self.client.get_or_create_collection( name=Config.CHROMA_COLLECTION_NAME, metadata={"description": "Trading patterns and market conditions"} ) logger.info(f"Initialized vector DB with {self.collection.count()} patterns") except Exception as e: logger.error(f"Error initializing ChromaDB: {e}") self.collection = None def create_price_pattern_embedding(self, prices: List[float]) -> List[float]: """ Create embedding from price sequence Normalizes and extracts features from price movements Args: prices: List of price values Returns: Embedding vector """ if len(prices) < 2: return [0.0] * Config.EMBEDDING_DIMENSION prices_array = np.array(prices) # Normalize prices normalized = (prices_array - prices_array.mean()) / (prices_array.std() + 1e-8) # Extract features features = [] # 1. Price returns returns = np.diff(normalized) features.extend([ returns.mean(), returns.std(), returns.min(), returns.max() ]) # 2. Trend features x = np.arange(len(normalized)) slope = np.polyfit(x, normalized, 1)[0] features.append(slope) # 3. Volatility features rolling_std = pd.Series(prices).rolling(5).std().fillna(0).values features.extend([ rolling_std.mean(), rolling_std.std() ]) # 4. Momentum features if len(prices) >= 10: momentum_5 = (prices_array[-1] - prices_array[-5]) / prices_array[-5] momentum_10 = (prices_array[-1] - prices_array[-10]) / prices_array[-10] features.extend([momentum_5, momentum_10]) else: features.extend([0.0, 0.0]) # 5. Pattern recognition features # Higher highs, lower lows detection highs = [i for i in range(1, len(prices_array)-1) if prices_array[i] > prices_array[i-1] and prices_array[i] > prices_array[i+1]] lows = [i for i in range(1, len(prices_array)-1) if prices_array[i] < prices_array[i-1] and prices_array[i] < prices_array[i+1]] features.extend([ len(highs) / len(prices_array), len(lows) / len(prices_array) ]) # Pad or truncate to fixed size target_size = 20 if len(features) < target_size: features.extend([0.0] * (target_size - len(features))) else: features = features[:target_size] # Normalize final embedding features = np.array(features) norm = np.linalg.norm(features) if norm > 0: features = features / norm return features.tolist() def create_indicator_embedding(self, indicators: Dict) -> List[float]: """ Create embedding from technical indicators Args: indicators: Dict with indicator values Returns: Embedding vector """ features = [] # RSI normalized if 'rsi' in indicators and indicators['rsi'] is not None: features.append((indicators['rsi'] - 50) / 50) else: features.append(0.0) # MACD if 'macd' in indicators and indicators['macd'] is not None: features.append(np.tanh(indicators['macd'] / 100)) else: features.append(0.0) # Stochastic if 'stoch_k' in indicators and indicators['stoch_k'] is not None: features.append((indicators['stoch_k'] - 50) / 50) else: features.append(0.0) # Bollinger Band position if all(k in indicators for k in ['close', 'bb_upper', 'bb_lower']): if indicators['bb_upper'] != indicators['bb_lower']: bb_pos = (indicators['close'] - indicators['bb_lower']) / \ (indicators['bb_upper'] - indicators['bb_lower']) features.append(bb_pos * 2 - 1) # Normalize to [-1, 1] else: features.append(0.0) else: features.append(0.0) # Volume relative if 'volume' in indicators and 'volume_sma' in indicators: if indicators['volume_sma'] > 0: vol_ratio = indicators['volume'] / indicators['volume_sma'] features.append(np.tanh(vol_ratio - 1)) else: features.append(0.0) else: features.append(0.0) # ATR (volatility) normalized if 'atr' in indicators and 'close' in indicators and indicators['close'] > 0: atr_pct = indicators['atr'] / indicators['close'] features.append(np.tanh(atr_pct * 10)) else: features.append(0.0) # Pad to minimum size min_size = 10 if len(features) < min_size: features.extend([0.0] * (min_size - len(features))) return features def add_pattern( self, pattern_id: str, symbol: str, timeframe: str, prices: List[float], indicators: Dict, outcome: Optional[str] = None, metadata: Optional[Dict] = None ): """ Add a trading pattern to the vector database Args: pattern_id: Unique pattern identifier symbol: Trading symbol timeframe: Timeframe of pattern prices: Price sequence indicators: Technical indicators outcome: Actual outcome (e.g., 'bullish_success', 'bearish_success') metadata: Additional metadata """ if self.collection is None: return try: # Create embeddings price_embedding = self.create_price_pattern_embedding(prices) indicator_embedding = self.create_indicator_embedding(indicators) # Combine embeddings combined_embedding = price_embedding + indicator_embedding # Prepare metadata meta = { 'symbol': symbol, 'timeframe': timeframe, 'timestamp': datetime.now().isoformat(), 'outcome': outcome or 'unknown', 'price_count': len(prices), **(metadata or {}) } # Add to collection self.collection.add( ids=[pattern_id], embeddings=[combined_embedding], metadatas=[meta], documents=[json.dumps({ 'prices': prices[-20:], # Store last 20 prices 'indicators': {k: v for k, v in indicators.items() if v is not None} })] ) logger.info(f"Added pattern {pattern_id} to vector DB") except Exception as e: logger.error(f"Error adding pattern to vector DB: {e}") def find_similar_patterns( self, prices: List[float], indicators: Dict, n_results: int = 10, symbol: Optional[str] = None ) -> List[Dict]: """ Find similar historical patterns Args: prices: Current price sequence indicators: Current indicators n_results: Number of results to return symbol: Optional symbol filter Returns: List of similar patterns with metadata """ if self.collection is None or self.collection.count() == 0: return [] try: # Create embedding for current pattern price_embedding = self.create_price_pattern_embedding(prices) indicator_embedding = self.create_indicator_embedding(indicators) combined_embedding = price_embedding + indicator_embedding # Search where_filter = {'symbol': symbol} if symbol else None results = self.collection.query( query_embeddings=[combined_embedding], n_results=min(n_results, self.collection.count()), where=where_filter ) # Format results similar_patterns = [] if results['ids'] and len(results['ids']) > 0: for i, pattern_id in enumerate(results['ids'][0]): similar_patterns.append({ 'pattern_id': pattern_id, 'similarity': 1 - results['distances'][0][i], # Convert distance to similarity 'metadata': results['metadatas'][0][i], 'document': json.loads(results['documents'][0][i]) if results['documents'][0][i] else {} }) return similar_patterns except Exception as e: logger.error(f"Error finding similar patterns: {e}") return [] def get_pattern_statistics(self, symbol: Optional[str] = None) -> Dict: """ Get statistics about stored patterns Args: symbol: Optional symbol filter Returns: Dict with statistics """ if self.collection is None: return {} try: total_count = self.collection.count() if total_count == 0: return {'total_patterns': 0} # Get all patterns (or filtered by symbol) where_filter = {'symbol': symbol} if symbol else None # Sample some patterns to get stats sample_size = min(100, total_count) results = self.collection.get( limit=sample_size, where=where_filter ) # Count outcomes outcomes = {} symbols_count = {} if results['metadatas']: for meta in results['metadatas']: outcome = meta.get('outcome', 'unknown') outcomes[outcome] = outcomes.get(outcome, 0) + 1 sym = meta.get('symbol', 'unknown') symbols_count[sym] = symbols_count.get(sym, 0) + 1 return { 'total_patterns': total_count, 'sampled': len(results['metadatas']) if results['metadatas'] else 0, 'outcomes_distribution': outcomes, 'symbols_distribution': symbols_count } except Exception as e: logger.error(f"Error getting pattern statistics: {e}") return {} def predict_outcome(self, prices: List[float], indicators: Dict, symbol: Optional[str] = None) -> Dict: """ Predict likely outcome based on similar historical patterns Args: prices: Current price sequence indicators: Current indicators symbol: Optional symbol filter Returns: Dict with prediction and confidence """ similar_patterns = self.find_similar_patterns(prices, indicators, n_results=20, symbol=symbol) if not similar_patterns: return { 'prediction': 'unknown', 'confidence': 0.0, 'sample_size': 0 } # Count outcomes weighted by similarity outcome_scores = {} total_weight = 0 for pattern in similar_patterns: outcome = pattern['metadata'].get('outcome', 'unknown') similarity = pattern['similarity'] if outcome != 'unknown': outcome_scores[outcome] = outcome_scores.get(outcome, 0) + similarity total_weight += similarity if total_weight == 0: return { 'prediction': 'unknown', 'confidence': 0.0, 'sample_size': len(similar_patterns) } # Normalize scores outcome_probs = {k: v / total_weight for k, v in outcome_scores.items()} # Get top prediction top_outcome = max(outcome_probs.items(), key=lambda x: x[1]) return { 'prediction': top_outcome[0], 'confidence': top_outcome[1], 'probabilities': outcome_probs, 'sample_size': len(similar_patterns), 'similar_patterns': similar_patterns[:5] # Top 5 for reference } def clear_old_patterns(self, days: int = 90): """Clear patterns older than specified days""" if self.collection is None: return try: from datetime import timedelta cutoff = (datetime.now() - timedelta(days=days)).isoformat() # This is a simplified version - ChromaDB doesn't have built-in time-based deletion # In production, you'd implement a more sophisticated cleanup strategy logger.info(f"Pattern cleanup for patterns older than {days} days would go here") except Exception as e: logger.error(f"Error clearing old patterns: {e}")