# telegram-analytics / semantic_search.py (commit 85ff768)
"""
Semantic Search using pre-computed embeddings from Colab.
Lightweight - only needs sentence-transformers for query encoding.
"""
import os
import sqlite3
from typing import List, Dict, Any, Optional

import numpy as np

# Try importing sentence-transformers (optional, heavy dependency; only
# needed for encoding queries — stored embeddings are pre-computed).
try:
    from sentence_transformers import SentenceTransformer
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    SentenceTransformer = None
class SemanticSearch:
    """
    Semantic search using pre-computed embeddings.

    The embeddings.db file is created offline (by the Colab notebook);
    this class only loads the stored vectors and searches them, so the
    heavyweight embedding model is needed solely to encode queries.
    """

    # Must match the model used offline to generate the stored embeddings.
    MODEL_NAME = 'intfloat/multilingual-e5-large'

    def __init__(self, embeddings_db: str = 'embeddings.db', messages_db: str = 'telegram.db'):
        """
        Args:
            embeddings_db: Path to the SQLite DB with pre-computed embeddings.
            messages_db: Path to the SQLite DB with the full Telegram messages.
        """
        self.embeddings_db = embeddings_db
        self.messages_db = messages_db
        self.model = None  # lazily loaded SentenceTransformer
        self.embeddings_loaded = False
        # Row-normalized embedding matrix; empty until _load_embeddings() runs.
        self.embeddings = np.empty((0, 0), dtype=np.float32)
        self.message_ids: List[int] = []
        self.from_names: List[str] = []
        self.text_previews: List[str] = []

    def _load_model(self):
        """Lazily load the query-encoding model (same one used offline).

        Raises:
            RuntimeError: If sentence-transformers is not installed.
        """
        if not HAS_TRANSFORMERS:
            raise RuntimeError(
                "sentence-transformers not installed.\n"
                "Install with: pip install sentence-transformers"
            )
        if self.model is None:
            print("Loading embedding model...")
            self.model = SentenceTransformer(self.MODEL_NAME)
            print("Model loaded!")

    def reload_embeddings(self):
        """Force reload embeddings from DB (e.g., after daily sync adds new ones)."""
        self.embeddings_loaded = False
        self.embeddings = np.empty((0, 0), dtype=np.float32)
        self.message_ids = []
        self.from_names = []
        self.text_previews = []
        self._load_embeddings()

    def _load_embeddings(self):
        """Load all embeddings into memory for fast search (idempotent).

        Rows are L2-normalized on load so a dot product against a normalized
        query vector yields cosine similarity directly.
        """
        if self.embeddings_loaded:
            return
        if not os.path.exists(self.embeddings_db):
            print(f"Embeddings DB not found: {self.embeddings_db}")
            self.embeddings_loaded = True
            self.embeddings = np.empty((0, 0), dtype=np.float32)
            return
        print(f"Loading embeddings from {self.embeddings_db}...")
        conn = sqlite3.connect(self.embeddings_db)
        emb_list = []
        try:
            cursor = conn.execute(
                "SELECT message_id, from_name, text_preview, embedding FROM embeddings"
            )
            for msg_id, name, text, emb_blob in cursor:
                self.message_ids.append(msg_id)
                self.from_names.append(name or '')
                self.text_previews.append(text or '')
                emb_list.append(np.frombuffer(emb_blob, dtype=np.float32))
        finally:
            conn.close()  # release the DB even if a row is malformed
        if not emb_list:
            print("No embeddings found in database")
            self.embeddings = np.empty((0, 0), dtype=np.float32)
            self.embeddings_loaded = True
            return
        # Stack into one matrix for vectorized similarity computation.
        self.embeddings = np.vstack(emb_list)
        norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
        norms = np.where(norms == 0, 1, norms)  # avoid division by zero
        self.embeddings = self.embeddings / norms
        self.embeddings_loaded = True
        print(f"Loaded {len(self.message_ids)} embeddings")

    def search(self, query: str, limit: int = 50, min_score: float = 0.3) -> List[Dict[str, Any]]:
        """
        Search for semantically similar messages.

        Args:
            query: The search query
            limit: Max results to return
            min_score: Minimum cosine similarity score (0-1)

        Returns:
            List of dicts with message_id, from_name, text, score,
            ordered by descending score.
        """
        self._load_model()
        self._load_embeddings()
        if not self.message_ids:
            return []
        # The e5 model family requires a "query: " prefix on the query side.
        query_emb = self.model.encode([f"query: {query}"], convert_to_numpy=True)[0]
        # Stored embeddings were normalized at load time; normalize the query
        # too so the dot product below is cosine similarity.
        query_norm = np.linalg.norm(query_emb)
        if query_norm == 0:  # degenerate embedding — nothing can match
            return []
        similarities = self.embeddings @ (query_emb / query_norm)
        results = []
        # Indices are in descending-similarity order: once one score drops
        # below min_score every later one is lower too, so break early.
        for idx in np.argsort(similarities)[::-1][:limit]:
            score = float(similarities[idx])
            if score < min_score:
                break
            results.append({
                'message_id': int(self.message_ids[idx]),
                'from_name': self.from_names[idx],
                'text': self.text_previews[idx],
                'score': score
            })
        return results

    def search_with_full_text(self, query: str, limit: int = 20) -> List[Dict[str, Any]]:
        """
        Search and return full message text from the messages DB.

        Each hit from search() is enriched in place with date, from_name,
        full text and reply_to_message_id from the messages table.
        """
        results = self.search(query, limit=limit)
        if not results:
            return []
        conn = sqlite3.connect(self.messages_db)
        conn.row_factory = sqlite3.Row
        try:
            for result in results:
                row = conn.execute(
                    "SELECT date, from_name, text_plain, reply_to_message_id FROM messages WHERE id = ?",
                    (result['message_id'],)
                ).fetchone()
                if row:
                    result['date'] = row['date']
                    result['from_name'] = row['from_name']
                    result['text'] = row['text_plain']
                    result['reply_to_message_id'] = row['reply_to_message_id']
        finally:
            conn.close()  # don't leak the connection if a query fails
        return results

    def _add_thread_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Add FULL thread context to search results.

        For each hit, follow reply_to_message_id links UP to the thread root,
        then collect every message in that thread (root plus replies,
        recursively). Context messages are flagged with is_thread_context.

        Returns:
            All hits plus their context messages, sorted by date.
        """
        if not results:
            return results
        conn = sqlite3.connect(self.messages_db)
        conn.row_factory = sqlite3.Row
        try:
            all_messages = {r['message_id']: r for r in results}
            thread_roots = set()
            # Step 1: follow each reply chain upward to find its root.
            for result in results:
                current_id = result['message_id']
                current_reply_to = result.get('reply_to_message_id')
                visited = {current_id}  # guards against reply cycles
                while current_reply_to and current_reply_to not in visited:
                    visited.add(current_reply_to)
                    row = conn.execute(
                        "SELECT id, reply_to_message_id FROM messages WHERE id = ?",
                        (current_reply_to,)
                    ).fetchone()
                    if row is None:
                        break  # dangling reply id — treat current as root
                    current_id = row['id']
                    current_reply_to = row['reply_to_message_id']
                thread_roots.add(current_id)

            def get_thread_messages(root_ids, depth=0, max_depth=10):
                """Recursively collect all thread messages, depth-capped."""
                if not root_ids or depth > max_depth:
                    return []
                messages = []
                # Fetch the root messages themselves.
                placeholders = ','.join('?' * len(root_ids))
                cursor = conn.execute(f"""
                    SELECT id, date, from_name, text_plain, reply_to_message_id
                    FROM messages WHERE id IN ({placeholders})
                """, list(root_ids))
                for row in cursor:
                    if row['id'] not in all_messages:
                        msg = {
                            'message_id': row['id'],
                            'date': row['date'],
                            'from_name': row['from_name'],
                            'text': row['text_plain'],
                            'reply_to_message_id': row['reply_to_message_id'],
                            'is_thread_context': True
                        }
                        messages.append(msg)
                        all_messages[row['id']] = msg
                # Fetch all replies to anything collected so far (capped).
                all_ids = set(root_ids) | set(all_messages.keys())
                placeholders = ','.join('?' * len(all_ids))
                cursor = conn.execute(f"""
                    SELECT id, date, from_name, text_plain, reply_to_message_id
                    FROM messages WHERE reply_to_message_id IN ({placeholders})
                    LIMIT 200
                """, list(all_ids))
                new_ids = set()
                for row in cursor:
                    if row['id'] not in all_messages:
                        msg = {
                            'message_id': row['id'],
                            'date': row['date'],
                            'from_name': row['from_name'],
                            'text': row['text_plain'],
                            'reply_to_message_id': row['reply_to_message_id'],
                            'is_thread_context': True
                        }
                        messages.append(msg)
                        all_messages[row['id']] = msg
                        new_ids.add(row['id'])
                # Descend into replies-of-replies.
                if new_ids:
                    messages.extend(get_thread_messages(new_ids, depth + 1, max_depth))
                return messages

            get_thread_messages(thread_roots)
        finally:
            conn.close()
        # Chronological order reads naturally for both humans and the AI.
        all_list = list(all_messages.values())
        all_list.sort(key=lambda x: x.get('date', '') or '')
        return all_list

    def search_with_ai_answer(self, query: str, ai_engine, limit: int = 30) -> Dict[str, Any]:
        """
        Search semantically and send results to AI for reasoning.

        Combines:
        1. Semantic search (finds relevant messages by meaning)
        2. Thread context (includes replies to/from found messages)
        3. AI reasoning (reads messages and answers the question)

        Args:
            query: The user's question.
            ai_engine: Object expected to expose a `provider` attribute and a
                matching _call_gemini/_call_groq/_call_ollama method.
            limit: Max semantic hits to feed into the prompt.
        """
        results = self.search_with_full_text(query, limit=limit)
        if not results:
            return {
                'query': query,
                'answer': '诇讗 谞诪爪讗讜 讛讜讚注讜转 专诇讜讜谞讟讬讜转',
                'mode': 'semantic_ai',
                'results': [],
                'count': 0
            }
        # Expand each hit with its full conversation thread.
        results_with_threads = self._add_thread_context(results)
        # One line per message: [date] sender: text (truncated to 500 chars).
        context_text = "\n".join([
            f"[{r.get('date', '')}] {r.get('from_name', 'Unknown')}: {r.get('text', '')[:500]}"
            for r in results_with_threads if r.get('text')
        ])
        reason_prompt = f"""You are analyzing a Telegram chat history to answer a question.
