""" Embedding processor for the AI Backend with RAG + Authentication Implements text preprocessing, caching, and document chunking for embeddings """ import hashlib import asyncio from typing import List, Optional, Tuple, Dict import logging from uuid import UUID from ..config.settings import settings from .gemini_client import generate_embedding, generate_embeddings_batch from ..qdrant.operations import get_vector_operations from ..db import crud from ..config.database import get_db_session logger = logging.getLogger(__name__) # Maximum characters per chunk (Gemini has token limits) MAX_CHUNK_SIZE = 2000 OVERLAP_SIZE = 200 # Overlap between chunks to maintain context class EmbeddingProcessor: """ Processor class to handle embedding workflows including preprocessing, caching, and document chunking """ def __init__(self): self.vector_ops = get_vector_operations() # Simple in-memory cache (in production, use Redis or similar) self.cache: Dict[str, List[float]] = {} def _generate_content_hash(self, content: str) -> str: """ Generate a hash for content to use for caching and deduplication """ return hashlib.sha256(content.encode('utf-8')).hexdigest() def _preprocess_text(self, text: str) -> str: """ Preprocess text by cleaning and normalizing """ if not text or not isinstance(text, str): raise ValueError("Input text must be a non-empty string") # Remove extra whitespace text = ' '.join(text.split()) # Validate text length if len(text) > 1000000: # 1M characters max logger.warning(f"Text is very long ({len(text)} chars), consider pre-processing") return text.strip() def _chunk_text(self, text: str, chunk_size: int = MAX_CHUNK_SIZE, overlap: int = OVERLAP_SIZE) -> List[str]: """ Split text into overlapping chunks to maintain context """ if len(text) <= chunk_size: return [text] chunks = [] start = 0 while start < len(text): end = start + chunk_size # If we're near the end, include the rest if end > len(text): end = len(text) start = max(0, end - chunk_size) chunk = text[start:end] # If this isn't the last chunk, try to break at sentence boundary if end < len(text): # Look for sentence endings to break at sentence_end = max( chunk.rfind('.'), chunk.rfind('!'), chunk.rfind('?'), chunk.rfind('\n') ) if sentence_end > chunk_size // 2: # Only if it's not too early end = start + sentence_end + 1 chunk = text[start:end] chunks.append(chunk) start = end - overlap return chunks async def _get_from_cache(self, content_hash: str) -> Optional[List[float]]: """ Get embedding from cache if available """ return self.cache.get(content_hash) async def _save_to_cache(self, content_hash: str, embedding: List[float]): """ Save embedding to cache """ self.cache[content_hash] = embedding async def process_single_text(self, text: str, user_id: UUID) -> Optional[List[float]]: """ Process a single text for embedding with caching """ try: # Preprocess the text processed_text = self._preprocess_text(text) if not processed_text: logger.warning("Text preprocessing resulted in empty string") return None # Generate content hash for caching content_hash = self._generate_content_hash(processed_text) # Check cache first cached_embedding = await self._get_from_cache(content_hash) if cached_embedding: logger.info(f"Found embedding in cache for text of length {len(processed_text)}") return cached_embedding # Generate embedding using Gemini embedding = await generate_embedding(processed_text) if embedding is None: logger.error(f"Failed to generate embedding for text of length {len(processed_text)}") return None # Save to cache await self._save_to_cache(content_hash, embedding) logger.info(f"Successfully processed embedding for text of length {len(processed_text)}") return embedding except Exception as e: logger.error(f"Error processing single text: {e}") return None async def process_document( self, document_id: UUID, user_id: UUID, content: str, title: Optional[str] = None, metadata: Optional[Dict] = None ) -> bool: """ Process a document for embedding, including chunking and storage """ try: # Preprocess the content processed_content = self._preprocess_text(content) if not processed_content: logger.warning("Document content preprocessing resulted in empty string") return False # Chunk the document if it's large if len(processed_content) > MAX_CHUNK_SIZE: chunks = self._chunk_text(processed_content) logger.info(f"Document chunked into {len(chunks)} parts") else: chunks = [processed_content] # Process each chunk all_embeddings = [] chunk_payloads = [] for i, chunk in enumerate(chunks): # Generate content hash for caching content_hash = self._generate_content_hash(chunk) # Check cache first embedding = await self._get_from_cache(content_hash) if embedding is None: # Generate embedding using Gemini embedding = await generate_embedding(chunk) if embedding is None: logger.error(f"Failed to generate embedding for chunk {i}") continue # Save to cache await self._save_to_cache(content_hash, embedding) all_embeddings.append(embedding) # Create payload for this chunk chunk_payload = { "chunk_index": i, "chunk_text": chunk[:100] + "..." if len(chunk) > 100 else chunk, # Store first 100 chars as reference "document_id": str(document_id), "user_id": str(user_id), "title": title or "Untitled Document", "total_chunks": len(chunks) } if metadata: chunk_payload.update(metadata) chunk_payloads.append(chunk_payload) # Store embeddings in Qdrant if all_embeddings: success = await self.vector_ops.batch_upsert_vectors( user_id=user_id, document_id=document_id, embeddings_list=all_embeddings, payloads_list=chunk_payloads ) if success: logger.info(f"Successfully stored {len(all_embeddings)} embeddings for document {document_id}") return True else: logger.error(f"Failed to store embeddings in Qdrant for document {document_id}") return False else: logger.warning("No embeddings were generated for the document") return False except Exception as e: logger.error(f"Error processing document {document_id}: {e}") return False async def process_texts_batch( self, texts: List[str], user_id: UUID ) -> Optional[List[List[float]]]: """ Process a batch of texts for embedding with caching """ try: embeddings = [] for text in texts: embedding = await self.process_single_text(text, user_id) if embedding is None: logger.error(f"Failed to process text: {text[:50]}...") return None embeddings.append(embedding) logger.info(f"Successfully processed batch of {len(texts)} texts") return embeddings except Exception as e: logger.error(f"Error processing text batch: {e}") return None async def invalidate_cache_for_document(self, document_id: UUID): """ Remove cached embeddings associated with a document In a real implementation with Redis, this would be more sophisticated """ # In our simple in-memory cache, we can't easily identify which cache entries # belong to a specific document, so we'd need to implement a more sophisticated # cache structure. For now, we'll just log the action. logger.info(f"Cache invalidation requested for document {document_id} (not implemented in simple cache)") # Global instance of EmbeddingProcessor embedding_processor = EmbeddingProcessor() def get_embedding_processor() -> EmbeddingProcessor: """Get the embedding processor instance""" return embedding_processor async def process_single_text(text: str, user_id: UUID) -> Optional[List[float]]: """ Process a single text for embedding with caching """ return await embedding_processor.process_single_text(text, user_id) async def process_document( document_id: UUID, user_id: UUID, content: str, title: Optional[str] = None, metadata: Optional[Dict] = None ) -> bool: """ Process a document for embedding, including chunking and storage """ return await embedding_processor.process_document(document_id, user_id, content, title, metadata) async def process_texts_batch( texts: List[str], user_id: UUID ) -> Optional[List[List[float]]]: """ Process a batch of texts for embedding with caching """ return await embedding_processor.process_texts_batch(texts, user_id)