from dataclasses import dataclass

from bson.objectid import ObjectId
from langchain.embeddings import CacheBackedEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.stores import InMemoryStore
from langchain_mongodb import MongoDBAtlasVectorSearch
from pymongo import MongoClient

from emdedd.Embedding import Embedding


@dataclass
class EmbeddingDbConnection:
    """Connection details for a MongoDB Atlas vector-search collection."""

    connection: str  # MongoDB connection string (URI)
    database: str    # database name
    collection: str  # collection name
    index: str       # Atlas vector-search index name


class MongoEmbedding(Embedding):
    """Embeds text chunks and stores/retrieves the vectors in MongoDB Atlas."""

    db: EmbeddingDbConnection
    embedding: Embeddings

    def __init__(self, db: EmbeddingDbConnection, embedding: Embeddings, cache: bool = True):
        """Create a MongoDB-backed embedder.

        Args:
            db: Connection/collection/index details.
            embedding: The underlying embedding model.
            cache: When True (default), wrap the embedder in an in-memory
                byte-store cache so identical texts are embedded only once
                per process.
        """
        self.db = db
        if cache:
            self.embedding = CacheBackedEmbeddings.from_bytes_store(
                underlying_embeddings=embedding,
                document_embedding_cache=InMemoryStore(),
                namespace="mongo-embedding-cache",
            )
        else:
            self.embedding = embedding

    def embedd(self, chunks, metadata: list[dict] = None):
        """Embed *chunks* and persist them, with optional per-chunk metadata."""
        self.__store_embeddings(chunks, metadata)

    def __store_embeddings(self, chunks, metadata: list[dict] = None):
        """Embed *chunks* into the configured collection, then back-fill ids.

        The Atlas vector-search index is assumed to exist already; it is NOT
        created here.
        """
        client = MongoClient(self.db.connection)
        try:
            collection = client[self.db.database][self.db.collection]
            MongoDBAtlasVectorSearch.from_texts(
                texts=chunks,
                metadatas=metadata,
                embedding=self.embedding,
                collection=collection,
                index_name=self.db.index,
            )
            self.__add_id_to_metadata(collection)
        finally:
            # FIX: the client was previously never closed (connection leak).
            client.close()

    def __add_id_to_metadata(self, collection):
        """Copy each document's ``_id`` into ``metadata.id`` where missing."""
        # FIX: $exists must be a boolean. The original passed the STRING
        # "false", which MongoDB treats as truthy — so the query matched
        # documents whose metadata.id already exists, the opposite of intent.
        for document in collection.find({"metadata.id": {"$exists": False}}):
            metadata: dict = document.get("metadata")
            if metadata is None:
                metadata = {}
            object_id: ObjectId = document["_id"]
            metadata["id"] = str(object_id)
            # FIX: update_one requires an update-operator document; the
            # original raw {"metadata": ...} raises ValueError in pymongo.
            collection.update_one(
                filter={"_id": object_id},
                update={"$set": {"metadata": metadata}},
            )

    def get_vector_store(self):
        """Return a vector store bound to the configured collection and index."""
        return MongoDBAtlasVectorSearch.from_connection_string(
            self.db.connection,
            self.db.database + "." + self.db.collection,
            embedding=self.embedding,
            index_name=self.db.index,
        )

    def search(self, query, search_type, doc_count):
        """Retrieve up to *doc_count* documents relevant to *query*.

        FIX: *search_type* was previously accepted but ignored (the retriever
        was hard-coded to "similarity"); it is now forwarded to the retriever.
        """
        retriever = self.get_vector_store().as_retriever(
            search_type=search_type,
            search_kwargs={"k": doc_count},
        )
        return retriever.get_relevant_documents(query=query)