File size: 19,129 Bytes
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
from typing import List, Tuple
import chromadb
from chromadb.utils import embedding_functions
import os
from .dto.chunk_dto import ChunkDTO
from lpm_kernel.common.llm import LLMClient
from lpm_kernel.file_data.document_dto import DocumentDTO
from typing import List, Dict, Optional
from lpm_kernel.configs.logging import get_train_process_logger
logger = get_train_process_logger()


class EmbeddingService:
    """Generate embeddings with the LLM client and persist/query them in ChromaDB.

    Two collections are used: ``documents`` (one embedding per document) and
    ``document_chunks`` (one embedding per chunk). Both are created with cosine
    space and a ``dimension`` metadata entry that must match the configured
    embedding model; a mismatch in an existing collection triggers
    reinitialization of the collections.
    """

    def __init__(self):
        """Connect to ChromaDB, resolve the embedding dimension and open both collections.

        The persistence directory is read from ``CHROMA_PERSIST_DIRECTORY``
        (default ``./data/chroma_db``). The dimension is detected from the
        user's configured embedding model name, falling back to 1536 when no
        model is configured or detection fails.

        Raises:
            RuntimeError: if a dimension mismatch cannot be repaired, or a
                collection cannot be created or still has the wrong dimension.
        """
        from lpm_kernel.file_data.chroma_utils import detect_embedding_model_dimension
        from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService

        chroma_path = os.getenv("CHROMA_PERSIST_DIRECTORY", "./data/chroma_db")
        self.client = chromadb.PersistentClient(path=chroma_path)
        self.llm_client = LLMClient()

        # Get embedding model dimension from user config
        try:
            user_llm_config_service = UserLLMConfigService()
            user_llm_config = user_llm_config_service.get_available_llm()

            if user_llm_config and user_llm_config.embedding_model_name:
                # Detect dimension based on model name
                self.dimension = detect_embedding_model_dimension(user_llm_config.embedding_model_name)
                logger.info(f"Detected embedding dimension: {self.dimension} for model: {user_llm_config.embedding_model_name}")
            else:
                # Default to OpenAI dimension if no config found
                self.dimension = 1536
                logger.info(f"No embedding model configured, using default dimension: {self.dimension}")
        except Exception as e:
            # Default to OpenAI dimension if error occurs
            self.dimension = 1536
            logger.error(f"Error detecting embedding dimension, using default: {self.dimension}. Error: {str(e)}", exc_info=True)

        # First pass: check every existing collection for a dimension mismatch
        collections_to_init = ["documents", "document_chunks"]
        dimension_mismatch_detected = False
        for collection_name in collections_to_init:
            try:
                collection = self.client.get_collection(name=collection_name)
            except ValueError:
                # Collection doesn't exist yet, will be created below.
                # NOTE(review): newer chromadb releases may raise a different
                # exception type for a missing collection — confirm against the
                # pinned chromadb version.
                continue
            existing_dimension = self._collection_dimension(collection)
            if existing_dimension != self.dimension:
                logger.warning(f"Dimension mismatch in '{collection_name}' collection: {existing_dimension} vs {self.dimension}")
                dimension_mismatch_detected = True

        # Handle dimension mismatch if detected in any collection
        if dimension_mismatch_detected:
            self._handle_dimension_mismatch()

        # Second pass: get or create the collections with the correct dimension
        self.document_collection = self._get_or_create_collection("documents")
        self.chunk_collection = self._get_or_create_collection("document_chunks")

    @staticmethod
    def _collection_dimension(collection):
        """Return the ``dimension`` stored in a collection's metadata, or None.

        Collections created without metadata expose ``metadata`` as None, so it
        is treated as an empty mapping instead of raising AttributeError.
        """
        return (collection.metadata or {}).get("dimension")

    @staticmethod
    def _has_embeddings(result) -> bool:
        """Return True when a ChromaDB ``get`` result holds at least one embedding.

        Uses ``len`` instead of truthiness because the embeddings may come back
        as a NumPy array, whose truth value is ambiguous for multi-element
        arrays (depends on the chromadb version).
        """
        if not result:
            return False
        embeddings = result.get("embeddings")
        return embeddings is not None and len(embeddings) > 0

    def _get_or_create_collection(self, name: str):
        """Open collection ``name``, creating it with ``self.dimension`` if it is missing.

        Args:
            name: ChromaDB collection name.

        Returns:
            The ChromaDB collection object.

        Raises:
            RuntimeError: if creation fails, or an existing collection's stored
                dimension still differs from ``self.dimension``.
        """
        try:
            collection = self.client.get_collection(name=name)
        except ValueError:
            # Collection doesn't exist, create it with the correct dimension
            try:
                collection = self.client.create_collection(
                    name=name, metadata={"hnsw:space": "cosine", "dimension": self.dimension}
                )
            except Exception as e:
                logger.error(f"Failed to create '{name}' collection: {str(e)}", exc_info=True)
                raise RuntimeError(f"Failed to create '{name}' collection: {str(e)}") from e
            logger.info(f"Created '{name}' collection with dimension {self.dimension}")
            return collection

        # Verify dimension after possible reinitialization
        existing_dimension = self._collection_dimension(collection)
        if existing_dimension != self.dimension:
            logger.error(f"Collection '{name}' still has incorrect dimension after reinitialization: {existing_dimension} vs {self.dimension}")
            raise RuntimeError(f"Failed to set correct dimension for '{name}' collection: {existing_dimension} vs {self.dimension}")
        return collection

    def generate_document_embedding(self, document: DocumentDTO) -> Optional[List[float]]:
        """Embed a document's raw content and store it in the ``documents`` collection.

        Args:
            document: document whose ``raw_content`` is embedded; ``str(document.id)``
                is used as the ChromaDB id.

        Returns:
            The embedding returned by the LLM client, or None when the document
            has no content, no embedding was returned, or storing/verifying the
            embedding in ChromaDB failed (those storage errors are logged, not raised).

        Raises:
            Exception: errors while generating the embedding are logged and re-raised.
        """
        try:
            if not document.raw_content:
                logger.warning(
                    f"Document {document.id} has no content to process embedding"
                )
                return None

            # get embedding
            logger.info(f"Generating embedding for document {document.id}")
            embeddings = self.llm_client.get_embedding([document.raw_content])

            if embeddings is None or len(embeddings) == 0:
                logger.error(f"Failed to get embedding for document {document.id}")
                return None

            embedding = embeddings[0]
            logger.info(f"Successfully got embedding for document {document.id}")

            # store to ChromaDB
            try:
                metadata = {
                    "title": document.title or document.name,
                    "mime_type": document.mime_type,
                    "create_time": document.create_time.isoformat()
                    if document.create_time
                    else None,
                    "document_size": document.document_size,
                    "url": document.url,
                }
                # Drop unset fields: ChromaDB metadata values are expected to be
                # str/int/float/bool, and None is rejected by some chromadb versions.
                metadata = {key: value for key, value in metadata.items() if value is not None}

                logger.info(f"Storing embedding for document {document.id} in ChromaDB")
                self.document_collection.add(
                    documents=[document.raw_content],
                    ids=[str(document.id)],
                    embeddings=[embedding.tolist()],
                    metadatas=[metadata],
                )
                logger.info(f"Successfully stored embedding for document {document.id}")

                # verify embedding storage
                result = self.document_collection.get(
                    ids=[str(document.id)], include=["embeddings"]
                )
                if not self._has_embeddings(result):
                    logger.error(
                        f"Failed to verify embedding storage for document {document.id}"
                    )
                    return None
                logger.info(f"Verified embedding storage for document {document.id}")

                return embedding

            except Exception as e:
                logger.error(f"Error storing document embedding in ChromaDB: {str(e)}", exc_info=True)
                return None

        except Exception as e:
            logger.error(f"Error processing document embedding: {str(e)}", exc_info=True)
            raise

    def generate_chunk_embeddings(self, chunks: List[ChunkDTO]) -> List[ChunkDTO]:
        """Embed every chunk without an embedding and store it in ``document_chunks``.

        Each stored record uses ``str(chunk.id)`` as id, the chunk content as
        document, and metadata ``{"document_id", "topic", "tags"}``; tags are
        joined with commas. After storing, each chunk's ``has_embedding`` flag
        is set according to whether its embedding can be read back.

        Args:
            chunks: chunks to process; those with ``has_embedding`` already set
                are skipped.

        Returns:
            The same ``chunks`` list (processed chunks are updated in place).
            If the LLM returns no embeddings, the list is returned unchanged.

        Raises:
            Exception: ChromaDB or LLM errors are logged and re-raised; on a
                storage error every unprocessed chunk is marked ``has_embedding = False``.
        """
        try:
            unprocessed_chunks = [c for c in chunks if not c.has_embedding]
            if not unprocessed_chunks:
                logger.info("No unprocessed chunks found")
                return chunks

            logger.info(f"Processing embeddings for {len(unprocessed_chunks)} chunks")

            contents = [c.content for c in unprocessed_chunks]
            # Log only the count: the raw chunk text can be large and sensitive.
            logger.info(f"Getting embeddings from LLM service for {len(contents)} chunks...")
            embeddings = self.llm_client.get_embedding(contents)

            if embeddings is None or len(embeddings) == 0:
                logger.error("Failed to get embeddings from LLM service")
                return chunks

            logger.info(f"Successfully got embeddings with shape: {embeddings.shape}")

            try:
                logger.info("Adding embeddings to ChromaDB...")
                self.chunk_collection.add(
                    documents=contents,
                    ids=[str(c.id) for c in unprocessed_chunks],
                    embeddings=embeddings.tolist(),
                    metadatas=[
                        {
                            "document_id": str(c.document_id),
                            "topic": c.topic or "",
                            "tags": ",".join(c.tags) if c.tags else "",
                        }
                        for c in unprocessed_chunks
                    ],
                )
                logger.info("Successfully added embeddings to ChromaDB")

                # verify embeddings storage
                for chunk in unprocessed_chunks:
                    result = self.chunk_collection.get(
                        ids=[str(chunk.id)], include=["embeddings"]
                    )
                    if self._has_embeddings(result):
                        chunk.has_embedding = True
                        logger.info(f"Verified embedding for chunk {chunk.id}")
                    else:
                        logger.warning(
                            f"Failed to verify embedding for chunk {chunk.id}"
                        )
                        chunk.has_embedding = False

            except Exception as e:
                logger.error(f"Error storing embeddings in ChromaDB: {str(e)}", exc_info=True)
                for chunk in unprocessed_chunks:
                    chunk.has_embedding = False
                raise

            return chunks

        except Exception as e:
            logger.error(f"Error processing chunk embeddings: {str(e)}", exc_info=True)
            raise

    def get_chunk_embedding_by_chunk_id(self, chunk_id: int) -> Optional[List[float]]:
        """Get the corresponding embedding vector by chunk_id

        Args:
            chunk_id (int): chunk ID

        Returns:
            List[float]: embedding vector, return None if not found

        Raises:
            ValueError: when chunk_id is invalid
            Exception: other errors
        """
        try:
            if not isinstance(chunk_id, int) or chunk_id < 0:
                raise ValueError("Invalid chunk_id")

            # query from ChromaDB
            result = self.chunk_collection.get(
                ids=[str(chunk_id)], include=["embeddings"]
            )

            if not self._has_embeddings(result):
                logger.warning(f"No embedding found for chunk {chunk_id}")
                return None

            return result["embeddings"][0]

        except Exception as e:
            logger.error(f"Error getting embedding for chunk {chunk_id}: {str(e)}")
            raise

    def get_document_embedding_by_document_id(
        self, document_id: int
    ) -> Optional[List[float]]:
        """Get the corresponding embedding vector by document_id

        Args:
            document_id (int): document ID

        Returns:
            List[float]: embedding vector, return None if not found

        Raises:
            ValueError: when document_id is invalid
            Exception: other errors
        """
        try:
            if not isinstance(document_id, int) or document_id < 0:
                raise ValueError("Invalid document_id")

            # query from ChromaDB
            result = self.document_collection.get(
                ids=[str(document_id)], include=["embeddings"]
            )

            if not self._has_embeddings(result):
                logger.warning(f"No embedding found for document {document_id}")
                return None

            return result["embeddings"][0]

        except Exception as e:
            logger.error(
                f"Error getting embedding for document {document_id}: {str(e)}"
            )
            raise

    def _handle_dimension_mismatch(self):
        """Reinitialize the ChromaDB collections with ``self.dimension``.

        Delegates to ``reinitialize_chroma_collections``, then refreshes
        ``self.document_collection`` / ``self.chunk_collection`` and checks that
        both now report the expected dimension.

        Raises:
            RuntimeError: if reinitialization fails or reports failure, the
                collections cannot be reopened, or their dimension is still wrong.
        """
        from lpm_kernel.file_data.chroma_utils import reinitialize_chroma_collections

        logger.warning(f"Detected dimension mismatch in ChromaDB collections. Reinitializing with dimension {self.dimension}")
        # Log the operation for better debugging
        logger.info(f"Calling reinitialize_chroma_collections with dimension {self.dimension}")

        try:
            success = reinitialize_chroma_collections(self.dimension)
        except Exception as e:
            logger.error(f"Error during dimension mismatch handling: {str(e)}", exc_info=True)
            raise RuntimeError(f"Failed to handle dimension mismatch in ChromaDB collections: {str(e)}") from e

        if not success:
            logger.error("Failed to reinitialize ChromaDB collections")
            raise RuntimeError("Failed to handle dimension mismatch in ChromaDB collections")

        logger.info(f"Successfully reinitialized ChromaDB collections with dimension {self.dimension}")

        # Refresh collection references
        try:
            self.document_collection = self.client.get_collection(name="documents")
            self.chunk_collection = self.client.get_collection(name="document_chunks")
        except Exception as e:
            logger.error(f"Error refreshing collection references: {str(e)}", exc_info=True)
            raise RuntimeError(f"Failed to refresh ChromaDB collections after reinitialization: {str(e)}") from e

        # Double-check dimensions after refresh
        doc_dimension = self._collection_dimension(self.document_collection)
        chunk_dimension = self._collection_dimension(self.chunk_collection)
        if doc_dimension != self.dimension or chunk_dimension != self.dimension:
            logger.error(f"Dimension mismatch after refresh: documents={doc_dimension}, chunks={chunk_dimension}, expected={self.dimension}")
            raise RuntimeError("Failed to handle dimension mismatch: collections have incorrect dimensions after reinitialization")

    def search_similar_chunks(
        self, query: str, limit: int = 5
    ) -> List[Tuple[ChunkDTO, float]]:
        """Search similar chunks, return list of ChunkDTO objects and their similarity scores

        Args:
            query (str): query text
            limit (int, optional): return result limit. Defaults to 5.

        Returns:
            List[Tuple[ChunkDTO, float]]: return list of (ChunkDTO, similarity score), sorted by similarity score in descending order

        Raises:
            ValueError: when query parameters are invalid
            Exception: other errors
        """
        try:
            if not query or not query.strip():
                raise ValueError("Query string cannot be empty")

            if limit < 1:
                raise ValueError("Limit must be positive")

            # calculate query text embedding
            query_embedding = self.llm_client.get_embedding([query])
            if query_embedding is None or len(query_embedding) == 0:
                raise Exception("Failed to generate embedding for query")

            # query ChromaDB
            results = self.chunk_collection.query(
                query_embeddings=[query_embedding[0].tolist()],
                n_results=limit,
                include=["documents", "metadatas", "distances"],
            )

            # ChromaDB returns one nested list per query embedding; a single
            # query embedding was sent, so all hits live at index 0.
            if not results or not results["ids"] or not results["ids"][0]:
                return []

            hit_ids = results["ids"][0]
            hit_documents = results["documents"][0]
            hit_metadatas = results["metadatas"][0]
            hit_distances = results["distances"][0]

            # convert results to ChunkDTO objects
            similar_chunks = []
            for chunk_id, content, metadata, distance in zip(
                hit_ids, hit_documents, hit_metadatas, hit_distances
            ):
                metadata = metadata or {}
                document_id = metadata["document_id"]
                topic = metadata.get("topic", "")
                raw_tags = metadata.get("tags")
                tags = raw_tags.split(",") if raw_tags else []

                # ChromaDB returns distances; collections created by this class
                # use cosine space, where 1 - distance is the cosine similarity.
                similarity_score = 1 - distance

                chunk = ChunkDTO(
                    id=int(chunk_id),
                    document_id=int(document_id),
                    content=content,
                    topic=topic,
                    tags=tags,
                    has_embedding=True,
                )

                similar_chunks.append((chunk, similarity_score))

            # sort by similarity score in descending order
            similar_chunks.sort(key=lambda x: x[1], reverse=True)

            return similar_chunks

        except ValueError as ve:
            logger.error(f"Invalid input parameters: {str(ve)}")
            raise
        except Exception as e:
            logger.error(f"Error searching similar chunks: {str(e)}")
            raise