import json
import os
import time
from typing import List, Tuple, Dict
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain.schema.document import Document
from server.constant.constants import (OPENAI_EMBEDDING_MODEL_NAME,
                                       ZHIPUAI_EMBEDDING_MODEL_NAME,
                                       CHROMA_DB_DIR, CHROMA_COLLECTION_NAME,
                                       OLLAMA_EMBEDDING_MODEL_NAME)
from server.logger.logger_config import my_logger as logger
from server.rag.index.embedder.zhipuai_embedder import ZhipuAIEmbeddings


class DocumentEmbedder:
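    """Wrapper around a persistent Chroma vector store for adding and
    deleting document chunk embeddings. The embedding backend is selected
    from the LLM_NAME environment variable.
    """

    # Max number of documents sent to Chroma per add/delete call, keeping
    # each embedding request within provider batch limits.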
    BATCH_SIZE = 30

    def __init__(self) -> None:
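        """Select the embedding backend from LLM_NAME and open the
        persistent Chroma collection."""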
        self.llm_name = os.getenv('LLM_NAME')
        if self.llm_name == 'OpenAI':
            embeddings = OpenAIEmbeddings(
                openai_api_key=os.getenv('OPENAI_API_KEY'),
                model=OPENAI_EMBEDDING_MODEL_NAME)
        elif self.llm_name == 'ZhipuAI':
            embeddings = ZhipuAIEmbeddings(
                api_key=os.getenv('ZHIPUAI_API_KEY'),
                model=ZHIPUAI_EMBEDDING_MODEL_NAME)
        elif self.llm_name == 'Ollama':
            base_url = os.getenv('OLLAMA_BASE_URL')
            embeddings = OllamaEmbeddings(base_url=base_url,
                                          model=OLLAMA_EMBEDDING_MODEL_NAME)
        elif self.llm_name in ['DeepSeek', 'Moonshot']:
            # DeepSeek and Moonshot use ZhipuAI's Embedding API
            embeddings = ZhipuAIEmbeddings(
                api_key=os.getenv('ZHIPUAI_API_KEY'),
                model=ZHIPUAI_EMBEDDING_MODEL_NAME)
        else:
            raise ValueError(
                f"Unsupported LLM_NAME '{self.llm_name}'. Must be in ['OpenAI', 'ZhipuAI', 'Ollama', 'DeepSeek', 'Moonshot']."
            )

        collection_name = CHROMA_COLLECTION_NAME
        persist_directory = CHROMA_DB_DIR
        logger.info(
            f"[DOC_EMBEDDER] init, collection_name: '{collection_name}', persist_directory: '{persist_directory}', llm_name: '{self.llm_name}'"
        )
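        # Use cosine distance for the HNSW index (Chroma's default is L2).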
        collection_metadata = {"hnsw:space": "cosine"}
        self.chroma_vector = Chroma(collection_name=collection_name,
                                    embedding_function=embeddings,
                                    persist_directory=persist_directory,
                                    collection_metadata=collection_metadata)

    async def aadd_document_embedding(
        self, data: List[Tuple[int, str, List[str]]], doc_source: int
    ) -> Tuple[List[Tuple[int, int, str, int, int]], List[Tuple[int, int]]]:
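        """Embed chunked documents and add them to Chroma in batches.

        Each item in `data` is (doc_id, url, chunk_text_vec). Returns two
        lists of records for the caller to persist:
        (doc_id, doc_source, embedding_ids_json, timestamp, timestamp)
        rows to add and (timestamp, doc_id) rows to update.
        """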
        records_to_add: List[Tuple[int, int, str, int, int]] = []
        records_to_update: List[Tuple[int, int]] = []
        for item in data:
            documents_to_add: List[Document] = []
            timestamp = int(time.time())
            doc_id, url, chunk_text_vec = item
            for part_index, part_content in enumerate(chunk_text_vec):
                metadata: Dict[str, str] = {
                    "source": url,
                    "id": f"{doc_source}-{doc_id}-part{part_index}"
                }
                doc = Document(page_content=part_content, metadata=metadata)
                documents_to_add.append(doc)

            if documents_to_add:
                embedding_id_vec: List[str] = []
                for start in range(0, len(documents_to_add), self.BATCH_SIZE):
                    batch = documents_to_add[start:start + self.BATCH_SIZE]
                    ret = await self.chroma_vector.aadd_documents(batch)
                    embedding_id_vec.extend(ret)
                logger.info(
                    f"[DOC_EMBEDDER] doc_id={doc_id}, url={url}, doc_source={doc_source}, added {len(documents_to_add)} chunk parts to Chroma, embedding_id_vec={embedding_id_vec}"
                )
                records_to_add.append(
                    (doc_id, doc_source, json.dumps(embedding_id_vec),
                     timestamp, timestamp))
                records_to_update.append((timestamp, doc_id))

        return records_to_add, records_to_update

    async def aadd_local_file_embedding(self, doc_id: int, url: str,
                                        chunk_text_vec: List[str],
                                        doc_source: int) -> List[str]:
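        """Embed the chunks of a single local file and return the Chroma
        embedding ids assigned to them."""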
        file_documents_to_add: List[Document] = []
        for part_index, part_content in enumerate(chunk_text_vec):
            metadata: Dict[str, str] = {
                "source": url,
                "id": f"{doc_source}-{doc_id}-part{part_index}"
            }
            doc = Document(page_content=part_content, metadata=metadata)
            file_documents_to_add.append(doc)

        if file_documents_to_add:
            embedding_id_vec: List[str] = []
            # Add in batches, mirroring aadd_document_embedding, so a large
            # file does not exceed the embedding provider's batch limits.
            for start in range(0, len(file_documents_to_add), self.BATCH_SIZE):
                batch = file_documents_to_add[start:start + self.BATCH_SIZE]
                ret = await self.chroma_vector.aadd_documents(batch)
                embedding_id_vec.extend(ret)
            logger.info(
                f"[DOC_EMBEDDER] doc_id={doc_id}, url={url}, doc_source={doc_source}, added {len(file_documents_to_add)} chunk parts to Chroma, embedding_id_vec={embedding_id_vec}"
            )
            return embedding_id_vec
        return []

    async def adelete_document_embedding(self,
                                         embedding_id_vec: List[str]) -> None:
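        """Delete embeddings from Chroma in batches by embedding id."""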
        for start in range(0, len(embedding_id_vec), self.BATCH_SIZE):
            batch = embedding_id_vec[start:start + self.BATCH_SIZE]
            await self.chroma_vector.adelete(batch)
        logger.info(
            f"[DOC_EMBEDDER] Deleted {len(embedding_id_vec)} embeddings from Chroma."
        )

    def delete_document_embedding(self, embedding_id_vec: List[str]) -> None:
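        """Synchronous counterpart of `adelete_document_embedding`."""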
        for start in range(0, len(embedding_id_vec), self.BATCH_SIZE):
            batch = embedding_id_vec[start:start + self.BATCH_SIZE]
            self.chroma_vector.delete(batch)
        logger.info(
            f"[DOC_EMBEDDER] Deleted {len(embedding_id_vec)} embeddings from Chroma."
        )


document_embedder = DocumentEmbedder()
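
# Example usage (a minimal sketch, not part of this module's API; it assumes
# LLM_NAME and the matching API key are already set in the environment, and
# that this file lives at server/rag/index/embedder/document_embedder.py,
# so the import path below is an assumption):
#
#     import asyncio
#     from server.rag.index.embedder.document_embedder import document_embedder
#
#     async def main() -> None:
#         # doc_id, url and doc_source are illustrative placeholders.
#         ids = await document_embedder.aadd_local_file_embedding(
#             doc_id=42, url="local://example.pdf",
#             chunk_text_vec=["first chunk", "second chunk"], doc_source=1)
#         await document_embedder.adelete_document_embedding(ids)
#
#     asyncio.run(main())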