# rag-latency-optimization / scripts / initialize_rag.py
# Source: Ariyan-Pro — "Deploy RAG Latency Optimization v1.0" (commit 04ab625)
#!/usr/bin/env python3
"""
Initialize the RAG system by creating embeddings and FAISS index.
"""
import sys
from pathlib import Path
# Add project root to Python path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from config import DATA_DIR, MODELS_DIR, CHUNK_SIZE, CHUNK_OVERLAP, EMBEDDING_MODEL
import sqlite3
import hashlib
from typing import List, Tuple
import os
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split *text* into word-based chunks.

    Each chunk holds up to ``chunk_size`` whitespace-separated words, and
    consecutive chunks share ``overlap`` words so context is not lost at
    chunk boundaries.

    Args:
        text: Raw document text to split.
        chunk_size: Maximum number of words per chunk (must be > 0).
        overlap: Number of words shared between consecutive chunks
            (must be < ``chunk_size`` so the window advances).

    Returns:
        List of chunk strings; empty list for empty/whitespace-only input.

    Raises:
        ValueError: If ``chunk_size`` <= 0 or ``overlap`` >= ``chunk_size``
            (otherwise the step would be non-positive and ``range`` would
            fail with an unhelpful error).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    chunks: List[str] = []
    step = chunk_size - overlap
    for i in range(0, len(words), step):
        chunks.append(" ".join(words[i:i + chunk_size]))
        # Stop once the current window reached the end of the text;
        # without this, the overlap would produce a trailing partial chunk
        # that duplicates already-covered words.
        if i + chunk_size >= len(words):
            break
    return chunks
def _chunks_from_files(files: List[Path], file_type: str) -> Tuple[List[str], List[str], List[dict]]:
    """Chunk every file in *files*; return parallel lists of chunk text, doc ids, and per-chunk metadata."""
    documents: List[str] = []
    doc_ids: List[str] = []
    metadata: List[dict] = []
    for file_path in files:
        content = file_path.read_text(encoding='utf-8')
        chunks = chunk_text(content)
        documents.extend(chunks)
        doc_ids.extend([file_path.name] * len(chunks))
        metadata.extend(
            {'doc_id': file_path.name, 'chunk_index': j, 'file_type': file_type}
            for j in range(len(chunks))
        )
    return documents, doc_ids, metadata


def _collect_documents() -> Tuple[List[str], List[str], List[dict]]:
    """Gather chunked documents from all .md/.txt files in DATA_DIR, bootstrapping sample data if none exist."""
    md_files = list(DATA_DIR.glob("*.md"))
    txt_files = list(DATA_DIR.glob("*.txt"))
    if not md_files and not txt_files:
        print("No documents found. Running download_sample_data.py first...")
        # Local import: only needed on first run, and avoids a hard
        # dependency when documents are already present.
        from scripts.download_sample_data import download_sample_data
        download_sample_data()
        # Refresh file list after the download created sample documents.
        md_files = list(DATA_DIR.glob("*.md"))
        txt_files = list(DATA_DIR.glob("*.txt"))
    print(f"Found {len(md_files)} .md files and {len(txt_files)} .txt files")
    documents: List[str] = []
    doc_ids: List[str] = []
    chunk_metadata: List[dict] = []
    # Same loading logic for both file types; only the recorded type differs.
    for files, file_type in ((md_files, 'markdown'), (txt_files, 'text')):
        docs, ids, meta = _chunks_from_files(files, file_type)
        documents.extend(docs)
        doc_ids.extend(ids)
        chunk_metadata.extend(meta)
    return documents, doc_ids, chunk_metadata


