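"""RAG pipeline wiring: a MongoDB vector store with Cohere embeddings,
a Cohere reranker, several interchangeable chat-model backends, and a
GPTCache-backed semantic cache, exposed through the LangChainRAG class."""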
import os
from dotenv import load_dotenv
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain.globals import set_debug
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank, CohereEmbeddings
from langchain_community.cache import GPTCache
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_google_genai import ChatGoogleGenerativeAI, HarmCategory, HarmBlockThreshold
from langchain_openai import ChatOpenAI
from agent.Agent import Agent
from agent.agents import deepinfra_chat, together_ai_chat, groq_chat, cohere_llm
from emdedd.Embedding import Embedding
from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
from prompt.prompt_store import PromptStore
from rag import vanilla_rag_chain, rag_chain

load_dotenv()
# set_verbose(True)
set_debug(True)


class LangChainRAG:
    """Wires embeddings, retrieval, reranking, LLM backends and caching together."""

    embedding: Embedding
    llms: dict[str, BaseChatModel]
    retriever: BaseRetriever
    prompt_template: PromptTemplate
    config: dict
    semantic_cache: GPTCache
    prompt_store = PromptStore()
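
    # config is expected to provide at least:
    #   temperature, retrieve_documents, prompt_id, check_prompt_id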
    def __init__(self, config):
        self.config = config
        # Cache shared by cache-aware models (see _init_gptcache below).
        self.semantic_cache = GPTCache(_init_gptcache)
        # MongoDB-backed vector store using Cohere multilingual embeddings.
        self.embedding = MongoEmbedding(
            db=EmbeddingDbConnection(
                connection=os.environ["DB_CONN_EMBED"],
                database=os.environ["MONGODB_DB_NAME_ZPL_EMBED"],
                collection="zpl-2402-cohere",
                index="knnVector-cosine-index"
            ),
            embedding=CohereEmbeddings(model="embed-multilingual-v3.0")
        )
        self.llms = {
            "gpt-4o 128k": ChatOpenAI(
                model_name="gpt-4o",
                temperature=config["temperature"],
                openai_api_key=os.environ["OPENAI_API_KEY"],
                openai_organization=os.environ["OPENAI_ORGANIZATION_ID"]
            ),
            "llama-3 70B deepinfra 8k": deepinfra_chat(
                "meta-llama/Meta-Llama-3-70B-Instruct", self.config["temperature"]),
            "llama-3 8B deepinfra 8k": deepinfra_chat(
                "meta-llama/Meta-Llama-3-8B-Instruct", self.config["temperature"]),
            "Mixtral-8x22B-Instruct-v0.1 deepinfra 32k": deepinfra_chat(
                "mistralai/Mixtral-8x22B-Instruct-v0.1", self.config["temperature"]),
"gemini-pro 128k": ChatGoogleGenerativeAI(
model="gemini-pro",
convert_system_message_to_human=True,
safety_settings={
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DEROGATORY: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
},
transport="rest",
stopSequence=["%%%%"],
temperature=config["temperature"],
cache=self.semantic_cache
),
"Mistral (7B) Instruct v0.3 together.ai 32k": together_ai_chat(
model="mistralai/Mistral-7B-Instruct-v0.3",
temperature=config["temperature"]
),
"OpenHermes-2.5 Mistral 7B together.ai 32k": together_ai_chat(
model="teknium/OpenHermes-2p5-Mistral-7B",
temperature=config["temperature"]
),
"chat_groq_llm": groq_chat("mixtral-8x7b-32768"),
"chat_groq_llama3_70": groq_chat("llama3-70b-8192"),
"command_r_plus": cohere_llm(),
}
        # Over-fetch 10x the target count from the vector store, then let the
        # Cohere reranker cut the candidates down to the top retrieve_documents.
        self.retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(
                model="rerank-multilingual-v3.0",
                # top_n must be an int; os.getenv would return a string (or None).
                top_n=config["retrieve_documents"]
            ),
            base_retriever=self.embedding.get_vector_store().as_retriever(
                search_type="similarity",
                search_kwargs={"k": config["retrieve_documents"] * 10}
            )
        )

    def get_llms(self):
        return self.llms.keys()

    async def rag_chain(self, query, llm_choice):
        print("Using " + llm_choice)
        # Alternative chain variants, kept for experimentation:
        # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
        # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
        # answer, check_result, context_doc = vanilla_rag_chain(
        # Note: this calls the module-level rag_chain imported from rag,
        # not this method of the same name.
        answer, check_result, context_doc = await rag_chain(
            Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
            query,
            self.config["retrieve_documents"],
            self.prompt_store.get_by_name(self.config["prompt_id"]).text,
            self.prompt_store.get_by_name(self.config["check_prompt_id"]).text
        )
        return answer, check_result, context_doc


def _init_gptcache(cache_obj: Cache, llm: str):
    # llm is intentionally unused, so every model shares one "map" cache
    # directory; fold the llm string into data_dir for per-model caches.
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=manager_factory(data_dir="map_cache"),
        # data_manager=get_data_manager(
        #     cache_base=CacheBase("mongo", url="mongodb://localhost:27017/"),
        #     vector_base=Chromadb(
        #         persist_directory="./chromadb/cache",
        #     ),
        # )
    )
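

# Minimal usage sketch (illustrative, not the app's actual entry point).
# It assumes the environment variables referenced above are set and that
# the two prompt names exist in PromptStore; both names are placeholders.
if __name__ == "__main__":
    import asyncio

    rag = LangChainRAG(config={
        "temperature": 0.0,
        "retrieve_documents": 5,
        "prompt_id": "rag-prompt",          # placeholder prompt name
        "check_prompt_id": "check-prompt",  # placeholder prompt name
    })
    print("Available models:", list(rag.get_llms()))
    answer, check_result, context_doc = asyncio.run(
        rag.rag_chain("How do I configure the device?", "chat_groq_llama3_70")
    )
    print(answer)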