"""
Working Hyper RAG system.

Keeps an explicit ID mapping between the keyword index and the FAISS
vector index, so keyword pre-filtering and vector search agree on which
chunk is which.
"""

import time
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import sqlite3
import hashlib
from typing import List, Tuple, Optional, Dict, Any
from pathlib import Path
from datetime import datetime, timedelta
import re
from collections import defaultdict
import psutil
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor

from config import (
    EMBEDDING_MODEL, DATA_DIR, FAISS_INDEX_PATH, DOCSTORE_PATH,
    EMBEDDING_CACHE_PATH, CHUNK_SIZE, TOP_K_DYNAMIC_HYPER,
    MAX_TOKENS, ENABLE_EMBEDDING_CACHE, ENABLE_QUERY_CACHE,
    ENABLE_PRE_FILTER, ENABLE_PROMPT_COMPRESSION
)

class WorkingHyperRAG:
    """Hyper RAG with an explicit FAISS-to-docstore ID mapping."""

    def __init__(self, metrics_tracker=None):
        self.metrics_tracker = metrics_tracker
        self.embedder = None
        self.faiss_index = None
        self.docstore_conn = None
        self._initialized = False
        self.process = psutil.Process(os.getpid())

        # Small pool used to run embedding and keyword pre-filtering in parallel.
        self.thread_pool = ThreadPoolExecutor(
            max_workers=2,
            thread_name_prefix="HyperRAGWorker"
        )

        # Lightweight performance tracking.
        self.performance_history = []
        self.avg_latency = 0.0
        self.total_queries = 0

        # In-memory embedding cache (write-through to SQLite when enabled).
        self._embedding_cache = {}

        # FAISS ID (0-based position) -> database ID (1-based rowid).
        self._id_mapping = {}

        # Populated in initialize().
        self.keyword_index = {}

    def initialize(self):
        """Initialize all components - MAIN THREAD ONLY."""
        if self._initialized:
            return

        print("🚀 Initializing WorkingHyperRAG...")
        start_time = time.perf_counter()

        # Load the embedding model and warm it up so the first real query
        # does not pay the lazy-initialization cost.
        self.embedder = SentenceTransformer(EMBEDDING_MODEL)
        self.embedder.encode(["warmup"])

        # Load the FAISS index if one has been built.
        if FAISS_INDEX_PATH.exists():
            self.faiss_index = faiss.read_index(str(FAISS_INDEX_PATH))
            print(f"   Loaded FAISS index with {self.faiss_index.ntotal} vectors")
        else:
            print("   ⚠ FAISS index not found, retrieval will be limited")

        # Open the docstore and ensure its lookup indices exist.
        self.docstore_conn = sqlite3.connect(DOCSTORE_PATH)
        self._init_docstore_indices()

        # Create the embedding-cache schema once, from the main thread.
        self._init_cache_schema()

        # Build the keyword index together with the FAISS <-> DB ID mapping.
        self.keyword_index = self._build_keyword_index_with_mapping()

        init_time = (time.perf_counter() - start_time) * 1000
        memory_mb = self.process.memory_info().rss / 1024 / 1024

        print(f"✅ WorkingHyperRAG initialized in {init_time:.2f}ms")
        print(f"   Memory: {memory_mb:.2f}MB")
        print(f"   Keyword index: {len(self.keyword_index)} unique words")
        print(f"   ID mapping: {len(self._id_mapping)} entries")

        self._initialized = True

    def _init_docstore_indices(self):
        """Create performance indices."""
        cursor = self.docstore_conn.cursor()
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_chunk_hash ON chunks(chunk_hash)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_doc_id ON chunks(doc_id)")
        self.docstore_conn.commit()

    def _init_cache_schema(self):
        """Initialize the cache schema - called once from the main thread."""
        if not ENABLE_EMBEDDING_CACHE:
            return

        # Use a short-lived connection; worker threads open their own later.
        conn = sqlite3.connect(EMBEDDING_CACHE_PATH)
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embedding_cache (
                text_hash TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                access_count INTEGER DEFAULT 0
            )
        """)
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_created_at ON embedding_cache(created_at)")
        conn.commit()
        conn.close()

    def _build_keyword_index_with_mapping(self) -> Dict[str, List[int]]:
        """Build the keyword index with a proper FAISS ID mapping."""
        cursor = self.docstore_conn.cursor()

        # Chunks were added to FAISS in insertion order, so iterating the
        # docstore in id order reproduces the FAISS ordering.
        cursor.execute("SELECT id, chunk_text FROM chunks ORDER BY id")
        chunks = cursor.fetchall()

        keyword_index = defaultdict(list)
        self._id_mapping = {}

        for faiss_id, (db_id, text) in enumerate(chunks):
            # FAISS IDs are 0-based positions; DB IDs are 1-based rowids.
            self._id_mapping[faiss_id] = db_id

            # Index every word of 3+ characters under the chunk's FAISS ID.
            words = set(re.findall(r'\b\w{3,}\b', text.lower()))
            for word in words:
                keyword_index[word].append(faiss_id)

        print(f"   Built mapping: {len(self._id_mapping)} FAISS IDs -> DB IDs")
        return keyword_index

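    # Illustrative shape of the two structures for a 3-chunk corpus
    # (hypothetical contents, not taken from a real docstore):
    #
    #   self._id_mapping  -> {0: 1, 1: 2, 2: 3}
    #   keyword_index     -> {"machine": [0, 2], "learning": [0], ...}
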
    def _faiss_id_to_db_id(self, faiss_id: int) -> int:
        """Convert a FAISS ID (0-based) to a database ID (1-based)."""
        # Fall back to the +1 offset if the ID was never mapped.
        return self._id_mapping.get(faiss_id, faiss_id + 1)

    def _db_id_to_faiss_id(self, db_id: int) -> int:
        """Convert a database ID (1-based) to a FAISS ID (0-based)."""
        # Linear scan of the mapping; acceptable for corpora of this size.
        for faiss_id, mapped_db_id in self._id_mapping.items():
            if mapped_db_id == db_id:
                return faiss_id
        return db_id - 1

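    # Round-trip example using the hypothetical mapping sketched above:
    #
    #   rag._faiss_id_to_db_id(0)   # -> 1
    #   rag._db_id_to_faiss_id(3)   # -> 2
    #   rag._faiss_id_to_db_id(99)  # -> 100 (unmapped, falls back to +1)
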
    def _get_thread_safe_cache_connection(self):
        """Open a short-lived cache connection usable from worker threads."""
        return sqlite3.connect(
            EMBEDDING_CACHE_PATH,
            check_same_thread=False,
            timeout=10.0
        )

    def _get_cached_embedding(self, text: str) -> Optional[np.ndarray]:
        """Get an embedding from the cache - THREAD-SAFE."""
        if not ENABLE_EMBEDDING_CACHE:
            return None

        # MD5 is used purely as a cache key, not for security.
        text_hash = hashlib.md5(text.encode()).hexdigest()

        # Fast path: in-memory cache.
        if text_hash in self._embedding_cache:
            return self._embedding_cache[text_hash]

        # Slow path: SQLite cache, via a per-call connection.
        conn = self._get_thread_safe_cache_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT embedding FROM embedding_cache WHERE text_hash = ?",
                (text_hash,)
            )
            result = cursor.fetchone()

            if result:
                cursor.execute(
                    "UPDATE embedding_cache SET access_count = access_count + 1 WHERE text_hash = ?",
                    (text_hash,)
                )
                conn.commit()

                # Embeddings are stored as raw float32 bytes.
                embedding = np.frombuffer(result[0], dtype=np.float32)
                self._embedding_cache[text_hash] = embedding
                return embedding

            return None
        finally:
            conn.close()

    def _cache_embedding(self, text: str, embedding: np.ndarray):
        """Cache an embedding - THREAD-SAFE."""
        if not ENABLE_EMBEDDING_CACHE:
            return

        text_hash = hashlib.md5(text.encode()).hexdigest()
        embedding_blob = embedding.astype(np.float32).tobytes()

        # Write-through: populate the in-memory cache first...
        self._embedding_cache[text_hash] = embedding

        # ...then persist to SQLite.
        conn = self._get_thread_safe_cache_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """INSERT OR REPLACE INTO embedding_cache
                   (text_hash, embedding, access_count) VALUES (?, ?, 1)""",
                (text_hash, embedding_blob)
            )
            conn.commit()
        finally:
            conn.close()

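    # Serialization note: the stored blob round-trips exactly, since it is
    # the raw float32 buffer. A minimal sketch (384 is illustrative; the
    # real dimension comes from EMBEDDING_MODEL):
    #
    #   vec = np.random.rand(384).astype(np.float32)
    #   blob = vec.tobytes()
    #   assert np.array_equal(np.frombuffer(blob, dtype=np.float32), vec)
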
    def _get_dynamic_top_k(self, question: str) -> int:
        """Determine top_k based on query complexity."""
        words = len(question.split())

        if words < 5:
            return TOP_K_DYNAMIC_HYPER["short"]
        elif words < 15:
            return TOP_K_DYNAMIC_HYPER["medium"]
        else:
            return TOP_K_DYNAMIC_HYPER["long"]

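    # Illustrative bucketing (the k values themselves live in config):
    #
    #   "Define RAG."                         -> 2 words  -> "short"
    #   "How does the pre-filter use FAISS?"  -> 6 words  -> "medium"
    #   any question of 15+ words             -> "long"
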
    def _pre_filter_chunks(self, question: str) -> Optional[List[int]]:
        """Keyword pre-filtering: map question words to candidate FAISS IDs."""
        if not ENABLE_PRE_FILTER:
            return None

        question_words = set(re.findall(r'\b\w{3,}\b', question.lower()))
        if not question_words:
            return None

        candidate_ids = set()

        # Union of all chunks containing at least one question word.
        for word in question_words:
            if word in self.keyword_index:
                candidate_ids.update(self.keyword_index[word])

        if candidate_ids:
            print(f"   [Filter] Matched {len(candidate_ids)} chunks")
            return list(candidate_ids)

        print("   [Filter] No matches")
        return None

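    # Illustrative run against the hypothetical index sketched earlier:
    #
    #   "What is machine learning?" -> {"what", "machine", "learning"}
    #   candidates = keyword_index["machine"] | keyword_index["learning"]
    #              -> [0, 2] (FAISS IDs; order is not significant)
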
    def _search_faiss_intelligent(self, query_embedding: np.ndarray,
                                  top_k: int,
                                  filter_ids: Optional[List[int]] = None) -> List[int]:
        """FAISS search, intersected with the keyword pre-filter when available."""
        if self.faiss_index is None:
            return []

        query_embedding = query_embedding.astype(np.float32).reshape(1, -1)

        min_k = max(1, top_k)

        if filter_ids:
            # Oversample so that enough neighbours survive the intersection
            # with the keyword candidates.
            search_k = min(top_k * 5, self.faiss_index.ntotal)
            distances, indices = self.faiss_index.search(query_embedding, search_k)

            # FAISS pads short result lists with -1; drop those entries.
            faiss_results = [int(idx) for idx in indices[0] if idx >= 0]

            # Intersect with the pre-filter candidates (set for O(1) lookups),
            # keeping the FAISS ranking order.
            filter_set = set(filter_ids)
            filtered_results = [idx for idx in faiss_results if idx in filter_set]

            if filtered_results:
                print(f"   [Search] Filtered to {len(filtered_results)} chunks")
                return filtered_results[:min_k]
            else:
                # Nothing survived the intersection; fall back to pure
                # vector-similarity results.
                print(f"   [Search] No filtered matches, using top {min_k} results")
                return faiss_results[:min_k]
        else:
            # No pre-filter: plain top-k vector search.
            distances, indices = self.faiss_index.search(query_embedding, min_k)
            results = [int(idx) for idx in indices[0] if idx >= 0]
            return results

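    # Worked example of the oversampling arithmetic (hypothetical numbers):
    # with top_k = 5 and 1,000 indexed vectors, search_k = min(25, 1000) = 25.
    # If only 3 of those 25 neighbours are in filter_ids, the method returns
    # those 3; if none are, it falls back to the unfiltered top 5.
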
    def _retrieve_chunks_by_faiss_ids(self, faiss_ids: List[int]) -> List[str]:
        """Retrieve chunk texts by FAISS IDs, preserving the ranking order."""
        if not faiss_ids:
            return []

        # Translate FAISS positions to docstore row IDs.
        db_ids = [self._faiss_id_to_db_id(faiss_id) for faiss_id in faiss_ids]

        cursor = self.docstore_conn.cursor()
        placeholders = ','.join('?' for _ in db_ids)
        cursor.execute(
            f"SELECT id, chunk_text FROM chunks WHERE id IN ({placeholders})",
            db_ids
        )
        # SQLite returns IN-query rows in arbitrary order; re-sort to match
        # the relevance ranking produced by the FAISS search.
        text_by_id = {row[0]: row[1] for row in cursor.fetchall()}
        return [text_by_id[db_id] for db_id in db_ids if db_id in text_by_id]

    def _compress_prompt(self, chunks: List[str]) -> List[str]:
        """Greedy context packing: keep whole chunks until MAX_TOKENS is hit."""
        if not ENABLE_PROMPT_COMPRESSION or not chunks:
            return chunks

        compressed = []
        total_tokens = 0

        # Whitespace-split word count is used as a cheap token approximation.
        for chunk in chunks:
            chunk_tokens = len(chunk.split())
            if total_tokens + chunk_tokens <= MAX_TOKENS:
                compressed.append(chunk)
                total_tokens += chunk_tokens
            else:
                break

        return compressed

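    # Worked example (hypothetical MAX_TOKENS = 100): chunk word counts
    # [60, 50, 30] keep only the 60-word chunk, because the 50-word chunk
    # would exceed the budget and the loop stops at the first misfit rather
    # than skipping ahead to the 30-word chunk.
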
    def _generate_hyper_response(self, question: str, chunks: List[str]) -> str:
        """Generate a response from the retrieved chunks (stubbed generation)."""
        if not chunks:
            return "I don't have enough specific information to answer that question."

        # Pack the chunks into the token budget.
        compressed_chunks = self._compress_prompt(chunks)

        # Stand-in for LLM generation latency.
        time.sleep(0.08)

        # Stubbed answer: echo the top of the retrieved context.
        context = "\n\n".join(compressed_chunks[:3])
        return f"Based on the information: {context[:300]}..."

    async def query_async(self, question: str, top_k: Optional[int] = None) -> Tuple[str, int]:
        """Async query processing - OPTIMIZED FOR SPEED."""
        if not self._initialized:
            self.initialize()

        start_time = time.perf_counter()

        # Run embedding and keyword pre-filtering concurrently on the pool.
        loop = asyncio.get_running_loop()

        embed_future = loop.run_in_executor(
            self.thread_pool,
            self._embed_and_cache_sync,
            question
        )

        filter_future = loop.run_in_executor(
            self.thread_pool,
            self._pre_filter_chunks,
            question
        )

        query_embedding, cache_status = await embed_future
        filter_ids = await filter_future

        # Pick k from query complexity unless the caller overrides it.
        dynamic_k = self._get_dynamic_top_k(question)
        effective_k = top_k or dynamic_k

        # Vector search, intersected with the keyword candidates.
        faiss_ids = self._search_faiss_intelligent(query_embedding, effective_k, filter_ids)

        # Fetch chunk texts in ranking order.
        chunks = self._retrieve_chunks_by_faiss_ids(faiss_ids)

        # Generate the answer.
        answer = self._generate_hyper_response(question, chunks)

        total_time = (time.perf_counter() - start_time) * 1000

        # Update running statistics so get_performance_stats() stays accurate.
        self.total_queries += 1
        self.avg_latency += (total_time - self.avg_latency) / self.total_queries
        self.performance_history.append(total_time)

        print(f"[Hyper RAG] Query: '{question[:50]}...'")
        print(f"   - Cache: {cache_status}")
        print(f"   - Filtered: {'Yes' if filter_ids else 'No'}")
        print(f"   - Top-K: {effective_k}")
        print(f"   - Chunks used: {len(chunks)}")
        print(f"   - Time: {total_time:.1f}ms")

        if self.metrics_tracker:
            self.metrics_tracker.record_query(
                model="hyper",
                latency_ms=total_time,
                memory_mb=0.0,
                chunks_used=len(chunks),
                question_length=len(question)
            )

        return answer, len(chunks)

    def _embed_and_cache_sync(self, text: str) -> Tuple[np.ndarray, str]:
        """Synchronous embedding with caching."""
        cached = self._get_cached_embedding(text)
        if cached is not None:
            return cached, "HIT"

        embedding = self.embedder.encode([text])[0]
        self._cache_embedding(text, embedding)
        return embedding, "MISS"

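    # Illustrative cache behaviour (assumes ENABLE_EMBEDDING_CACHE is True):
    #
    #   rag._embed_and_cache_sync("hello")  # -> (vector, "MISS") first call
    #   rag._embed_and_cache_sync("hello")  # -> (vector, "HIT") thereafter
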
    def query(self, question: str, top_k: Optional[int] = None) -> Tuple[str, int]:
        """Synchronous wrapper; use query_async() from a running event loop."""
        # asyncio.run() creates a fresh event loop, so this must not be
        # called from inside an already-running loop.
        return asyncio.run(self.query_async(question, top_k))

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        return {
            "total_queries": self.total_queries,
            "avg_latency_ms": self.avg_latency,
            "memory_cache_size": len(self._embedding_cache),
            "keyword_index_size": len(self.keyword_index),
            "faiss_vectors": self.faiss_index.ntotal if self.faiss_index else 0
        }

    def close(self):
        """Cleanup."""
        if self.thread_pool:
            self.thread_pool.shutdown(wait=True)
        if self.docstore_conn:
            self.docstore_conn.close()


if __name__ == "__main__":
    print("\n🧪 Quick test of WorkingHyperRAG...")

    from app.metrics import MetricsTracker

    metrics = MetricsTracker()
    rag = WorkingHyperRAG(metrics)

    query = "What is machine learning?"
    print(f"\n📝 Query: {query}")
    answer, chunk_count = rag.query(query)
    print(f"   Answer: {answer[:100]}...")
    print(f"   Chunks used: {chunk_count}")

    rag.close()
    print("\n✅ Test complete!")