File size: 10,497 Bytes
db7c1e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
"""
Embedding processor for the AI Backend with RAG + Authentication
Implements text preprocessing, caching, and document chunking for embeddings
"""
import hashlib
import asyncio
from typing import List, Optional, Tuple, Dict
import logging
from uuid import UUID

from ..config.settings import settings
from .gemini_client import generate_embedding, generate_embeddings_batch
from ..qdrant.operations import get_vector_operations
from ..db import crud
from ..config.database import get_db_session

logger = logging.getLogger(__name__)

# Maximum characters per chunk (Gemini has token limits)
MAX_CHUNK_SIZE = 2000
OVERLAP_SIZE = 200  # Overlap between chunks to maintain context


class EmbeddingProcessor:
    """
    Processor class to handle embedding workflows including preprocessing,
    caching, and document chunking.

    Holds a simple in-memory embedding cache keyed by SHA-256 of the
    (preprocessed) text, and delegates vector storage to Qdrant via
    ``get_vector_operations()``.
    """

    def __init__(self):
        self.vector_ops = get_vector_operations()
        # Simple in-memory cache (in production, use Redis or similar).
        # Keys are SHA-256 hex digests of preprocessed text; values are
        # the embedding vectors returned by Gemini.
        self.cache: Dict[str, List[float]] = {}

    def _generate_content_hash(self, content: str) -> str:
        """
        Return the SHA-256 hex digest of ``content``.

        Used as the cache key and for deduplication of identical texts.
        """
        return hashlib.sha256(content.encode('utf-8')).hexdigest()

    def _preprocess_text(self, text: str) -> str:
        """
        Preprocess text by cleaning and normalizing whitespace.

        Collapses all runs of whitespace (spaces, tabs, newlines) into
        single spaces and strips leading/trailing whitespace.

        Raises:
            ValueError: if ``text`` is empty or not a string.
        """
        if not text or not isinstance(text, str):
            raise ValueError("Input text must be a non-empty string")

        # split() with no args splits on any whitespace run, so this
        # collapses tabs/newlines/multiple spaces into single spaces
        text = ' '.join(text.split())

        # Warn (but do not reject) very long inputs
        if len(text) > 1000000:  # 1M characters max
            logger.warning(f"Text is very long ({len(text)} chars), consider pre-processing")

        return text.strip()

    def _chunk_text(self, text: str, chunk_size: int = MAX_CHUNK_SIZE, overlap: int = OVERLAP_SIZE) -> List[str]:
        """
        Split text into overlapping chunks to maintain context.

        Chunks are at most ``chunk_size`` characters; consecutive chunks
        share ``overlap`` characters. Where possible, a chunk is cut at a
        sentence boundary ('.', '!', '?', or newline) past the midpoint.

        Bug fix: the previous implementation never terminated once it
        reached the tail of the text — after clamping ``end`` to
        ``len(text)`` and appending, ``start = end - overlap`` was still
        ``< len(text)``, so the same tail chunk was appended forever.
        We now emit the remainder and break.
        """
        if len(text) <= chunk_size:
            return [text]

        chunks: List[str] = []
        start = 0

        while start < len(text):
            end = start + chunk_size

            # Final (possibly short) chunk: emit the remainder and stop.
            # Without this break, `start = end - overlap` would stay below
            # len(text) and the loop would never terminate.
            if end >= len(text):
                chunks.append(text[start:])
                break

            chunk = text[start:end]

            # Prefer to break at a sentence boundary so chunks stay coherent
            sentence_end = max(
                chunk.rfind('.'),
                chunk.rfind('!'),
                chunk.rfind('?'),
                chunk.rfind('\n')
            )

            if sentence_end > chunk_size // 2:  # Only if it's not too early
                end = start + sentence_end + 1
                chunk = text[start:end]

            chunks.append(chunk)
            # Step back by `overlap` chars so adjacent chunks share context
            start = end - overlap

        return chunks

    async def _get_from_cache(self, content_hash: str) -> Optional[List[float]]:
        """
        Get an embedding from the in-memory cache, or None on a miss.
        """
        return self.cache.get(content_hash)

    async def _save_to_cache(self, content_hash: str, embedding: List[float]):
        """
        Save an embedding to the in-memory cache under its content hash.
        """
        self.cache[content_hash] = embedding

    async def process_single_text(self, text: str, user_id: UUID) -> Optional[List[float]]:
        """
        Process a single text for embedding with caching.

        ``user_id`` is accepted for interface symmetry with the other
        processing methods; it is not used here.

        Returns the embedding vector, or None on preprocessing/generation
        failure (errors are logged, not raised).
        """
        try:
            # Preprocess the text
            processed_text = self._preprocess_text(text)

            if not processed_text:
                logger.warning("Text preprocessing resulted in empty string")
                return None

            # Generate content hash for caching
            content_hash = self._generate_content_hash(processed_text)

            # Check cache first. Use an explicit None check: a cached empty
            # vector (however unlikely) would be falsy and wrongly treated
            # as a cache miss under truthiness.
            cached_embedding = await self._get_from_cache(content_hash)
            if cached_embedding is not None:
                logger.info(f"Found embedding in cache for text of length {len(processed_text)}")
                return cached_embedding

            # Generate embedding using Gemini
            embedding = await generate_embedding(processed_text)
            if embedding is None:
                logger.error(f"Failed to generate embedding for text of length {len(processed_text)}")
                return None

            # Save to cache
            await self._save_to_cache(content_hash, embedding)

            logger.info(f"Successfully processed embedding for text of length {len(processed_text)}")
            return embedding

        except Exception as e:
            logger.error(f"Error processing single text: {e}")
            return None

    async def process_document(
        self,
        document_id: UUID,
        user_id: UUID,
        content: str,
        title: Optional[str] = None,
        metadata: Optional[Dict] = None
    ) -> bool:
        """
        Process a document for embedding, including chunking and storage.

        Chunks large documents, embeds each chunk (with caching), and
        batch-upserts the vectors plus per-chunk payloads into Qdrant.

        Returns True only if at least one embedding was generated and the
        Qdrant upsert succeeded; failures are logged and return False.
        """
        try:
            # Preprocess the content
            processed_content = self._preprocess_text(content)

            if not processed_content:
                logger.warning("Document content preprocessing resulted in empty string")
                return False

            # Chunk the document if it's large
            if len(processed_content) > MAX_CHUNK_SIZE:
                chunks = self._chunk_text(processed_content)
                logger.info(f"Document chunked into {len(chunks)} parts")
            else:
                chunks = [processed_content]

            # Process each chunk; embeddings and payloads are appended in
            # lockstep so the two lists stay index-aligned.
            all_embeddings = []
            chunk_payloads = []

            for i, chunk in enumerate(chunks):
                # Generate content hash for caching
                content_hash = self._generate_content_hash(chunk)

                # Check cache first
                embedding = await self._get_from_cache(content_hash)
                if embedding is None:
                    # Generate embedding using Gemini
                    embedding = await generate_embedding(chunk)
                    if embedding is None:
                        # Skip failed chunks (best-effort); note this leaves
                        # a gap in chunk_index values for this document.
                        logger.error(f"Failed to generate embedding for chunk {i}")
                        continue

                    # Save to cache
                    await self._save_to_cache(content_hash, embedding)

                all_embeddings.append(embedding)

                # Create payload for this chunk; only a 100-char preview of
                # the chunk text is stored as a reference.
                chunk_payload = {
                    "chunk_index": i,
                    "chunk_text": chunk[:100] + "..." if len(chunk) > 100 else chunk,
                    "document_id": str(document_id),
                    "user_id": str(user_id),
                    "title": title or "Untitled Document",
                    "total_chunks": len(chunks)
                }

                if metadata:
                    chunk_payload.update(metadata)

                chunk_payloads.append(chunk_payload)

            # Store embeddings in Qdrant
            if all_embeddings:
                success = await self.vector_ops.batch_upsert_vectors(
                    user_id=user_id,
                    document_id=document_id,
                    embeddings_list=all_embeddings,
                    payloads_list=chunk_payloads
                )

                if success:
                    logger.info(f"Successfully stored {len(all_embeddings)} embeddings for document {document_id}")
                    return True
                else:
                    logger.error(f"Failed to store embeddings in Qdrant for document {document_id}")
                    return False
            else:
                logger.warning("No embeddings were generated for the document")
                return False

        except Exception as e:
            logger.error(f"Error processing document {document_id}: {e}")
            return False

    async def process_texts_batch(
        self,
        texts: List[str],
        user_id: UUID
    ) -> Optional[List[List[float]]]:
        """
        Process a batch of texts for embedding with caching.

        Texts are processed sequentially via ``process_single_text``.
        Returns None if ANY text fails (all-or-nothing semantics),
        otherwise a list of embeddings in input order.
        """
        try:
            embeddings = []

            for text in texts:
                embedding = await self.process_single_text(text, user_id)
                if embedding is None:
                    logger.error(f"Failed to process text: {text[:50]}...")
                    return None
                embeddings.append(embedding)

            logger.info(f"Successfully processed batch of {len(texts)} texts")
            return embeddings

        except Exception as e:
            logger.error(f"Error processing text batch: {e}")
            return None

    async def invalidate_cache_for_document(self, document_id: UUID):
        """
        Remove cached embeddings associated with a document.

        Not implemented for the simple in-memory cache: entries are keyed
        by content hash only, so they cannot be traced back to a document.
        A real implementation (e.g. Redis) would maintain a document->hash
        index to support this.
        """
        logger.info(f"Cache invalidation requested for document {document_id} (not implemented in simple cache)")


