# LunarTech / src / rag.py
"""
RAG System β€” Hybrid approach:
1. Simple vector search (OpenAI embeddings + local storage) for reliable chat retrieval
2. LightRAG knowledge graph for enriched context (optional, non-blocking)
This avoids LightRAG's internal async worker issues with Streamlit.
"""
import json
import os
from pathlib import Path
from typing import List, Optional
import numpy as np
from openai import OpenAI
from config import (
OPENAI_API_KEY,
CHAT_MODEL,
WORKING_DIR,
EMBEDDING_MODEL,
CHUNK_SIZE,
CHUNK_OVERLAP,
)
from pdf_processor import extract_text_from_pdf, chunk_text
# ─── Vector store file ───────────────────────────────────────────────
# On-disk location of the persisted embeddings (under the app's working dir).
VECTORS_FILE = WORKING_DIR / "vectors.json"
# Lazily created singleton OpenAI client; see _get_client().
_client: Optional[OpenAI] = None
# In-memory vector store; each entry: {"text": ..., "source": ..., "embedding": [...]}
_chunks_db: list[dict] = []
def _get_client() -> OpenAI:
    """Return the process-wide OpenAI client, creating it on first use."""
    global _client
    if _client is not None:
        return _client
    _client = OpenAI(api_key=OPENAI_API_KEY)
    return _client
def _embed_texts(texts: list[str]) -> list[list[float]]:
    """Embed *texts* with the configured OpenAI model (synchronous call)."""
    response = _get_client().embeddings.create(
        model=EMBEDDING_MODEL,
        input=texts,
    )
    return [entry.embedding for entry in response.data]
def _cosine_similarity(a: list[float], b: list[float]) -> float:
a_np = np.array(a)
b_np = np.array(b)
dot = np.dot(a_np, b_np)
norm = np.linalg.norm(a_np) * np.linalg.norm(b_np)
return float(dot / norm) if norm > 0 else 0.0
def _load_db():
    """Populate the in-memory chunk DB from disk; empty when no file exists yet."""
    global _chunks_db
    if not VECTORS_FILE.exists():
        _chunks_db = []
        return
    with VECTORS_FILE.open("r", encoding="utf-8") as handle:
        _chunks_db = json.load(handle)
def _save_db():
    """Persist the in-memory chunk DB to VECTORS_FILE, creating the dir if needed."""
    WORKING_DIR.mkdir(parents=True, exist_ok=True)
    VECTORS_FILE.write_text(json.dumps(_chunks_db), encoding="utf-8")
# ─── Public API (all synchronous β€” no event loop issues) ─────────────
def index_pdf(pdf_path: str | Path, source_name: str | None = None) -> int:
    """Extract text from a PDF, chunk it, embed the chunks, and persist them.

    Args:
        pdf_path: Path to the PDF file.
        source_name: Label stored with each chunk; defaults to the file name.

    Returns:
        Number of chunks indexed (0 if the PDF yielded no text or no chunks).
    """
    global _chunks_db
    # Bug fix: on a cold process _chunks_db starts empty, so appending and then
    # _save_db() would overwrite vectors.json and silently discard every
    # previously indexed PDF. Merge with the persisted store first.
    if not _chunks_db:
        _load_db()
    text = extract_text_from_pdf(pdf_path)
    if not text:
        return 0
    source = source_name or Path(pdf_path).name
    chunks = chunk_text(text)
    if not chunks:
        return 0
    texts = [c["text"] for c in chunks]
    # Embed in small batches to stay under the API's per-request token limits.
    batch_size = 20
    all_embeddings: list[list[float]] = []
    for i in range(0, len(texts), batch_size):
        all_embeddings.extend(_embed_texts(texts[i:i + batch_size]))
    # Store each chunk alongside its embedding and originating file name.
    for chunk, embedding in zip(chunks, all_embeddings):
        _chunks_db.append({
            "text": chunk["text"],
            "source": source,
            "embedding": embedding,
        })
    _save_db()
    return len(chunks)
def index_pdfs(pdf_paths: List[str | Path]) -> int:
    """Index every PDF in *pdf_paths*; return the total number of chunks added."""
    return sum(index_pdf(path) for path in pdf_paths)
def get_context_for_query(query: str, top_k: int = 5) -> str:
    """Return the *top_k* stored chunks most similar to *query*, formatted
    as a single context string ("" when nothing is indexed)."""
    _load_db()
    if not _chunks_db:
        return ""
    q_vec = _embed_texts([query])[0]
    # Score every chunk against the query, rank by similarity, keep the best.
    ranked = sorted(
        (
            (
                _cosine_similarity(q_vec, entry["embedding"]),
                entry["text"],
                entry.get("source", "unknown"),
            )
            for entry in _chunks_db
        ),
        key=lambda item: item[0],
        reverse=True,
    )[:top_k]
    if not ranked:
        return ""
    # One labelled section per chunk, separated by horizontal rules.
    return "\n\n---\n\n".join(
        f"[Source: {src} | Relevance: {score:.2f}]\n{text}"
        for score, text, src in ranked
    )
def reset_index():
    """Wipe all indexed data: clear memory, delete the working dir, recreate it empty."""
    from shutil import rmtree
    global _chunks_db
    _chunks_db = []
    if WORKING_DIR.exists():
        rmtree(WORKING_DIR)
    WORKING_DIR.mkdir(parents=True, exist_ok=True)