|
|
|
|
|
""" |
|
|
Initialize the RAG system by creating embeddings and FAISS index. |
|
|
""" |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
|
|
|
from sentence_transformers import SentenceTransformer |
|
|
import faiss |
|
|
import numpy as np |
|
|
from config import DATA_DIR, MODELS_DIR, CHUNK_SIZE, CHUNK_OVERLAP, EMBEDDING_MODEL |
|
|
import sqlite3 |
|
|
import hashlib |
|
|
from typing import List, Tuple |
|
|
import os |
|
|
|
|
|
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split *text* into overlapping, word-based chunks.

    Args:
        text: Raw document text; tokenized by whitespace via ``str.split``.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of chunk strings. Empty/whitespace-only input yields ``[]``;
        input shorter than ``chunk_size`` yields a single chunk.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size`` (the stride would be non-positive).
    """
    # Validate up front: with overlap >= chunk_size the range() step below
    # would be <= 0 and raise an opaque "range() arg 3 must not be zero" /
    # produce no progress. Fail fast with a clear message instead.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap < 0 or overlap >= chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got overlap={overlap}, chunk_size={chunk_size}"
        )

    words = text.split()
    chunks: List[str] = []
    step = chunk_size - overlap  # hoisted loop-invariant stride

    for i in range(0, len(words), step):
        chunks.append(" ".join(words[i:i + chunk_size]))
        # Stop once the final words have been consumed; without this break
        # the overlap would generate trailing chunks that only repeat the
        # tail of the previous one.
        if i + chunk_size >= len(words):
            break

    return chunks
|
|
|
|
|
def initialize_rag():
    """Initialize the RAG system with sample data.

    Pipeline (all artifacts written under ``DATA_DIR``):
      1. Load the SentenceTransformer model named by ``EMBEDDING_MODEL``.
      2. Collect ``*.md`` / ``*.txt`` files from ``DATA_DIR`` (downloading
         sample data first if none are present) and split them into
         word-based chunks via :func:`chunk_text`.
      3. Embed every chunk and store the vectors in a flat L2 FAISS index
         (``faiss_index.bin``).
      4. Persist chunk text + metadata to a SQLite document store
         (``docstore.db``) and create an embedding-cache DB
         (``embedding_cache.db``) if it does not already exist.

    Returns None. Progress is reported on stdout; returns early if no
    documents can be found even after the sample-data download.
    """
    print("Initializing RAG system...")

    # Load the embedding model once up front; first use may download
    # model weights over the network.
    print(f"Loading embedding model: {EMBEDDING_MODEL}")
    embedder = SentenceTransformer(EMBEDDING_MODEL)

    # Parallel lists: documents[i], doc_ids[i] and chunk_metadata[i] all
    # describe the same chunk, and i is also that chunk's row position in
    # the FAISS index built below.
    documents = []
    doc_ids = []
    chunk_metadata = []

    md_files = list(DATA_DIR.glob("*.md"))
    txt_files = list(DATA_DIR.glob("*.txt"))

    if not md_files and not txt_files:
        print("No documents found. Running download_sample_data.py first...")
        # Imported lazily so the script still runs when sample data already
        # exists and the downloader (and its dependencies) are not needed.
        from scripts.download_sample_data import download_sample_data
        download_sample_data()

        # Re-scan DATA_DIR now that the download has (hopefully) populated it.
        md_files = list(DATA_DIR.glob("*.md"))
        txt_files = list(DATA_DIR.glob("*.txt"))

    print(f"Found {len(md_files)} .md files and {len(txt_files)} .txt files")

    # Chunk every markdown file.
    # NOTE(review): this loop and the .txt loop below are identical except
    # for the 'file_type' tag — a shared helper would remove the duplication.
    for file_path in md_files:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        chunks = chunk_text(content)
        documents.extend(chunks)
        doc_ids.extend([file_path.name] * len(chunks))
        for j, chunk in enumerate(chunks):
            chunk_metadata.append({
                'doc_id': file_path.name,
                'chunk_index': j,
                'file_type': 'markdown'
            })

    # Same chunking for plain-text files.
    for file_path in txt_files:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        chunks = chunk_text(content)
        documents.extend(chunks)
        doc_ids.extend([file_path.name] * len(chunks))
        for j, chunk in enumerate(chunks):
            chunk_metadata.append({
                'doc_id': file_path.name,
                'chunk_index': j,
                'file_type': 'text'
            })

    print(f"Found {len(documents)} chunks from {len(set(doc_ids))} documents")

    # Guard: nothing to index (download failed or directory is empty).
    if not documents:
        print("ERROR: No documents found. Please add documents to the data/ directory first.")
        return

    # Encode all chunks in one batched call; per the sentence-transformers
    # API this returns a (num_chunks, dim) numpy array.
    print("Creating embeddings...")
    embeddings = embedder.encode(documents, show_progress_bar=True, batch_size=32)

    # Exact (non-approximate) L2 index — fine for small corpora; swap for an
    # IVF/HNSW index if the corpus grows large.
    print("Creating FAISS index...")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    # FAISS requires float32 input.
    index.add(embeddings.astype(np.float32))

    faiss_index_path = DATA_DIR / "faiss_index.bin"
    faiss.write_index(index, str(faiss_index_path))
    print(f"Saved FAISS index to {faiss_index_path}")

    # --- Document store (SQLite) ---
    print("Creating document store...")
    conn = sqlite3.connect(DATA_DIR / "docstore.db")
    cursor = conn.cursor()

    # chunk_hash is UNIQUE so re-runs skip already-stored chunks (see the
    # IntegrityError handler below).
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            chunk_text TEXT NOT NULL,
            doc_id TEXT NOT NULL,
            chunk_hash TEXT UNIQUE NOT NULL,
            embedding_hash TEXT,
            chunk_index INTEGER,
            file_type TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)

    # NOTE(review): an embedding_cache table is created here in docstore.db
    # AND a separate embedding_cache.db is created further below — one of the
    # two looks redundant; verify which one the application actually reads.
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS embedding_cache (
            text_hash TEXT PRIMARY KEY,
            embedding BLOB NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            access_count INTEGER DEFAULT 0
        )
    """)

    # Insert every chunk; duplicates (same MD5 of the chunk text) are
    # silently skipped via the UNIQUE constraint.
    # NOTE(review): when a duplicate is skipped, its vector was still added
    # to the FAISS index above, so FAISS row ids and docstore rows can
    # diverge across re-runs — confirm how retrieval joins the two.
    inserted_count = 0
    for i, (chunk, doc_id, metadata) in enumerate(zip(documents, doc_ids, chunk_metadata)):
        # MD5 is used purely as a dedup fingerprint here, not for security.
        chunk_hash = hashlib.md5(chunk.encode()).hexdigest()
        try:
            cursor.execute(
                """INSERT INTO chunks
                (chunk_text, doc_id, chunk_hash, chunk_index, file_type)
                VALUES (?, ?, ?, ?, ?)""",
                (chunk, doc_id, chunk_hash, metadata['chunk_index'], metadata['file_type'])
            )
            inserted_count += 1
        except sqlite3.IntegrityError:
            # Duplicate chunk_hash: chunk already stored by a previous run.
            pass

    conn.commit()

    # Secondary indexes for the two lookup paths (by hash, by source doc).
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_chunk_hash ON chunks(chunk_hash)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_doc_id ON chunks(doc_id)")
    conn.commit()

    conn.close()
    print(f"Saved {inserted_count} chunks to document store")

    # --- Embedding cache (separate SQLite file) ---
    # Created empty here; presumably populated lazily at query time by the
    # API server — TODO confirm against the consumer code.
    cache_path = DATA_DIR / "embedding_cache.db"
    if not cache_path.exists():
        conn = sqlite3.connect(cache_path)
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embedding_cache (
                text_hash TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                access_count INTEGER DEFAULT 0
            )
        """)
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_created_at ON embedding_cache(created_at)")
        conn.commit()
        conn.close()
        print(f"Created embedding cache at {cache_path}")

    # Final summary of everything that was written.
    print("\nRAG system initialized successfully!")
    print(f"FAISS index: {faiss_index_path}")
    print(f"Document store: {DATA_DIR / 'docstore.db'}")
    print(f"Embedding cache: {DATA_DIR / 'embedding_cache.db'}")
    print(f"Total chunks: {len(documents)}")
    print(f"Embedding dimension: {dimension}")
    print("\nYou can now start the API server with: python -m app.main")
|
|
|
|
|
# Script entry point: build all RAG artifacts when run directly.
if __name__ == "__main__":
    initialize_rag()
|
|
|