# chatbot_gradio / step3_setup_vector_db.py
# datasciencesage's picture
# Update step3_setup_vector_db.py
# 1e34ca6 verified
import json
from pathlib import Path
from typing import List, Dict, Any
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import os
# Note: the working directory is changed to this script's folder inside set_vector_db().
# ============================================
# CONFIGURATION
# ============================================
# The embedding model name is chosen inside set_vector_db().
# Alternative model: "BAAI/bge-small-en-v1.5" for better quality.
# ============================================
# LOAD EXTRACTED DATA
# ============================================
def load_extracted_data(json_path: Path) -> List[Dict]:
    """Read the extracted page data from a JSON file.

    Args:
        json_path: Location of the JSON file to read (decoded as UTF-8).

    Returns:
        The parsed JSON content; callers treat it as a list of page dicts.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
    """
    print(f"πŸ“‚ Loading data from: {json_path}")

    # Success path first; a missing file falls through to the raise below.
    if json_path.exists():
        with open(json_path, 'r', encoding='utf-8') as handle:
            pages = json.load(handle)
        print(f"βœ… Loaded {len(pages)} pages\n")
        return pages

    raise FileNotFoundError(f"JSON file not found: {json_path}")
# ============================================
# CREATE SEARCHABLE TEXT CHUNKS
# ============================================
def create_searchable_chunks(pages_data: List[Dict]) -> List[Dict]:
    """Flatten extracted pages into a list of chunks ready for embedding.

    For every chunk the searchable text is the chunk's ``text_content``
    followed by one "Equation: ..." line per equation (description and
    context, not the raw LaTeX) and one "Table (...): ..." line per table
    (caption and summary). Chunks whose searchable text is blank are skipped.

    Args:
        pages_data: Page dicts; the keys read here are ``page_number``,
            ``original_filename`` and ``chunks``. Each chunk may carry
            ``chunk_id``, ``content_type``, ``text_content``, ``equations``
            and ``tables``. Missing or ``None`` values are treated as empty.

    Returns:
        One dict per non-empty chunk with the keys ``chunk_id``,
        ``page_number``, ``filename``, ``content_type``, ``searchable_text``,
        ``raw_text``, ``equations_latex``, ``equations_context``,
        ``tables_markdown`` and ``table_captions``. ``chunk_id`` values are
        unique within the returned list.
    """
    all_chunks = []
    # Ids already handed out, so every returned chunk_id is unique.
    seen_ids = set()

    print("πŸ”§ Creating searchable chunks...")
    for page in tqdm(pages_data):
        page_num = page.get('page_number', 0)
        filename = page.get('original_filename', 'unknown')

        for chunk_index, chunk in enumerate(page.get('chunks') or [], start=1):
            # `or` also covers keys that are present but null in the JSON.
            raw_text = chunk.get('text_content') or ''
            equations = chunk.get('equations') or []
            tables = chunk.get('tables') or []

            # Build the searchable text from parts and join once.
            parts = [raw_text]
            if equations:
                # Descriptions rather than raw LaTeX, for better search.
                parts.append("\n\n")
                parts.extend(
                    f"Equation: {eq.get('equation_description', '')}. {eq.get('context', '')}\n"
                    for eq in equations
                )
            if tables:
                parts.append("\n\n")
                parts.extend(
                    f"Table ({tbl.get('caption', '')}): {tbl.get('summary', '')}\n"
                    for tbl in tables
                )

            searchable_text = "".join(parts).strip()
            if not searchable_text:
                # Nothing to embed for this chunk.
                continue

            # Fallback id uses the chunk's position on the page, so several
            # id-less chunks on one page no longer all become "..._chunk1".
            base_id = chunk.get('chunk_id') or f"page{page_num}_chunk{chunk_index}"
            chunk_id = base_id
            suffix = 2
            # Same id seen before (e.g. same page number in another file):
            # append a counter instead of emitting a duplicate.
            while chunk_id in seen_ids:
                chunk_id = f"{base_id}_{suffix}"
                suffix += 1
            seen_ids.add(chunk_id)

            all_chunks.append({
                'chunk_id': chunk_id,
                'page_number': page_num,
                'filename': filename,
                'content_type': chunk.get('content_type', 'text'),
                # Searchable text (for embedding)
                'searchable_text': searchable_text,
                # Original content (for generation)
                'raw_text': raw_text,
                'equations_latex': [eq.get('equation_latex', '') for eq in equations],
                'equations_context': [eq.get('context', '') for eq in equations],
                'tables_markdown': [tbl.get('markdown_table', '') for tbl in tables],
                'table_captions': [tbl.get('caption', '') for tbl in tables],
            })

    print(f"βœ… Created {len(all_chunks)} searchable chunks\n")
    return all_chunks
