Spaces:
Sleeping
Sleeping
File size: 10,497 Bytes
db7c1e8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 | """
Embedding processor for the AI Backend with RAG + Authentication
Implements text preprocessing, caching, and document chunking for embeddings
"""
import hashlib
import asyncio
from typing import List, Optional, Tuple, Dict
import logging
from uuid import UUID
from ..config.settings import settings
from .gemini_client import generate_embedding, generate_embeddings_batch
from ..qdrant.operations import get_vector_operations
from ..db import crud
from ..config.database import get_db_session
logger = logging.getLogger(__name__)

# Largest chunk, in characters, handed to the embedding model in one call
# (kept small because Gemini enforces token limits per request).
MAX_CHUNK_SIZE = 2_000
# Characters repeated at the start of each following chunk so that context
# spanning a split point is not lost.
OVERLAP_SIZE = 200
class EmbeddingProcessor:
    """
    Processor class to handle embedding workflows including preprocessing,
    caching, and document chunking.

    Embeddings are memoized in a process-local dict keyed by the SHA-256 hex
    digest of the preprocessed text. The cache holds at most
    ``max_cache_entries`` items and evicts the least recently used entry when
    full (``None`` = unbounded, ``0`` or less = caching disabled).
    """

    # Default cache bound; __init__ may override it per instance.
    max_cache_entries: Optional[int] = 10_000

    def __init__(self, max_cache_entries: Optional[int] = 10_000):
        """
        Create a processor bound to the vector-store operations object.

        Args:
            max_cache_entries: Upper bound on cached embeddings. ``None`` keeps
                the cache unbounded; ``0`` or a negative value disables caching.
        """
        self.vector_ops = get_vector_operations()
        # Simple in-memory cache (in production, use Redis or similar).
        # dicts keep insertion order and hits are re-inserted at the end, so
        # the first key is always the least recently used one.
        self.cache: Dict[str, List[float]] = {}
        self.max_cache_entries = max_cache_entries

    def _generate_content_hash(self, content: str) -> str:
        """
        Return the SHA-256 hex digest of ``content`` (UTF-8 encoded), used as
        the cache key for caching and deduplication.
        """
        return hashlib.sha256(content.encode('utf-8')).hexdigest()

    def _preprocess_text(self, text: str) -> str:
        """
        Normalize text by collapsing every run of whitespace to a single space.

        Returns:
            The normalized text; "" when the input contained only whitespace.

        Raises:
            ValueError: if ``text`` is not a string or is empty.
        """
        # Check the type first so non-str inputs are never truth-tested.
        if not isinstance(text, str) or not text:
            raise ValueError("Input text must be a non-empty string")
        # str.split() with no separator also drops leading/trailing whitespace,
        # so the joined result needs no further stripping.
        text = ' '.join(text.split())
        if len(text) > 1_000_000:  # 1M characters max
            logger.warning("Text is very long (%d chars), consider pre-processing", len(text))
        return text

    def _chunk_text(self, text: str, chunk_size: int = MAX_CHUNK_SIZE, overlap: int = OVERLAP_SIZE) -> List[str]:
        """
        Split text into overlapping chunks to maintain context.

        Every chunk is at most ``chunk_size`` characters. A chunk that does not
        reach the end of the text is cut just after the last '.', '!', '?' or
        newline it contains, if that boundary lies past the chunk's midpoint.
        The next chunk starts ``overlap`` characters before the previous one
        ended. The final chunk is the last ``chunk_size`` characters of the
        text (or the whole text if shorter).

        Raises:
            ValueError: if ``chunk_size`` is not positive or ``overlap`` is negative.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if overlap < 0:
            raise ValueError("overlap must be non-negative")

        text_len = len(text)
        if text_len <= chunk_size:
            return [text]

        chunks: List[str] = []
        start = 0
        while start < text_len:
            end = start + chunk_size
            if end >= text_len:
                # Final chunk: anchor it to the end of the text and stop.
                # Stepping back by `overlap` from here would always land below
                # text_len again and loop forever.
                chunks.append(text[max(0, text_len - chunk_size):text_len])
                break

            chunk = text[start:end]
            # Prefer to break at a sentence/line boundary, but not too early.
            sentence_end = max(
                chunk.rfind('.'),
                chunk.rfind('!'),
                chunk.rfind('?'),
                chunk.rfind('\n'),
            )
            if sentence_end > chunk_size // 2:
                end = start + sentence_end + 1
                chunk = text[start:end]

            chunks.append(chunk)
            # Always move forward, even if `overlap` is as large as the chunk.
            start = max(end - overlap, start + 1)
        return chunks

    async def _get_from_cache(self, content_hash: str) -> Optional[List[float]]:
        """
        Return the cached embedding for ``content_hash``, or None on a miss.
        A hit is marked as most recently used.
        """
        embedding = self.cache.pop(content_hash, None)
        if embedding is not None:
            # Re-insert at the end of the insertion order (most recently used).
            self.cache[content_hash] = embedding
        return embedding

    async def _save_to_cache(self, content_hash: str, embedding: List[float]):
        """
        Store ``embedding`` under ``content_hash``, evicting least recently
        used entries while the cache exceeds ``max_cache_entries``.
        """
        limit = self.max_cache_entries
        if limit is not None and limit <= 0:
            return  # caching disabled
        # Drop any existing entry so the new one lands at the MRU end.
        self.cache.pop(content_hash, None)
        self.cache[content_hash] = embedding
        if limit is not None:
            while len(self.cache) > limit:
                del self.cache[next(iter(self.cache))]

    async def process_single_text(self, text: str, user_id: UUID) -> Optional[List[float]]:
        """
        Process a single text for embedding with caching.

        Args:
            text: Raw text; whitespace is normalized before hashing/embedding.
            user_id: Accepted for interface symmetry; not used by this method.

        Returns:
            The embedding vector, or None when the text is empty/invalid or
            embedding generation fails. Errors are logged, not raised.
        """
        try:
            processed_text = self._preprocess_text(text)
            if not processed_text:
                logger.warning("Text preprocessing resulted in empty string")
                return None

            content_hash = self._generate_content_hash(processed_text)

            # Compare against None (not truthiness), matching process_document.
            cached_embedding = await self._get_from_cache(content_hash)
            if cached_embedding is not None:
                logger.info("Found embedding in cache for text of length %d", len(processed_text))
                return cached_embedding

            embedding = await generate_embedding(processed_text)
            if embedding is None:
                logger.error("Failed to generate embedding for text of length %d", len(processed_text))
                return None

            await self._save_to_cache(content_hash, embedding)
            logger.info("Successfully processed embedding for text of length %d", len(processed_text))
            return embedding
        except Exception as e:
            logger.exception("Error processing single text: %s", e)
            return None

    @staticmethod
    def _build_chunk_payload(
        chunk: str,
        chunk_index: int,
        total_chunks: int,
        document_id: UUID,
        user_id: UUID,
        title: Optional[str],
        metadata: Optional[Dict],
    ) -> Dict:
        """Build the payload stored in the vector store alongside one chunk's embedding."""
        payload: Dict = dict(metadata) if metadata else {}
        # Core fields are applied last so caller-supplied metadata cannot
        # overwrite identity fields such as user_id or document_id.
        payload.update({
            "chunk_index": chunk_index,
            # Keep only the first 100 characters as a human-readable reference.
            "chunk_text": (chunk[:100] + "...") if len(chunk) > 100 else chunk,
            "document_id": str(document_id),
            "user_id": str(user_id),
            "title": title or "Untitled Document",
            "total_chunks": total_chunks,
        })
        return payload

    async def process_document(
        self,
        document_id: UUID,
        user_id: UUID,
        content: str,
        title: Optional[str] = None,
        metadata: Optional[Dict] = None
    ) -> bool:
        """
        Process a document for embedding, including chunking and storage.

        The content is normalized, split into overlapping chunks, embedded
        chunk by chunk (using the cache), and written to the vector store with
        a single ``batch_upsert_vectors`` call. Chunks whose embedding fails
        are skipped and logged.

        Returns:
            True if at least one chunk was embedded and the upsert reported
            success; False otherwise. Exceptions are logged and reported as False.
        """
        try:
            processed_content = self._preprocess_text(content)
            if not processed_content:
                logger.warning("Document content preprocessing resulted in empty string")
                return False

            # _chunk_text returns [text] unchanged when it fits in one chunk.
            chunks = self._chunk_text(processed_content)
            if len(chunks) > 1:
                logger.info("Document chunked into %d parts", len(chunks))

            all_embeddings: List[List[float]] = []
            chunk_payloads: List[Dict] = []
            for i, chunk in enumerate(chunks):
                content_hash = self._generate_content_hash(chunk)
                embedding = await self._get_from_cache(content_hash)
                if embedding is None:
                    embedding = await generate_embedding(chunk)
                    if embedding is None:
                        logger.error("Failed to generate embedding for chunk %d", i)
                        continue
                    await self._save_to_cache(content_hash, embedding)

                all_embeddings.append(embedding)
                chunk_payloads.append(self._build_chunk_payload(
                    chunk, i, len(chunks), document_id, user_id, title, metadata
                ))

            if not all_embeddings:
                logger.warning("No embeddings were generated for the document")
                return False
            if len(all_embeddings) < len(chunks):
                # Partial results are still stored; make the gap visible.
                logger.warning(
                    "Only %d of %d chunks were embedded for document %s",
                    len(all_embeddings), len(chunks), document_id,
                )

            success = await self.vector_ops.batch_upsert_vectors(
                user_id=user_id,
                document_id=document_id,
                embeddings_list=all_embeddings,
                payloads_list=chunk_payloads
            )
            if success:
                logger.info("Successfully stored %d embeddings for document %s", len(all_embeddings), document_id)
                return True
            logger.error("Failed to store embeddings in Qdrant for document %s", document_id)
            return False
        except Exception as e:
            logger.exception("Error processing document %s: %s", document_id, e)
            return False

    async def process_texts_batch(
        self,
        texts: List[str],
        user_id: UUID
    ) -> Optional[List[List[float]]]:
        """
        Process a batch of texts for embedding with caching.

        Texts are embedded one at a time, in order, via process_single_text.

        Returns:
            One embedding per input text, or None as soon as any text fails
            (all-or-nothing). Errors are logged, not raised.
        """
        try:
            embeddings: List[List[float]] = []
            for text in texts:
                embedding = await self.process_single_text(text, user_id)
                if embedding is None:
                    logger.error("Failed to process text: %s...", text[:50])
                    return None
                embeddings.append(embedding)
            logger.info("Successfully processed batch of %d texts", len(texts))
            return embeddings
        except Exception as e:
            logger.exception("Error processing text batch: %s", e)
            return None

    async def invalidate_cache_for_document(self, document_id: UUID):
        """
        Remove cached embeddings associated with a document.

        Currently a no-op that only logs the request: the cache is keyed by
        content hash alone, so entries cannot be traced back to a document.
        In a real implementation with Redis, this would be more sophisticated.
        """
        logger.info("Cache invalidation requested for document %s (not implemented in simple cache)", document_id)
