Spaces:
Paused
Paused
| from dataclasses import dataclass | |
| from bson.objectid import ObjectId | |
| from langchain.embeddings import CacheBackedEmbeddings | |
| from langchain_core.embeddings import Embeddings | |
| from langchain_core.stores import InMemoryStore | |
| from langchain_mongodb import MongoDBAtlasVectorSearch | |
| from pymongo import MongoClient | |
| from emdedd.Embedding import Embedding | |
@dataclass
class EmbeddingDbConnection:
    """Connection settings for the MongoDB collection that holds embeddings.

    The file imports ``dataclass`` but never applied it, which left this
    class without a constructor for its four fields. Decorating it generates
    ``__init__``, ``__repr__`` and ``__eq__`` from the annotations below.
    """

    # MongoDB connection string handed to ``MongoClient``.
    connection: str
    # Name of the database containing the embedding collection.
    database: str
    # Name of the collection the embedded documents are written to.
    collection: str
    # Name of the vector search index used when storing and querying.
    index: str
class MongoEmbedding(Embedding):
    """Embed text chunks into a MongoDB collection and search them by vector.

    Chunks are written through ``MongoDBAtlasVectorSearch`` into the
    collection described by ``db``; queries go through a retriever built on
    the same collection and index.
    """

    db: EmbeddingDbConnection
    embedding: Embeddings

    # Search types that work with only a ``k`` argument; anything else falls
    # back to plain similarity search (the historical behaviour).
    _SUPPORTED_SEARCH_TYPES = ("similarity", "mmr")

    def __init__(self, db, embedding, cache: bool = True):
        """Remember the connection settings and the embedding model.

        Args:
            db: Connection settings (``connection``, ``database``,
                ``collection`` and ``index`` attributes are read later).
            embedding: The underlying embeddings implementation.
            cache: When true, wrap ``embedding`` in a
                ``CacheBackedEmbeddings`` backed by an ``InMemoryStore`` so
                document embeddings are cached for the life of this object.
        """
        self.db = db
        if cache:
            self.embedding = CacheBackedEmbeddings.from_bytes_store(
                underlying_embeddings=embedding,
                document_embedding_cache=InMemoryStore(),
                namespace="mongo-embedding-cache",
            )
        else:
            self.embedding = embedding

    def embedd(self, chunks, metadata: list[dict] = None):
        """Embed ``chunks`` and store them, with optional per-chunk metadata.

        Args:
            chunks: The texts to embed and store.
            metadata: Optional list of metadata dicts passed alongside the
                texts, or ``None``.
        """
        self.__store_embeddings(chunks, metadata)

    def __store_embeddings(self, chunks, metadata: list[dict] = None):
        """Write the chunks to the collection, then back-fill ``metadata.id``.

        This method does not create the search index; an index named
        ``self.db.index`` is expected to exist already.
        NOTE(review): earlier commented-out code sketched a ``knnVector``
        index on ``embedding`` with 1536 dimensions and cosine similarity --
        confirm those values against the embedding model in use.
        """
        # The context manager closes the client (and its connection pool)
        # once the write has finished, instead of leaking it on every call.
        with MongoClient(self.db.connection) as client:
            collection = client[self.db.database][self.db.collection]
            MongoDBAtlasVectorSearch.from_texts(
                texts=chunks,
                metadatas=metadata,
                embedding=self.embedding,
                collection=collection,
                index_name=self.db.index,
            )
            self.__add_id_to_metadata(collection)

    def __add_id_to_metadata(self, collection):
        """Copy each document's ``_id`` (as a string) into ``metadata.id``.

        Only documents that do not yet have ``metadata.id`` are touched.
        """
        # ``$exists`` needs a real boolean: the string "false" is truthy and
        # would select the documents that already HAVE the field.
        missing_id = {"metadata.id": {"$exists": False}}
        # Project only ``metadata`` (``_id`` is always returned) so the stored
        # embedding vectors are not fetched just to be thrown away.
        for document in collection.find(missing_id, {"metadata": 1}):
            metadata = document.get("metadata")
            if metadata is None:
                metadata = {}
            elif not isinstance(metadata, dict):
                # Leave unexpected non-mapping metadata untouched rather than
                # overwriting it.
                continue
            object_id: ObjectId = document["_id"]
            metadata["id"] = str(object_id)
            # update_one rejects a bare replacement document; the change must
            # be expressed with an update operator.
            collection.update_one(
                filter={"_id": object_id},
                update={"$set": {"metadata": metadata}},
            )

    def get_vector_store(self):
        """Build a vector store bound to the configured collection and index.

        A new store is constructed from the connection string on every call.
        """
        return MongoDBAtlasVectorSearch.from_connection_string(
            self.db.connection,
            self.db.database + "." + self.db.collection,
            embedding=self.embedding,
            index_name=self.db.index,
        )

    def search(self, query, search_type, doc_count):
        """Return up to ``doc_count`` stored documents relevant to ``query``.

        Args:
            query: The query text.
            search_type: ``"similarity"`` or ``"mmr"``. Any other value falls
                back to ``"similarity"``, which is what every call used
                before this argument was honoured.
            doc_count: Passed to the retriever as ``k``.

        Returns:
            The documents returned by the retriever.
        """
        if search_type not in self._SUPPORTED_SEARCH_TYPES:
            search_type = "similarity"
        retriever = self.get_vector_store().as_retriever(
            search_type=search_type,
            search_kwargs={"k": doc_count},
        )
        # ``invoke`` is the supported replacement for the deprecated
        # ``get_relevant_documents``.
        return retriever.invoke(query)