Spaces:
Running
Running
| import os | |
| import pickle | |
| import numpy as np | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from flashrank import Ranker, RerankRequest | |
| import logging | |
| import threading | |
| import time | |
| import ast | |
| import re | |
| from filelock import FileLock | |
| import atexit | |
| import gc | |
| from typing import List, Dict, Any, Optional, Tuple, Union | |
| from collections import defaultdict, OrderedDict # <-- FIX 1: Add OrderedDict | |
# Logging is configured BEFORE the optional imports below. The module-level
# logging.warning() calls in their except-branches implicitly run a default
# logging.basicConfig() when the root logger has no handlers; configuring
# afterwards would then be a silent no-op and every INFO message would be lost.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("NeuralSessionEngine")

# Module-level parser registries, defined unconditionally so that references
# to them never raise NameError when tree-sitter is not installed.
TREE_SITTER_PARSERS = {}
TREE_SITTER_LANGUAGES = {}

# === TREE-SITTER IMPORTS (optional) ===
try:
    import tree_sitter
    from tree_sitter import Language, Parser
    # Prebuilt grammar bundle; optional on top of tree-sitter itself.
    try:
        from tree_sitter_languages import get_language, get_parser
        TREE_SITTER_IMPORTS_AVAILABLE = True
    except ImportError:
        TREE_SITTER_IMPORTS_AVAILABLE = False
    TREE_SITTER_AVAILABLE = True
    logger.info("π³ Tree-sitter successfully imported")
except ImportError as e:
    TREE_SITTER_AVAILABLE = False
    TREE_SITTER_IMPORTS_AVAILABLE = False
    logging.warning(f"β Tree-sitter import failed: {e}")
    logging.warning("Install: pip install tree-sitter tree-sitter-languages")

# === HYBRID SEARCH IMPORTS ===
try:
    from rank_bm25 import BM25Okapi
    BM25_AVAILABLE = True
except ImportError:
    BM25_AVAILABLE = False
    # Placeholder so annotations mentioning BM25Okapi still evaluate; every
    # use of the real class is guarded by BM25_AVAILABLE.
    BM25Okapi = None
    logging.warning("BM25 not available. Install: pip install rank-bm25")
try:
    import nltk
    from nltk.tokenize import word_tokenize, sent_tokenize
    NLTK_AVAILABLE = True
except ImportError:
    NLTK_AVAILABLE = False
    logging.warning("NLTK not available. Install: pip install nltk")
| class VectorDatabase: | |
def __init__(self, index_path: str = "faiss_session_index.bin", metadata_path: str = "session_metadata.pkl") -> None:
    """Set up locks, load the embedding/rerank models and the FAISS index, and init caches.

    Args:
        index_path: Path of the FAISS index file. ``index_path + ".lock"`` is
            used as the path of the inter-process file lock.
        metadata_path: Path of the metadata file (presumably pickled, given the
            default ``.pkl`` name -- confirm in ``_load_or_create_index``).

    Raises:
        RuntimeError: If the SentenceTransformer embedder or the FlashRank
            ranker fails to load (the original error message is embedded).
    """
    self.index_path = index_path
    self.metadata_path = metadata_path
    self.lock_path = index_path + ".lock"
    # Inter-process lock (60 s acquisition timeout) guarding the on-disk files
    self.file_lock = FileLock(self.lock_path, timeout=60)
    # Re-entrant in-process lock guarding self.index / self.metadata
    self.memory_lock = threading.RLock()
    logger.info("π§ Initializing Production Vector Engine with Hybrid Search...")
    # Load models; any failure is logged and re-raised as RuntimeError
    try:
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        self.ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2", cache_dir="./flashrank_cache")
    except Exception as e:
        logger.error(f"β Failed to load models: {e}")
        raise RuntimeError(f"Model initialization failed: {e}")
    # Per-instance caches filled lazily by _get_tree_sitter_parser
    self.tree_sitter_parsers = {}
    self.tree_sitter_languages = {}
    # Load or create the index. NOTE(review): presumably uses self.file_lock and
    # sets self.index / self.metadata (both are read at the end of __init__);
    # it runs before the BM25 attributes below exist -- confirm it never needs them.
    self._load_or_create_index()
    # BM25 indices are NOT built for all metadata at startup. They are built
    # per (user_id, chat_id) session on first search by _get_or_build_bm25 and
    # kept in an LRU cache of at most bm25_cache_size sessions.
    self.bm25_cache_size = 50 # Limit concurrent BM25 indices in memory
    self.bm25_indices = OrderedDict() # {(user_id, chat_id): BM25Okapi}; insertion order == LRU order
    self.bm25_docs = {} # {(user_id, chat_id): [tokenized_documents]}
    self.bm25_doc_to_vector = {} # {(user_id, chat_id): [metadata positions of those documents]}
    self.bm25_lock = threading.RLock()
    # Performance tracking
    self.query_history = []
    self.performance_stats = {
        "exact_matches": 0,
        "semantic_matches": 0,
        "bm25_matches": 0,
        "hybrid_matches": 0,
        "fallback_matches": 0,
        "avg_retrieval_time": 0
    }
    # Query type classification stats (missing keys start at 0)
    self.query_types = defaultdict(int)
    # Run _cleanup at interpreter exit (note: this keeps the instance alive until then)
    atexit.register(self._cleanup)
    logger.info(f"β Vector Engine Ready. Index: {self.index.ntotal} vectors, {len(self.metadata)} metadata entries")
    logger.info(f"β BM25 LRU Cache: {self.bm25_cache_size} sessions max, BM25 Available: {BM25_AVAILABLE}")
| # ==================== FIX 2: LAZY BM25 LOADING ==================== | |
def _get_or_build_bm25(self, user_id: str, chat_id: str) -> "Optional[BM25Okapi]":
    """Return the BM25 index for one (user_id, chat_id) session, building it lazily.

    On a cache hit the entry is marked most-recently-used. On a miss, every
    metadata entry of the session is tokenized and indexed; the index, its
    tokenized corpus (``bm25_docs``) and the metadata positions of the indexed
    documents (``bm25_doc_to_vector``) are cached, evicting the
    least-recently-used sessions so at most ``bm25_cache_size`` stay cached.

    Returns:
        The ``BM25Okapi`` index, or None if BM25 is unavailable, the session has
        no indexable text, or building the index fails.
    """
    # The return annotation is a string on purpose: when rank_bm25 is not
    # installed, an evaluated ``Optional[BM25Okapi]`` would raise NameError at
    # class-definition time and make the whole module unimportable.
    if not BM25_AVAILABLE:
        return None
    key = (user_id, chat_id)
    with self.bm25_lock:
        # 1. CACHE HIT: mark as most recently used
        if key in self.bm25_indices:
            self.bm25_indices.move_to_end(key)
            return self.bm25_indices[key]

        # 2. CACHE MISS: build from this session's documents only (session isolation)
        logger.debug(f"π Building BM25 index on-demand for session {key}")
        tokenized_corpus = []
        vector_ids = []
        with self.memory_lock:
            for idx, meta in enumerate(self.metadata):
                if meta.get("user_id") != user_id or meta.get("chat_id") != chat_id:
                    continue
                tokens = self._tokenize_for_bm25(meta.get("text", ""))
                if tokens:  # BM25 cannot use empty documents
                    tokenized_corpus.append(tokens)
                    vector_ids.append(idx)
        if not tokenized_corpus:
            logger.debug(f"β οΈ No documents found for BM25 index {key}")
            return None

        try:
            bm25 = BM25Okapi(tokenized_corpus)
        except Exception as e:
            logger.error(f"β Failed to build BM25 index for {key}: {e}")
            return None

        # 3. STORE IN CACHE with LRU eviction. `while` + emptiness check keeps
        # this safe for any cache size (popitem on an empty dict would raise).
        while self.bm25_indices and len(self.bm25_indices) >= self.bm25_cache_size:
            oldest_key, _ = self.bm25_indices.popitem(last=False)
            self.bm25_docs.pop(oldest_key, None)
            self.bm25_doc_to_vector.pop(oldest_key, None)
            logger.debug(f"π§Ή Evicted BM25 cache for session {oldest_key}")
        self.bm25_indices[key] = bm25
        self.bm25_docs[key] = tokenized_corpus
        self.bm25_doc_to_vector[key] = vector_ids
        logger.debug(f"β Built BM25 index for session {key}: {len(tokenized_corpus)} docs")
        return bm25
def _invalidate_bm25_cache(self, user_id: str, chat_id: str):
    """Forget every cached BM25 artefact of one (user_id, chat_id) session.

    Nothing is rebuilt here; ``_get_or_build_bm25`` rebuilds the index on the
    next lookup. Calling this for a session that was never cached is harmless.
    """
    session_key = (user_id, chat_id)
    with self.bm25_lock:
        # Index, tokenized corpus and document->metadata mapping share one key.
        for cache in (self.bm25_indices, self.bm25_docs, self.bm25_doc_to_vector):
            cache.pop(session_key, None)
        logger.debug(f"π§Ή Invalidated BM25 cache for session {session_key}")
