trykopy / emdedd /MongoEmbedding.py
Pavol Liška
v1-fix
593b823
raw
history blame
2.91 kB
from dataclasses import dataclass
from bson.objectid import ObjectId
from langchain.embeddings import CacheBackedEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.stores import InMemoryStore
from langchain_mongodb import MongoDBAtlasVectorSearch
from pymongo import MongoClient
from emdedd.Embedding import Embedding
@dataclass
class EmbeddingDbConnection:
    """Connection settings for a MongoDB collection used as a vector store.

    All four values are plain strings; field order is the positional
    constructor order.
    """

    # MongoDB connection string (used to build a MongoClient).
    connection: str
    # Name of the database that holds the collection.
    database: str
    # Name of the collection the embedded documents live in.
    collection: str
    # Name of the Atlas vector search index used for queries.
    index: str
class MongoEmbedding(Embedding):
    """Embed text chunks and store/search them in a MongoDB Atlas collection.

    Chunks are embedded with a LangChain ``Embeddings`` model (optionally
    wrapped in ``CacheBackedEmbeddings``) and persisted through
    ``MongoDBAtlasVectorSearch``.
    """

    db: EmbeddingDbConnection
    embedding: Embeddings

    def __init__(self, db, embedding, cache: bool = True):
        """Create the store.

        Args:
            db: Connection settings (connection string, database, collection,
                vector search index name).
            embedding: The underlying LangChain embedding model.
            cache: When true, wrap ``embedding`` in ``CacheBackedEmbeddings``
                backed by an ``InMemoryStore`` created per instance, so the
                cache lives only as long as this object.
        """
        self.db = db
        if cache:
            self.embedding = CacheBackedEmbeddings.from_bytes_store(
                underlying_embeddings=embedding,
                document_embedding_cache=InMemoryStore(),
                namespace="mongo-embedding-cache",
            )
        else:
            self.embedding = embedding

    def embedd(self, chunks, metadata: list[dict] = None):
        """Embed ``chunks`` and store them with optional per-chunk ``metadata``."""
        self.__store_embeddings(chunks, metadata)

    def __store_embeddings(self, chunks, metadata: list[dict] = None):
        """Insert embedded chunks into the collection, then back-fill ids.

        The vector search index named ``self.db.index`` is not created here;
        it must already exist. (An earlier, removed snippet defined it as a
        1536-dimension, cosine-similarity ``knnVector`` mapping on the
        ``embedding`` field.)
        """
        # Close the client (and its connection pool) after this batch instead
        # of leaking one MongoClient per call.
        with MongoClient(self.db.connection) as client:
            collection = client[self.db.database][self.db.collection]
            MongoDBAtlasVectorSearch.from_texts(
                texts=chunks,
                metadatas=metadata,
                embedding=self.embedding,
                collection=collection,
                index_name=self.db.index,
            )
            self.__add_id_to_metadata(collection)

    def __add_id_to_metadata(self, collection):
        """Set ``metadata.id`` to the stringified ``_id`` where it is missing."""
        # ``$exists`` must be a real boolean: the string "false" is truthy and
        # would select documents that already HAVE metadata.id.
        # NOTE(review): confirm the vector store writes metadata under a
        # "metadata" sub-document; if it flattens metadata keys onto the
        # document, this adds a new "metadata" sub-document instead.
        for document in collection.find({"metadata.id": {"$exists": False}}):
            object_id: ObjectId = document["_id"]
            # The field may be absent or null on the documents selected here.
            metadata: dict = document.get("metadata") or {}
            metadata["id"] = str(object_id)
            # update_one requires update operators; a bare replacement
            # document is rejected by PyMongo.
            collection.update_one(
                filter={"_id": object_id},
                update={"$set": {"metadata": metadata}},
            )

    def get_vector_store(self):
        """Return a ``MongoDBAtlasVectorSearch`` bound to the configured collection.

        A new vector store is built from the connection string on every call.
        """
        return MongoDBAtlasVectorSearch.from_connection_string(
            self.db.connection,
            f"{self.db.database}.{self.db.collection}",
            embedding=self.embedding,
            index_name=self.db.index,
        )

    def search(self, query, search_type, doc_count):
        """Return the ``doc_count`` most similar documents to ``query``.

        NOTE(review): ``search_type`` is accepted for interface compatibility
        but currently ignored; the retriever always uses "similarity".
        """
        retriever = self.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": doc_count},
        )
        return retriever.get_relevant_documents(query=query)