"""
RAG (Retrieval-Augmented Generation) Utilities

Provides document loading, chunking, embedding, and retrieval for the AI chatbot.
"""

import hashlib
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

from utils.openrouter_client import get_embedding, get_query_embedding, get_openrouter_client

# Configuration
CHUNK_SIZE = 500  # Target tokens per chunk (approximate, see _estimate_tokens)
CHUNK_OVERLAP = 50  # Overlap between chunks, in whitespace-separated words
SUPPORTED_EXTENSIONS = {'.txt', '.md'}
CACHE_FILE = "embeddings_cache.json"


class DocumentChunk:
    """Represents a chunk of a document with its embedding."""

    def __init__(
        self,
        content: str,
        source_file: str,
        chunk_index: int,
        embedding: Optional[List[float]] = None
    ):
        self.content = content
        self.source_file = source_file
        self.chunk_index = chunk_index
        self.embedding = embedding
        # MD5 is used purely as a cache key for the chunk text, not for security.
        self.content_hash = hashlib.md5(content.encode()).hexdigest()

    def to_dict(self) -> Dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "content": self.content,
            "source_file": self.source_file,
            "chunk_index": self.chunk_index,
            "embedding": self.embedding,
            "content_hash": self.content_hash
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'DocumentChunk':
        """Create from dictionary.

        A stored "content_hash", when present, overrides the hash computed
        from the content.
        """
        chunk = cls(
            content=data["content"],
            source_file=data["source_file"],
            chunk_index=data["chunk_index"],
            embedding=data.get("embedding")
        )
        chunk.content_hash = data.get("content_hash", chunk.content_hash)
        return chunk


