Naveedtechlab's picture
Add full AI Native Textbook project source code
db7c1e8
"""
Embedding processor for the AI Backend with RAG + Authentication
Implements text preprocessing, caching, and document chunking for embeddings
"""
import hashlib
import asyncio
from typing import List, Optional, Tuple, Dict
import logging
from uuid import UUID
from ..config.settings import settings
from .gemini_client import generate_embedding, generate_embeddings_batch
from ..qdrant.operations import get_vector_operations
from ..db import crud
from ..config.database import get_db_session
logger = logging.getLogger(__name__)

# Upper bound, in characters, for a single chunk sent for embedding
# (Gemini requests are size-limited, so long documents are split).
MAX_CHUNK_SIZE = 2_000
# Characters repeated between consecutive chunks so context carries across cuts.
OVERLAP_SIZE = 200
class EmbeddingProcessor:
    """
    Processor class to handle embedding workflows including preprocessing,
    caching, and document chunking.

    Embeddings are cached in a per-instance in-memory dict keyed by the
    SHA-256 hex digest of the preprocessed text.
    """

    def __init__(self):
        self.vector_ops = get_vector_operations()
        # Simple in-memory cache (in production, use Redis or similar).
        # NOTE(review): entries are never evicted, so this grows with the number
        # of distinct texts embedded during the process lifetime.
        self.cache: Dict[str, List[float]] = {}

    def _generate_content_hash(self, content: str) -> str:
        """Return the SHA-256 hex digest of ``content`` (UTF-8), used as the cache key."""
        return hashlib.sha256(content.encode('utf-8')).hexdigest()

    def _preprocess_text(self, text: str) -> str:
        """
        Normalize whitespace in ``text``.

        Every run of whitespace (spaces, tabs, newlines) is collapsed to a single
        space and leading/trailing whitespace is dropped, so whitespace-only input
        yields ``""``.

        Raises:
            ValueError: if ``text`` is empty or not a ``str``.
        """
        if not text or not isinstance(text, str):
            raise ValueError("Input text must be a non-empty string")
        # str.split() with no separator splits on any whitespace and discards
        # empty pieces, so the joined result has no leading/trailing spaces.
        text = ' '.join(text.split())
        if len(text) > 1000000:  # 1M characters max
            logger.warning(f"Text is very long ({len(text)} chars), consider pre-processing")
        return text

    def _chunk_text(self, text: str, chunk_size: int = MAX_CHUNK_SIZE, overlap: int = OVERLAP_SIZE) -> List[str]:
        """
        Split ``text`` into overlapping chunks of at most ``chunk_size`` characters.

        Each non-final chunk ends just after the last '.', '!', '?' or newline in
        its window when that position lies past both half the window and
        ``overlap``; otherwise it is cut at exactly ``chunk_size`` characters.
        The next chunk starts ``overlap`` characters before the previous cut.
        The final chunk is the last ``chunk_size`` characters of ``text``, so it
        may overlap its predecessor by more than ``overlap``.

        Args:
            text: Text to split.
            chunk_size: Maximum characters per chunk; must be positive.
            overlap: Characters shared by consecutive chunks; must satisfy
                ``0 <= overlap < chunk_size``.

        Returns:
            The chunks in order; ``[text]`` when ``text`` already fits in one chunk.

        Raises:
            ValueError: if ``chunk_size`` or ``overlap`` is out of range.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        text_len = len(text)
        if text_len <= chunk_size:
            return [text]

        # Only accept a sentence break that leaves a chunk longer than both half
        # the window and the overlap; this guarantees `start` strictly advances.
        min_break = max(chunk_size // 2, overlap)
        chunks: List[str] = []
        start = 0
        while True:
            end = start + chunk_size
            if end >= text_len:
                # Final chunk: take the tail and stop. Looping on here would step
                # back by `overlap` from the end and never reach text_len.
                chunks.append(text[max(0, text_len - chunk_size):])
                return chunks
            window = text[start:end]
            # Look for sentence endings to break at.
            sentence_end = max(window.rfind(mark) for mark in '.!?\n')
            if sentence_end > min_break:
                end = start + sentence_end + 1
            chunks.append(text[start:end])
            start = end - overlap

    async def _get_from_cache(self, content_hash: str) -> Optional[List[float]]:
        """Return the cached embedding for ``content_hash``, or ``None`` if absent."""
        return self.cache.get(content_hash)

    async def _save_to_cache(self, content_hash: str, embedding: List[float]):
        """Store ``embedding`` in the in-memory cache under ``content_hash``."""
        self.cache[content_hash] = embedding

    async def process_single_text(self, text: str, user_id: UUID) -> Optional[List[float]]:
        """
        Return the embedding for ``text``, using the in-memory cache when possible.

        Args:
            text: Raw text; whitespace is normalized before hashing and embedding.
            user_id: Accepted for API compatibility; not used by this method.

        Returns:
            The embedding vector, or ``None`` when the text is empty/invalid,
            generation fails, or an unexpected error occurs (errors are logged,
            not raised).
        """
        try:
            processed_text = self._preprocess_text(text)
            if not processed_text:
                logger.warning("Text preprocessing resulted in empty string")
                return None

            content_hash = self._generate_content_hash(processed_text)

            cached_embedding = await self._get_from_cache(content_hash)
            if cached_embedding is not None:
                logger.info(f"Found embedding in cache for text of length {len(processed_text)}")
                return cached_embedding

            embedding = await generate_embedding(processed_text)
            # An empty vector is treated as a failure so it is never cached.
            if not embedding:
                logger.error(f"Failed to generate embedding for text of length {len(processed_text)}")
                return None

            await self._save_to_cache(content_hash, embedding)
            logger.info(f"Successfully processed embedding for text of length {len(processed_text)}")
            return embedding
        except Exception as e:
            logger.exception(f"Error processing single text: {e}")
            return None

    async def process_document(
        self,
        document_id: UUID,
        user_id: UUID,
        content: str,
        title: Optional[str] = None,
        metadata: Optional[Dict] = None
    ) -> bool:
        """
        Chunk, embed and store a document via ``self.vector_ops.batch_upsert_vectors``.

        Chunks whose embedding cannot be generated are skipped (with a warning);
        the remaining chunks are still stored.

        Args:
            document_id: Identifier stored in every chunk payload.
            user_id: Owner identifier stored in every chunk payload and passed
                to the vector store.
            content: Raw document text.
            title: Stored in each payload; defaults to "Untitled Document".
            metadata: Extra payload fields. They cannot override the core fields
                (chunk_index, chunk_text, document_id, user_id, title, total_chunks).

        Returns:
            True if at least one chunk was embedded and the upsert reported
            success; False otherwise (errors are logged, not raised).
        """
        try:
            processed_content = self._preprocess_text(content)
            if not processed_content:
                logger.warning("Document content preprocessing resulted in empty string")
                return False

            if len(processed_content) > MAX_CHUNK_SIZE:
                chunks = self._chunk_text(processed_content)
                logger.info(f"Document chunked into {len(chunks)} parts")
            else:
                chunks = [processed_content]

            all_embeddings: List[List[float]] = []
            chunk_payloads: List[Dict] = []
            for i, chunk in enumerate(chunks):
                content_hash = self._generate_content_hash(chunk)

                embedding = await self._get_from_cache(content_hash)
                if embedding is None:
                    embedding = await generate_embedding(chunk)
                    if not embedding:
                        logger.error(f"Failed to generate embedding for chunk {i}")
                        continue
                    await self._save_to_cache(content_hash, embedding)

                # Caller metadata goes first so it cannot overwrite the ownership
                # and indexing fields written below.
                chunk_payload: Dict = dict(metadata) if metadata else {}
                chunk_payload.update({
                    "chunk_index": i,
                    # Store first 100 chars as reference
                    "chunk_text": (chunk[:100] + "...") if len(chunk) > 100 else chunk,
                    "document_id": str(document_id),
                    "user_id": str(user_id),
                    "title": title or "Untitled Document",
                    "total_chunks": len(chunks),
                })

                # Appended together so embeddings and payloads stay index-aligned.
                all_embeddings.append(embedding)
                chunk_payloads.append(chunk_payload)

            if not all_embeddings:
                logger.warning("No embeddings were generated for the document")
                return False

            failed = len(chunks) - len(all_embeddings)
            if failed:
                logger.warning(
                    f"{failed} of {len(chunks)} chunks failed to embed for document {document_id}; "
                    f"storing the remaining chunks"
                )

            success = await self.vector_ops.batch_upsert_vectors(
                user_id=user_id,
                document_id=document_id,
                embeddings_list=all_embeddings,
                payloads_list=chunk_payloads
            )
            if not success:
                logger.error(f"Failed to store embeddings in Qdrant for document {document_id}")
                return False

            logger.info(f"Successfully stored {len(all_embeddings)} embeddings for document {document_id}")
            return True
        except Exception as e:
            logger.exception(f"Error processing document {document_id}: {e}")
            return False

    async def process_texts_batch(
        self,
        texts: List[str],
        user_id: UUID
    ) -> Optional[List[List[float]]]:
        """
        Embed each text in order via ``process_single_text`` (all-or-nothing).

        Returns:
            One embedding per input text, or ``None`` if any text fails or an
            unexpected error occurs (errors are logged, not raised).
        """
        try:
            embeddings: List[List[float]] = []
            for text in texts:
                embedding = await self.process_single_text(text, user_id)
                if embedding is None:
                    logger.error(f"Failed to process text: {str(text)[:50]}...")
                    return None
                embeddings.append(embedding)
            logger.info(f"Successfully processed batch of {len(texts)} texts")
            return embeddings
        except Exception as e:
            logger.exception(f"Error processing text batch: {e}")
            return None

    async def invalidate_cache_for_document(self, document_id: UUID):
        """
        Request removal of cached embeddings associated with a document.

        The in-memory cache is keyed only by content hash and records no
        document ownership, so this currently just logs the request.
        """
        logger.info(f"Cache invalidation requested for document {document_id} (not implemented in simple cache)")
# Module-wide shared processor. Constructed at import time, which also calls
# get_vector_operations() inside EmbeddingProcessor.__init__.
embedding_processor: EmbeddingProcessor = EmbeddingProcessor()
def get_embedding_processor() -> EmbeddingProcessor:
    """
    Return the shared, module-level ``EmbeddingProcessor`` instance.

    The same object is returned on every call.
    """
    shared = embedding_processor
    return shared
async def process_single_text(text: str, user_id: UUID) -> Optional[List[float]]:
    """
    Embed one text using the shared module-level processor.

    Delegates to ``embedding_processor.process_single_text`` and returns its
    result unchanged (an embedding vector, or ``None`` on failure).
    """
    processor = embedding_processor
    return await processor.process_single_text(text=text, user_id=user_id)
async def process_document(
    document_id: UUID,
    user_id: UUID,
    content: str,
    title: Optional[str] = None,
    metadata: Optional[Dict] = None
) -> bool:
    """
    Chunk, embed and store a document using the shared module-level processor.

    Delegates to ``embedding_processor.process_document`` and returns its
    boolean result unchanged.
    """
    processor = embedding_processor
    return await processor.process_document(
        document_id=document_id,
        user_id=user_id,
        content=content,
        title=title,
        metadata=metadata,
    )
async def process_texts_batch(
    texts: List[str],
    user_id: UUID
) -> Optional[List[List[float]]]:
    """
    Embed a list of texts using the shared module-level processor.

    Delegates to ``embedding_processor.process_texts_batch`` and returns its
    result unchanged (one embedding per text, or ``None`` on failure).
    """
    processor = embedding_processor
    return await processor.process_texts_batch(texts=texts, user_id=user_id)