| def _tokenize_for_bm25(self, text: str) -> List[str]: | |
| if not text: return [] | |
| # Try NLTK first | |
| if NLTK_AVAILABLE: | |
| try: | |
| return word_tokenize(text.lower()) | |
| except: pass | |
| # FALLBACK: Improved Regex for Code & Technical Terms | |
| # Captures: | |
| # 1. Standard words (word) | |
| # 2. Words with dots/dashes (v1.0, my-class) | |
| # 3. Code symbols combined with text (C++, #include) | |
| token_pattern = r'(?u)\b\w[\w.-]*\w\b|\b\w\b|[!#@$]\w+' | |
| return re.findall(token_pattern, text.lower()) | |
| # ==================== ENHANCED STORAGE WITH CACHE INVALIDATION ==================== | |
def store_session_document(self, text: str, filename: str, user_id: str, chat_id: str, file_id: str = None) -> bool:
    """Chunk, embed and persist one file's content for a (user_id, chat_id) session.

    Chunking by extension: tree-sitter for supported code/markup extensions when
    tree-sitter is available, ``ast`` for ``.py`` otherwise, the regex "smart"
    chunker for other web/code extensions, and overlapping 600-char text windows
    for everything else (also used if a chunker raises). One extra "whole file"
    entry (first 4000 chars) is embedded alongside the chunks. On success the
    session's BM25 cache is invalidated so the next search rebuilds it lazily.

    Args:
        text: Extracted file content (must be at least 10 characters).
        filename: Original file name; its extension selects the chunker.
        user_id: Owner of the document (required).
        chat_id: Chat session the document belongs to.
        file_id: Optional external file identifier stored in the metadata.

    Returns:
        True if the vectors and metadata were added and saved; False on invalid
        input or if embedding/storage failed. If vectors had already been added
        when the failure happened, ``_rollback_partial_storage`` is invoked.
    """
    if not text or len(text) < 10 or not user_id:
        logger.warning(f"Invalid input for {filename}")
        return False
    logger.info(f"π₯ Storing {filename} ({len(text)} chars) for user {user_id[:8]}...")

    ext = os.path.splitext(filename)[1].lower()
    tree_sitter_exts = {
        '.py', '.js', '.jsx', '.ts', '.tsx', '.java', '.cpp', '.c', '.cc',
        '.go', '.rs', '.php', '.rb', '.cs', '.swift', '.kt', '.scala',
        '.lua', '.r', '.sh', '.bash', '.sql', '.html', '.css', '.xml',
        '.json', '.yaml', '.yml', '.toml', '.vue', '.md',
    }
    smart_code_exts = {'.js', '.html', '.css', '.java', '.cpp', '.ts', '.tsx', '.jsx', '.vue', '.xml', '.scss'}
    try:
        if TREE_SITTER_AVAILABLE and ext in tree_sitter_exts:
            chunks_data = self._chunk_with_tree_sitter(text, filename)
            logger.debug(f"Used Tree-sitter for {filename}")
        elif ext == '.py':
            chunks_data = self._chunk_python_ast_enhanced(text, filename)
        elif ext in smart_code_exts:
            chunks_data = self._chunk_smart_code(text, filename)
        else:
            chunks_data = self._chunk_text_enhanced(text, chunk_size=600, overlap=100)
    except Exception as e:
        logger.error(f"Chunking failed for {filename}: {e}")
        chunks_data = self._chunk_text_enhanced(text, chunk_size=600, overlap=100)

    if not chunks_data:
        # `text` is guaranteed non-empty here, so this fallback always yields a chunk.
        chunks_data = [{
            "text": text[:2000],
            "type": "fallback",
            "name": "full_document"
        }]

    # One timestamp for the whole file keeps its chunks consistent.
    stored_at = time.time()
    final_texts = []
    final_meta = []
    # chunk_index is 1-based; the whole-file entry below uses -1.
    for position, chunk in enumerate(chunks_data, start=1):
        final_texts.append(chunk["text"])
        final_meta.append({
            "text": chunk["text"],
            "source": filename,
            "file_id": file_id,
            "type": "file",
            "subtype": chunk.get("type", "general"),
            "name": chunk.get("name", "unknown"),
            "user_id": user_id,
            "chat_id": chat_id,
            "timestamp": stored_at,
            "chunk_index": position
        })

    # Whole-file embedding for comprehensive answers (embedded text is capped
    # at 4000 chars; the full content is kept in "actual_content").
    whole_file_text = text[:4000]
    final_texts.append(f"Complete File: {filename} | Full Content: {whole_file_text}")
    final_meta.append({
        "text": whole_file_text,
        "actual_content": text,
        "source": filename,
        "file_id": file_id,
        "type": "file",
        "subtype": "whole_file",
        "is_whole_file": True,
        "user_id": user_id,
        "chat_id": chat_id,
        "timestamp": stored_at,
        "chunk_index": -1
    })

    # Tracks whether THIS call modified the index, so a failure before the add
    # (e.g. while encoding) never triggers a rollback of earlier, valid data.
    vectors_added = False
    try:
        embeddings = self.embedder.encode(
            final_texts,
            show_progress_bar=False,
            batch_size=32,
            convert_to_numpy=True,
            normalize_embeddings=True
        )
        # faiss.normalize_L2 requires a C-contiguous float32 array.
        vectors = np.ascontiguousarray(embeddings, dtype='float32')
        faiss.normalize_L2(vectors)
        with self.memory_lock:
            self.index.add(vectors)
            vectors_added = True
            self.metadata.extend(final_meta)
            self._save_index()
        logger.info(f"β Stored {len(final_texts)} chunks from {filename} for user {user_id[:8]}")
        # Invalidate instead of rebuilding: the next search rebuilds BM25
        # (including this file) lazily.
        self._invalidate_bm25_cache(user_id, chat_id)
        self._verify_storage(user_id, chat_id, len(final_texts))
        return True
    except Exception as e:
        logger.error(f"β Failed to store vectors for {filename}: {e}")
        if vectors_added:
            with self.memory_lock:
                logger.warning("Rolling back partial storage...")
                self._rollback_partial_storage(user_id, chat_id)
            # A concurrent search may have cached BM25 over the rolled-back entries.
            self._invalidate_bm25_cache(user_id, chat_id)
        return False
def _get_tree_sitter_parser(self, language_name: str) -> Optional[Any]:
    """Return a cached or newly created tree-sitter parser for *language_name*.

    Plan A uses ``tree_sitter_languages.get_parser`` when that bundle is
    importable. Plan B imports the per-language binding package (for example
    ``tree_sitter_python``) and wires its grammar into a ``Parser``. Plan B
    supports both the newer py-tree-sitter API (``Language(capsule)`` plus
    ``parser.language = ...``) and the older ``Language(ptr, name)`` /
    ``parser.set_language`` API.

    Returns:
        The parser, or None when tree-sitter is unavailable or both plans fail.
        Successful parsers (and Plan-B languages) are cached on the instance.
    """
    if not TREE_SITTER_AVAILABLE:
        return None

    # 1. CHECK CACHE FIRST
    if language_name in self.tree_sitter_parsers:
        return self.tree_sitter_parsers[language_name]

    # 2. Grammar name -> binding package used by Plan B
    lang_lib_map = {
        'python': 'tree_sitter_python',
        'javascript': 'tree_sitter_javascript',
        'typescript': 'tree_sitter_typescript',
        'tsx': 'tree_sitter_typescript',
        'java': 'tree_sitter_java',
        'cpp': 'tree_sitter_cpp',
        'c': 'tree_sitter_c',
        'go': 'tree_sitter_go',
        'rust': 'tree_sitter_rust',
        'php': 'tree_sitter_php',
        'ruby': 'tree_sitter_ruby',
        'c_sharp': 'tree_sitter_c_sharp',
        'swift': 'tree_sitter_swift',
        'kotlin': 'tree_sitter_kotlin',
        'scala': 'tree_sitter_scala',
        'html': 'tree_sitter_html',
        'css': 'tree_sitter_css',
        'json': 'tree_sitter_json',
        'yaml': 'tree_sitter_yaml',
        'toml': 'tree_sitter_toml',
        'xml': 'tree_sitter_xml',
        'markdown': 'tree_sitter_markdown',
        'bash': 'tree_sitter_bash',
        'sql': 'tree_sitter_sql'
    }

    try:
        logger.debug(f"π³ Creating parser for {language_name}")

        # 3. PLAN A: prebuilt grammar bundle
        if TREE_SITTER_IMPORTS_AVAILABLE:
            try:
                parser = get_parser(language_name)
                if parser:
                    self.tree_sitter_parsers[language_name] = parser
                    logger.debug(f"β Got parser for {language_name} via tree_sitter_languages")
                    return parser
            except Exception as e:
                logger.warning(f"β οΈ Plan A failed (tree_sitter_languages) for {language_name}: {e}")

        # 4. PLAN B: per-language binding package
        lib_name = lang_lib_map.get(language_name)
        if lib_name:
            try:
                module = __import__(lib_name)
                # Single-grammar packages export `language`; multi-grammar ones
                # (tree_sitter_typescript, _php, _xml) export `language_<name>`.
                lang_obj = getattr(module, f"language_{language_name}", None)
                if lang_obj is None:
                    lang_obj = getattr(module, "language", None)
                raw_language = lang_obj() if callable(lang_obj) else lang_obj
                if raw_language:
                    language = raw_language
                    if not isinstance(language, Language):
                        # Bindings return a pointer/capsule that must be wrapped.
                        try:
                            language = Language(raw_language)  # newer API
                        except TypeError:
                            language = Language(raw_language, language_name)  # older API
                    parser = Parser()
                    try:
                        parser.language = language  # newer API (settable property)
                    except AttributeError:
                        parser.set_language(language)  # older API
                    self.tree_sitter_parsers[language_name] = parser
                    self.tree_sitter_languages[language_name] = language
                    logger.debug(f"β Loaded {language_name} manually from {lib_name}")
                    return parser
            except ImportError:
                logger.debug(f"β οΈ Manual load skipped: {lib_name} not installed.")
            except Exception as e:
                logger.warning(f"β Manual load error for {lib_name}: {e}")

        logger.warning(f"β Could not load parser for {language_name} (Plan A and B failed)")
        return None
    except Exception as e:
        logger.error(f"β Critical parser error for {language_name}: {e}")
        return None
def _chunk_with_tree_sitter(self, text: str, filename: str) -> List[Dict[str, Any]]:
    """Chunk a source file with tree-sitter, choosing the grammar by extension.

    HTML and Vue files are delegated to ``_chunk_hybrid_file`` (embedded
    languages). Other files go to ``_chunk_single_language`` with an ordered list
    of grammars to try. Extensions without a mapping use ``_chunk_smart_code``.
    Without tree-sitter, ``.py`` uses ``_chunk_python_ast_enhanced`` and
    everything else ``_chunk_smart_code``.
    """
    ext = os.path.splitext(filename)[1].lower()
    if not TREE_SITTER_AVAILABLE:
        logger.warning("β TREE-SITTER UNAVAILABLE: Falling back to alternative methods")
        if ext == '.py':
            return self._chunk_python_ast_enhanced(text, filename)
        return self._chunk_smart_code(text, filename)

    # Extension -> tree-sitter grammar name
    language_map = {
        '.py': 'python',
        '.js': 'javascript',
        '.jsx': 'javascript',
        '.ts': 'typescript',
        # JSX syntax needs the tsx grammar; the plain typescript grammar misparses it.
        '.tsx': 'tsx',
        '.java': 'java',
        '.cpp': 'cpp',
        '.c': 'c',
        '.cc': 'cpp',
        '.h': 'c',
        '.hpp': 'cpp',
        '.go': 'go',
        '.rs': 'rust',
        '.php': 'php',
        '.rb': 'ruby',
        '.cs': 'c_sharp',
        '.swift': 'swift',
        '.kt': 'kotlin',
        '.kts': 'kotlin',
        '.scala': 'scala',
        '.lua': 'lua',
        '.r': 'r',
        '.sh': 'bash',
        '.bash': 'bash',
        '.zsh': 'bash',
        '.sql': 'sql',
        '.html': 'html',
        '.htm': 'html',
        '.css': 'css',
        '.scss': 'css',
        '.sass': 'css',
        '.json': 'json',
        '.yaml': 'yaml',
        '.yml': 'yaml',
        '.toml': 'toml',
        '.xml': 'xml',
        '.vue': 'vue',
        '.md': 'markdown',
    }
    language_name = language_map.get(ext)
    if not language_name:
        logger.warning(f"π NO PARSER FOR EXTENSION: {ext} for {filename}, falling back to smart chunking")
        return self._chunk_smart_code(text, filename)

    # Files mixing several languages get section-aware handling.
    if language_name in ('html', 'vue'):
        return self._chunk_hybrid_file(text, filename, language_name)

    # Grammars tried in order until one yields chunks.
    parser_fallbacks = {
        'javascript': ['javascript', 'tsx', 'typescript'],
        'typescript': ['typescript', 'tsx'],
        'tsx': ['tsx', 'typescript'],
    }
    return self._chunk_single_language(text, filename, parser_fallbacks.get(language_name, [language_name]))
