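"""trykopy/rag_langchain.py

Wires together a LangChain RAG pipeline: a MongoDB vector store with Cohere
multilingual embeddings, a registry of candidate chat models, a Cohere-rerank
compression retriever, and a GPTCache-backed semantic cache.
"""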
import os
from dotenv import load_dotenv
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain.globals import set_debug
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank, CohereEmbeddings
from langchain_community.cache import GPTCache
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_google_genai import ChatGoogleGenerativeAI, HarmCategory, HarmBlockThreshold
from langchain_openai import ChatOpenAI
from agent.Agent import Agent
from agent.agents import (
    deepinfra_chat,
    together_ai_chat,
    groq_chat,
    cohere_llm,
)
from emdedd.Embedding import Embedding
from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
from prompt.prompt_store import PromptStore
from rag import vanilla_rag_chain, rag_chain

load_dotenv()

# set_verbose(True)
set_debug(True)


class LangChainRAG:
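    """Assembles and holds the RAG components.

    Construction is driven by a config dict; the keys read below are
    "temperature", "retrieve_documents", "prompt_id" and "check_prompt_id".
    """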
    embedding: Embedding
    llms: dict[str, BaseChatModel]
    retriever: BaseRetriever
    prompt_template: PromptTemplate
    config: dict
    semantic_cache: GPTCache
    prompt_store = PromptStore()

    def __init__(self, config):
        self.config = config
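        # Shared semantic cache; _init_gptcache (defined at module level
        # below) configures the underlying GPTCache instance per LLM.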
        self.semantic_cache = GPTCache(_init_gptcache)
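        # Vector store: a MongoDB collection searched through the
        # knnVector cosine index, embedded with Cohere's multilingual model.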
        self.embedding = MongoEmbedding(
            db=EmbeddingDbConnection(
                connection=os.environ["DB_CONN_EMBED"],
                database=os.environ["MONGODB_DB_NAME_ZPL_EMBED"],
                collection="zpl-2402-cohere",
                index="knnVector-cosine-index"
            ),
            embedding=CohereEmbeddings(model="embed-multilingual-v3.0")
        )
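        # Registry of selectable chat models, keyed by the display names
        # surfaced through get_llms(); context sizes are noted in the keys.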
        self.llms = {
            "gpt-4o 128k": ChatOpenAI(
                model_name="gpt-4o",
                temperature=config["temperature"],
                openai_api_key=os.environ["OPENAI_API_KEY"],
                openai_organization=os.environ["OPENAI_ORGANIZATION_ID"]
            ),
            "llama-3 70B deepinfra 8k": deepinfra_chat(
                "meta-llama/Meta-Llama-3-70B-Instruct", self.config["temperature"]
            ),
            "llama-3 8B deepinfra 8k": deepinfra_chat(
                "meta-llama/Meta-Llama-3-8B-Instruct", self.config["temperature"]
            ),
            "Mixtral-8x22B-Instruct-v0.1 deepinfra 32k": deepinfra_chat(
                "mistralai/Mixtral-8x22B-Instruct-v0.1", self.config["temperature"]
            ),
            "gemini-pro 128k": ChatGoogleGenerativeAI(
                model="gemini-pro",
                convert_system_message_to_human=True,
                safety_settings={
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DEROGATORY: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                },
                transport="rest",
                stopSequence=["%%%%"],
                temperature=config["temperature"],
                cache=self.semantic_cache
            ),
            "Mistral (7B) Instruct v0.3 together.ai 32k": together_ai_chat(
                model="mistralai/Mistral-7B-Instruct-v0.3",
                temperature=config["temperature"]
            ),
            "OpenHermes-2.5 Mistral 7B together.ai 32k": together_ai_chat(
                model="teknium/OpenHermes-2p5-Mistral-7B",
                temperature=config["temperature"]
            ),
            "chat_groq_llm": groq_chat("mixtral-8x7b-32768"),
            "chat_groq_llama3_70": groq_chat("llama3-70b-8192"),
            "command_r_plus": cohere_llm(),
        }
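        # Two-stage retrieval: over-fetch 10x the requested number of
        # documents by vector similarity, then let Cohere's reranker
        # compress the list down to the best top_n.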
        self.retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(
                model="rerank-multilingual-v3.0",
                # top_n must be an int; read it from config like k below
                # (os.getenv would have returned a string, or None if unset).
                top_n=config["retrieve_documents"]
            ),
            base_retriever=self.embedding.get_vector_store().as_retriever(
                search_type="similarity",
                search_kwargs={"k": config["retrieve_documents"] * 10}
            )
        )

    def get_llms(self):
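        """Return the display names of the configured chat models."""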
        return self.llms.keys()

    async def rag_chain(self, query, llm_choice):
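        """Answer `query` with the chosen LLM.

        Note: the awaited call below resolves to the module-level rag_chain
        imported from rag, not to this method. Returns a tuple of
        (answer, check_result, context_doc).
        """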
print("Using " + llm_choice)
# answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
# answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
# answer, check_result, context_doc = vanilla_rag_chain(
answer, check_result, context_doc = await rag_chain(
Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
query,
self.config["retrieve_documents"],
self.prompt_store.get_by_name(self.config["prompt_id"]).text,
self.prompt_store.get_by_name(self.config["check_prompt_id"]).text
)
return answer, check_result, context_doc


def _init_gptcache(cache_obj: Cache, llm: str):
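    """Initialize a GPTCache instance for one LLM string.

    LangChain's GPTCache wrapper calls this once per new llm string; the
    default map data manager from manager_factory stores entries under the
    map_cache directory.
    """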
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=manager_factory(data_dir="map_cache"),
        # data_manager=get_data_manager(
        #     cache_base=CacheBase("mongo", url="mongodb://localhost:27017/"),
        #     vector_base=Chromadb(
        #         persist_directory="./chromadb/cache",
        #     ),
        # )
    )
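

# Minimal usage sketch (not part of the original file). It assumes the
# environment variables read above are set; the config values and prompt
# names below are illustrative placeholders, not the project's real ones.
if __name__ == "__main__":
    import asyncio

    rag = LangChainRAG(config={
        "temperature": 0.0,                 # illustrative
        "retrieve_documents": 5,            # illustrative
        "prompt_id": "rag-prompt",          # hypothetical prompt name
        "check_prompt_id": "check-prompt",  # hypothetical prompt name
    })
    print(list(rag.get_llms()))
    answer, check_result, context_doc = asyncio.run(
        rag.rag_chain("What does the source document say?", "chat_groq_llama3_70")
    )
    print(answer)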