Spaces:
Running
Running
File size: 10,961 Bytes
62884e7 ef3e36a 62884e7 ef3e36a 62884e7 ef3e36a 62884e7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 | """
RAG (Retrieval-Augmented Generation) Utilities
Provides document loading, chunking, embedding, and retrieval for the AI chatbot.
"""
import os
import json
import hashlib
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import numpy as np
from utils.openrouter_client import get_embedding, get_query_embedding, get_openrouter_client
# Configuration
CHUNK_SIZE: int = 500  # Target tokens per chunk (approximate; tokens estimated as len(text) // 4)
CHUNK_OVERLAP: int = 50  # Overlap between consecutive chunks, counted in WORDS, not tokens
# File suffixes that are loaded as documents; compared after lower-casing the file's suffix.
SUPPORTED_EXTENSIONS = {'.txt', '.md'}
# Embedding cache filename; the cache lives inside the docs folder and is skipped when loading documents.
CACHE_FILE: str = "embeddings_cache.json"
class DocumentChunk:
    """A piece of a source document together with its optional embedding.

    Attributes:
        content: The chunk's text.
        source_file: Name of the file the text came from.
        chunk_index: Position of this chunk within its source file.
        embedding: Embedding vector, or None when not computed yet.
        content_hash: Hex MD5 digest of ``content`` (UTF-8 encoded); used to
            match chunks against the embedding cache.
    """

    def __init__(
        self,
        content: str,
        source_file: str,
        chunk_index: int,
        embedding: Optional[List[float]] = None
    ):
        self.content = content
        self.source_file = source_file
        self.chunk_index = chunk_index
        self.embedding = embedding
        digest = hashlib.md5(content.encode())
        self.content_hash = digest.hexdigest()

    def to_dict(self) -> Dict:
        """Return every field as a plain dict suitable for ``json.dump``."""
        field_names = ("content", "source_file", "chunk_index", "embedding", "content_hash")
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, data: Dict) -> 'DocumentChunk':
        """Rebuild a chunk from a dict in the ``to_dict`` layout.

        A ``content_hash`` present in ``data`` replaces the freshly computed
        one; a missing ``embedding`` becomes None. Missing ``content``,
        ``source_file`` or ``chunk_index`` raises KeyError.
        """
        restored = cls(
            data["content"],
            data["source_file"],
            data["chunk_index"],
            data.get("embedding"),
        )
        if "content_hash" in data:
            restored.content_hash = data["content_hash"]
        return restored
