Spaces:
Sleeping
Sleeping
| """ | |
| Semantic Search using pre-computed embeddings from Colab. | |
| Lightweight - only needs sentence-transformers for query encoding. | |
| """ | |
| import sqlite3 | |
| import numpy as np | |
| from typing import List, Dict, Any, Optional | |
# Optional dependency guard: sentence-transformers is only needed to encode
# queries at search time. Record availability in a flag so callers can degrade
# gracefully instead of failing at import time.
try:
    from sentence_transformers import SentenceTransformer
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    SentenceTransformer = None  # keep the name defined so references don't NameError
class SemanticSearch:
    """
    Semantic search using pre-computed embeddings.

    The embeddings.db file is created by running the Colab notebook.
    This class just loads and searches them.
    """

    def __init__(self, embeddings_db: str = 'embeddings.db', messages_db: str = 'telegram.db'):
        """
        Args:
            embeddings_db: SQLite file containing the `embeddings` table
                (produced offline, e.g. by the Colab notebook).
            messages_db: SQLite file containing the full `messages` table.
        """
        self.embeddings_db = embeddings_db
        self.messages_db = messages_db
        self.model = None               # SentenceTransformer, loaded lazily
        self.embeddings_loaded = False  # guards repeated DB loads
        # Initialize as an empty 2-D ndarray so the attribute always has the
        # same type as after _load_embeddings()/reload_embeddings() (which
        # assign np arrays) — previously this started life as a plain list.
        self.embeddings = np.array([]).reshape(0, 0)
        self.message_ids = []
        self.from_names = []
        self.text_previews = []
