Spaces:
Runtime error
Runtime error
"""
RAG system — hybrid approach:

1. Simple vector search (OpenAI embeddings + local JSON storage) for reliable
   chat retrieval.
2. LightRAG knowledge graph for enriched context (optional, non-blocking).

This avoids LightRAG's internal async-worker issues with Streamlit.
"""
import json
import os
from pathlib import Path
from typing import List, Optional

import numpy as np
from openai import OpenAI

from config import (
    OPENAI_API_KEY,
    CHAT_MODEL,
    WORKING_DIR,
    EMBEDDING_MODEL,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
)
from pdf_processor import extract_text_from_pdf, chunk_text
# ─── Vector store file ───────────────────────────────────────────────
VECTORS_FILE = WORKING_DIR / "vectors.json"

# Lazily created shared OpenAI client and the in-memory vector store.
# Each DB entry is a dict: {"text": ..., "source": ..., "embedding": [...]}.
_client: Optional[OpenAI] = None
_chunks_db: list[dict] = []
def _get_client() -> OpenAI:
    """Return the shared OpenAI client, constructing it on first use."""
    global _client
    if _client is not None:
        return _client
    _client = OpenAI(api_key=OPENAI_API_KEY)
    return _client
def _embed_texts(texts: list[str]) -> list[list[float]]:
    """Embed *texts* with the configured OpenAI model (synchronous call)."""
    resp = _get_client().embeddings.create(model=EMBEDDING_MODEL, input=texts)
    return [row.embedding for row in resp.data]
| def _cosine_similarity(a: list[float], b: list[float]) -> float: | |
| a_np = np.array(a) | |
| b_np = np.array(b) | |
| dot = np.dot(a_np, b_np) | |
| norm = np.linalg.norm(a_np) * np.linalg.norm(b_np) | |
| return float(dot / norm) if norm > 0 else 0.0 | |
def _load_db():
    """Replace the in-memory store with the contents of VECTORS_FILE (empty if absent)."""
    global _chunks_db
    if not VECTORS_FILE.exists():
        _chunks_db = []
        return
    with open(VECTORS_FILE, "r", encoding="utf-8") as fh:
        _chunks_db = json.load(fh)
def _save_db():
    """Persist the in-memory store to VECTORS_FILE as JSON, creating dirs as needed."""
    WORKING_DIR.mkdir(parents=True, exist_ok=True)
    VECTORS_FILE.write_text(json.dumps(_chunks_db), encoding="utf-8")
| # βββ Public API (all synchronous β no event loop issues) βββββββββββββ | |
| def index_pdf(pdf_path: str | Path, source_name: str | None = None) -> int: | |
| """Extract text from PDF, chunk, embed, and store. Returns number of chunks.""" | |
| global _chunks_db | |
| text = extract_text_from_pdf(pdf_path) | |
| if not text: | |
| return 0 | |
| source = source_name or Path(pdf_path).name | |
| chunks = chunk_text(text) | |
| if not chunks: | |
| return 0 | |
| # Get embeddings for all chunks | |
| texts = [c["text"] for c in chunks] | |
| # Embed in batches of 20 to avoid token limits | |
| all_embeddings = [] | |
| for i in range(0, len(texts), 20): | |
| batch = texts[i:i+20] | |
| batch_embeddings = _embed_texts(batch) | |
| all_embeddings.extend(batch_embeddings) | |
| # Store | |
| for chunk, embedding in zip(chunks, all_embeddings): | |
| _chunks_db.append({ | |
| "text": chunk["text"], | |
| "source": source, | |
| "embedding": embedding, | |
| }) | |
| _save_db() | |
| return len(chunks) | |
def index_pdfs(pdf_paths: List[str | Path]) -> int:
    """Index every PDF in *pdf_paths* and return the total number of chunks."""
    return sum(index_pdf(path) for path in pdf_paths)
def get_context_for_query(query: str, top_k: int = 5) -> str:
    """Return the *top_k* most relevant stored chunks for *query* as formatted context.

    Reloads the vector DB from disk, embeds the query, ranks all chunks by
    cosine similarity, and joins the best matches with "---" separators.
    Returns "" when nothing is indexed.
    """
    _load_db()
    if not _chunks_db:
        return ""
    query_vec = _embed_texts([query])[0]
    ranked = sorted(
        (
            (
                _cosine_similarity(query_vec, entry["embedding"]),
                entry["text"],
                entry.get("source", "unknown"),
            )
            for entry in _chunks_db
        ),
        key=lambda item: item[0],
        reverse=True,
    )
    top = ranked[:top_k]
    if not top:
        return ""
    parts = [
        f"[Source: {source} | Relevance: {score:.2f}]\n{text}"
        for score, text, source in top
    ]
    return "\n\n---\n\n".join(parts)
def reset_index():
    """Wipe every indexed chunk: clear memory and recreate an empty working dir."""
    import shutil

    global _chunks_db
    _chunks_db = []
    if WORKING_DIR.exists():
        shutil.rmtree(WORKING_DIR)
    WORKING_DIR.mkdir(parents=True, exist_ok=True)