class RAGService:
    """Service for managing RAG document retrieval."""

    def __init__(self, docs_path: str = "rag_docs"):
        """
        Initialize the RAG service.

        Args:
            docs_path: Path to the documents folder
        """
        self.docs_path = Path(docs_path)
        # The embedding cache lives inside the documents folder itself.
        self.cache_path = self.docs_path / CACHE_FILE
        self.chunks: List[DocumentChunk] = []
        self._loaded = False

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough approximation: ~4 chars per token)."""
        return len(text) // 4

    def _chunk_text(self, text: str, source_file: str) -> List[DocumentChunk]:
        """
        Split text into chunks with overlap.

        Paragraphs (separated by blank lines) are accumulated until the
        estimated token count would exceed CHUNK_SIZE. A single paragraph
        larger than CHUNK_SIZE is not split further.

        Args:
            text: Text content to chunk
            source_file: Name of the source file

        Returns:
            List of DocumentChunk objects
        """
        chunks = []

        # Split into paragraphs first
        paragraphs = text.split('\n\n')

        current_chunk = ""
        chunk_index = 0

        for para in paragraphs:
            para = para.strip()
            if not para:
                continue

            # If adding this paragraph exceeds chunk size, save current and start new
            if self._estimate_tokens(current_chunk + para) > CHUNK_SIZE and current_chunk:
                chunks.append(DocumentChunk(
                    content=current_chunk.strip(),
                    source_file=source_file,
                    chunk_index=chunk_index
                ))
                chunk_index += 1

                # Keep overlap from the end of current chunk
                words = current_chunk.split()
                overlap_words = words[-CHUNK_OVERLAP:] if len(words) > CHUNK_OVERLAP else words
                current_chunk = " ".join(overlap_words) + "\n\n"

            current_chunk += para + "\n\n"

        # Don't forget the last chunk
        if current_chunk.strip():
            chunks.append(DocumentChunk(
                content=current_chunk.strip(),
                source_file=source_file,
                chunk_index=chunk_index
            ))

        return chunks

    def load_documents(self) -> int:
        """
        Load and chunk all documents from the docs folder.

        Embeddings are restored from the cache for chunks whose content hash
        matches a cached entry; all other chunks are left without an embedding.

        Returns:
            Number of chunks loaded
        """
        if not self.docs_path.exists():
            print(f"RAG docs folder not found: {self.docs_path}")
            return 0

        # Index the cache by content hash for O(1) lookups. setdefault keeps
        # the first cached entry when the same content was cached twice.
        cached_by_hash: Dict[str, DocumentChunk] = {}
        for cached in self._load_cache():
            cached_by_hash.setdefault(cached.content_hash, cached)

        new_chunks = []

        # iterdir() yields entries in arbitrary order; sort so chunk order
        # (and therefore cache order and tie-breaking) is deterministic.
        for file_path in sorted(self.docs_path.iterdir()):
            if file_path.suffix.lower() not in SUPPORTED_EXTENSIONS:
                continue
            if file_path.name == CACHE_FILE or file_path.name.startswith('.'):
                continue

            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()

                for chunk in self._chunk_text(content, file_path.name):
                    cached = cached_by_hash.get(chunk.content_hash)
                    if cached is not None:
                        # Reuse only the embedding. The fresh chunk carries the
                        # current source_file/chunk_index, which may differ from
                        # the cached ones after a file was renamed or edited.
                        chunk.embedding = cached.embedding
                    new_chunks.append(chunk)

            except Exception as e:
                # Best effort: one unreadable file must not block the others.
                print(f"Error loading {file_path}: {e}")

        self.chunks = new_chunks
        self._loaded = True

        return len(self.chunks)

    def embed_documents(self) -> int:
        """
        Generate embeddings for all chunks that don't have them.

        Loads the documents first if that has not happened yet. The cache is
        rewritten only when at least one new embedding was generated.

        Returns:
            Number of new embeddings generated
        """
        if not self._loaded:
            self.load_documents()

        client = get_openrouter_client()
        if not client.is_available:
            print("OpenRouter client not available, skipping embedding generation")
            return 0

        embedded_count = 0

        for chunk in self.chunks:
            if chunk.embedding is None:
                embedding = get_embedding(chunk.content)
                if embedding:
                    chunk.embedding = embedding
                    embedded_count += 1

        # Save to cache after embedding
        if embedded_count > 0:
            self._save_cache()

        return embedded_count

    def _load_cache(self) -> List[DocumentChunk]:
        """Load cached embeddings from file.

        Returns:
            The cached chunks, or an empty list when the cache file is
            missing or cannot be read/parsed.
        """
        if not self.cache_path.exists():
            return []

        try:
            with open(self.cache_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return [DocumentChunk.from_dict(d) for d in data]
        except Exception as e:
            print(f"Error loading cache: {e}")
            return []

    def _save_cache(self) -> None:
        """Save embeddings to cache file.

        Only chunks that have an embedding are written. Failures are reported
        and otherwise ignored, since the cache is an optimization.
        """
        try:
            data = [c.to_dict() for c in self.chunks if c.embedding is not None]
            # Write to a sibling temp file and swap it in, so an interrupted
            # write cannot leave a truncated cache behind.
            tmp_path = self.cache_path.with_name(self.cache_path.name + ".tmp")
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f)
            os.replace(tmp_path, self.cache_path)
        except Exception as e:
            print(f"Error saving cache: {e}")

    def retrieve(self, query: str, top_k: int = 3) -> List[Tuple[DocumentChunk, float]]:
        """
        Retrieve the most relevant chunks for a query.

        Args:
            query: User's query
            top_k: Number of chunks to retrieve; values <= 0 yield an empty list

        Returns:
            List of (chunk, similarity_score) tuples, best match first
        """
        # A non-positive top_k can never select anything; bail out before
        # spending an embedding call (a negative slice bound would otherwise
        # return "all but the last k" results).
        if top_k <= 0:
            return []

        if not self._loaded:
            self.load_documents()
            self.embed_documents()

        # Get query embedding
        query_embedding = get_query_embedding(query)
        if query_embedding is None:
            return []

        query_vec = np.array(query_embedding)
        query_norm = np.linalg.norm(query_vec)  # loop-invariant

        # Calculate similarities
        results = []
        for chunk in self.chunks:
            if chunk.embedding is None:
                continue

            chunk_vec = np.array(chunk.embedding)
            if chunk_vec.shape != query_vec.shape:
                # Embedding of a different dimensionality (e.g. a stale cache
                # entry); it cannot be compared with the query, so skip it.
                continue

            # Cosine similarity; the epsilon guards against division by zero
            similarity = np.dot(query_vec, chunk_vec) / (
                query_norm * np.linalg.norm(chunk_vec) + 1e-8
            )
            results.append((chunk, float(similarity)))

        # Sort by similarity and return top_k
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    def build_context(self, query: str, top_k: int = 3) -> str:
        """
        Build context string from retrieved chunks.

        Args:
            query: User's query
            top_k: Number of chunks to include

        Returns:
            Formatted context string for the prompt, or "" when nothing
            was retrieved
        """
        results = self.retrieve(query, top_k)

        if not results:
            return ""

        context_parts = [
            f"[From {chunk.source_file}]:\n{chunk.content}"
            for chunk, _score in results
        ]

        return "\n\n---\n\n".join(context_parts)


# Singleton instance
_rag_instance: Optional[RAGService] = None


def get_rag_service(docs_path: str = "rag_docs") -> RAGService:
    """Get or create the singleton RAG service instance.

    Note: docs_path is only used when the instance is first created; later
    calls return the existing instance regardless of the path passed.
    """
    global _rag_instance
    if _rag_instance is None:
        _rag_instance = RAGService(docs_path)
    return _rag_instance


def retrieve_relevant_chunks(query: str, top_k: int = 3) -> List[Tuple[DocumentChunk, float]]:
    """
    Convenience function to retrieve relevant chunks.

    Args:
        query: User's query
        top_k: Number of chunks to retrieve

    Returns:
        List of (chunk, score) tuples
    """
    service = get_rag_service()
    return service.retrieve(query, top_k)


def build_rag_context(query: str, top_k: int = 3) -> str:
    """
    Convenience function to build RAG context.

    Args:
        query: User's query
        top_k: Number of chunks to include

    Returns:
        Formatted context string
    """
    service = get_rag_service()
    return service.build_context(query, top_k)


def initialize_rag(docs_path: str = "rag_docs") -> int:
    """
    Initialize the RAG service by loading and embedding documents.

    Args:
        docs_path: Path to documents folder

    Returns:
        Number of chunks loaded
    """
    service = get_rag_service(docs_path)
    num_chunks = service.load_documents()
    service.embed_documents()
    return num_chunks