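"""LangChain-based RAG pipeline.

Wires together a MongoDB vector store of Cohere embeddings, a Cohere
reranker, a set of selectable chat-LLM backends, and a GPTCache-backed
semantic cache.
"""
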
import hashlib
import os

from dotenv import load_dotenv
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain.globals import set_debug
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank, CohereEmbeddings
from langchain_community.cache import GPTCache
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_google_genai import ChatGoogleGenerativeAI, HarmCategory, HarmBlockThreshold
from langchain_openai import ChatOpenAI

from agent.Agent import Agent
from agent.agents import (
    deepinfra_chat, together_ai_chat, groq_chat, cohere_llm
)
from emdedd.Embedding import Embedding
from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
from prompt.prompt_store import PromptStore
from rag import vanilla_rag_chain, rag_chain

load_dotenv()

# set_verbose(True)
set_debug(True)


class LangChainRAG:
    embedding: Embedding
    llms: dict[str, BaseChatModel]
    retriever: BaseRetriever
    prompt_template: PromptTemplate
    config: dict
    semantic_cache: GPTCache
    prompt_store: PromptStore = PromptStore()

    def __init__(self, config):
        self.config = config
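        # Semantic cache shared by any LLM that opts in via `cache=`
        # (see the gemini-pro entry below).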
        self.semantic_cache = GPTCache(_init_gptcache)
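        # MongoDB vector store (cosine kNN index) holding Cohere
        # embed-multilingual-v3.0 embeddings; connection details come
        # from environment variables.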
        self.embedding = MongoEmbedding(
            db=EmbeddingDbConnection(
                connection=os.environ["DB_CONN_EMBED"],
                database=os.environ["MONGODB_DB_NAME_ZPL_EMBED"],
                collection="zpl-2402-cohere",
                index="knnVector-cosine-index"
            ),
            embedding=CohereEmbeddings(model="embed-multilingual-v3.0")
        )

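        # One entry per selectable backend; the dict keys double as the
        # display names returned by get_llms().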
        self.llms = {
            "gpt-4o 128k": ChatOpenAI(
                model_name="gpt-4o",
                temperature=config["temperature"],
                openai_api_key=os.environ["OPENAI_API_KEY"],
                openai_organization=os.environ["OPENAI_ORGANIZATION_ID"]
            ),
            "llama-3 70B deepinfra 8k": deepinfra_chat("meta-llama/Meta-Llama-3-70B-Instruct",
                                                       self.config["temperature"]),
            "llama-3 8B deepinfra 8k": deepinfra_chat("meta-llama/Meta-Llama-3-8B-Instruct",
                                                      self.config["temperature"]),
            "Mixtral-8x22B-Instruct-v0.1 deepinfra 32k": deepinfra_chat("mistralai/Mixtral-8x22B-Instruct-v0.1",
                                                                        self.config["temperature"]),
            "gemini-pro 128k": ChatGoogleGenerativeAI(
                model="gemini-pro",
                convert_system_message_to_human=True,
                safety_settings={
                    # Gemini supports exactly these four harm categories;
                    # PaLM-era values like DEROGATORY/UNSPECIFIED are rejected.
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                },
                transport="rest",
                stopSequence=["%%%%"],
                temperature=config["temperature"],
                cache=self.semantic_cache
            ),
            "Mistral (7B) Instruct v0.3 together.ai 32k": together_ai_chat(
                model="mistralai/Mistral-7B-Instruct-v0.3",
                temperature=config["temperature"]
            ),
            "OpenHermes-2.5 Mistral 7B together.ai 32k": together_ai_chat(
                model="teknium/OpenHermes-2p5-Mistral-7B",
                temperature=config["temperature"]
            ),
            "chat_groq_llm": groq_chat("mixtral-8x7b-32768"),
            "chat_groq_llama3_70": groq_chat("llama3-70b-8192"),
            "command_r_plus": cohere_llm(),

        }

        self.retriever = ContextualCompressionRetriever(
            # Over-fetch 10x candidates by vector similarity, then let the
            # Cohere reranker cut them down to the requested document count.
            # CohereRerank expects an integer top_n, taken from config like
            # the other retrieval settings.
            base_compressor=CohereRerank(
                model="rerank-multilingual-v3.0",
                top_n=config["retrieve_documents"]
            ),
            base_retriever=self.embedding.get_vector_store().as_retriever(
                search_type="similarity",
                search_kwargs={"k": config["retrieve_documents"] * 10}
            )
        )

    def get_llms(self):
        return self.llms.keys()

    async def rag_chain(self, query, llm_choice):
        print(f"Using {llm_choice}")

        # Alternative chain variants, kept for experimentation:
        # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
        # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
        # answer, check_result, context_doc = vanilla_rag_chain(
        answer, check_result, context_doc = await rag_chain(
            Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
            query,
            self.config["retrieve_documents"],
            self.prompt_store.get_by_name(self.config["prompt_id"]).text,
            self.prompt_store.get_by_name(self.config["check_prompt_id"]).text
        )

        return answer, check_result, context_doc


def _init_gptcache(cache_obj: Cache, llm: str):
    # Key the cache directory per LLM so cached answers from different
    # models don't collide.
    hashed_llm = hashlib.sha256(llm.encode()).hexdigest()[:16]
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=manager_factory(data_dir=f"map_cache_{hashed_llm}"),
        # data_manager=get_data_manager(
        #     cache_base=CacheBase("mongo", url="mongodb://localhost:27017/"),
        #     vector_base=Chromadb(
        #         persist_directory="./chromadb/cache",
        #     ),
        # )
    )
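

# ---------------------------------------------------------------------------
# Hypothetical usage sketch. Assumes the environment variables read above
# are set and that the prompt names exist in your PromptStore; both names
# below are placeholders, not values from this repo.
if __name__ == "__main__":
    import asyncio

    rag = LangChainRAG(config={
        "temperature": 0.0,
        "retrieve_documents": 5,
        "prompt_id": "rag-answer",        # placeholder prompt name
        "check_prompt_id": "rag-check",   # placeholder prompt name
    })
    print("Available LLMs:", list(rag.get_llms()))

    answer, check_result, context_doc = asyncio.run(
        rag.rag_chain("What does the ZPL ^FO command do?", "chat_groq_llama3_70")
    )
    print(answer)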