# ============================================
# SETUP VECTOR DATABASE
# ============================================
def setup_vector_database(chunks: List[Dict], db_path: Path, model_name: str):
    """Create a fresh Chroma collection and store every chunk with its embedding.

    Any existing "document_chunks" collection under ``db_path`` is dropped
    and recreated, so the result contains only the given chunks.

    Args:
        chunks: Chunk dicts carrying the keys ``chunk_id``, ``page_number``,
            ``filename``, ``content_type``, ``searchable_text``, ``raw_text``,
            ``equations_latex``, ``equations_context``, ``tables_markdown``
            and ``table_captions``.
        db_path: Directory for the persistent Chroma database; created if
            it does not exist.
        model_name: Name passed to ``SentenceTransformer`` to load the
            embedding model.

    Returns:
        A ``(collection, embedding_model)`` tuple: the populated Chroma
        collection and the loaded SentenceTransformer instance.
    """
    print(f"πŸ—„οΈ Setting up vector database...")
    print(f" Path: {db_path}")
    print(f" Model: {model_name}\n")

    # Load embedding model
    print("Loading embedding model...")
    embedding_model = SentenceTransformer(model_name)
    print(f"βœ… Model loaded (dimension: {embedding_model.get_sentence_embedding_dimension()})\n")

    # Create Chroma client
    db_path.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(
        path=str(db_path),
        settings=Settings(anonymized_telemetry=False),
    )

    collection_name = "document_chunks"

    # Best-effort removal of a previous collection so we start clean.
    # Catch Exception rather than a bare `except:` so KeyboardInterrupt and
    # SystemExit still propagate. The specific "not found" exception type is
    # not pinned here; if deletion failed for another reason,
    # create_collection below will surface the problem.
    try:
        client.delete_collection(collection_name)
    except Exception:
        pass
    else:
        print("πŸ—‘οΈ Deleted existing collection")

    collection = client.create_collection(
        name=collection_name,
        metadata={"description": "Mathematical document chunks with LaTeX"},
    )

    print(f"πŸ“ Embedding {len(chunks)} chunks...")

    # Prepare data for batch insertion
    ids = [chunk['chunk_id'] for chunk in chunks]
    documents = [chunk['searchable_text'] for chunk in chunks]
    # Chroma doesn't support lists in metadata, so store them as JSON strings.
    metadatas = [
        {
            'chunk_id': chunk['chunk_id'],
            'page_number': chunk['page_number'],
            'filename': chunk['filename'],
            'content_type': chunk['content_type'],
            'raw_text': chunk['raw_text'],
            'equations_latex': json.dumps(chunk['equations_latex']),
            'equations_context': json.dumps(chunk['equations_context']),
            'tables_markdown': json.dumps(chunk['tables_markdown']),
            'table_captions': json.dumps(chunk['table_captions']),
        }
        for chunk in chunks
    ]

    # Encode all texts in one call so the model batches internally, instead
    # of one forward pass per chunk. Skipped entirely when there are no chunks.
    if documents:
        embeddings = embedding_model.encode(
            documents,
            convert_to_numpy=True,
            show_progress_bar=True,
        ).tolist()
    else:
        embeddings = []

    # Insert in slices of 100; slicing past the end is safe, so no min() needed.
    batch_size = 100
    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        collection.add(
            ids=ids[start:stop],
            documents=documents[start:stop],
            embeddings=embeddings[start:stop],
            metadatas=metadatas[start:stop],
        )

    print(f"\nβœ… Vector database created!")
    print(f" Total chunks: {collection.count()}")
    print(f" Location: {db_path}\n")
    return collection, embedding_model
