ask-my-research / embed_papers.py
anthonym21's picture
Initial commit - RAG chatbot over research papers
a864e35
#!/usr/bin/env python3
"""
Embed papers for RAG chatbot.
Run this locally before deploying to your Hugging Face Space.
Usage:
1. Place your PDF papers in the papers/ directory
2. Run: python embed_papers.py
3. This creates index/faiss.index and index/chunks.json
4. Commit the generated index/ files and push them to the Hugging Face Space
"""
import os
import json
import fitz # PyMuPDF
import numpy as np
from pathlib import Path
from sentence_transformers import SentenceTransformer
# Configuration: input/output locations (relative to the current working
# directory) and chunking / embedding parameters.
PAPERS_DIR: Path = Path("papers")  # scanned non-recursively for *.pdf files
INDEX_DIR: Path = Path("index")  # receives faiss.index and chunks.json
CHUNK_SIZE: int = 500  # maximum chunk length, in characters
CHUNK_OVERLAP: int = 100  # characters repeated at the start of the next chunk
EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"  # name passed to SentenceTransformer
def extract_text_from_pdf(pdf_path: Path) -> list[dict]:
    """Extract the text of every non-blank page of a PDF.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        One dict per page whose extracted text is not all whitespace, in page
        order, with keys ``"text"`` (the raw page text), ``"page"`` (1-based
        page number) and ``"source"`` (the file name without its extension).
    """
    pages = []
    # Use the document as a context manager so the file is closed even if
    # text extraction raises part-way through the document.
    with fitz.open(pdf_path) as doc:
        for page_num, page in enumerate(doc, start=1):
            text = page.get_text()
            if not text.strip():
                continue  # no extractable text on this page
            pages.append({
                "text": text,
                "page": page_num,
                "source": pdf_path.stem,
            })
    return pages