def _chunk_single_language(self, text: str, filename: str, language_names: Union[str, List[str]]) -> List[Dict[str, Any]]:
    """Chunk single-language source with tree-sitter, trying grammars in order.

    Args:
        text: Source code to chunk.
        filename: Used in chunk headers and to choose the final fallback.
        language_names: One grammar name or an ordered list of names to try.

    Returns:
        The chunks of the first grammar that loads, whose root node is not
        ``ERROR``, and that yields at least one chunk. Import chunks come first,
        then the nodes listed in ``_get_node_types_config(lang)['extract']``.
        If no grammar succeeds: ``_chunk_python_ast_enhanced`` for ``.py`` files,
        otherwise ``_chunk_smart_code``.
    """
    if isinstance(language_names, str):
        language_names = [language_names]

    # Loop-invariant inputs, computed once (previously per grammar / per node).
    text_bytes = bytes(text, 'utf-8')
    # tree-sitter rows count '\n' only. str.splitlines() also breaks on \f, \v,
    # \x1c-\x1e, \x85, \u2028, \u2029 and lone \r, which would misalign rows
    # with list indices. Mirror splitlines() for ordinary \n / \r\n endings.
    lines = [line[:-1] if line.endswith('\r') else line for line in text.split('\n')]
    if text.endswith('\n'):
        lines.pop()

    for lang in language_names:
        try:
            parser = self._get_tree_sitter_parser(lang)
            if not parser:
                continue
            root_node = parser.parse(text_bytes).root_node

            # A root ERROR node means this grammar failed completely.
            if not root_node or root_node.type == 'ERROR':
                logger.warning(f"β οΈ Parser {lang} failed (Root ERROR) for {filename}. Trying next..." if len(language_names) > 1 else f"β οΈ Parser {lang} failed for {filename}")
                continue

            node_types_config = self._get_node_types_config(lang)
            target_types = node_types_config.get('extract', [])
            skip_types = node_types_config.get('skip', [])
            name_fields = node_types_config.get('name_fields', ['identifier', 'name'])
            context_config = node_types_config.get('context', {})
            context_before = context_config.get('before', 5)
            context_after = context_config.get('after', 5)
            local_chunks = []

            def extract_node_with_context(node, node_type, current_lang):
                """Build one chunk for `node`, padded with surrounding source lines."""
                start_line = node.start_point[0]
                end_line = node.end_point[0]
                node_text = text_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='ignore')
                context_start = max(0, start_line - context_before)
                context_end = min(len(lines), end_line + context_after + 1)
                if context_start < start_line or context_end > end_line + 1:
                    segment = '\n'.join(lines[context_start:context_end])
                else:
                    segment = node_text
                node_name = self._extract_node_name(node, text_bytes, name_fields)
                if not node_name:
                    node_name = f"{node_type}_{start_line + 1}"
                return {
                    "text": f"File: {filename} | Type: {node_type} | Name: {node_name}\n{segment}",
                    "type": f"code_{node_type}",
                    "name": node_name,
                    "line_start": start_line + 1,
                    "line_end": end_line + 1,
                    "context_start": context_start + 1,
                    "context_end": context_end,
                    "language": current_lang
                }

            def find_target_nodes(node, depth=0):
                """Depth-first walk; matching nodes are emitted AND descended into."""
                if depth > 200:  # keep well below Python's recursion limit
                    return
                if node.type in skip_types:
                    return
                if node.type in target_types:
                    local_chunks.append(extract_node_with_context(node, node.type, lang))
                for child in node.children:
                    find_target_nodes(child, depth + 1)

            find_target_nodes(root_node)

            # Imports / top-level declarations go first.
            import_chunks = self._extract_imports(root_node, text_bytes, lang, filename)
            if import_chunks:
                local_chunks = import_chunks + local_chunks

            if local_chunks:
                logger.info(f"β TREE-SITTER SUCCESS: Parsed {filename} with ({lang}) into {len(local_chunks)} chunks")
                return local_chunks

            # Nothing matched with this grammar; try the next one.
            logger.debug(f"βΉοΈ Parser {lang} yielded 0 chunks for {filename}. Trying next...")
        except Exception as e:
            logger.warning(f"β οΈ Parser {lang} exception for {filename}: {e}")
            continue

    # Every grammar failed or produced nothing.
    logger.warning(f"β ALL Parsers failed for {filename}, falling back to smart chunking")
    ext = os.path.splitext(filename)[1].lower()
    if ext == '.py':
        return self._chunk_python_ast_enhanced(text, filename)
    return self._chunk_smart_code(text, filename)
