import hashlib
import os

from dotenv import load_dotenv
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain.globals import set_debug
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereEmbeddings, CohereRerank
from langchain_community.cache import GPTCache
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)
from langchain_openai import ChatOpenAI

from agent.Agent import Agent
from agent.agents import cohere_llm, deepinfra_chat, groq_chat, together_ai_chat
from emdedd.Embedding import Embedding
from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
from prompt.prompt_store import PromptStore
from rag import rag_chain, vanilla_rag_chain  # vanilla_rag_chain: used by the commented-out chain variants

load_dotenv()
# set_verbose(True)
set_debug(True)


class LangChainRAG:
    """Wires together the RAG pipeline: a Mongo-backed embedding store,
    a pool of named chat LLMs, a Cohere-reranking compression retriever,
    and a GPTCache-based semantic response cache.
    """

    embedding: Embedding
    llms: dict[str, BaseChatModel]
    retriever: BaseRetriever
    prompt_template: PromptTemplate
    config: dict
    semantic_cache: GPTCache
    # Class attribute: one store shared by all instances (matches original wiring).
    prompt_store = PromptStore()

    def __init__(self, config: dict):
        """Build embeddings, the LLM pool, and the reranking retriever.

        :param config: expects at least ``temperature``,
            ``retrieve_documents`` (int), ``prompt_id`` and
            ``check_prompt_id``.
        """
        self.config = config
        self.semantic_cache = GPTCache(_init_gptcache)

        self.embedding = MongoEmbedding(
            db=EmbeddingDbConnection(
                connection=os.environ["DB_CONN_EMBED"],
                database=os.environ["MONGODB_DB_NAME_ZPL_EMBED"],
                collection="zpl-2402-cohere",
                index="knnVector-cosine-index",
            ),
            embedding=CohereEmbeddings(model="embed-multilingual-v3.0"),
        )

        temperature = config["temperature"]
        self.llms = {
            "gpt-4o 128k": ChatOpenAI(
                model_name="gpt-4o",
                temperature=temperature,
                openai_api_key=os.environ["OPENAI_API_KEY"],
                openai_organization=os.environ["OPENAI_ORGANIZATION_ID"],
            ),
            "llama-3 70B deepinfra 8k": deepinfra_chat(
                "meta-llama/Meta-Llama-3-70B-Instruct", temperature
            ),
            "llama-3 8B deepinfra 8k": deepinfra_chat(
                "meta-llama/Meta-Llama-3-8B-Instruct", temperature
            ),
            "Mixtral-8x22B-Instruct-v0.1 deepinfra 32k": deepinfra_chat(
                "mistralai/Mixtral-8x22B-Instruct-v0.1", temperature
            ),
            "gemini-pro 128k": ChatGoogleGenerativeAI(
                model="gemini-pro",
                convert_system_message_to_human=True,
                # NOTE(review): DEROGATORY and UNSPECIFIED are PaLM-era
                # categories; confirm the installed langchain-google-genai
                # accepts them for Gemini — newer versions reject unknown
                # safety categories at request time.
                safety_settings={
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DEROGATORY: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                },
                transport="rest",
                # NOTE(review): "stopSequence" is not a documented
                # ChatGoogleGenerativeAI kwarg (the standard name is "stop");
                # verify this is actually honoured and not silently dropped.
                stopSequence=["%%%%"],
                temperature=temperature,
                cache=self.semantic_cache,
            ),
            "Mistral (7B) Instruct v0.3 together.ai 32k": together_ai_chat(
                model="mistralai/Mistral-7B-Instruct-v0.3",
                temperature=temperature,
            ),
            "OpenHermes-2.5 Mistral 7B together.ai 32k": together_ai_chat(
                model="teknium/OpenHermes-2p5-Mistral-7B",
                temperature=temperature,
            ),
            "chat_groq_llm": groq_chat("mixtral-8x7b-32768"),
            "chat_groq_llama3_70": groq_chat("llama3-70b-8192"),
            "command_r_plus": cohere_llm(),
        }

        # Over-fetch 10x candidates by vector similarity, then let the
        # Cohere reranker trim the list to the configured document count.
        self.retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(
                model="rerank-multilingual-v3.0",
                # BUGFIX: was os.getenv("retrieve_documents"), which yields a
                # string (or None) where CohereRerank expects an int. The
                # integer is already in the config dict, exactly as used in
                # search_kwargs below.
                top_n=config["retrieve_documents"],
            ),
            base_retriever=self.embedding.get_vector_store().as_retriever(
                search_type="similarity",
                search_kwargs={"k": config["retrieve_documents"] * 10},
            ),
        )

    def get_llms(self):
        """Return the display names of the configured LLMs."""
        return self.llms.keys()

    async def rag_chain(self, query, llm_choice: str):
        """Run the RAG chain for ``query`` using the LLM named ``llm_choice``.

        :param query: the user question to answer.
        :param llm_choice: key into ``self.llms`` (see :meth:`get_llms`).
        :returns: ``(answer, check_result, context_doc)`` as produced by
            the module-level ``rag_chain``.
        """
        print("Using " + llm_choice)
        # Alternative chain implementations kept for experimentation:
        # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
        # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
        # answer, check_result, context_doc = vanilla_rag_chain(
        answer, check_result, context_doc = await rag_chain(
            Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
            query,
            self.config["retrieve_documents"],
            self.prompt_store.get_by_name(self.config["prompt_id"]).text,
            self.prompt_store.get_by_name(self.config["check_prompt_id"]).text,
        )
        return answer, check_result, context_doc


def _init_gptcache(cache_obj: Cache, llm: str):
    """Initialise one GPTCache store per LLM identity string.

    LangChain's ``GPTCache`` wrapper invokes this with each LLM's identity
    string; keying the data directory on a hash of it keeps every model's
    cached answers separate.
    """
    # Hash the llm string: it may contain characters unsafe in a path.
    hashed_llm = hashlib.sha256(llm.encode()).hexdigest()
    cache_obj.init(
        pre_embedding_func=get_prompt,
        # BUGFIX: was data_dir=f"map_cache" (a placeholder-free f-string,
        # shared by every model), so a response cached for one LLM could be
        # served for another. NOTE: any existing on-disk cache under
        # "map_cache" is abandoned by this change.
        data_manager=manager_factory(data_dir=f"map_cache_{hashed_llm}"),
        # data_manager=get_data_manager(
        #     cache_base=CacheBase("mongo", url="mongodb://localhost:27017/"),
        #     vector_base=Chromadb(
        #         persist_directory="./chromadb/cache",
        #     ),
        # )
    )