def _save_docstore(documents: List[str], doc_ids: List[str], chunk_metadata: List[dict]) -> int:
    """Create/refresh the SQLite document store; return the number of newly inserted chunks.

    Duplicate chunks (same MD5 of the chunk text) are silently skipped via
    the UNIQUE constraint on chunk_hash, which makes re-runs idempotent.
    """
    conn = sqlite3.connect(DATA_DIR / "docstore.db")
    try:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS chunks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                chunk_text TEXT NOT NULL,
                doc_id TEXT NOT NULL,
                chunk_hash TEXT UNIQUE NOT NULL,
                embedding_hash TEXT,
                chunk_index INTEGER,
                file_type TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embedding_cache (
                text_hash TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                access_count INTEGER DEFAULT 0
            )
        """)
        inserted_count = 0
        for chunk, doc_id, metadata in zip(documents, doc_ids, chunk_metadata):
            # md5 is used only as a dedup fingerprint here, not for security.
            chunk_hash = hashlib.md5(chunk.encode()).hexdigest()
            try:
                cursor.execute(
                    """INSERT INTO chunks
                       (chunk_text, doc_id, chunk_hash, chunk_index, file_type)
                       VALUES (?, ?, ?, ?, ?)""",
                    (chunk, doc_id, chunk_hash, metadata['chunk_index'], metadata['file_type'])
                )
                inserted_count += 1
            except sqlite3.IntegrityError:
                # Duplicate chunk_hash: already stored from a previous run.
                pass
        conn.commit()
        # Create lookup indexes for retrieval performance.
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_chunk_hash ON chunks(chunk_hash)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_doc_id ON chunks(doc_id)")
        conn.commit()
    finally:
        # Close even if an SQL statement raised, so the DB file isn't leaked.
        conn.close()
    return inserted_count


def _ensure_embedding_cache(cache_path: Path) -> None:
    """Create the standalone embedding-cache database if it does not already exist."""
    if cache_path.exists():
        return
    conn = sqlite3.connect(cache_path)
    try:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embedding_cache (
                text_hash TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                access_count INTEGER DEFAULT 0
            )
        """)
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_created_at ON embedding_cache(created_at)")
        conn.commit()
    finally:
        conn.close()
    print(f"Created embedding cache at {cache_path}")


def initialize_rag():
    """Initialize the RAG system with sample data.

    Pipeline: load the embedding model, collect and chunk all .md/.txt
    documents from DATA_DIR (downloading sample data if none exist), embed
    the chunks, build and persist a FAISS L2 index, store chunk text and
    metadata in an SQLite document store, and ensure the embedding-cache
    database exists.
    """
    print("Initializing RAG system...")
    # Load embedding model
    print(f"Loading embedding model: {EMBEDDING_MODEL}")
    embedder = SentenceTransformer(EMBEDDING_MODEL)

    # Collect all documents as (chunk, doc_id, metadata) parallel lists.
    documents, doc_ids, chunk_metadata = _collect_documents()
    print(f"Found {len(documents)} chunks from {len(set(doc_ids))} documents")
    if not documents:
        print("ERROR: No documents found. Please add documents to the data/ directory first.")
        return

    # Create embeddings
    print("Creating embeddings...")
    embeddings = embedder.encode(documents, show_progress_bar=True, batch_size=32)

    # Create FAISS index
    print("Creating FAISS index...")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)  # L2 distance
    # FAISS requires float32; sentence-transformers may return float64.
    index.add(embeddings.astype(np.float32))

    # Save FAISS index
    faiss_index_path = DATA_DIR / "faiss_index.bin"
    faiss.write_index(index, str(faiss_index_path))
    print(f"Saved FAISS index to {faiss_index_path}")

    # Create document store (SQLite)
    print("Creating document store...")
    inserted_count = _save_docstore(documents, doc_ids, chunk_metadata)
    print(f"Saved {inserted_count} chunks to document store")

    # Also create embedding_cache.db if it doesn't exist
    _ensure_embedding_cache(DATA_DIR / "embedding_cache.db")

    print("\nRAG system initialized successfully!")
    print(f"FAISS index: {faiss_index_path}")
    print(f"Document store: {DATA_DIR / 'docstore.db'}")
    print(f"Embedding cache: {DATA_DIR / 'embedding_cache.db'}")
    print(f"Total chunks: {len(documents)}")
    print(f"Embedding dimension: {dimension}")
    print("\nYou can now start the API server with: python -m app.main")
# Script entry point: run the one-shot initialization when executed directly.
if __name__ == "__main__":
    initialize_rag()