| def _chunk_hybrid_file(self, text: str, filename: str, primary_lang: str) -> List[Dict[str, Any]]: | |
| """ | |
| Chunk files that contain multiple languages (HTML with CSS/JS, Vue files, etc.) | |
| """ | |
| chunks = [] | |
| if primary_lang == 'html': | |
| # Use regex-based approach for HTML to avoid tree-sitter issues | |
| return self._chunk_html_with_embedded_languages(text, filename) | |
| elif primary_lang == 'vue': | |
| # Vue files have template, script, style sections | |
| return self._chunk_vue_file(text, filename) | |
| # Default fallback | |
| return self._chunk_smart_code(text, filename) | |
| def _chunk_html_with_embedded_languages(self, text: str, filename: str) -> List[Dict[str, Any]]: | |
| """Chunk HTML files with embedded CSS and JavaScript.""" | |
| chunks = [] | |
| # Split HTML into sections | |
| lines = text.splitlines() | |
| # Find all script and style tags | |
| script_pattern = re.compile(r'<script(\s[^>]*)?>([\s\S]*?)</script>', re.IGNORECASE) | |
| style_pattern = re.compile(r'<style(\s[^>]*)?>([\s\S]*?)</style>', re.IGNORECASE) | |
| # Extract and chunk script blocks | |
| for match in script_pattern.finditer(text): | |
| full_match = match.group(0) | |
| attrs = match.group(1) or "" | |
| content = match.group(2) | |
| # Determine language | |
| lang = 'javascript' | |
| if 'type="text/typescript"' in attrs or 'lang="ts"' in attrs: | |
| lang = 'typescript' | |
| # Find line numbers | |
| start_pos = match.start() | |
| line_num = text[:start_pos].count('\n') + 1 | |
| # Chunk the script content | |
| if content.strip(): | |
| script_chunks = self._chunk_single_language(content, filename, lang) | |
| if script_chunks: | |
| for chunk in script_chunks: | |
| chunk['text'] = f"File: {filename} | In <script> block (starting line {line_num}) | Language: {lang}\n{chunk['text']}" | |
| chunk['type'] = 'html_script_' + chunk['type'] | |
| chunk['language'] = lang | |
| chunks.extend(script_chunks) | |
| # Extract and chunk style blocks | |
| for match in style_pattern.finditer(text): | |
| full_match = match.group(0) | |
| attrs = match.group(1) or "" | |
| content = match.group(2) | |
| # Determine language | |
| lang = 'css' | |
| if 'lang="scss"' in attrs: | |
| lang = 'css' # Treat SCSS as CSS for now | |
| # Find line numbers | |
| start_pos = match.start() | |
| line_num = text[:start_pos].count('\n') + 1 | |
| # Chunk the style content | |
| if content.strip(): | |
| style_chunks = self._chunk_single_language(content, filename, lang) | |
| if style_chunks: | |
| for chunk in style_chunks: | |
| chunk['text'] = f"File: {filename} | In <style> block (starting line {line_num}) | Language: {lang}\n{chunk['text']}" | |
| chunk['type'] = 'html_style_' + chunk['type'] | |
| chunk['language'] = lang | |
| chunks.extend(style_chunks) | |
| # Chunk remaining HTML content | |
| # Remove script and style blocks for HTML-only chunking | |
| html_only = text | |
| for match in script_pattern.finditer(text): | |
| # Calculate line numbers separately to avoid backslash in f-string | |
| start_line = text[:match.start()].count('\n') + 1 | |
| end_line = text[:match.end()].count('\n') + 1 | |
| html_only = html_only.replace(match.group(0), f"<!-- SCRIPT BLOCK REMOVED (lines {start_line}-{end_line}) -->") | |
| for match in style_pattern.finditer(text): | |
| # Calculate line numbers separately to avoid backslash in f-string | |
| start_line = text[:match.start()].count('\n') + 1 | |
| end_line = text[:match.end()].count('\n') + 1 | |
| html_only = html_only.replace(match.group(0), f"<!-- STYLE BLOCK REMOVED (lines {start_line}-{end_line}) -->") | |
| # Use smart chunking for HTML | |
| html_chunks = self._chunk_smart_code(html_only, filename) | |
| if html_chunks: | |
| for chunk in html_chunks: | |
| chunk['type'] = 'html_' + chunk['type'] | |
| chunk['language'] = 'html' | |
| chunks.extend(html_chunks) | |
| if not chunks: | |
| return self._chunk_smart_code(text, filename) | |
| logger.info(f"β HYBRID HTML PARSED: {filename} into {len(chunks)} mixed-language chunks") | |
| return chunks | |
| def _chunk_vue_file(self, text: str, filename: str) -> List[Dict[str, Any]]: | |
| """Chunk Vue.js files with template, script, and style sections.""" | |
| chunks = [] | |
| # Extract template section | |
| template_match = re.search(r'<template[^>]*>([\s\S]*?)</template>', text) | |
| if template_match: | |
| template_content = template_match.group(1) | |
| # Find line numbers | |
| start_pos = template_match.start() | |
| line_num = text[:start_pos].count('\n') + 1 | |
| # Chunk template (treat as HTML) | |
| template_chunks = self._chunk_smart_code(template_content, filename) | |
| if template_chunks: | |
| for chunk in template_chunks: | |
| chunk['text'] = f"File: {filename} | Vue Template Section (starting line {line_num})\n{chunk['text']}" | |
| chunk['type'] = 'vue_template_' + chunk['type'] | |
| chunk['language'] = 'html' | |
| chunks.extend(template_chunks) | |
| # Extract script section | |
| script_match = re.search(r'<script[^>]*>([\s\S]*?)</script>', text, re.DOTALL) | |
| if script_match: | |
| script_content = script_match.group(1) | |
| attrs = script_match.group(0)[:script_match.group(0).index('>')] | |
| # Find line numbers | |
| start_pos = script_match.start() | |
| line_num = text[:start_pos].count('\n') + 1 | |
| # Detect language | |
| lang = 'javascript' | |
| if 'lang="ts"' in attrs or 'lang="typescript"' in attrs: | |
| lang = 'typescript' | |
| # Chunk script | |
| script_chunks = self._chunk_single_language(script_content, filename, lang) | |
| if script_chunks: | |
| for chunk in script_chunks: | |
| chunk['text'] = f"File: {filename} | Vue Script Section (starting line {line_num}) | Language: {lang}\n{chunk['text']}" | |
| chunk['type'] = 'vue_script_' + chunk['type'] | |
| chunk['language'] = lang | |
| chunks.extend(script_chunks) | |
| # Extract style section | |
| style_match = re.search(r'<style[^>]*>([\s\S]*?)</style>', text, re.DOTALL) | |
| if style_match: | |
| style_content = style_match.group(1) | |
| attrs = style_match.group(0)[:style_match.group(0).index('>')] | |
| # Find line numbers | |
| start_pos = style_match.start() | |
| line_num = text[:start_pos].count('\n') + 1 | |
| # Detect language | |
| lang = 'css' | |
| if 'lang="scss"' in attrs: | |
| lang = 'css' # Treat SCSS as CSS | |
| # Chunk style | |
| style_chunks = self._chunk_single_language(style_content, filename, lang) | |
| if style_chunks: | |
| for chunk in style_chunks: | |
| chunk['text'] = f"File: {filename} | Vue Style Section (starting line {line_num}) | Language: {lang}\n{chunk['text']}" | |
| chunk['type'] = 'vue_style_' + chunk['type'] | |
| chunk['language'] = lang | |
| chunks.extend(style_chunks) | |
| if not chunks: | |
| return self._chunk_smart_code(text, filename) | |
| logger.info(f"β VUE PARSED: {filename} into {len(chunks)} chunks") | |
| return chunks | |
| def _get_node_types_config(self, language_name: str) -> Dict[str, Any]: | |
| """Get configuration for what node types to extract for each language.""" | |
| configs = { | |
| 'python': { | |
| 'extract': ['function_definition', 'class_definition', 'async_function_definition'], | |
| 'skip': ['decorated_definition'], | |
| 'name_fields': ['identifier', 'name'], | |
| 'context': {'before': 2, 'after': 2} | |
| }, | |
| 'javascript': { | |
| 'extract': ['function_declaration', 'method_definition', 'class_declaration', | |
| 'arrow_function', 'function_expression', 'variable_declaration', | |
| 'export_statement'], | |
| 'skip': [], | |
| 'name_fields': ['identifier', 'name', 'property_identifier'], | |
| 'context': {'before': 5, 'after': 5} | |
| }, | |
| 'tsx': { | |
| 'extract': ['function_declaration', 'method_declaration', 'class_declaration', | |
| 'arrow_function', 'interface_declaration', 'type_alias_declaration', | |
| 'enum_declaration', 'export_statement', 'variable_declaration', | |
| 'lexical_declaration' | |
| ], | |
| 'skip': [], | |
| 'name_fields': ['identifier', 'name', 'type_identifier'], | |
| 'context': {'before': 2, 'after': 2} | |
| }, | |
| 'java': { | |
| 'extract': ['method_declaration', 'class_declaration', 'interface_declaration', | |
| 'constructor_declaration'], | |
| 'skip': [], | |
| 'name_fields': ['identifier'], | |
| 'context': {'before': 2, 'after': 2} | |
| }, | |
| 'cpp': { | |
| 'extract': ['function_definition', 'class_specifier', 'struct_specifier', | |
| 'namespace_definition'], | |
| 'skip': [], | |
| 'name_fields': ['identifier', 'type_identifier'], | |
| 'context': {'before': 2, 'after': 2} | |
| }, | |
| 'c': { | |
| 'extract': ['function_definition', 'struct_specifier', 'declaration'], | |
| 'skip': [], | |
| 'name_fields': ['identifier'], | |
| 'context': {'before': 2, 'after': 2} | |
| }, | |
| 'go': { | |
| 'extract': ['function_declaration', 'method_declaration', 'type_declaration'], | |
| 'skip': [], | |
| 'name_fields': ['identifier'], | |
| 'context': {'before': 2, 'after': 2} | |
| }, | |
| 'rust': { | |
| 'extract': ['function_item', 'impl_item', 'struct_item', 'trait_item', | |
| 'enum_item', 'mod_item'], | |
| 'skip': [], | |
| 'name_fields': ['identifier'], | |
| 'context': {'before': 2, 'after': 2} | |
| }, | |
| 'html': { | |
| 'extract': ['element', 'script_element', 'style_element'], | |
| 'skip': ['text'], | |
| 'name_fields': ['tag_name'], | |
| 'context': {'before': 1, 'after': 1} | |
| }, | |
| 'css': { | |
| 'extract': ['rule_set', 'at_rule'], | |
| 'skip': [], | |
| 'name_fields': [], | |
| 'context': {'before': 1, 'after': 1} | |
| }, | |
| 'sql': { | |
| 'extract': ['select_statement', 'insert_statement', 'update_statement', | |
| 'delete_statement', 'create_statement'], | |
| 'skip': [], | |
| 'name_fields': [], | |
| 'context': {'before': 1, 'after': 1} | |
| } | |
| } | |
| return configs.get(language_name, { | |
| 'extract': ['function_definition', 'class_definition'], | |
| 'skip': [], | |
| 'name_fields': ['identifier', 'name'], | |
| 'context': {'before': 2, 'after': 2} | |
| }) | |
| def _extract_node_name(self, node, text_bytes: bytes, name_fields: List[str]) -> str: | |
| """Extract the name/identifier from a node.""" | |
| for field in name_fields: | |
| for child in node.children: | |
| if child.type == field: | |
| return text_bytes[child.start_byte:child.end_byte].decode('utf-8', errors='ignore') | |
| # Try to find any identifier | |
| for child in node.children: | |
| if 'identifier' in child.type or 'name' in child.type: | |
| return text_bytes[child.start_byte:child.end_byte].decode('utf-8', errors='ignore') | |
| return "" | |
| def _extract_imports(self, root_node, text_bytes: bytes, language_name: str, filename: str) -> List[Dict[str, Any]]: | |
| """Extract import statements from the code.""" | |
| import_chunks = [] | |
| import_types = { | |
| 'python': ['import_statement', 'import_from_statement'], | |
| 'javascript': ['import_statement', 'import_declaration'], | |
| 'typescript': ['import_statement', 'import_declaration'], | |
| 'java': ['import_declaration'], | |
| 'cpp': ['preproc_include'], | |
| 'rust': ['use_declaration'], | |
| 'go': ['import_declaration'], | |
| 'php': ['use_declaration'], | |
| 'c_sharp': ['using_directive'] | |
| } | |
| target_types = import_types.get(language_name, []) | |
| def collect_imports(node): | |
| if node.type in target_types: | |
| import_text = text_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='ignore') | |
| if import_text: | |
| import_chunks.append({ | |
| "text": f"File: {filename} | Import Statement:\n{import_text}", | |
| "type": "code_imports", | |
| "name": "imports", | |
| "line_start": node.start_point[0] + 1, | |
| "line_end": node.end_point[0] + 1, | |
| "language": language_name | |
| }) | |
| for child in node.children: | |
| collect_imports(child) | |
| collect_imports(root_node) | |
| # Group imports if there are many | |
| if len(import_chunks) > 5: | |
| import_texts = [] | |
| for chunk in import_chunks: | |
| # Extract just the import statement from the chunk text | |
| import_lines = chunk['text'].split('\n', 1) | |
| if len(import_lines) > 1: | |
| import_texts.append(import_lines[1]) | |
| return [{ | |
| "text": f"File: {filename} | Import Statements:\n" + "\n".join(import_texts[:10]) + | |
| (f"\n... and {len(import_texts) - 10} more" if len(import_texts) > 10 else ""), | |
| "type": "code_imports", | |
| "name": "imports_grouped", | |
| "language": language_name | |
| }] | |
| return import_chunks | |
| def _fallback_chunking(self, text: str, filename: str) -> List[Dict[str, Any]]: | |
| """Fallback chunking method when tree-sitter fails.""" | |
| ext = os.path.splitext(filename)[1].lower() | |
| if ext == '.py': | |
| return self._chunk_python_ast_enhanced(text, filename) | |
| elif ext in ['.js', '.jsx', '.ts', '.tsx', '.java', '.cpp', '.c', '.html', '.css', '.vue']: | |
| return self._chunk_smart_code(text, filename) | |
| else: | |
| return self._chunk_text_enhanced(text) | |
def delete_file(self, user_id: str, chat_id: str, file_id: str) -> bool:
    """Surgical strike: remove every chunk belonging to one uploaded file.

    A chunk is removed when its metadata matches ``user_id``, ``chat_id`` and
    ``file_id``. The FAISS index is rebuilt from the surviving entries.
    Index and metadata are then saved and BM25 caches invalidated.

    Returns:
        True when at least one chunk was removed. False when nothing matched
        or the rebuild failed; in that case index and metadata are left
        untouched.
    """
    with self.memory_lock:
        keep_positions = []
        new_metadata = []
        # Keep everything that does NOT match user + chat + file_id.
        for position, meta in enumerate(self.metadata):
            if (meta.get("user_id") == user_id and
                    meta.get("chat_id") == chat_id and
                    meta.get("file_id") == file_id):
                continue
            keep_positions.append(position)
            new_metadata.append(meta)

        removed_count = len(self.metadata) - len(new_metadata)
        if removed_count == 0:
            logger.info(f"βΉοΈ No vectors found for file_id {file_id}")
            return False
        logger.info(f"π§Ή Surgically removing {removed_count} vectors for file {file_id}...")

        new_index = faiss.IndexFlatIP(self.index.d)
        if new_metadata:
            try:
                if self.index.ntotal == len(self.metadata):
                    # Index rows line up with metadata, so reuse the stored
                    # vectors. This is far cheaper than re-embedding every
                    # surviving chunk and keeps vectors identical to what was
                    # indexed. They may not have been embedded from meta["text"].
                    stored = self.index.reconstruct_n(0, self.index.ntotal)
                    vectors = stored[np.asarray(keep_positions, dtype=np.int64)]
                else:
                    # Rows and metadata disagree; re-embed from the chunk text.
                    texts = [m.get("text", "") for m in new_metadata]
                    vectors = np.ascontiguousarray(
                        self.embedder.encode(texts, show_progress_bar=False), dtype='float32')
                    faiss.normalize_L2(vectors)
                new_index.add(np.ascontiguousarray(vectors, dtype='float32'))
            except Exception as e:
                logger.error(f"β Rebuild failed during file deletion: {e}")
                return False

        self.index = new_index
        self.metadata = new_metadata
        self._save_index()

        # Removing rows shifts the position of every later vector. The BM25
        # doc->vector maps of *all* sessions are therefore stale, not just
        # this session's; left alone they would point at other users' chunks.
        stale_sessions = {(user_id, chat_id)}
        stale_sessions.update(getattr(self, "bm25_doc_to_vector", {}).keys())
        for session_user, session_chat in stale_sessions:
            self._invalidate_bm25_cache(session_user, session_chat)

        logger.info(f"β Successfully deleted file {file_id}")
        return True
# ==================== UPDATED BM25 SEARCH WITH LAZY LOADING ====================
def bm25_search(self, query: str, user_id: str, chat_id: str,
                filter_type: str = None,
                top_k: int = 50, min_score: float = 0.0) -> List[Dict[str, Any]]:
    """Pure BM25 keyword search inside one chat session, with strict filtering.

    The session's BM25 index is built lazily via ``_get_or_build_bm25``. The
    best ``4 * top_k`` documents are mapped back to metadata through
    ``self.bm25_doc_to_vector[(user_id, chat_id)]``. Hits are dropped when
    they score below ``min_score``, belong to a different session, or have a
    ``type`` other than ``filter_type`` (when given). Scores are normalised
    as ``min(raw / 10, 1)``.

    Returns:
        Up to ``top_k`` result dicts (``match_type="bm25"``), best first.
        Returns ``[]`` when BM25 is unavailable, the query has no tokens, or
        scoring fails.
    """
    if not BM25_AVAILABLE:
        logger.warning("BM25 not available. Falling back to semantic search.")
        return []
    bm25_index = self._get_or_build_bm25(user_id, chat_id)
    if not bm25_index:
        return []
    query_tokens = self._tokenize_for_bm25(query)
    if not query_tokens:
        return []
    try:
        key = (user_id, chat_id)
        if key not in self.bm25_doc_to_vector:
            return []
        doc_to_vector = self.bm25_doc_to_vector[key]
        metadata = self.metadata
        bm25_scores = bm25_index.get_scores(query_tokens)
        # Over-fetch so hits removed by the filters below can be replaced.
        candidate_limit = top_k * 4
        top_indices = np.argsort(bm25_scores)[::-1][:candidate_limit]
        results = []
        for idx in top_indices:
            score = float(bm25_scores[idx])
            if score < min_score:
                # Indices are visited in descending score order; nothing later passes.
                break
            if idx >= len(doc_to_vector):
                continue
            vector_idx = doc_to_vector[idx]
            if vector_idx >= len(metadata):
                continue
            meta = metadata[vector_idx]
            # The mapping can go stale when the shared index is rebuilt; never
            # surface a chunk owned by another user or chat.
            if meta.get("user_id") != user_id or meta.get("chat_id") != chat_id:
                continue
            if filter_type and meta.get("type") != filter_type:
                continue
            normalized_score = min(score / 10.0, 1.0) if score > 0 else 0.0
            results.append({
                "id": int(vector_idx),
                "text": meta.get("text", ""),
                "meta": meta,
                "score": normalized_score,
                "match_type": "bm25",
                "bm25_raw_score": score,
                "is_whole_file": meta.get("is_whole_file", False)
            })
        results.sort(key=lambda x: x["score"], reverse=True)
        return results[:top_k]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []
