"""
Vector database for pattern recognition and similarity search.

Uses ChromaDB for efficient vector storage and retrieval.
"""
import chromadb
from chromadb.config import Settings
import numpy as np
import pandas as pd
from typing import List, Dict, Optional, Tuple
import logging
import json
from datetime import datetime

from config import Config

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TradingPatternVectorDB:
    """
    Vector database for storing and searching trading patterns.

    Wraps a ChromaDB collection: each pattern is stored as a fixed-size
    embedding built from a price sequence plus technical indicators, and
    nearest-neighbour search is used to find similar historical market
    conditions and predict likely outcomes from them.
    """

    # Fixed length of the price-pattern embedding.  Both the normal path
    # and the too-short-input fallback must return this many values,
    # otherwise the collection would hold mixed-dimension embeddings and
    # ChromaDB would reject inserts/queries.
    PRICE_EMBEDDING_SIZE = 20

    def __init__(self, persist_dir: Optional[str] = None):
        """
        Initialize the ChromaDB client and the pattern collection.

        Args:
            persist_dir: Directory for ChromaDB persistence. Defaults to
                Config.CHROMA_PERSIST_DIR.
        """
        self.persist_dir = persist_dir or Config.CHROMA_PERSIST_DIR
        # NOTE(review): chromadb.Client(Settings(persist_directory=...))
        # is the legacy construction style; recent chromadb versions
        # prefer chromadb.PersistentClient(path=...).  Kept as-is so the
        # runtime dependency contract does not change — confirm against
        # the pinned chromadb version before migrating.
        self.client = chromadb.Client(Settings(
            persist_directory=self.persist_dir,
            anonymized_telemetry=False
        ))
        try:
            self.collection = self.client.get_or_create_collection(
                name=Config.CHROMA_COLLECTION_NAME,
                metadata={"description": "Trading patterns and market conditions"}
            )
            logger.info(f"Initialized vector DB with {self.collection.count()} patterns")
        except Exception as e:
            # Degrade gracefully: every public method checks for a None
            # collection and becomes a no-op instead of crashing callers.
            logger.error(f"Error initializing ChromaDB: {e}")
            self.collection = None

    def create_price_pattern_embedding(self, prices: List[float]) -> List[float]:
        """
        Create a fixed-size, L2-normalized embedding from a price sequence.

        Features: return statistics, linear-trend slope, rolling
        volatility, 5/10-bar momentum, and local high/low density.

        Args:
            prices: List of price values (chronological order assumed).

        Returns:
            Embedding vector of length PRICE_EMBEDDING_SIZE.
        """
        target_size = self.PRICE_EMBEDDING_SIZE
        if len(prices) < 2:
            # FIX: this fallback previously returned
            # [0.0] * Config.EMBEDDING_DIMENSION while the normal path
            # always produced `target_size` values; if the two constants
            # differ, the collection would contain mixed-dimension
            # embeddings.  Use the same fixed size on both paths.
            return [0.0] * target_size

        prices_array = np.array(prices, dtype=float)
        # Z-score normalization; epsilon avoids division by zero on
        # constant price sequences.
        normalized = (prices_array - prices_array.mean()) / (prices_array.std() + 1e-8)

        features = []

        # 1. Return statistics of the normalized series.
        returns = np.diff(normalized)
        features.extend([
            returns.mean(),
            returns.std(),
            returns.min(),
            returns.max()
        ])

        # 2. Trend: slope of a degree-1 least-squares fit.
        x = np.arange(len(normalized))
        slope = np.polyfit(x, normalized, 1)[0]
        features.append(slope)

        # 3. Volatility: 5-bar rolling std of the RAW prices (first four
        # NaN windows filled with 0).
        rolling_std = pd.Series(prices).rolling(5).std().fillna(0).values
        features.extend([
            rolling_std.mean(),
            rolling_std.std()
        ])

        # 4. Momentum over the last 5 and 10 bars.  FIX: guard against a
        # zero reference price, which previously raised/produced inf.
        if len(prices) >= 10:
            ref5 = prices_array[-5]
            ref10 = prices_array[-10]
            momentum_5 = (prices_array[-1] - ref5) / ref5 if ref5 != 0 else 0.0
            momentum_10 = (prices_array[-1] - ref10) / ref10 if ref10 != 0 else 0.0
            features.extend([momentum_5, momentum_10])
        else:
            features.extend([0.0, 0.0])

        # 5. Pattern shape: density of local maxima (higher highs) and
        # local minima (lower lows) among interior points.
        highs = [i for i in range(1, len(prices_array) - 1)
                 if prices_array[i] > prices_array[i - 1] and prices_array[i] > prices_array[i + 1]]
        lows = [i for i in range(1, len(prices_array) - 1)
                if prices_array[i] < prices_array[i - 1] and prices_array[i] < prices_array[i + 1]]
        features.extend([
            len(highs) / len(prices_array),
            len(lows) / len(prices_array)
        ])

        # Pad or truncate to the fixed embedding size.
        if len(features) < target_size:
            features.extend([0.0] * (target_size - len(features)))
        else:
            features = features[:target_size]

        # L2-normalize so distances compare patterns by shape, not scale.
        features = np.array(features)
        norm = np.linalg.norm(features)
        if norm > 0:
            features = features / norm
        return features.tolist()

    def create_indicator_embedding(self, indicators: Dict) -> List[float]:
        """
        Create an embedding from technical indicator values.

        Each feature is squashed to roughly [-1, 1]; missing or None
        indicators contribute 0.0 so the vector length is stable.

        Args:
            indicators: Dict of indicator values (rsi, macd, stoch_k,
                close, bb_upper, bb_lower, volume, volume_sma, atr).
                Values may be None.

        Returns:
            Embedding vector of length >= 10.
        """
        def _num(key: str) -> Optional[float]:
            # FIX: several branches below previously checked only key
            # presence, so a None value raised TypeError in arithmetic.
            # Treat None the same as "missing" everywhere.
            value = indicators.get(key)
            return value if value is not None else None

        features = []

        # RSI centered at 50 and scaled to [-1, 1].
        rsi = _num('rsi')
        features.append((rsi - 50) / 50 if rsi is not None else 0.0)

        # MACD squashed with tanh (the /100 sets the sensitivity scale).
        macd = _num('macd')
        features.append(np.tanh(macd / 100) if macd is not None else 0.0)

        # Stochastic %K centered at 50 and scaled to [-1, 1].
        stoch_k = _num('stoch_k')
        features.append((stoch_k - 50) / 50 if stoch_k is not None else 0.0)

        # Position of close within the Bollinger band, mapped to [-1, 1].
        close = _num('close')
        bb_upper = _num('bb_upper')
        bb_lower = _num('bb_lower')
        if close is not None and bb_upper is not None and bb_lower is not None \
                and bb_upper != bb_lower:
            bb_pos = (close - bb_lower) / (bb_upper - bb_lower)
            features.append(bb_pos * 2 - 1)
        else:
            features.append(0.0)

        # Volume relative to its SMA, squashed with tanh.
        volume = _num('volume')
        volume_sma = _num('volume_sma')
        if volume is not None and volume_sma is not None and volume_sma > 0:
            features.append(np.tanh(volume / volume_sma - 1))
        else:
            features.append(0.0)

        # ATR as a fraction of price (volatility), squashed with tanh.
        atr = _num('atr')
        if atr is not None and close is not None and close > 0:
            features.append(np.tanh(atr / close * 10))
        else:
            features.append(0.0)

        # Pad to the minimum indicator-embedding size.
        min_size = 10
        if len(features) < min_size:
            features.extend([0.0] * (min_size - len(features)))
        return features

    def add_pattern(
        self,
        pattern_id: str,
        symbol: str,
        timeframe: str,
        prices: List[float],
        indicators: Dict,
        outcome: Optional[str] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Add a trading pattern to the vector database.

        The stored embedding is the concatenation of the price embedding
        and the indicator embedding; a no-op if the collection failed to
        initialize. Errors are logged, not raised.

        Args:
            pattern_id: Unique pattern identifier.
            symbol: Trading symbol.
            timeframe: Timeframe of the pattern.
            prices: Price sequence.
            indicators: Technical indicator values.
            outcome: Actual outcome (e.g. 'bullish_success'); stored as
                'unknown' when not given.
            metadata: Additional metadata merged into the stored record.
        """
        if self.collection is None:
            return
        try:
            price_embedding = self.create_price_pattern_embedding(prices)
            indicator_embedding = self.create_indicator_embedding(indicators)
            # List concatenation, not element-wise addition.
            combined_embedding = price_embedding + indicator_embedding

            meta = {
                'symbol': symbol,
                'timeframe': timeframe,
                # NOTE(review): naive local time; consider
                # datetime.now(timezone.utc) if records cross timezones.
                'timestamp': datetime.now().isoformat(),
                'outcome': outcome or 'unknown',
                'price_count': len(prices),
                **(metadata or {})
            }

            self.collection.add(
                ids=[pattern_id],
                embeddings=[combined_embedding],
                metadatas=[meta],
                documents=[json.dumps({
                    'prices': prices[-20:],  # Store last 20 prices
                    'indicators': {k: v for k, v in indicators.items() if v is not None}
                })]
            )
            logger.info(f"Added pattern {pattern_id} to vector DB")
        except Exception as e:
            logger.error(f"Error adding pattern to vector DB: {e}")

    def find_similar_patterns(
        self,
        prices: List[float],
        indicators: Dict,
        n_results: int = 10,
        symbol: Optional[str] = None
    ) -> List[Dict]:
        """
        Find similar historical patterns via nearest-neighbour search.

        Args:
            prices: Current price sequence.
            indicators: Current indicator values.
            n_results: Maximum number of results to return.
            symbol: Optional symbol filter.

        Returns:
            List of dicts with pattern_id, similarity, metadata and the
            stored document; empty list on error or empty collection.
        """
        if self.collection is None or self.collection.count() == 0:
            return []
        try:
            price_embedding = self.create_price_pattern_embedding(prices)
            indicator_embedding = self.create_indicator_embedding(indicators)
            combined_embedding = price_embedding + indicator_embedding

            where_filter = {'symbol': symbol} if symbol else None
            results = self.collection.query(
                query_embeddings=[combined_embedding],
                n_results=min(n_results, self.collection.count()),
                where=where_filter
            )

            similar_patterns = []
            if results['ids'] and len(results['ids']) > 0:
                for i, pattern_id in enumerate(results['ids'][0]):
                    similar_patterns.append({
                        'pattern_id': pattern_id,
                        # NOTE(review): 1 - distance is only a proper
                        # similarity for normalized/cosine distances; with
                        # ChromaDB's default squared-L2 it can go negative.
                        # Confirm the collection's distance metric.
                        'similarity': 1 - results['distances'][0][i],
                        'metadata': results['metadatas'][0][i],
                        'document': json.loads(results['documents'][0][i]) if results['documents'][0][i] else {}
                    })
            return similar_patterns
        except Exception as e:
            logger.error(f"Error finding similar patterns: {e}")
            return []

    def get_pattern_statistics(self, symbol: Optional[str] = None) -> Dict:
        """
        Get summary statistics over a sample of stored patterns.

        Args:
            symbol: Optional symbol filter.

        Returns:
            Dict with total_patterns, sampled count, and outcome/symbol
            distributions; empty dict on error.
        """
        if self.collection is None:
            return {}
        try:
            total_count = self.collection.count()
            if total_count == 0:
                return {'total_patterns': 0}

            where_filter = {'symbol': symbol} if symbol else None
            # Distributions are computed on a sample of at most 100
            # records, not the whole collection.
            sample_size = min(100, total_count)
            results = self.collection.get(
                limit=sample_size,
                where=where_filter
            )

            outcomes = {}
            symbols_count = {}
            if results['metadatas']:
                for meta in results['metadatas']:
                    outcome = meta.get('outcome', 'unknown')
                    outcomes[outcome] = outcomes.get(outcome, 0) + 1
                    sym = meta.get('symbol', 'unknown')
                    symbols_count[sym] = symbols_count.get(sym, 0) + 1

            return {
                'total_patterns': total_count,
                'sampled': len(results['metadatas']) if results['metadatas'] else 0,
                'outcomes_distribution': outcomes,
                'symbols_distribution': symbols_count
            }
        except Exception as e:
            logger.error(f"Error getting pattern statistics: {e}")
            return {}

    def predict_outcome(self, prices: List[float], indicators: Dict, symbol: Optional[str] = None) -> Dict:
        """
        Predict the likely outcome from similar historical patterns.

        Looks up the 20 nearest patterns and takes a similarity-weighted
        vote over their recorded outcomes ('unknown' outcomes are ignored).

        Args:
            prices: Current price sequence.
            indicators: Current indicator values.
            symbol: Optional symbol filter.

        Returns:
            Dict with prediction, confidence, per-outcome probabilities,
            sample_size, and the top-5 similar patterns for reference.
        """
        similar_patterns = self.find_similar_patterns(prices, indicators, n_results=20, symbol=symbol)
        if not similar_patterns:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'sample_size': 0
            }

        # Similarity-weighted outcome vote.
        outcome_scores = {}
        total_weight = 0
        for pattern in similar_patterns:
            outcome = pattern['metadata'].get('outcome', 'unknown')
            similarity = pattern['similarity']
            if outcome != 'unknown':
                outcome_scores[outcome] = outcome_scores.get(outcome, 0) + similarity
                total_weight += similarity

        if total_weight == 0:
            # Every neighbour had an unknown outcome — nothing to vote on.
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'sample_size': len(similar_patterns)
            }

        outcome_probs = {k: v / total_weight for k, v in outcome_scores.items()}
        top_outcome = max(outcome_probs.items(), key=lambda x: x[1])

        return {
            'prediction': top_outcome[0],
            'confidence': top_outcome[1],
            'probabilities': outcome_probs,
            'sample_size': len(similar_patterns),
            'similar_patterns': similar_patterns[:5]  # Top 5 for reference
        }

    def clear_old_patterns(self, days: int = 90):
        """
        Clear patterns older than the given number of days.

        Currently a logged stub: ChromaDB has no built-in time-based
        deletion, so a real implementation would page through metadata
        timestamps and delete matching ids.
        """
        if self.collection is None:
            return
        try:
            from datetime import timedelta
            cutoff = (datetime.now() - timedelta(days=days)).isoformat()
            # This is a simplified version - ChromaDB doesn't have built-in time-based deletion
            # In production, you'd implement a more sophisticated cleanup strategy
            # FIX: `cutoff` was computed but never used; include it in the
            # log so the intended threshold is visible.
            logger.info(f"Pattern cleanup for patterns older than {days} days (cutoff {cutoff}) would go here")
        except Exception as e:
            logger.error(f"Error clearing old patterns: {e}")