The messages below were found using semantic search, along with their thread context (replies).
Read them carefully and provide a comprehensive answer.
Question: {query}
Relevant messages and their threads:
{context_text}
Based on these messages, answer the question in Hebrew.
If you can find the answer, provide it clearly.
Pay special attention to reply chains - the answer might be in a reply!
If you can infer information from context clues, do so.
Cite specific messages when relevant.
Answer:"""
        try:
            # Dispatch to whichever provider the engine is configured for.
            provider = getattr(ai_engine, 'provider', None)
            if provider == 'gemini':
                answer = ai_engine._call_gemini(reason_prompt)
            elif provider == 'groq':
                answer = ai_engine._call_groq(reason_prompt)
            elif provider == 'ollama':
                answer = ai_engine._call_ollama(reason_prompt)
            else:
                answer = "AI engine not available for reasoning"
        except Exception as e:
            answer = f"砖讙讬讗讛 讘-AI: {str(e)}"
        return {
            'query': query,
            'answer': answer,
            'mode': 'semantic_ai',
            'results': results,  # original hits for display
            'count': len(results),
            'total_with_threads': len(results_with_threads)
        }

    def is_available(self) -> bool:
        """Check if semantic search is available (DB exists and has embeddings)."""
        if not HAS_TRANSFORMERS or not os.path.exists(self.embeddings_db):
            return False
        try:
            conn = sqlite3.connect(self.embeddings_db)
            try:
                count = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0]
            finally:
                conn.close()
            return count > 0
        except Exception:
            # Missing table / corrupt DB simply means "not available".
            return False

    def stats(self) -> Dict[str, Any]:
        """Get statistics about the embeddings DB (count, file size, model)."""
        if not os.path.exists(self.embeddings_db):
            return {'available': False, 'error': 'embeddings.db not found'}
        conn = sqlite3.connect(self.embeddings_db)
        try:
            count = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0]
        finally:
            conn.close()
        size_mb = os.path.getsize(self.embeddings_db) / (1024 * 1024)
        return {
            'available': True,
            'count': count,
            'size_mb': round(size_mb, 1),
            'model': self.MODEL_NAME
        }
# Process-wide singleton, created lazily by get_semantic_search().
_search_instance = None


def get_semantic_search() -> SemanticSearch:
    """Return the shared SemanticSearch, constructing it on first call."""
    global _search_instance
    instance = _search_instance
    if instance is None:
        instance = SemanticSearch()
        _search_instance = instance
    return instance
if __name__ == '__main__':
    # Quick manual smoke test: show DB stats, then run one sample query.
    searcher = SemanticSearch()
    print("Stats:", searcher.stats())
    if searcher.is_available():
        hits = searcher.search("讗讬驻讛 讗转讛 注讜讘讚?", limit=5)
        print("\nResults for '讗讬驻讛 讗转讛 注讜讘讚?':")
        for hit in hits:
            print(f" [{hit['score']:.3f}] {hit['from_name']}: {hit['text'][:60]}...")