File size: 2,905 Bytes
869eb7d
 
593b823
869eb7d
 
 
593b823
869eb7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593b823
869eb7d
 
593b823
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from dataclasses import dataclass

from bson.objectid import ObjectId
from langchain.embeddings import CacheBackedEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.stores import InMemoryStore
from langchain_mongodb import MongoDBAtlasVectorSearch
from pymongo import MongoClient

from emdedd.Embedding import Embedding


@dataclass
class EmbeddingDbConnection:
    """Settings that identify where embeddings are stored in MongoDB.

    All four values are plain strings; no validation is performed here.
    """

    # Connection string handed to the MongoDB client.
    connection: str
    # Name of the database that holds the embedding collection.
    database: str
    # Name of the collection inside ``database``.
    collection: str
    # Name of the search index passed to the vector store as ``index_name``.
    index: str


class MongoEmbedding(Embedding):
    """Store text chunks as embeddings in MongoDB Atlas and search them.

    The target database, collection and search index are described by an
    ``EmbeddingDbConnection``. The embedding model is any LangChain
    ``Embeddings`` implementation, optionally wrapped in an in-memory cache.
    """

    db: EmbeddingDbConnection
    embedding: Embeddings

    def __init__(self, db: EmbeddingDbConnection, embedding: Embeddings,
                 cache: bool = True):
        """Remember the connection settings and prepare the embedding model.

        Args:
            db: Connection string, database, collection and index names.
            embedding: The underlying embedding model.
            cache: When true (default), wrap ``embedding`` in a
                ``CacheBackedEmbeddings`` backed by a fresh ``InMemoryStore``
                created for this instance.
        """
        self.db = db
        if cache:
            self.embedding = CacheBackedEmbeddings.from_bytes_store(
                underlying_embeddings=embedding,
                document_embedding_cache=InMemoryStore(),
                namespace="mongo-embedding-cache"
            )
        else:
            self.embedding = embedding

    def embedd(self, chunks, metadata: list[dict] = None):
        """Embed ``chunks`` and persist them in the configured collection.

        Args:
            chunks: The texts to embed and store.
            metadata: Optional list of metadata dicts, passed to the vector
                store as ``metadatas`` alongside ``chunks``.
        """
        self.__store_embeddings(chunks, metadata)

    def __store_embeddings(self, chunks, metadata: list[dict] = None):
        """Write the chunks to MongoDB, then back-fill ``metadata.id``."""
        # The context manager closes the client (and its connection pool)
        # once storing is finished, even if an error is raised.
        with MongoClient(self.db.connection) as client:
            collection = client[self.db.database][self.db.collection]

            # NOTE: the search index named ``self.db.index`` is not created
            # here. An earlier, disabled draft created a "knnVector" index on
            # the "embedding" field with 1536 dimensions and cosine similarity.

            MongoDBAtlasVectorSearch.from_texts(
                texts=chunks,
                metadatas=metadata,
                embedding=self.embedding,
                collection=collection,
                index_name=self.db.index
            )

            self.__add_id_to_metadata(collection)

    def __add_id_to_metadata(self, collection):
        """Copy each document's ``_id`` (as a string) into ``metadata.id``.

        Only documents that do not yet have a ``metadata.id`` field are
        touched; their other metadata entries are kept.
        """
        # ``$exists`` needs a real boolean: a non-empty string such as
        # "false" is truthy and would select the documents that already
        # have the field.
        missing_id = {"metadata.id": {"$exists": False}}
        for document in collection.find(missing_id):
            object_id: ObjectId = document["_id"]
            # The metadata field may be absent or null; start from an empty
            # dict in that case.
            metadata: dict = document.get("metadata") or {}
            metadata["id"] = str(object_id)
            # An update document must use an operator; ``$set`` replaces the
            # metadata field and leaves the rest of the document untouched.
            collection.update_one(
                filter={"_id": object_id},
                update={"$set": {"metadata": metadata}}
            )

    def get_vector_store(self):
        """Return a ``MongoDBAtlasVectorSearch`` bound to the configured
        ``database.collection`` namespace, embedding model and index."""
        return MongoDBAtlasVectorSearch.from_connection_string(
            self.db.connection,
            self.db.database + "." + self.db.collection,
            embedding=self.embedding,
            index_name=self.db.index
        )

    def search(self, query, search_type, doc_count):
        """Return the documents most relevant to ``query``.

        Args:
            query: The query text.
            search_type: Retriever search type forwarded to ``as_retriever``
                (for example ``"similarity"``). A falsy value falls back to
                ``"similarity"``.
            doc_count: Number of documents to request (``k``).

        Returns:
            The documents returned by the retriever.
        """
        return self.get_vector_store().as_retriever(
            # Honour the caller's choice instead of a hard-coded value.
            search_type=search_type or "similarity",
            search_kwargs={"k": doc_count}
        ).get_relevant_documents(query=query)