| def _load_model(self): | |
| """Load the embedding model (same one used in Colab).""" | |
| if not HAS_TRANSFORMERS: | |
| raise RuntimeError( | |
| "sentence-transformers not installed.\n" | |
| "Install with: pip install sentence-transformers" | |
| ) | |
| if self.model is None: | |
| print("Loading embedding model...") | |
| self.model = SentenceTransformer('intfloat/multilingual-e5-large') | |
| print("Model loaded!") | |
| def reload_embeddings(self): | |
| """Force reload embeddings from DB (e.g., after daily sync adds new ones).""" | |
| self.embeddings_loaded = False | |
| self.embeddings = np.array([]).reshape(0, 0) | |
| self.message_ids = [] | |
| self.from_names = [] | |
| self.text_previews = [] | |
| self._load_embeddings() | |
| def _load_embeddings(self): | |
| """Load all embeddings into memory for fast search.""" | |
| if self.embeddings_loaded: | |
| return | |
| import os | |
| if not os.path.exists(self.embeddings_db): | |
| print(f"Embeddings DB not found: {self.embeddings_db}") | |
| self.embeddings_loaded = True | |
| self.embeddings = np.array([]).reshape(0, 0) | |
| return | |
| print(f"Loading embeddings from {self.embeddings_db}...") | |
| conn = sqlite3.connect(self.embeddings_db) | |
| cursor = conn.execute( | |
| "SELECT message_id, from_name, text_preview, embedding FROM embeddings" | |
| ) | |
| emb_list = [] | |
| for row in cursor: | |
| msg_id, name, text, emb_blob = row | |
| emb = np.frombuffer(emb_blob, dtype=np.float32) | |
| self.message_ids.append(msg_id) | |
| self.from_names.append(name or '') | |
| self.text_previews.append(text or '') | |
| emb_list.append(emb) | |
| conn.close() | |
| if len(emb_list) == 0: | |
| print("No embeddings found in database") | |
| self.embeddings = np.array([]).reshape(0, 0) | |
| self.embeddings_loaded = True | |
| return | |
| # Stack into numpy array for fast computation | |
| self.embeddings = np.vstack(emb_list) | |
| # Normalize embeddings for cosine similarity | |
| norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) | |
| norms = np.where(norms == 0, 1, norms) # Avoid division by zero | |
| self.embeddings = self.embeddings / norms | |
| self.embeddings_loaded = True | |
| print(f"Loaded {len(self.message_ids)} embeddings") | |
| def search(self, query: str, limit: int = 50, min_score: float = 0.3) -> List[Dict[str, Any]]: | |
| """ | |
| Search for semantically similar messages. | |
| Args: | |
| query: The search query | |
| limit: Max results to return | |
| min_score: Minimum similarity score (0-1) | |
| Returns: | |
| List of dicts with message_id, from_name, text, score | |
| """ | |
| self._load_model() | |
| self._load_embeddings() | |
| if len(self.message_ids) == 0: | |
| return [] | |
| # Encode query (e5 model requires "query: " prefix) | |
| query_emb = self.model.encode([f"query: {query}"], convert_to_numpy=True)[0] | |
| # Compute cosine similarity with all embeddings | |
| # embeddings are already normalized from Colab | |
| query_norm = query_emb / np.linalg.norm(query_emb) | |
| similarities = np.dot(self.embeddings, query_norm) | |
| # Get top results | |
| top_indices = np.argsort(similarities)[::-1][:limit * 2] # Get more, then filter | |
| results = [] | |
| for idx in top_indices: | |
| score = float(similarities[idx]) | |
| if score < min_score: | |
| continue | |
| results.append({ | |
| 'message_id': int(self.message_ids[idx]), | |
| 'from_name': self.from_names[idx], | |
| 'text': self.text_previews[idx], | |
| 'score': score | |
| }) | |
| if len(results) >= limit: | |
| break | |
| return results | |
| def search_with_full_text(self, query: str, limit: int = 20) -> List[Dict[str, Any]]: | |
| """ | |
| Search and return full message text from messages DB. | |
| """ | |
| results = self.search(query, limit=limit) | |
| if not results: | |
| return [] | |
| # Get full text from messages DB | |
| conn = sqlite3.connect(self.messages_db) | |
| conn.row_factory = sqlite3.Row | |
| for result in results: | |
| cursor = conn.execute( | |
| "SELECT date, from_name, text_plain, reply_to_message_id FROM messages WHERE id = ?", | |
| (result['message_id'],) | |
| ) | |
| row = cursor.fetchone() | |
| if row: | |
| result['date'] = row['date'] | |
| result['from_name'] = row['from_name'] | |
| result['text'] = row['text_plain'] | |
| result['reply_to_message_id'] = row['reply_to_message_id'] | |
| conn.close() | |
| return results | |
    def _add_thread_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Add FULL thread context to search results.

        For each message, find the entire conversation thread:
        1. Go up to find the root message
        2. Get all messages in that thread

        Returns the union of the original hits and the collected thread
        messages (the latter flagged with 'is_thread_context': True),
        sorted ascending by date. Mutates nothing in the input dicts.
        """
        if not results:
            return results
        conn = sqlite3.connect(self.messages_db)
        conn.row_factory = sqlite3.Row
        # Index the original hits by id; thread messages get merged in here.
        all_messages = {r['message_id']: r for r in results}
        thread_roots = set()
        # Step 1: Find root messages by following reply chains UP
        for result in results:
            msg_id = result['message_id']
            reply_to = result.get('reply_to_message_id')
            # Follow the chain up to find the root
            current_id = msg_id
            current_reply_to = reply_to
            visited = {current_id}  # cycle guard for malformed reply chains
            while current_reply_to and current_reply_to not in visited:
                visited.add(current_reply_to)
                cursor = conn.execute(
                    "SELECT id, reply_to_message_id FROM messages WHERE id = ?",
                    (current_reply_to,)
                )
                row = cursor.fetchone()
                if row:
                    current_id = row['id']
                    current_reply_to = row['reply_to_message_id']
                else:
                    # Parent missing from the DB: treat the last seen id as root.
                    break
            # current_id is now the root of this thread
            thread_roots.add(current_id)
        # Step 2: Get ALL messages in these threads (recursively)
        def get_thread_messages(root_ids, depth=0, max_depth=10):
            """Recursively collect all messages in the given threads (depth-capped)."""
            if not root_ids or depth > max_depth:
                return []
            messages = []
            # Get root messages themselves
            if root_ids:
                placeholders = ','.join('?' * len(root_ids))
                cursor = conn.execute(f"""
                    SELECT id, date, from_name, text_plain, reply_to_message_id
                    FROM messages WHERE id IN ({placeholders})
                """, list(root_ids))
                for row in cursor:
                    if row['id'] not in all_messages:
                        messages.append({
                            'message_id': row['id'],
                            'date': row['date'],
                            'from_name': row['from_name'],
                            'text': row['text_plain'],
                            'reply_to_message_id': row['reply_to_message_id'],
                            'is_thread_context': True
                        })
                        all_messages[row['id']] = messages[-1]
            # Get all replies to these messages
            # NOTE(review): includes replies to every message seen so far, not
            # just this level's roots — widens the net on purpose, presumably.
            all_ids = set(root_ids) | set(all_messages.keys())
            if all_ids:
                placeholders = ','.join('?' * len(all_ids))
                cursor = conn.execute(f"""
                    SELECT id, date, from_name, text_plain, reply_to_message_id
                    FROM messages WHERE reply_to_message_id IN ({placeholders})
                    LIMIT 200
                """, list(all_ids))
                new_ids = set()
                for row in cursor:
                    if row['id'] not in all_messages:
                        msg = {
                            'message_id': row['id'],
                            'date': row['date'],
                            'from_name': row['from_name'],
                            'text': row['text_plain'],
                            'reply_to_message_id': row['reply_to_message_id'],
                            'is_thread_context': True
                        }
                        messages.append(msg)
                        all_messages[row['id']] = msg
                        new_ids.add(row['id'])
                # Recursively get replies to the new messages
                if new_ids:
                    messages.extend(get_thread_messages(new_ids, depth + 1, max_depth))
            return messages
        # Get all thread messages (collected via the all_messages dict)
        get_thread_messages(thread_roots)
        conn.close()
        # Sort all messages by date (empty string sorts missing dates first)
        all_list = list(all_messages.values())
        all_list.sort(key=lambda x: x.get('date', '') or '')
        return all_list
| def search_with_ai_answer(self, query: str, ai_engine, limit: int = 30) -> Dict[str, Any]: | |
| """ | |
| Search semantically and send results to AI for reasoning. | |
| This combines the power of: | |
| 1. Semantic search (finds relevant messages by meaning) | |
| 2. Thread context (includes replies to/from found messages) | |
| 3. AI reasoning (reads messages and answers the question) | |
| """ | |
| results = self.search_with_full_text(query, limit=limit) | |
| if not results: | |
| return { | |
| 'query': query, | |
| 'answer': '诇讗 谞诪爪讗讜 讛讜讚注讜转 专诇讜讜谞讟讬讜转', | |
| 'mode': 'semantic_ai', | |
| 'results': [], | |
| 'count': 0 | |
| } | |
| # Get thread context for each result | |
| results_with_threads = self._add_thread_context(results) | |
| # Build context from semantic search results + threads | |
| context_text = "\n".join([ | |
| f"[{r.get('date', '')}] {r.get('from_name', 'Unknown')}: {r.get('text', '')[:500]}" | |
| for r in results_with_threads if r.get('text') | |
| ]) | |
| # Send to AI for reasoning | |
| reason_prompt = f"""You are analyzing a Telegram chat history to answer a question. | |
| The messages below were found using semantic search, along with their thread context (replies). | |
| Read them carefully and provide a comprehensive answer. | |
| Question: {query} | |
| Relevant messages and their threads: | |
| {context_text} | |
| Based on these messages, answer the question in Hebrew. | |
| If you can find the answer, provide it clearly. | |
| Pay special attention to reply chains - the answer might be in a reply! | |
| If you can infer information from context clues, do so. | |
| Cite specific messages when relevant. | |
| Answer:""" | |
| try: | |
| # Call the appropriate AI provider based on engine configuration | |
| provider = getattr(ai_engine, 'provider', None) | |
| if provider == 'gemini': | |
| answer = ai_engine._call_gemini(reason_prompt) | |
| elif provider == 'groq': | |
| answer = ai_engine._call_groq(reason_prompt) | |
| elif provider == 'ollama': | |
| answer = ai_engine._call_ollama(reason_prompt) | |
| else: | |
| answer = "AI engine not available for reasoning" | |
| except Exception as e: | |
| answer = f"砖讙讬讗讛 讘-AI: {str(e)}" | |
| return { | |
| 'query': query, | |
| 'answer': answer, | |
| 'mode': 'semantic_ai', | |
| 'results': results, # Original results for display | |
| 'count': len(results), | |
| 'total_with_threads': len(results_with_threads) | |
| } | |
| def is_available(self) -> bool: | |
| """Check if semantic search is available (DB exists and has embeddings).""" | |
| import os | |
| if not HAS_TRANSFORMERS or not os.path.exists(self.embeddings_db): | |
| return False | |
| try: | |
| conn = sqlite3.connect(self.embeddings_db) | |
| count = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0] | |
| conn.close() | |
| return count > 0 | |
| except Exception: | |
| return False | |
| def stats(self) -> Dict[str, Any]: | |
| """Get statistics about the embeddings.""" | |
| import os | |
| if not os.path.exists(self.embeddings_db): | |
| return {'available': False, 'error': 'embeddings.db not found'} | |
| conn = sqlite3.connect(self.embeddings_db) | |
| cursor = conn.execute("SELECT COUNT(*) FROM embeddings") | |
| count = cursor.fetchone()[0] | |
| conn.close() | |
| size_mb = os.path.getsize(self.embeddings_db) / (1024 * 1024) | |
| return { | |
| 'available': True, | |
| 'count': count, | |
| 'size_mb': round(size_mb, 1), | |
| 'model': 'intfloat/multilingual-e5-large' | |
| } | |
# Module-level singleton so every caller shares one loaded model and
# one in-memory embedding matrix.
_search_instance = None

def get_semantic_search() -> SemanticSearch:
    """Get or create the shared SemanticSearch instance (default DB paths)."""
    global _search_instance
    if _search_instance is None:
        _search_instance = SemanticSearch()
    return _search_instance
if __name__ == '__main__':
    # Smoke test: print stats and, if embeddings exist, run one sample
    # Hebrew query ("where do you work?") and show the top hits.
    ss = SemanticSearch()
    print("Stats:", ss.stats())
    if ss.is_available():
        results = ss.search("讗讬驻讛 讗转讛 注讜讘讚?", limit=5)
        print("\nResults for '讗讬驻讛 讗转讛 注讜讘讚?':")
        for r in results:
            print(f" [{r['score']:.3f}] {r['from_name']}: {r['text'][:60]}...")