# ==================== HYBRID RETRIEVAL ENGINE (UPDATED) ====================
def hybrid_retrieve(self, query: str, user_id: str, chat_id: str,
                    filter_type: str = None, top_k: int = 100,
                    final_k: int = 5, strategy: str = "smart") -> List[Dict[str, Any]]:
    """HYBRID RETRIEVAL: BM25 + semantic fusion with an exact-match fallback.

    Steps:
      1. Classify the query. For ``strategy="smart"``, map "code" to
         "bm25_first", "natural" to "semantic_first", and anything else to
         "fusion".
      2. Run both BM25 and semantic search, each for ``2 * top_k`` hits with
         min score 0.1.
      3. Fuse per strategy. Strategies other than the two ``*_first`` ones
         use reciprocal rank fusion.
      4. If fusion yields nothing, return up to ``final_k`` ``retrieve_exact``
         results as-is.
      5. Otherwise rerank, boost whole-file hits by 1.2x, and clamp scores to
         [0, 1]. Return nothing if the best score is below 0.010; otherwise
         drop results under that threshold.

    Returns:
        At most ``final_k`` result dicts, best first.
    """
    logger.info(f"π€ HYBRID SEARCH: '{query[:80]}...' | Strategy: {strategy}")
    query_category = self._classify_query(query)
    self.query_types[query_category] += 1
    if strategy == "smart":
        strategy = {"code": "bm25_first", "natural": "semantic_first"}.get(query_category, "fusion")
    start_time = time.time()

    # === PHASE 1: GET RESULTS FROM BOTH METHODS ===
    # Both lists are always fetched. The *_first strategies top up a weak
    # primary list with the other engine's hits, which is impossible if the
    # secondary list was never retrieved.
    bm25_results = self.bm25_search(
        query=query,
        user_id=user_id,
        chat_id=chat_id,
        filter_type=filter_type,
        top_k=top_k * 2,
        min_score=0.1
    )
    semantic_results = self._semantic_search(
        query=query,
        user_id=user_id,
        chat_id=chat_id,
        filter_type=filter_type,
        top_k=top_k * 2,
        min_score=0.1,
        final_k=top_k
    )

    # === PHASE 2: APPLY STRATEGY ===
    if strategy == "bm25_first":
        results = self._bm25_first_fusion(bm25_results, semantic_results, final_k)
    elif strategy == "semantic_first":
        results = self._semantic_first_fusion(semantic_results, bm25_results, final_k)
    else:
        results = self._reciprocal_rank_fusion(bm25_results, semantic_results, final_k)

    # === PHASE 3: EXACT FALLBACK IF NO RESULTS ===
    if not results:
        logger.info("π No hybrid results, trying exact fallback...")
        results = self.retrieve_exact(
            query=query,
            user_id=user_id,
            chat_id=chat_id,
            filter_type=filter_type,
            aggressive=True
        )
        if results:
            self.performance_stats["fallback_matches"] += 1
        return results[:final_k]

    # === PHASE 4: SMART RERANKING ===
    if len(results) > 1:
        try:
            results = self._smart_rerank(query, results, final_k)
        except Exception as e:
            logger.warning(f"Reranking failed: {e}")

    # === PHASE 5: FINAL PROCESSING ===
    for result in results:
        score = result["score"]
        if result.get("is_whole_file"):
            score *= 1.2  # whole files tend to give complete answers
        result["score"] = min(max(score, 0.0), 1.0)
    results.sort(key=lambda x: x["score"], reverse=True)

    MIN_CONFIDENCE_THRESHOLD = 0.010
    if not results:
        logger.warning(f"β Hybrid search found no results")
        return []
    top_score = results[0]["score"]
    if top_score < MIN_CONFIDENCE_THRESHOLD:
        # Even the best hit is noise (e.g. a query like 'thanks'): return nothing.
        logger.warning(f"π Results found but discarded due to low confidence (Top: {top_score:.3f} < {MIN_CONFIDENCE_THRESHOLD})")
        return []
    filtered_results = [r for r in results if r["score"] >= MIN_CONFIDENCE_THRESHOLD]
    logger.info(f"β Hybrid search found {len(filtered_results)} RELEVANT results (Top: {top_score:.3f})")
    self.performance_stats["hybrid_matches"] += 1
    logger.debug(f"Hybrid retrieval took {time.time() - start_time:.3f}s")
    return filtered_results[:final_k]