class RAGService:
    """Service for managing RAG document retrieval.

    Reads files whose suffix is in SUPPORTED_EXTENSIONS from a folder, splits
    them into overlapping chunks, embeds the chunks (reusing a JSON cache kept
    in the same folder), and ranks chunks against a query by cosine similarity.
    """

    def __init__(self, docs_path: str = "rag_docs"):
        """
        Initialize the RAG service. No files are read here; loading is lazy.

        Args:
            docs_path: Path to the documents folder
        """
        self.docs_path = Path(docs_path)
        self.cache_path = self.docs_path / CACHE_FILE
        self.chunks: List[DocumentChunk] = []
        # Set once load_documents() has scanned the folder; retrieve() and
        # embed_documents() use it to trigger a lazy first load.
        self._loaded = False

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough approximation: ~4 chars per token)."""
        return len(text) // 4

    def _chunk_text(self, text: str, source_file: str) -> List[DocumentChunk]:
        """
        Split text into chunks with overlap.

        Paragraphs (separated by blank lines) are packed greedily until adding
        the next one would exceed CHUNK_SIZE estimated tokens. Each new chunk
        starts with the last CHUNK_OVERLAP words of the previous chunk.

        NOTE(review): a single paragraph longer than CHUNK_SIZE is never split,
        so an individual chunk can exceed the target size.

        Args:
            text: Text content to chunk
            source_file: Name of the source file

        Returns:
            List of DocumentChunk objects, numbered from 0 in document order
        """
        chunks: List[DocumentChunk] = []
        current_chunk = ""
        chunk_index = 0

        for para in text.split('\n\n'):
            para = para.strip()
            if not para:
                continue

            # If adding this paragraph exceeds chunk size, save current and start new
            if self._estimate_tokens(current_chunk + para) > CHUNK_SIZE and current_chunk:
                chunks.append(DocumentChunk(
                    content=current_chunk.strip(),
                    source_file=source_file,
                    chunk_index=chunk_index
                ))
                chunk_index += 1
                # Seed the next chunk with the tail of this one. Guard against
                # CHUNK_OVERLAP <= 0: words[-0:] is the WHOLE list, which would
                # copy the entire previous chunk forward.
                words = current_chunk.split()
                overlap_words = words[-CHUNK_OVERLAP:] if CHUNK_OVERLAP > 0 else []
                current_chunk = (" ".join(overlap_words) + "\n\n") if overlap_words else ""

            current_chunk += para + "\n\n"

        # Don't forget the last chunk
        if current_chunk.strip():
            chunks.append(DocumentChunk(
                content=current_chunk.strip(),
                source_file=source_file,
                chunk_index=chunk_index
            ))

        return chunks

    def load_documents(self) -> int:
        """
        Load and chunk all documents from the docs folder.

        Embeddings found in the on-disk cache are reattached to the freshly
        built chunks by content hash, so unchanged text is not re-embedded.
        Unreadable files are reported and skipped.

        Returns:
            Number of chunks loaded (0 if the folder does not exist)
        """
        if not self.docs_path.exists():
            print(f"RAG docs folder not found: {self.docs_path}")
            return 0

        # Index cached vectors by content hash once (O(1) lookups below instead
        # of a linear scan per chunk). Identical text has an identical
        # embedding, so the hash is a sufficient key for the VECTOR only.
        cached_embeddings = {
            cached.content_hash: cached.embedding
            for cached in self._load_cache()
            if cached.embedding is not None
        }

        new_chunks: List[DocumentChunk] = []
        # Sorted for a deterministic chunk order; iterdir() order is arbitrary.
        for file_path in sorted(self.docs_path.iterdir()):
            if not file_path.is_file():
                continue
            if file_path.suffix.lower() not in SUPPORTED_EXTENSIONS:
                continue
            if file_path.name == CACHE_FILE or file_path.name.startswith('.'):
                continue

            try:
                content = file_path.read_text(encoding='utf-8')
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error loading {file_path}: {e}")
                continue

            for chunk in self._chunk_text(content, file_path.name):
                # Reuse only the cached vector and keep this chunk's own
                # source_file/chunk_index: the cached object may describe the
                # same text at a different file or position, and reusing it
                # would mis-attribute the source shown by build_context().
                chunk.embedding = cached_embeddings.get(chunk.content_hash)
                new_chunks.append(chunk)

        self.chunks = new_chunks
        self._loaded = True
        return len(self.chunks)

    def embed_documents(self) -> int:
        """
        Generate embeddings for all chunks that don't have them.

        Newly embedded chunks are written to the cache, including when an
        embedding call raises part-way through (the exception still propagates).

        Returns:
            Number of new embeddings generated
        """
        if not self._loaded:
            self.load_documents()

        pending = [chunk for chunk in self.chunks if chunk.embedding is None]
        if not pending:
            # Nothing to do: avoid touching the client on every retrieve().
            return 0

        client = get_openrouter_client()
        if not client.is_available:
            print("OpenRouter client not available, skipping embedding generation")
            return 0

        embedded_count = 0
        try:
            for chunk in pending:
                embedding = get_embedding(chunk.content)
                # Falsy result is treated as a failed embedding (presumably
                # None on error — confirm against get_embedding).
                if embedding:
                    chunk.embedding = embedding
                    embedded_count += 1
        finally:
            # Persist progress even if a later call raised, so that finished
            # work is not repeated on the next run.
            if embedded_count > 0:
                self._save_cache()

        return embedded_count

    def _load_cache(self) -> List[DocumentChunk]:
        """Load cached chunks from the cache file; returns [] if missing or unreadable."""
        if not self.cache_path.exists():
            return []

        try:
            with open(self.cache_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return [DocumentChunk.from_dict(d) for d in data]
        except Exception as e:
            # Best effort: a corrupt cache only costs a re-embed.
            print(f"Error loading cache: {e}")
            return []

    def _save_cache(self):
        """Save every chunk that has an embedding to the cache file.

        The JSON is written to a temporary sibling file and then moved over
        the cache with os.replace, so a crash mid-write cannot leave a
        truncated cache behind. Errors are reported, not raised.
        """
        tmp_path = self.cache_path.with_name(self.cache_path.name + '.tmp')
        try:
            data = [c.to_dict() for c in self.chunks if c.embedding is not None]
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f)
            os.replace(tmp_path, self.cache_path)
        except Exception as e:
            print(f"Error saving cache: {e}")

    def retrieve(self, query: str, top_k: int = 3) -> List[Tuple[DocumentChunk, float]]:
        """
        Retrieve the most relevant chunks for a query.

        Loads documents on first use and embeds any chunk still lacking an
        embedding before scoring; chunks without an embedding are skipped.

        Args:
            query: User's query
            top_k: Number of chunks to retrieve (values <= 0 return [])

        Returns:
            List of (chunk, similarity_score) tuples, highest cosine similarity first
        """
        if top_k <= 0:
            # A negative slice bound would otherwise return "all but the last k".
            return []

        if not self._loaded:
            self.load_documents()
        self.embed_documents()

        # Get query embedding
        query_embedding = get_query_embedding(query)
        if query_embedding is None:
            return []

        query_vec = np.array(query_embedding)
        query_norm = np.linalg.norm(query_vec)  # loop-invariant

        results = []
        for chunk in self.chunks:
            if chunk.embedding is None:
                continue
            chunk_vec = np.array(chunk.embedding)
            # Cosine similarity; the epsilon avoids division by zero for zero vectors.
            similarity = np.dot(query_vec, chunk_vec) / (
                query_norm * np.linalg.norm(chunk_vec) + 1e-8
            )
            results.append((chunk, float(similarity)))

        # Sort by similarity and return top_k
        results.sort(key=lambda pair: pair[1], reverse=True)
        return results[:top_k]

    def build_context(self, query: str, top_k: int = 3) -> str:
        """
        Build context string from retrieved chunks.

        Args:
            query: User's query
            top_k: Number of chunks to include

        Returns:
            Chunks formatted as "[From <file>]:\\n<text>" joined by "---"
            separators, or "" when nothing was retrieved
        """
        results = self.retrieve(query, top_k)
        if not results:
            return ""

        # Scores only drive the ranking in retrieve(); they are not shown.
        return "\n\n---\n\n".join(
            f"[From {chunk.source_file}]:\n{chunk.content}"
            for chunk, _score in results
        )
# Singleton instance
_rag_instance: Optional[RAGService] = None


def get_rag_service(docs_path: Optional[str] = None) -> RAGService:
    """Get or create the shared RAG service instance.

    Args:
        docs_path: Documents folder. When omitted (None), the existing
            instance is returned, or one is created for "rag_docs" if none
            exists yet. When given and different from the existing instance's
            folder, a new instance for that folder replaces the shared one,
            so an explicit path is never silently ignored.

    Returns:
        The shared RAGService instance.
    """
    global _rag_instance
    if _rag_instance is None:
        _rag_instance = RAGService(docs_path if docs_path is not None else "rag_docs")
    elif docs_path is not None and Path(docs_path) != _rag_instance.docs_path:
        _rag_instance = RAGService(docs_path)
    return _rag_instance
def retrieve_relevant_chunks(query: str, top_k: int = 3) -> List[Tuple[DocumentChunk, float]]:
    """
    Return the ``top_k`` best-matching chunks for ``query`` using the shared service.

    Args:
        query: User's query
        top_k: Number of chunks to retrieve

    Returns:
        List of (chunk, score) tuples
    """
    return get_rag_service().retrieve(query, top_k)
def build_rag_context(query: str, top_k: int = 3) -> str:
    """
    Format the ``top_k`` best-matching chunks for ``query`` as prompt context.

    Delegates to the shared service's ``build_context``.

    Args:
        query: User's query
        top_k: Number of chunks to include

    Returns:
        Formatted context string
    """
    return get_rag_service().build_context(query, top_k)
def initialize_rag(docs_path: str = "rag_docs") -> int:
    """
    Load the documents in ``docs_path``, then embed any chunk lacking an embedding.

    Args:
        docs_path: Path to documents folder

    Returns:
        Number of chunks loaded (not the number newly embedded)
    """
    rag = get_rag_service(docs_path)
    loaded_count = rag.load_documents()
    rag.embed_documents()
    return loaded_count
|