# ============================================
# TEST RETRIEVAL
# ============================================
def test_retrieval(collection, embedding_model, test_queries: List[str], top_k: int = 3):
    """Run sample queries against the collection and print the top matches.

    Args:
        collection: Object exposing ``query(query_embeddings=..., n_results=...)``
            whose result has ``'documents'`` and ``'metadatas'`` entries.
        embedding_model: Object exposing ``encode(text, convert_to_numpy=True)``.
        test_queries: Query strings to try, in order.
        top_k: Number of results requested per query.

    Returns:
        None. Everything is reported through ``print``.
    """
    print("πŸ” Testing retrieval system...\n")

    thin_rule = "-" * 60
    query_footer = "\n" + "=" * 60 + "\n"

    for query in test_queries:
        print(f"Query: '{query}'")
        print(thin_rule)

        # Embed the query, then search the collection with that vector.
        query_vector = embedding_model.encode(query, convert_to_numpy=True).tolist()
        hits = collection.query(query_embeddings=[query_vector], n_results=top_k)

        # One query was sent, so its results sit at index 0 of each list.
        matches = zip(hits['documents'][0], hits['metadatas'][0])
        for rank, (text, meta) in enumerate(matches, start=1):
            print(f"\nResult {rank}:")
            print(f" File: {meta['filename']}")
            print(f" Page: {meta['page_number']}")
            print(f" Type: {meta['content_type']}")
            print(f" Text preview: {text[:150]}...")

            # Equations are stored as a JSON-encoded list in the metadata.
            latex_items = json.loads(meta.get('equations_latex', '[]'))
            if not latex_items:
                continue
            print(f" Equations: {len(latex_items)} found")
            for latex in latex_items[:2]:  # Show first 2
                print(f" - {latex[:80]}...")

        print(query_footer)
# ============================================
# MAIN SETUP
# ============================================
def set_vector_db():
    """Build the vector database from the extracted JSON and smoke-test it.

    Changes the working directory to this script's folder, loads
    ``extracted_jsons/extracted_content.json`` from there, turns it into
    searchable chunks, stores them in a Chroma database under ``chroma_db``
    and runs a handful of sample queries. Returns None; if no chunks are
    produced it prints a message and stops before creating the database.
    """
    script_dir = Path(__file__).parent.resolve()
    os.chdir(script_dir)

    extracted_json = script_dir / "extracted_jsons" / "extracted_content.json"
    chroma_db_path = script_dir / "chroma_db"
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Fast, good quality

    banner = "=" * 60
    print(banner)
    print("STEP 2: VECTOR DATABASE SETUP")
    print(banner + "\n")

    # Load the extracted pages and convert them into searchable chunks.
    pages_data = load_extracted_data(extracted_json)
    chunks = create_searchable_chunks(pages_data)
    if not chunks:
        print("❌ No chunks created. Check your extracted data.")
        return

    # Embed and store the chunks.
    collection, embedding_model = setup_vector_database(
        chunks, chroma_db_path, embedding_model_name
    )

    # Try a few sample queries against the new collection.
    sample_queries = [
        "quadratic formula derivation",
        "solve differential equations",
        "matrix multiplication",
        "integration by parts",
        "probability distribution",
    ]
    test_retrieval(collection, embedding_model, sample_queries, top_k=3)

    print(banner)
    print("βœ… SETUP COMPLETE!")
    print(banner)
    print(f"\nVector database saved at: {chroma_db_path}")
    print(f"Total indexed chunks: {collection.count()}")
    print("\nYou can now use this for RAG queries!")
# Script entry point: run the full setup only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    set_vector_db()