| # ==================== CORE METHODS (PRESERVED WITH FIXES) ==================== | |
| def _chunk_python_ast_enhanced(self, text: str, filename: str) -> List[Dict[str, Any]]: | |
| chunks = [] | |
| try: | |
| tree = ast.parse(text) | |
| lines = text.splitlines() | |
| # Helper to extract exact source including decorators | |
| def get_source_segment(node): | |
| # 1. Find start line (check decorators first) | |
| start_lineno = node.lineno | |
| if hasattr(node, 'decorator_list') and node.decorator_list: | |
| start_lineno = node.decorator_list[0].lineno | |
| # 2. Add minimal context buffer (1 line) | |
| start_idx = max(0, start_lineno - 2) | |
| end_idx = getattr(node, 'end_lineno', start_lineno) + 1 | |
| return "\n".join(lines[start_idx:end_idx]), start_idx, end_idx | |
| # Recursive visitor to flatten nested structures | |
| class CodeVisitor(ast.NodeVisitor): | |
| def visit_FunctionDef(self, node): | |
| self._add_chunk(node, "function") | |
| # Do NOT generic_visit chunks we've already handled to avoid duplicates | |
| # But DO visit nested functions if needed (optional) | |
| def visit_AsyncFunctionDef(self, node): | |
| self._add_chunk(node, "async_function") | |
| def visit_ClassDef(self, node): | |
| # 1. Create a "Summary Chunk" for the class definition (docstring + init) | |
| class_header, start, _ = get_source_segment(node) | |
| # Truncate body for the summary | |
| summary_text = f"Class Definition: {node.name}\n" + "\n".join(class_header.splitlines()[:10]) | |
| chunks.append({ | |
| "text": f"File: {filename} | Type: class_def | Name: {node.name}\n{summary_text}", | |
| "type": "code_class", | |
| "name": node.name, | |
| "line_start": start | |
| }) | |
| # 2. Recursively visit children (methods) | |
| self.generic_visit(node) | |
| def _add_chunk(self, node, type_label): | |
| content, start, end = get_source_segment(node) | |
| # Enforce context window limits here if needed | |
| chunks.append({ | |
| "text": f"File: {filename} | Type: {type_label} | Name: {node.name}\n{content}", | |
| "type": f"code_{type_label}", | |
| "name": node.name, | |
| "line_start": start, | |
| "line_end": end | |
| }) | |
| # Run the visitor | |
| CodeVisitor().visit(tree) | |
| # Capture Globals (Imports, Constants, Main Guard) | |
| global_context = [] | |
| for node in tree.body: | |
| if isinstance(node, (ast.Import, ast.ImportFrom, ast.Assign, ast.If)): | |
| # Only capture short logic blocks, skip giant if-blocks | |
| segment, _, _ = get_source_segment(node) | |
| if len(segment) < 500: | |
| global_context.append(segment) | |
| if global_context: | |
| chunks.insert(0, { | |
| "text": f"File: {filename} | Global Context\n" + "\n".join(global_context), | |
| "type": "code_globals", | |
| "name": "globals" | |
| }) | |
| except Exception as e: | |
| logger.warning(f"AST Parsing failed: {e}") | |
| return self._chunk_text_enhanced(text) # Fallback | |
| return chunks | |
| def _chunk_smart_code(self, text: str, filename: str) -> List[Dict[str, Any]]: | |
| """ENHANCED Structure-aware chunker with context preservation""" | |
| ext = os.path.splitext(filename)[1].lower() | |
| chunks = [] | |
| # Define split patterns for different languages | |
| patterns = { | |
| '.html': r'(?=\n\s*<[^/])', | |
| '.htm': r'(?=\n\s*<[^/])', | |
| '.xml': r'(?=\n\s*<[^/])', | |
| '.vue': r'(?=\n\s*<[^/])', | |
| '.js': r'(?=\n\s*(?:function|class|export|import|async|def))', | |
| '.jsx': r'(?=\n\s*(?:function|class|export|import|async|def))', | |
| '.ts': r'(?=\n\s*(?:function|class|export|import|async|interface|type|def))', | |
| '.tsx': r'(?=\n\s*(?:function|class|export|import|async|interface|type|def))', | |
| '.css': r'(?=\n\s*[.#@a-zA-Z])', | |
| '.scss': r'(?=\n\s*[.#@a-zA-Z])', | |
| '.java': r'(?=\n\s*(?:public|private|protected|class|interface|enum|@))', | |
| '.cpp': r'(?=\n\s*(?:#include|using|namespace|class|struct|enum|template))', | |
| } | |
| pattern = patterns.get(ext) | |
| # Fallback to standard if no pattern matches or regex fails | |
| if not pattern: | |
| return self._chunk_text_enhanced(text) | |
| try: | |
| segments = re.split(pattern, text) | |
| # Process with CONTEXT OVERLAP for better retrieval | |
| current_chunk = "" | |
| TARGET_SIZE = 1900 | |
| OVERLAP_SIZE = 100 | |
| for seg_idx, seg in enumerate(segments): | |
| if not seg.strip(): | |
| continue | |
| # Check if adding this segment would exceed target | |
| if len(current_chunk) + len(seg) > TARGET_SIZE and len(current_chunk) > 50: | |
| # Save current chunk | |
| chunk_text = current_chunk.strip() | |
| if chunk_text: | |
| chunks.append({ | |
| "text": f"File: {filename} | Content: {chunk_text}", | |
| "type": "code_block", | |
| "name": f"block_{len(chunks)}", | |
| "context_id": seg_idx | |
| }) | |
| # Start new chunk with overlap from previous | |
| current_chunk = current_chunk[-OVERLAP_SIZE:] + "\n" + seg if OVERLAP_SIZE > 0 else seg | |
| else: | |
| current_chunk += seg | |
| # Add final chunk | |
| if current_chunk: | |
| chunks.append({ | |
| "text": f"File: {filename} | Content: {current_chunk.strip()}", | |
| "type": "code_block", | |
| "name": f"block_{len(chunks)}", | |
| "context_id": len(segments) | |
| }) | |
| return chunks | |
| except Exception as e: | |
| logger.warning(f"Smart chunking failed for {filename}: {e}. Falling back.") | |
| return self._chunk_text_enhanced(text) | |
| def _chunk_text_enhanced(self, text: str, chunk_size: int = 600, overlap: int = 100) -> List[Dict[str, Any]]: | |
| """Enhanced text chunking that preserves natural boundaries""" | |
| chunks = [] | |
| # Try to split by paragraphs first | |
| paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()] | |
| if not paragraphs: | |
| # Fallback to standard chunking | |
| return self._chunk_text_standard(text, chunk_size, overlap) | |
| current_chunk = "" | |
| for para in paragraphs: | |
| if len(current_chunk) + len(para) > chunk_size and current_chunk: | |
| chunks.append({ | |
| "text": current_chunk.strip(), | |
| "type": "text_paragraph", | |
| "name": f"para_{len(chunks)}" | |
| }) | |
| # Keep last overlap portion | |
| current_chunk = current_chunk[-overlap:] + "\n\n" + para if overlap > 0 else para | |
| else: | |
| current_chunk += "\n\n" + para if current_chunk else para | |
| if current_chunk: | |
| chunks.append({ | |
| "text": current_chunk.strip(), | |
| "type": "text_paragraph", | |
| "name": f"para_{len(chunks)}" | |
| }) | |
| return chunks | |
| def _chunk_text_standard(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[Dict[str, Any]]: | |
| """Standard text chunking with sliding window""" | |
| chunks = [] | |
| if len(text) <= chunk_size: | |
| return [{ | |
| "text": text, | |
| "type": "text_block", | |
| "name": "full_content" | |
| }] | |
| for i in range(0, len(text), chunk_size - overlap): | |
| chunk = text[i:i + chunk_size] | |
| if len(chunk) > 100: | |
| chunks.append({ | |
| "text": chunk, | |
| "type": "text_block", | |
| "name": f"chunk_{i//chunk_size}" | |
| }) | |
| return chunks | |
| # ==================== HELPER METHODS FOR HYBRID SEARCH ==================== | |
| def _classify_query(self, query: str) -> str: | |
| """Classify query type to determine best search strategy""" | |
| query_lower = query.lower() | |
| # Code/technical query indicators | |
| code_indicators = [ | |
| r'def\s+\w+\(', r'class\s+\w+', r'function\s+\w+', | |
| r'import\s+', r'from\s+', r'\.py$', r'\.js$', r'\.java$', | |
| r'\w+\(.*\)', r'\{.*\}', r'\[.*\]', r'=\s*\w+', | |
| r'const\s+', r'let\s+', r'var\s+', r'type\s+', | |
| r'interface\s+', r'export\s+', r'async\s+', r'await\s+', | |
| r'SELECT\s+', r'FROM\s+', r'WHERE\s+', r'JOIN\s+', | |
| r'#include', r'using\s+', r'namespace\s+', r'template\s+' | |
| ] | |
| for pattern in code_indicators: | |
| if re.search(pattern, query_lower): | |
| return "code" | |
| # Natural language query indicators | |
| natural_indicators = [ | |
| r'^how\s+', r'^what\s+', r'^why\s+', r'^explain\s+', | |
| r'^describe\s+', r'^summarize\s+', r'^tell\s+me\s+about', | |
| r'\?$', r'please', r'could you', r'would you', | |
| r'understand', r'meaning', r'concept', r'idea' | |
| ] | |
| for pattern in natural_indicators: | |
| if re.search(pattern, query_lower): | |
| return "natural" | |
| # Short keyword query (good for BM25) | |
| words = query.split() | |
| if len(words) <= 4 and len(query) < 30: | |
| return "keyword" | |
| # Mixed query | |
| return "mixed" | |
| def _bm25_first_fusion(self, bm25_results: List[Dict], semantic_results: List[Dict], | |
| final_k: int) -> List[Dict]: | |
| """BM25 first, supplement with semantic if needed""" | |
| results = bm25_results.copy() | |
| # If BM25 results are weak, add semantic results | |
| if not results or (results[0]["score"] < 0.3): | |
| seen_ids = set(r["id"] for r in results) | |
| for sem in semantic_results: | |
| if sem["id"] not in seen_ids and len(results) < final_k * 2: | |
| seen_ids.add(sem["id"]) | |
| sem["match_type"] = "semantic_supplement" | |
| results.append(sem) | |
| return results[:final_k] | |
| def _semantic_first_fusion(self, semantic_results: List[Dict], bm25_results: List[Dict], | |
| final_k: int) -> List[Dict]: | |
| """Semantic first, supplement with BM25 if needed""" | |
| results = semantic_results.copy() | |
| # If semantic results are weak, add BM25 results | |
| if not results or (results[0]["score"] < 0.3): | |
| seen_ids = set(r["id"] for r in results) | |
| for bm in bm25_results: | |
| if bm["id"] not in seen_ids and len(results) < final_k * 2: | |
| seen_ids.add(bm["id"]) | |
| bm["match_type"] = "bm25_supplement" | |
| results.append(bm) | |
| return results[:final_k] | |
| def _reciprocal_rank_fusion(self, results1: List[Dict[str, Any]], results2: List[Dict[str, Any]], | |
| final_k: int, k: int = 60) -> List[Dict[str, Any]]: | |
| """ | |
| Robust RRF Fusion for hybrid search (BM25 + Semantic). | |
| Prioritizes BM25 metadata (results1) on overlaps for keyword precision. | |
| Handles empty lists/duplicates gracefully; O(n log n) efficient. | |
| """ | |
| merged_scores = defaultdict(float) | |
| merged_meta: Dict[str, Dict[str, Any]] = {} | |
| # Process semantic (results2) first | |
| for rank, item in enumerate(results2): | |
| doc_id = item.get("id") | |
| if doc_id is None: | |
| continue # Skip invalid | |
| score = 1.0 / (rank + k) | |
| merged_scores[doc_id] += score | |
| merged_meta[doc_id] = item.copy() # Avoid mutating input | |
| # Process BM25 (results1) second: overwrites meta for precision | |
| for rank, item in enumerate(results1): | |
| doc_id = item.get("id") | |
| if doc_id is None: | |
| continue | |
| score = 1.0 / (rank + k) | |
| merged_scores[doc_id] += score | |
| merged_meta[doc_id] = item.copy() | |
| # Sort by descending RRF score | |
| sorted_ids = sorted(merged_scores, key=merged_scores.get, reverse=True) | |
| # Package top-k | |
| final_results = [] | |
| for doc_id in sorted_ids[:final_k]: | |
| if doc_id in merged_meta: | |
| res = merged_meta[doc_id].copy() | |
| res["score"] = merged_scores[doc_id] | |
| res["match_type"] = "hybrid_rrf" | |
| final_results.append(res) | |
| return final_results | |
| def _smart_rerank(self, query: str, candidates: List[Dict], final_k: int) -> List[Dict]: | |
| """Smart reranking using cross-encoder""" | |
| if len(candidates) <= 1: | |
| return candidates | |
| try: | |
| # Prepare passages for reranking | |
| passages = [] | |
| for cand in candidates[:30]: | |
| text = cand.get("text", "") | |
| if len(text) > 1000: | |
| text = text[:1000] + "..." | |
| source = cand.get("meta", {}).get("source", "unknown") | |
| subtype = cand.get("meta", {}).get("subtype", "general") | |
| passages.append({ | |
| "id": cand["id"], | |
| "text": f"File: {source} | Type: {subtype} | Content: {text}" | |
| }) | |
| if not passages: | |
| return candidates | |
| # Rerank with FlashRank | |
| rerank_request = RerankRequest(query=query, passages=passages) | |
| reranked = self.ranker.rerank(rerank_request) | |
| # Update scores based on reranking | |
| rerank_map = {r["id"]: r["score"] for r in reranked} | |
| for cand in candidates: | |
| if cand["id"] in rerank_map: | |
| cand["score"] = (cand["score"] * 0.3) + (rerank_map[cand["id"]] * 0.7) | |
| cand["match_type"] = cand.get("match_type", "unknown") + "_reranked" | |
| candidates.sort(key=lambda x: x["score"], reverse=True) | |
| logger.debug(f"Smart reranking applied to {len(candidates)} candidates") | |
| except Exception as e: | |
| logger.warning(f"Reranking error: {e}") | |
| return candidates[:final_k] | |
# ==================== COMPATIBILITY METHODS (UPDATED) ====================
def retrieve_session_context(self, query: str, user_id: str, chat_id: str,
                             filter_type: str = None, top_k: int = 100,
                             final_k: int = 5, min_score: float = 0.25,
                             use_hybrid: bool = True) -> List[Dict[str, Any]]:
    """Retrieve context for a chat session, hybrid when possible.

    When ``use_hybrid`` is true and BM25 is available, the call is delegated
    to ``hybrid_retrieve`` with the "smart" strategy. ``min_score`` is not
    forwarded there; that method applies its own thresholds. Otherwise plain
    semantic search (``_semantic_search``) is used with all arguments
    forwarded.
    """
    if not (use_hybrid and BM25_AVAILABLE):
        return self._semantic_search(
            query=query,
            user_id=user_id,
            chat_id=chat_id,
            filter_type=filter_type,
            top_k=top_k,
            min_score=min_score,
            final_k=final_k
        )
    return self.hybrid_retrieve(
        query=query,
        user_id=user_id,
        chat_id=chat_id,
        filter_type=filter_type,
        top_k=top_k,
        final_k=final_k,
        strategy="smart"
    )
| def _semantic_search(self, query: str, user_id: str, chat_id: str, | |
| filter_type: str = None, top_k: int = 100, | |
| min_score: float = 0.25, final_k: int = 10) -> List[Dict[str, Any]]: | |
| """Core semantic search engine""" | |
| with self.memory_lock: | |
| total_vectors = self.index.ntotal | |
| user_vectors = sum(1 for m in self.metadata if m.get("user_id") == user_id and m.get("chat_id") == chat_id) | |
| if total_vectors == 0 or user_vectors == 0: | |
| return [] | |
| try: | |
| query_vec = self.embedder.encode([query], show_progress_bar=False) | |
| faiss.normalize_L2(query_vec) | |
| except Exception as e: | |
| logger.error(f"β Failed to encode query: {e}") | |
| return [] | |
| search_k = min(top_k * 2, total_vectors) | |
| if search_k == 0: | |
| search_k = min(10, total_vectors) | |
| try: | |
| with self.memory_lock: | |
| if self.index.ntotal == 0: | |
| return [] | |
| D, I = self.index.search(np.array(query_vec).astype('float32'), search_k) | |
| except Exception as e: | |
| logger.error(f"β Search failed: {e}") | |
| return [] | |
| candidates = [] | |
| query_lower = query.lower() | |
| for i, idx in enumerate(I[0]): | |
| if idx == -1 or idx >= len(self.metadata): | |
| continue | |
| item = self.metadata[idx] | |
| # Filter by user and chat | |
| if item.get("user_id") != user_id or item.get("chat_id") != chat_id: | |
| continue | |
| # Filter by type if specified | |
| if filter_type and item.get("type") != filter_type: | |
| continue | |
| score = float(D[0][i]) | |
| if np.isnan(score) or np.isinf(score): | |
| continue | |
| # Whole file boosting | |
| is_whole_file = item.get("is_whole_file", False) or item.get("subtype") == "whole_file" | |
| if is_whole_file: | |
| filename = item.get("source", "").lower() | |
| if filename in query_lower or any(word in filename for word in query_lower.split()): | |
| score = 2.5 | |
| if item.get("actual_content"): | |
| item = item.copy() | |
| item["text"] = item["actual_content"] | |
| if score < min_score: | |
| continue | |
| candidates.append({ | |
| "id": int(idx), | |
| "text": item.get("text", ""), | |
| "meta": item, | |
| "score": score | |
| }) | |
| return candidates | |
def retrieve_exact(self, query: str, user_id: str, chat_id: str,
                   filter_type: str = None, aggressive: bool = True) -> List[Dict[str, Any]]:
    """PRIMARY EXACT MATCH RETRIEVAL - accuracy first.

    Tactics, in order:
      1. Case-insensitive substring match of the whole (non-blank) query
         against each in-scope chunk's ``text`` / ``actual_content``. Any hit
         short-circuits: up to 3 matches (whole files first) are returned
         with score 3.0.
      2. (``aggressive`` only) Keyword overlap. Chunks whose ``text``
         contains at least 60% of the query's 3+ character words score
         2.0-2.5.
      3. Low-threshold (0.1) semantic search.

    Candidates from tactics 2 and 3 are de-duplicated by id (first one wins).
    The top 50 are re-scored with FlashRank (0.3 * old + 0.7 * rerank) when
    that works. Of the top 10, those scoring at least 0.4 (aggressive) or 0.5
    are kept. Only this path records an entry in ``self.query_history``.

    Returns:
        At most 5 results (3 for the exact-substring path). Returns ``[]``
        for an empty index or a falsy ``user_id``.
    """
    start_time = time.time()
    query_lower = query.lower().strip()
    if self.index.ntotal == 0 or not user_id:
        logger.warning(f"β Empty index or invalid user_id")
        return []
    logger.info(f"π― EXACT MODE: Searching for '{query[:80]}...'")

    def in_scope(meta) -> bool:
        # Same session, and the requested chunk type if a filter is given.
        if meta.get("user_id") != user_id or meta.get("chat_id") != chat_id:
            return False
        return not (filter_type and meta.get("type") != filter_type)

    all_candidates = []
    exact_matches = []

    # TACTIC 1: BRUTE FORCE SUBSTRING SEARCH
    logger.debug("π Tactic 1: Brute force substring search...")
    # "" is a substring of everything. Without this guard a blank query would
    # report the session's first chunks as perfect matches.
    if query_lower:
        with self.memory_lock:
            for idx, meta in enumerate(self.metadata):
                if not in_scope(meta):
                    continue
                text = (meta.get("text") or "").lower()
                actual_content = (meta.get("actual_content") or "").lower()
                if query_lower in text or query_lower in actual_content:
                    exact_matches.append({
                        "id": idx,
                        "text": meta.get("actual_content", meta.get("text", "")),
                        "meta": meta,
                        "score": 3.0,
                        "match_type": "exact_substring",
                        "confidence": "perfect"
                    })
    if exact_matches:
        logger.info(f"β¨ Found {len(exact_matches)} PERFECT exact matches!")
        self.performance_stats["exact_matches"] += 1
        # Whole-file entries first, then by score.
        exact_matches.sort(key=lambda x: (
            1 if x["meta"].get("is_whole_file") else 0,
            x["score"]
        ), reverse=True)
        elapsed = time.time() - start_time
        logger.info(f"β‘ Exact match retrieval took {elapsed:.3f}s")
        return exact_matches[:3]

    # TACTIC 2: KEYWORD MATCHING
    if aggressive:
        logger.debug("π Tactic 2: Aggressive keyword matching...")
        keywords = re.findall(r'\b\w{3,}\b', query_lower)
        if keywords:
            required = max(1, len(keywords) * 0.6)
            with self.memory_lock:
                for idx, meta in enumerate(self.metadata):
                    if not in_scope(meta):
                        continue
                    text = (meta.get("text") or "").lower()
                    keyword_matches = sum(1 for kw in keywords if kw in text)
                    if keyword_matches >= required:
                        ratio = keyword_matches / len(keywords)
                        all_candidates.append({
                            "id": idx,
                            "text": meta.get("actual_content", meta.get("text", "")),
                            "meta": meta,
                            "score": 2.0 + ratio * 0.5,
                            "match_type": "keyword_explosion",
                            "keyword_match_ratio": ratio
                        })

    # TACTIC 3: SEMANTIC SEARCH WITH LOW THRESHOLD
    logger.debug("π Tactic 3: Semantic search with low threshold...")
    semantic_results = self._semantic_search(
        query=query,
        user_id=user_id,
        chat_id=chat_id,
        filter_type=filter_type,
        top_k=200,
        min_score=0.1,
        final_k=30
    )
    for res in semantic_results:
        res["match_type"] = "semantic_low_threshold"
        all_candidates.append(res)

    # DEDUPLICATE (first occurrence wins) AND RANK
    seen_ids = set()
    unique_candidates = []
    for candidate in all_candidates:
        if candidate["id"] not in seen_ids:
            seen_ids.add(candidate["id"])
            unique_candidates.append(candidate)
    unique_candidates.sort(key=lambda x: x["score"], reverse=True)

    # Re-score the top 50 with the cross-encoder when available.
    if unique_candidates:
        try:
            passages = []
            for cand in unique_candidates[:50]:
                text_for_rerank = cand["text"]
                if len(text_for_rerank) > 1000:
                    text_for_rerank = text_for_rerank[:1000] + "..."
                passages.append({
                    "id": cand["id"],
                    "text": text_for_rerank
                })
            if passages:
                rerank_request = RerankRequest(query=query, passages=passages)
                reranked = self.ranker.rerank(rerank_request)
                rerank_map = {r["id"]: r["score"] for r in reranked}
                for cand in unique_candidates:
                    if cand["id"] in rerank_map:
                        cand["score"] = cand["score"] * 0.3 + rerank_map[cand["id"]] * 0.7
                unique_candidates.sort(key=lambda x: x["score"], reverse=True)
        except Exception as e:
            logger.warning(f"β οΈ Reranking failed: {e}")

    # FINAL SELECTION
    confidence_threshold = 0.4 if aggressive else 0.5
    final_results = [c for c in unique_candidates[:10] if c["score"] >= confidence_threshold]
    if final_results:
        self.performance_stats["semantic_matches"] += 1
        logger.info(f"β Found {len(final_results)} relevant results")
        top_match = final_results[0]
        logger.info(f"π Top match: Score={top_match['score']:.3f}, Type={top_match.get('match_type', 'unknown')}")
        if top_match["meta"].get("is_whole_file"):
            logger.info(f"π Returning whole file: {top_match['meta'].get('source', 'unknown')}")

    elapsed = time.time() - start_time
    logger.info(f"β±οΈ Exact retrieval completed in {elapsed:.3f}s")

    # Store in query history (bounded: trimmed to the last 500 past 1000 entries).
    self.query_history.append({
        "query": query[:100],
        "timestamp": time.time(),
        "results_count": len(final_results),
        "top_score": final_results[0]["score"] if final_results else 0,
        "elapsed_time": elapsed,
        "method": "exact"
    })
    if len(self.query_history) > 1000:
        self.query_history = self.query_history[-500:]
    return final_results[:5]
# ==================== INFRASTRUCTURE METHODS ====================
def _load_or_create_index(self):
    """Load the persisted FAISS index and metadata, or start with an empty index.

    Runs while holding ``self.file_lock``. Both ``self.index_path`` and
    ``self.metadata_path`` must exist for a load to be attempted. A read
    error, a negative vector count, or a metadata/vector count mismatch all
    end in ``_create_new_index`` (an empty index and empty metadata).

    NOTE(review): metadata is read with ``pickle.load``. Anyone who can write
    ``self.metadata_path`` can run arbitrary code at load time, so the file
    must stay in a trusted location.
    """
    with self.file_lock:
        if not (os.path.exists(self.index_path) and os.path.exists(self.metadata_path)):
            logger.info("π Creating new vector index...")
            self._create_new_index()
            return
        try:
            logger.info("π Loading existing vector index...")
            self.index = faiss.read_index(self.index_path)
            if self.index.ntotal < 0:
                raise ValueError("Corrupt index: negative vector count")
            with open(self.metadata_path, "rb") as handle:
                self.metadata = pickle.load(handle)
            if len(self.metadata) == self.index.ntotal:
                logger.info(f"β Loaded index with {self.index.ntotal} vectors, {len(self.metadata)} metadata entries")
            else:
                logger.error(f"β οΈ Metadata mismatch: {len(self.metadata)} entries vs {self.index.ntotal} vectors. Rebuilding...")
                self._create_new_index()
        except Exception as e:
            logger.error(f"β οΈ Failed to load index: {e}. Creating new one.")
            self._create_new_index()
def _create_new_index(self):
    """Reset the store to an empty exact inner-product FAISS index.

    Callers in this class L2-normalise embeddings (``faiss.normalize_L2``)
    before adding them, so inner product ranks by cosine similarity.
    ``self.metadata`` is cleared at the same time so both stay aligned.
    """
    dim = 384  # NOTE(review): assumed to equal the embedder's output size — confirm
    self.index, self.metadata = faiss.IndexFlatIP(dim), []
    logger.info(f"π Created new IndexFlatIP with dimension {dim}")
def _save_index(self):
    """Persist the FAISS index and metadata to disk under ``self.file_lock``.

    Each file is first written to a ``.tmp`` sibling and then moved into place
    with ``os.replace`` (atomic per file on POSIX). The two replaces are separate
    steps, so they are not atomic as a pair. Failures are logged rather than
    raised, and any leftover temp files are removed. ``gc.collect()`` runs after
    every attempt, successful or not.
    """
    with self.file_lock:
        tmp_paths = (f"{self.index_path}.tmp", f"{self.metadata_path}.tmp")
        index_tmp, meta_tmp = tmp_paths
        try:
            faiss.write_index(self.index, index_tmp)
            with open(meta_tmp, "wb") as handle:
                pickle.dump(self.metadata, handle)
            os.replace(index_tmp, self.index_path)
            os.replace(meta_tmp, self.metadata_path)
            logger.info(f"πΎ Saved index: {self.index.ntotal} vectors, {len(self.metadata)} metadata entries")
        except Exception as e:
            logger.error(f"β Failed to save index: {e}")
            # Best-effort removal of whichever temp files were created.
            for leftover in tmp_paths:
                if not os.path.exists(leftover):
                    continue
                try:
                    os.remove(leftover)
                except Exception:
                    logger.warning(f"Failed to remove temp file: {leftover}")
        finally:
            gc.collect()
def _rollback_partial_storage(self, user_id: str, chat_id: str):
    """Remove every stored vector belonging to one (user_id, chat_id) session.

    Used to undo a partially completed store. The FAISS index is rebuilt by
    re-embedding the surviving metadata texts, which also realigns it with
    ``self.metadata`` if the failed store left them out of sync. The result is
    then saved and the session's BM25 cache entry is invalidated.

    On failure, the in-memory store is reset to empty only when the index and
    metadata actually disagree (vector count != metadata count). A store that
    is still consistent is left intact instead of discarding every other
    session's vectors.
    """
    try:
        new_metadata = [
            meta for meta in self.metadata
            if meta.get("user_id") != user_id or meta.get("chat_id") != chat_id
        ]
        if len(new_metadata) == len(self.metadata):
            return  # nothing stored for this session, nothing to undo

        surviving_texts = [meta["text"] for meta in new_metadata]
        if surviving_texts:
            embeddings = self.embedder.encode(surviving_texts, show_progress_bar=False)
            faiss.normalize_L2(embeddings)
            new_index = faiss.IndexFlatIP(384)
            new_index.add(np.array(embeddings).astype('float32'))
            self.index = new_index
        else:
            self.index = faiss.IndexFlatIP(384)
        self.metadata = new_metadata
        self._save_index()
        # Invalidate BM25 cache
        self._invalidate_bm25_cache(user_id, chat_id)
        logger.info(f"π Rolled back partial storage for user {user_id[:8]}")
    except Exception as e:
        logger.error(f"β Rollback failed: {e}")
        # Every step that can fail before the two assignments above leaves
        # self.index/self.metadata untouched. _create_new_index() drops ALL
        # users' vectors, so only use it when the store is actually misaligned.
        if self.index.ntotal != len(self.metadata):
            self._create_new_index()
def _verify_storage(self, user_id: str, chat_id: str, expected_count: int):
    """Log how many vectors are stored for one session; warn on a shortfall.

    Counts the ``self.metadata`` entries whose ``user_id`` and ``chat_id`` both
    match, while holding ``self.memory_lock``. The count is logged next to
    ``expected_count``, and a warning is logged when fewer entries were found.
    Finding more entries than expected is not reported. Returns None.
    """
    with self.memory_lock:
        matching = [
            entry for entry in self.metadata
            if entry.get("user_id") == user_id and entry.get("chat_id") == chat_id
        ]
        found = len(matching)
        logger.info(f"π Storage verification: User {user_id[:8]} has {found} vectors (expected: {expected_count})")
        if found < expected_count:
            logger.warning(f"β οΈ Storage mismatch for user {user_id[:8]}")
| # ==================== ANALYTICS & ADMIN METHODS ==================== | |
def get_retrieval_analytics(self, query: Optional[str] = None) -> Dict[str, Any]:
    """Return a snapshot of retrieval/index statistics, optionally analysing a query.

    Args:
        query: When non-empty, a ``query_analysis`` section is added. It holds
            the query length, its word count, and ``keyword_frequency``: for each
            distinct 3+-character query word, how many metadata texts contain it
            as a case-insensitive substring. It also holds a file-extension
            heuristic and ``self._classify_query(query)``.

    Returns:
        Dict with ``performance_stats``, ``query_types``, ``query_history_count``,
        ``index_stats``, ``recent_queries`` (the last 10 history entries),
        ``cache_stats`` and, when ``query`` is given, ``query_analysis``.
    """
    analytics = {
        "performance_stats": self.performance_stats.copy(),
        "query_types": dict(self.query_types),
        "query_history_count": len(self.query_history),
        "index_stats": {
            "total_vectors": self.index.ntotal,
            "metadata_count": len(self.metadata),
            "avg_metadata_size": 0,
            "bm25_cache_size": len(self.bm25_indices),
            "bm25_cache_capacity": self.bm25_cache_size,
            "bm25_available": BM25_AVAILABLE,
            "nltk_available": NLTK_AVAILABLE
        },
        "recent_queries": [
            {
                "query_preview": qh.get("query", "")[:50],
                "results": qh.get("results_count", 0),
                "top_score": qh.get("top_score", 0),
                "elapsed": qh.get("elapsed_time", 0),
                "method": qh.get("method", "unknown")
            }
            for qh in self.query_history[-10:]
        ],
        "cache_stats": {
            "bm25_cache_hits": 0,  # Could be tracked with more instrumentation
            "bm25_cache_misses": 0
        }
    }
    if self.metadata:
        total_text_size = sum(len(m.get("text", "")) for m in self.metadata)
        analytics["index_stats"]["avg_metadata_size"] = total_text_size / len(self.metadata)

    if query:
        query_lower = query.lower()
        # Extract the query keywords once, outside the metadata loop. De-duplicate
        # them (keeping first-occurrence order) so a repeated query word is not
        # counted more than once per metadata entry.
        query_words = list(dict.fromkeys(re.findall(r'\b\w{3,}\b', query_lower)))
        keyword_matches = defaultdict(int)
        for meta in self.metadata:
            text = meta.get("text", "").lower()
            for word in query_words:
                if word in text:
                    keyword_matches[word] += 1
        analytics["query_analysis"] = {
            "query_length": len(query),
            "word_count": len(query.split()),
            "keyword_frequency": dict(keyword_matches),
            "has_file_reference": bool(re.search(r'\.(?:py|js|html|css|ts|java|cpp)', query, re.I)),
            "classified_as": self._classify_query(query)
        }
    return analytics
def store_chat_context(self, messages: list, user_id: str, chat_id: str) -> bool:
    """Index the recent part of a chat transcript as session memory.

    Only the last 10 messages are considered, and messages with empty content
    are skipped. Each remaining message is rendered as "ROLE: content". The
    transcript must reach 50 characters. It is then split with
    ``_chunk_text_enhanced(chunk_size=800, overlap=100)``, embedded,
    L2-normalised and appended to the FAISS index, with one metadata record per
    chunk. The index is then saved and the session's BM25 cache entry is
    invalidated.

    Returns:
        True when chunks were stored. False when there was nothing worth
        storing, or when embedding/indexing raised (the error is logged).
    """
    if not (messages and user_id):
        return False

    rendered = []
    for msg in messages[-10:]:
        content = msg.get("content", "")
        if not content:
            continue
        role = msg.get("role", "unknown")
        rendered.append(f"{role.upper()}: {content}\n\n")
    conversation = "".join(rendered)
    if len(conversation) < 50:
        return False

    chunks = self._chunk_text_enhanced(conversation, chunk_size=800, overlap=100)
    if not chunks:
        return False

    texts = [chunk["text"] for chunk in chunks]
    metadata_list = [
        {
            "text": chunk["text"],
            "source": "chat_history",
            "type": "history",
            "user_id": user_id,
            "chat_id": chat_id,
            "timestamp": time.time(),
            "chunk_index": idx
        }
        for idx, chunk in enumerate(chunks)
    ]

    try:
        embeddings = self.embedder.encode(texts, show_progress_bar=False)
        faiss.normalize_L2(embeddings)
        with self.memory_lock:
            self.index.add(np.array(embeddings).astype('float32'))
            self.metadata.extend(metadata_list)
            self._save_index()
            # The session's cached BM25 index no longer covers the new chunks.
            self._invalidate_bm25_cache(user_id, chat_id)
        logger.info(f"π Stored {len(texts)} chat history chunks for user {user_id[:8]}")
        return True
    except Exception as e:
        logger.error(f"β Failed to store chat history: {e}")
        return False
def delete_session(self, user_id: str, chat_id: str) -> bool:
    """Permanently remove every vector for one (user_id, chat_id) session.

    Runs while holding ``self.memory_lock``. The FAISS index is rebuilt by
    re-embedding all remaining metadata texts, from every user. The result is
    saved and the session's BM25 cache entry is invalidated.

    Returns:
        True when vectors were removed. False when the session had no vectors,
        or when re-embedding failed (in which case the store is left unchanged).
    """
    with self.memory_lock:
        kept = [
            meta for meta in self.metadata
            if not (meta.get("user_id") == user_id and meta.get("chat_id") == chat_id)
        ]
        removed_count = len(self.metadata) - len(kept)
        if not removed_count:
            logger.info(f"βΉοΈ No vectors to delete for session {chat_id}")
            return False

        logger.info(f"π§Ή Surgically removing {removed_count} vectors for session {chat_id}...")
        if not kept:
            self.index = faiss.IndexFlatIP(384)
        else:
            surviving_texts = [meta["text"] for meta in kept]
            try:
                embeddings = self.embedder.encode(surviving_texts, show_progress_bar=False)
                faiss.normalize_L2(embeddings)
                rebuilt = faiss.IndexFlatIP(384)
                rebuilt.add(np.array(embeddings).astype('float32'))
                # Swap in only after the rebuild fully succeeded.
                self.index = rebuilt
            except Exception as e:
                logger.error(f"β Rebuild failed: {e}")
                return False

        self.metadata = kept
        self._save_index()
        # The session's cached BM25 index now refers to deleted texts.
        self._invalidate_bm25_cache(user_id, chat_id)
        logger.info(f"β Successfully deleted session {chat_id}")
        return True
def get_user_stats(self, user_id: str) -> Dict[str, Any]:
    """Summarise the vectors stored for one user.

    Args:
        user_id: Owner whose metadata entries are counted.

    Returns:
        A dict containing:
        - ``user_id`` and ``total_vectors``.
        - Per-``type``, per-``source`` and per-``chat_id`` counts (``by_type``,
          ``by_source``, ``sessions``); missing fields are bucketed as
          ``"unknown"``.
        - ``bm25_cached``: whether any of the user's sessions currently has an
          entry in ``self.bm25_indices``.
    """
    # Take a snapshot of the user's entries under the lock; count from the copy.
    with self.memory_lock:
        user_vectors = [meta for meta in self.metadata if meta.get("user_id") == user_id]

    by_type: Dict[str, int] = {}
    by_source: Dict[str, int] = {}
    sessions: Dict[str, int] = {}
    for vec in user_vectors:
        vec_type = vec.get("type", "unknown")
        source = vec.get("source", "unknown")
        chat_id = vec.get("chat_id", "unknown")
        by_type[vec_type] = by_type.get(vec_type, 0) + 1
        by_source[source] = by_source.get(source, 0) + 1
        sessions[chat_id] = sessions.get(chat_id, 0) + 1

    return {
        "user_id": user_id,
        "total_vectors": len(user_vectors),
        "by_type": by_type,
        "by_source": by_source,
        "sessions": sessions,
        # BM25 cache keys are (user_id, chat_id) tuples.
        "bm25_cached": any((user_id, cid) in self.bm25_indices for cid in sessions),
    }
def cleanup_old_sessions(self, max_age_hours: int = 24) -> int:
    """Remove every stored vector older than ``max_age_hours``.

    An entry is stale when its ``timestamp`` is before now minus
    ``max_age_hours``. Entries with no ``timestamp`` are treated as timestamp 0,
    so they are always removed. The FAISS index is rebuilt by re-embedding the
    surviving texts, the result is saved, and the BM25 cache entry of every
    affected (user_id, chat_id) session is invalidated. Runs while holding
    ``self.memory_lock``.

    Returns:
        Number of vectors removed. Returns 0 when nothing was stale, or when
        the rebuild failed (the store is then left unchanged).
    """
    cutoff = time.time() - (max_age_hours * 3600)
    with self.memory_lock:
        old_metadata = []
        new_metadata = []
        affected_sessions = set()
        for meta in self.metadata:
            if meta.get("timestamp", 0) < cutoff:
                old_metadata.append(meta)
                user_id = meta.get("user_id")
                chat_id = meta.get("chat_id")
                if user_id and chat_id:
                    affected_sessions.add((user_id, chat_id))
            else:
                new_metadata.append(meta)
        if not old_metadata:
            return 0

        logger.info(f"π§Ή Cleaning up {len(old_metadata)} old vectors...")
        recent_texts = [m["text"] for m in new_metadata]
        if recent_texts:
            try:
                embeddings = self.embedder.encode(recent_texts, show_progress_bar=False)
                faiss.normalize_L2(embeddings)
                # Build into a local first. Assigning self.index before add()
                # succeeds would leave an empty index paired with the full,
                # unchanged metadata if add() raised.
                rebuilt = faiss.IndexFlatIP(384)
                rebuilt.add(np.array(embeddings).astype('float32'))
            except Exception as e:
                logger.error(f"β Failed to rebuild index: {e}")
                return 0
            self.index = rebuilt
        else:
            self.index = faiss.IndexFlatIP(384)
        self.metadata = new_metadata
        self._save_index()
        # Remove affected sessions from BM25 cache
        for key in affected_sessions:
            self._invalidate_bm25_cache(*key)
        logger.info(f"β Cleanup complete. Removed {len(old_metadata)} vectors.")
        return len(old_metadata)
def _cleanup(self):
    """Best-effort teardown: release ``self.file_lock`` if set, then force a GC pass.

    Intended to run at interpreter exit (per the original "Cleanup on exit"
    note). Presumably it is registered via ``atexit`` elsewhere — confirm.
    Any exception is logged as a warning instead of being raised.
    """
    try:
        lock_present = hasattr(self, 'file_lock')
        if lock_present:
            self.file_lock.release()
        gc.collect()
    except Exception as exc:
        logger.warning(f"Cleanup warning: {exc}")
# Process-wide singleton state shared by get_vector_db().
_vdb_instance = None
_vdb_lock = threading.Lock()


def get_vector_db(index_path: str = "faiss_session_index.bin", metadata_path: str = "session_metadata.pkl") -> VectorDatabase:
    """Return the shared VectorDatabase, creating it on first use.

    Uses double-checked locking: an unlocked fast path returns an existing
    instance, and creation happens under ``_vdb_lock``. The path arguments only
    take effect on the call that creates the instance; later calls ignore them.
    """
    global _vdb_instance
    existing = _vdb_instance
    if existing is not None:
        return existing
    with _vdb_lock:
        if _vdb_instance is None:
            _vdb_instance = VectorDatabase(index_path, metadata_path)
        return _vdb_instance


# Backward-compatible module-level handle; created eagerly at import time.
vdb = get_vector_db()