def chunk_text(pages: list[dict], chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[dict]:
    """Split page texts into overlapping character chunks.

    Each page is chunked independently, so chunks never span pages. A window
    of at most ``chunk_size`` characters is taken. When more text follows,
    the window is cut back to just after its last ``". "``, provided that
    point lies past the window's midpoint, so chunks tend to end on a
    sentence. The next window starts ``overlap`` characters before the
    previous one ended.

    Args:
        pages: Dicts with ``"text"``, ``"source"`` and ``"page"`` keys, as
            produced by ``extract_text_from_pdf``.
        chunk_size: Maximum chunk length in characters; must be positive.
        overlap: Characters shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        Dicts with ``"text"`` (whitespace-stripped), ``"source"``, ``"page"``
        and ``"chunk_id"`` (0-based position in the returned list). Windows
        that are empty after stripping are dropped.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range. Such
            values would otherwise stop the window from ever advancing.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), got {overlap}"
        )

    chunks = []
    for page_data in pages:
        text = page_data["text"]
        source = page_data["source"]
        page_num = page_data["page"]
        text_len = len(text)

        start = 0
        while start < text_len:
            end = start + chunk_size
            window = text[start:end]

            # Prefer ending on a sentence boundary, but only if doing so
            # keeps more than half of the window.
            if end < text_len:
                last_period = window.rfind(". ")
                if last_period > chunk_size // 2:
                    window = window[:last_period + 1]
                    end = start + last_period + 1

            stripped = window.strip()
            if stripped:
                chunks.append({
                    "text": stripped,
                    "source": source,
                    "page": page_num,
                    "chunk_id": len(chunks),
                })

            if end >= text_len:
                break
            # Step back by ``overlap`` for the next window, but always move
            # forward. A sentence cut can leave a window shorter than
            # ``overlap``, which would otherwise rewind or stall the loop.
            start = max(end - overlap, start + 1)
    return chunks
def create_embeddings(chunks: list[dict], model: SentenceTransformer) -> np.ndarray:
    """Encode the ``"text"`` of every chunk with ``model``.

    Args:
        chunks: Chunk dicts; only their ``"text"`` values are read.
        model: Loaded sentence-transformers model used for encoding.

    Returns:
        The NumPy array produced by ``model.encode``, with embeddings in the
        same order as ``chunks``. A progress bar is shown while encoding.
    """
    return model.encode(
        [item["text"] for item in chunks],
        convert_to_numpy=True,
        show_progress_bar=True,
    )
def save_faiss_index(embeddings: np.ndarray, output_path: Path):
    """Build a cosine-similarity FAISS index from ``embeddings`` and write it to disk.

    Rows are L2-normalised and stored in a flat inner-product index, so
    inner-product scores equal cosine similarity. Row ``i`` of
    ``embeddings`` becomes vector id ``i``.

    The caller's array is left untouched. ``faiss.normalize_L2`` works in
    place and FAISS expects float32, C-contiguous data, so a converted copy
    is normalised instead.

    Args:
        embeddings: 2-D array of shape ``(n_vectors, dimension)``.
        output_path: Destination file passed to ``faiss.write_index``.

    Raises:
        ValueError: If ``embeddings`` is not 2-D.
    """
    # np.array copies by default, so the in-place normalisation below never
    # modifies the caller's array.
    vectors = np.array(embeddings, dtype=np.float32, order="C")
    if vectors.ndim != 2:
        raise ValueError(
            f"embeddings must be a 2-D array (n_vectors, dimension), got shape {vectors.shape}"
        )

    import faiss

    # Normalize for cosine similarity
    faiss.normalize_L2(vectors)

    # Inner product over unit vectors == cosine similarity
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)

    faiss.write_index(index, str(output_path))
    print(f"Saved FAISS index with {index.ntotal} vectors to {output_path}")
def main():
    """Build ``faiss.index`` and ``chunks.json`` in ``INDEX_DIR`` from the PDFs in ``PAPERS_DIR``.

    PDFs are processed in sorted path order, so rebuilding from the same
    papers yields the same index regardless of filesystem listing order.

    Each saved chunk's ``chunk_id`` is rewritten to its position in
    ``chunks.json``, which is also its vector id in the FAISS index.
    ``chunk_text`` numbers chunks per call (per PDF), so ids would otherwise
    repeat across papers.

    Returns early, without writing index files, if no PDFs are found or no
    text could be extracted from them.
    """
    # Ensure the output directory exists
    INDEX_DIR.mkdir(parents=True, exist_ok=True)

    # Find all PDFs; sort so the chunk / vector order is reproducible.
    pdf_files = sorted(PAPERS_DIR.glob("*.pdf"))
    if not pdf_files:
        print(f"No PDF files found in {PAPERS_DIR}/")
        print(f"Please add your research papers to the {PAPERS_DIR}/ directory.")
        return
    print(f"Found {len(pdf_files)} PDF files:")
    for pdf in pdf_files:
        print(f" - {pdf.name}")

    # Extract and chunk
    all_chunks = []
    for pdf_path in pdf_files:
        print(f"\nProcessing {pdf_path.name}...")
        pages = extract_text_from_pdf(pdf_path)
        chunks = chunk_text(pages)
        all_chunks.extend(chunks)
        print(f" Extracted {len(pages)} pages, {len(chunks)} chunks")
    print(f"\nTotal chunks: {len(all_chunks)}")

    # Nothing to embed (e.g. every page was blank or image-only). Stop before
    # loading the model rather than failing later while building the index.
    if not all_chunks:
        print("No text could be extracted from the PDFs; nothing to index.")
        return

    # Make chunk_id unique across papers and equal to the FAISS vector id.
    for position, chunk in enumerate(all_chunks):
        chunk["chunk_id"] = position

    # Load embedding model
    print(f"\nLoading embedding model: {EMBEDDING_MODEL}")
    model = SentenceTransformer(EMBEDDING_MODEL)

    # Generate embeddings
    print("Generating embeddings...")
    embeddings = create_embeddings(all_chunks, model)
    print(f"Embeddings shape: {embeddings.shape}")

    # Save FAISS index
    save_faiss_index(embeddings, INDEX_DIR / "faiss.index")

    # Save chunk metadata
    chunks_path = INDEX_DIR / "chunks.json"
    with open(chunks_path, "w", encoding="utf-8") as f:
        json.dump(all_chunks, f, ensure_ascii=False, indent=2)
    print(f"Saved chunk metadata to {chunks_path}")

    # Summary
    print("\n" + "=" * 50)
    print("DONE! Your index is ready.")
    print("=" * 50)
    print("\nFiles created:")
    print(f" - {INDEX_DIR}/faiss.index ({embeddings.shape[0]} vectors)")
    print(f" - {INDEX_DIR}/chunks.json ({len(all_chunks)} chunks)")
    print("\nNext steps:")
    print(" 1. Commit these files to your HuggingFace Space")
    print(" 2. The chatbot will use this index for retrieval")
# Run the indexing pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()