# Global instance of EmbeddingProcessor.
# Created eagerly at import time; the module-level helper functions below
# all delegate to this shared instance (and therefore share its cache).
embedding_processor = EmbeddingProcessor()


def get_embedding_processor() -> EmbeddingProcessor:
    """Return the shared module-level EmbeddingProcessor instance."""
    return embedding_processor


async def process_single_text(text: str, user_id: UUID) -> Optional[List[float]]:
    """
    Module-level convenience wrapper: embed one text (with caching)
    via the shared EmbeddingProcessor.
    """
    processor = get_embedding_processor()
    return await processor.process_single_text(text, user_id)


async def process_document(
    document_id: UUID,
    user_id: UUID,
    content: str,
    title: Optional[str] = None,
    metadata: Optional[Dict] = None
) -> bool:
    """
    Module-level convenience wrapper: chunk, embed, and store a document
    via the shared EmbeddingProcessor. Returns True on success.
    """
    processor = get_embedding_processor()
    return await processor.process_document(
        document_id=document_id,
        user_id=user_id,
        content=content,
        title=title,
        metadata=metadata,
    )


async def process_texts_batch(
    texts: List[str],
    user_id: UUID
) -> Optional[List[List[float]]]:
    """
    Module-level convenience wrapper: embed a batch of texts via the
    shared EmbeddingProcessor. Returns None if any text fails.
    """
    processor = get_embedding_processor()
    return await processor.process_texts_batch(texts, user_id)