# Global, module-level instance of EmbeddingProcessor. It is constructed at
# import time (EmbeddingProcessor.__init__ calls get_vector_operations()), and
# the module-level wrapper functions in this file delegate to it, so they all
# share this one instance's in-memory embedding cache.
embedding_processor = EmbeddingProcessor()
def get_embedding_processor() -> EmbeddingProcessor:
    """
    Return the shared module-level ``EmbeddingProcessor``.

    The same object is handed back on every call; no new processor is built.
    """
    shared = embedding_processor
    return shared
async def process_single_text(text: str, user_id: UUID) -> Optional[List[float]]:
    """
    Embed one piece of text using the shared module-level processor.

    Convenience wrapper around ``EmbeddingProcessor.process_single_text``;
    returns the embedding, or None on failure.
    """
    processor = embedding_processor
    return await processor.process_single_text(text=text, user_id=user_id)
async def process_document(
    document_id: UUID,
    user_id: UUID,
    content: str,
    title: Optional[str] = None,
    metadata: Optional[Dict] = None
) -> bool:
    """
    Chunk, embed and store a document using the shared module-level processor.

    Convenience wrapper around ``EmbeddingProcessor.process_document``;
    returns True on success, False otherwise.
    """
    processor = embedding_processor
    return await processor.process_document(
        document_id=document_id,
        user_id=user_id,
        content=content,
        title=title,
        metadata=metadata,
    )
async def process_texts_batch(
    texts: List[str],
    user_id: UUID
) -> Optional[List[List[float]]]:
    """
    Embed a batch of texts using the shared module-level processor.

    Convenience wrapper around ``EmbeddingProcessor.process_texts_batch``.

    Returns:
        One embedding per text, in input order, or None if any text fails.
    """
    return await embedding_processor.process_texts_batch(texts, user_id)