""" Semantic Search using pre-computed embeddings from Colab. Lightweight - only needs sentence-transformers for query encoding. """ import sqlite3 import numpy as np from typing import List, Dict, Any, Optional # Try importing sentence-transformers try: from sentence_transformers import SentenceTransformer HAS_TRANSFORMERS = True except ImportError: HAS_TRANSFORMERS = False SentenceTransformer = None class SemanticSearch: """ Semantic search using pre-computed embeddings. The embeddings.db file is created by running the Colab notebook. This class just loads and searches them. """ def __init__(self, embeddings_db: str = 'embeddings.db', messages_db: str = 'telegram.db'): self.embeddings_db = embeddings_db self.messages_db = messages_db self.model = None self.embeddings_loaded = False self.embeddings = [] self.message_ids = [] self.from_names = [] self.text_previews = [] def _load_model(self): """Load the embedding model (same one used in Colab).""" if not HAS_TRANSFORMERS: raise RuntimeError( "sentence-transformers not installed.\n" "Install with: pip install sentence-transformers" ) if self.model is None: print("Loading embedding model...") self.model = SentenceTransformer('intfloat/multilingual-e5-large') print("Model loaded!") def reload_embeddings(self): """Force reload embeddings from DB (e.g., after daily sync adds new ones).""" self.embeddings_loaded = False self.embeddings = np.array([]).reshape(0, 0) self.message_ids = [] self.from_names = [] self.text_previews = [] self._load_embeddings() def _load_embeddings(self): """Load all embeddings into memory for fast search.""" if self.embeddings_loaded: return import os if not os.path.exists(self.embeddings_db): print(f"Embeddings DB not found: {self.embeddings_db}") self.embeddings_loaded = True self.embeddings = np.array([]).reshape(0, 0) return print(f"Loading embeddings from {self.embeddings_db}...") conn = sqlite3.connect(self.embeddings_db) cursor = conn.execute( "SELECT message_id, from_name, text_preview, embedding FROM embeddings" ) emb_list = [] for row in cursor: msg_id, name, text, emb_blob = row emb = np.frombuffer(emb_blob, dtype=np.float32) self.message_ids.append(msg_id) self.from_names.append(name or '') self.text_previews.append(text or '') emb_list.append(emb) conn.close() if len(emb_list) == 0: print("No embeddings found in database") self.embeddings = np.array([]).reshape(0, 0) self.embeddings_loaded = True return # Stack into numpy array for fast computation self.embeddings = np.vstack(emb_list) # Normalize embeddings for cosine similarity norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) norms = np.where(norms == 0, 1, norms) # Avoid division by zero self.embeddings = self.embeddings / norms self.embeddings_loaded = True print(f"Loaded {len(self.message_ids)} embeddings") def search(self, query: str, limit: int = 50, min_score: float = 0.3) -> List[Dict[str, Any]]: """ Search for semantically similar messages. Args: query: The search query limit: Max results to return min_score: Minimum similarity score (0-1) Returns: List of dicts with message_id, from_name, text, score """ self._load_model() self._load_embeddings() if len(self.message_ids) == 0: return [] # Encode query (e5 model requires "query: " prefix) query_emb = self.model.encode([f"query: {query}"], convert_to_numpy=True)[0] # Compute cosine similarity with all embeddings # embeddings are already normalized from Colab query_norm = query_emb / np.linalg.norm(query_emb) similarities = np.dot(self.embeddings, query_norm) # Get top results top_indices = np.argsort(similarities)[::-1][:limit * 2] # Get more, then filter results = [] for idx in top_indices: score = float(similarities[idx]) if score < min_score: continue results.append({ 'message_id': int(self.message_ids[idx]), 'from_name': self.from_names[idx], 'text': self.text_previews[idx], 'score': score }) if len(results) >= limit: break return results def search_with_full_text(self, query: str, limit: int = 20) -> List[Dict[str, Any]]: """ Search and return full message text from messages DB. """ results = self.search(query, limit=limit) if not results: return [] # Get full text from messages DB conn = sqlite3.connect(self.messages_db) conn.row_factory = sqlite3.Row for result in results: cursor = conn.execute( "SELECT date, from_name, text_plain, reply_to_message_id FROM messages WHERE id = ?", (result['message_id'],) ) row = cursor.fetchone() if row: result['date'] = row['date'] result['from_name'] = row['from_name'] result['text'] = row['text_plain'] result['reply_to_message_id'] = row['reply_to_message_id'] conn.close() return results def _add_thread_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Add FULL thread context to search results. For each message, find the entire conversation thread: 1. Go up to find the root message 2. Get all messages in that thread """ if not results: return results conn = sqlite3.connect(self.messages_db) conn.row_factory = sqlite3.Row all_messages = {r['message_id']: r for r in results} thread_roots = set() # Step 1: Find root messages by following reply chains UP for result in results: msg_id = result['message_id'] reply_to = result.get('reply_to_message_id') # Follow the chain up to find the root current_id = msg_id current_reply_to = reply_to visited = {current_id} while current_reply_to and current_reply_to not in visited: visited.add(current_reply_to) cursor = conn.execute( "SELECT id, reply_to_message_id FROM messages WHERE id = ?", (current_reply_to,) ) row = cursor.fetchone() if row: current_id = row['id'] current_reply_to = row['reply_to_message_id'] else: break # current_id is now the root of this thread thread_roots.add(current_id) # Step 2: Get ALL messages in these threads (recursively) def get_thread_messages(root_ids, depth=0, max_depth=10): """Recursively get all messages in threads.""" if not root_ids or depth > max_depth: return [] messages = [] # Get root messages themselves if root_ids: placeholders = ','.join('?' * len(root_ids)) cursor = conn.execute(f""" SELECT id, date, from_name, text_plain, reply_to_message_id FROM messages WHERE id IN ({placeholders}) """, list(root_ids)) for row in cursor: if row['id'] not in all_messages: messages.append({ 'message_id': row['id'], 'date': row['date'], 'from_name': row['from_name'], 'text': row['text_plain'], 'reply_to_message_id': row['reply_to_message_id'], 'is_thread_context': True }) all_messages[row['id']] = messages[-1] # Get all replies to these messages all_ids = set(root_ids) | set(all_messages.keys()) if all_ids: placeholders = ','.join('?' * len(all_ids)) cursor = conn.execute(f""" SELECT id, date, from_name, text_plain, reply_to_message_id FROM messages WHERE reply_to_message_id IN ({placeholders}) LIMIT 200 """, list(all_ids)) new_ids = set() for row in cursor: if row['id'] not in all_messages: msg = { 'message_id': row['id'], 'date': row['date'], 'from_name': row['from_name'], 'text': row['text_plain'], 'reply_to_message_id': row['reply_to_message_id'], 'is_thread_context': True } messages.append(msg) all_messages[row['id']] = msg new_ids.add(row['id']) # Recursively get replies to the new messages if new_ids: messages.extend(get_thread_messages(new_ids, depth + 1, max_depth)) return messages # Get all thread messages get_thread_messages(thread_roots) conn.close() # Sort all messages by date all_list = list(all_messages.values()) all_list.sort(key=lambda x: x.get('date', '') or '') return all_list def search_with_ai_answer(self, query: str, ai_engine, limit: int = 30) -> Dict[str, Any]: """ Search semantically and send results to AI for reasoning. This combines the power of: 1. Semantic search (finds relevant messages by meaning) 2. Thread context (includes replies to/from found messages) 3. AI reasoning (reads messages and answers the question) """ results = self.search_with_full_text(query, limit=limit) if not results: return { 'query': query, 'answer': 'לא נמצאו הודעות רלוונטיות', 'mode': 'semantic_ai', 'results': [], 'count': 0 } # Get thread context for each result results_with_threads = self._add_thread_context(results) # Build context from semantic search results + threads context_text = "\n".join([ f"[{r.get('date', '')}] {r.get('from_name', 'Unknown')}: {r.get('text', '')[:500]}" for r in results_with_threads if r.get('text') ]) # Send to AI for reasoning reason_prompt = f"""You are analyzing a Telegram chat history to answer a question. The messages below were found using semantic search, along with their thread context (replies). Read them carefully and provide a comprehensive answer. Question: {query} Relevant messages and their threads: {context_text} Based on these messages, answer the question in Hebrew. If you can find the answer, provide it clearly. Pay special attention to reply chains - the answer might be in a reply! If you can infer information from context clues, do so. Cite specific messages when relevant. Answer:""" try: # Call the appropriate AI provider based on engine configuration provider = getattr(ai_engine, 'provider', None) if provider == 'gemini': answer = ai_engine._call_gemini(reason_prompt) elif provider == 'groq': answer = ai_engine._call_groq(reason_prompt) elif provider == 'ollama': answer = ai_engine._call_ollama(reason_prompt) else: answer = "AI engine not available for reasoning" except Exception as e: answer = f"שגיאה ב-AI: {str(e)}" return { 'query': query, 'answer': answer, 'mode': 'semantic_ai', 'results': results, # Original results for display 'count': len(results), 'total_with_threads': len(results_with_threads) } def is_available(self) -> bool: """Check if semantic search is available (DB exists and has embeddings).""" import os if not HAS_TRANSFORMERS or not os.path.exists(self.embeddings_db): return False try: conn = sqlite3.connect(self.embeddings_db) count = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0] conn.close() return count > 0 except Exception: return False def stats(self) -> Dict[str, Any]: """Get statistics about the embeddings.""" import os if not os.path.exists(self.embeddings_db): return {'available': False, 'error': 'embeddings.db not found'} conn = sqlite3.connect(self.embeddings_db) cursor = conn.execute("SELECT COUNT(*) FROM embeddings") count = cursor.fetchone()[0] conn.close() size_mb = os.path.getsize(self.embeddings_db) / (1024 * 1024) return { 'available': True, 'count': count, 'size_mb': round(size_mb, 1), 'model': 'intfloat/multilingual-e5-large' } # Singleton instance _search_instance = None def get_semantic_search() -> SemanticSearch: """Get or create semantic search instance.""" global _search_instance if _search_instance is None: _search_instance = SemanticSearch() return _search_instance if __name__ == '__main__': # Test ss = SemanticSearch() print("Stats:", ss.stats()) if ss.is_available(): results = ss.search("איפה אתה עובד?", limit=5) print("\nResults for 'איפה אתה עובד?':") for r in results: print(f" [{r['score']:.3f}] {r['from_name']}: {r['text'][:60]}...")