import os

from dotenv import load_dotenv
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain.globals import set_debug
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank, CohereEmbeddings
from langchain_community.cache import GPTCache
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_google_genai import ChatGoogleGenerativeAI, HarmCategory, HarmBlockThreshold
from langchain_openai import ChatOpenAI

from agent.Agent import Agent
from agent.agents import deepinfra_chat, together_ai_chat, groq_chat, cohere_llm
from emdedd.Embedding import Embedding
from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
from prompt.prompt_store import PromptStore
from rag import vanilla_rag_chain, rag_chain

load_dotenv()

# set_verbose(True)
set_debug(True)
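

# RAG pipeline wrapper: wires a Cohere-embedded MongoDB vector store, several
# candidate chat models, a rerank-based retriever, and a GPTCache-backed cache.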
class LangChainRAG:
    embedding: Embedding
    llms: dict[str, BaseChatModel]
    retriever: BaseRetriever
    prompt_template: PromptTemplate
    config: dict
    semantic_cache: GPTCache
    prompt_store = PromptStore()

    def __init__(self, config):
        self.config = config
        self.semantic_cache = GPTCache(_init_gptcache)
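
        # MongoDB-backed vector store (knnVector cosine index) populated with
        # Cohere embed-multilingual-v3.0 embeddings.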
        self.embedding = MongoEmbedding(
            db=EmbeddingDbConnection(
                connection=os.environ["DB_CONN_EMBED"],
                database=os.environ["MONGODB_DB_NAME_ZPL_EMBED"],
                collection="zpl-2402-cohere",
                index="knnVector-cosine-index"
            ),
            embedding=CohereEmbeddings(model="embed-multilingual-v3.0")
        )
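
        # Candidate chat models, keyed by human-readable labels (context window
        # noted in each label); get_llms() exposes these keys for selection.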
        self.llms = {
            "gpt-4o 128k": ChatOpenAI(
                model_name="gpt-4o",
                temperature=config["temperature"],
                openai_api_key=os.environ["OPENAI_API_KEY"],
                openai_organization=os.environ["OPENAI_ORGANIZATION_ID"]
            ),
            "llama-3 70B deepinfra 8k": deepinfra_chat("meta-llama/Meta-Llama-3-70B-Instruct",
                                                       self.config["temperature"]),
            "llama-3 8B deepinfra 8k": deepinfra_chat("meta-llama/Meta-Llama-3-8B-Instruct",
                                                      self.config["temperature"]),
            "Mixtral-8x22B-Instruct-v0.1 deepinfra 32k": deepinfra_chat("mistralai/Mixtral-8x22B-Instruct-v0.1",
                                                                        self.config["temperature"]),
            "gemini-pro 128k": ChatGoogleGenerativeAI(
                model="gemini-pro",
                convert_system_message_to_human=True,
                safety_settings={
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DEROGATORY: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                },
                transport="rest",
                stopSequence=["%%%%"],
                temperature=config["temperature"],
                cache=self.semantic_cache
            ),
            "Mistral (7B) Instruct v0.3 together.ai 32k": together_ai_chat(
                model="mistralai/Mistral-7B-Instruct-v0.3",
                temperature=config["temperature"]
            ),
            "OpenHermes-2.5 Mistral 7B together.ai 32k": together_ai_chat(
                model="teknium/OpenHermes-2p5-Mistral-7B",
                temperature=config["temperature"]
            ),
            "chat_groq_llm": groq_chat("mixtral-8x7b-32768"),
            "chat_groq_llama3_70": groq_chat("llama3-70b-8192"),
            "command_r_plus": cohere_llm(),
        }
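
        # Two-stage retrieval: over-fetch by vector similarity (k = retrieve_documents * 10),
        # then rerank with Cohere rerank-multilingual-v3.0 and keep the top_n documents.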
        self.retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(model="rerank-multilingual-v3.0", top_n=config["retrieve_documents"]),
            base_retriever=self.embedding.get_vector_store().as_retriever(
                search_type="similarity",
                search_kwargs={"k": config["retrieve_documents"] * 10}
            )
        )

    def get_llms(self):
        return self.llms.keys()
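
    # Runs the RAG chain for one query with the selected model: retrieval, answer
    # generation, and a check pass driven by the configured check prompt;
    # returns (answer, check_result, context_doc).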
    async def rag_chain(self, query, llm_choice):
        print("Using " + llm_choice)
        # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
        # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
        # answer, check_result, context_doc = vanilla_rag_chain(
        answer, check_result, context_doc = await rag_chain(
            Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
            query,
            self.config["retrieve_documents"],
            self.prompt_store.get_by_name(self.config["prompt_id"]).text,
            self.prompt_store.get_by_name(self.config["check_prompt_id"]).text
        )
        return answer, check_result, context_doc
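

# Cache initializer passed to langchain_community.cache.GPTCache above: entries are
# keyed on the raw prompt (get_prompt) in a local map-style cache directory, so only
# exact-match prompts hit the cache.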
def _init_gptcache(cache_obj: Cache, llm: str):
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=manager_factory(data_dir="map_cache"),
        # data_manager=get_data_manager(
        #     cache_base=CacheBase("mongo", url="mongodb://localhost:27017/"),
        #     vector_base=Chromadb(
        #         persist_directory="./chromadb/cache",
        #     ),
        # )
    )
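

# Minimal usage sketch (assumptions: the prompt names "default" and "default-check"
# are hypothetical PromptStore entries, and the environment variables read above are
# set; the config keys mirror what this module actually reads).
#
# import asyncio
#
# rag = LangChainRAG({
#     "temperature": 0.0,
#     "retrieve_documents": 5,
#     "prompt_id": "default",              # hypothetical prompt name
#     "check_prompt_id": "default-check",  # hypothetical prompt name
# })
# print(list(rag.get_llms()))
# answer, check_result, context_doc = asyncio.run(
#     rag.rag_chain("How do I rotate a barcode in ZPL?", "chat_groq_llama3_70")
# )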