Spaces:
Runtime error
Runtime error
| import os | |
| from dotenv import load_dotenv | |
| import re | |
| from loguru import logger | |
| from langchain import PromptTemplate, LLMChain | |
| from langchain.agents import initialize_agent, Tool | |
| from langchain.chat_models import AzureChatOpenAI | |
| from langchain.agents import ZeroShotAgent, AgentExecutor | |
| from langchain.chains.conversation.memory import ConversationBufferMemory | |
| from langchain.callbacks import get_openai_callback | |
| from langchain.chains.llm import LLMChain | |
| from langchain.llms import AzureOpenAI | |
| from langchain.prompts import PromptTemplate | |
| from utils import lctool_search_allo_api, cut_dialogue_history | |
| from prompts.mod_prompt import MOD_PROMPT, FALLBACK_MESSAGE, MOD_PROMPT_OPTIM_v2 | |
| from prompts.ans_prompt import ANS_PREFIX, ANS_FORMAT_INSTRUCTIONS, ANS_SUFFIX, ANS_CHAIN_PROMPT | |
| from prompts.reco_prompt import RECO_PREFIX, RECO_FORMAT_INSTRUCTIONS, RECO_SUFFIX, NO_RECO_OUTPUT | |
| load_dotenv() | |
| class AllofreshChatbot(): | |
def __init__(self, debug=False):
    """Assemble every chain and agent the chatbot relies on.

    Args:
        debug: Stored on the instance and later forwarded as the
            ``verbose`` flag of the agents built by ``init_ans_agent``
            and ``init_reco_agent``.
    """
    # Placeholder so the attribute exists before the real memory is built.
    self.ans_memory = None
    self.debug = debug

    # Language models come first: every chain/agent below reads self.llms.
    self.llms = self.init_llm()

    # Moderation step.
    self.mod_chain = self.init_mod_chain()

    # Answering components: memory, tool-using agent, and plain chain.
    self.ans_memory = self.init_ans_memory()
    self.ans_agent = self.init_ans_agent()
    self.ans_chain = self.init_ans_chain()

    # Recommendation agent.
    self.reco_agent = self.init_reco_agent()
def init_llm(self):
    """Create the Azure-hosted language model clients.

    Every connection setting is read from environment variables via
    ``os.getenv``.

    Returns:
        dict: Model clients keyed by ``"gpt-4"`` and ``"gpt-3.5"`` (both
        ``AzureChatOpenAI``) and ``"gpt-3"`` (``AzureOpenAI``), all created
        with ``temperature=0``.
    """
    # Endpoint and credentials passed to all three clients.
    endpoint_settings = {
        "openai_api_base": os.getenv("OPENAI_API_BASE"),
        "openai_api_key": os.getenv("OPENAI_API_KEY"),
        "openai_organization": os.getenv("OPENAI_ORGANIZATION"),
    }
    # The two chat clients additionally receive the API type and version.
    chat_settings = {
        "openai_api_type": os.getenv("OPENAI_API_TYPE"),
        "openai_api_version": os.getenv("OPENAI_API_VERSION"),
        **endpoint_settings,
    }

    models = {}
    models["gpt-4"] = AzureChatOpenAI(
        temperature=0,
        deployment_name=os.getenv("DEPLOYMENT_NAME_GPT4"),
        model_name=os.getenv("MODEL_NAME_GPT4"),
        **chat_settings,
    )
    models["gpt-3.5"] = AzureChatOpenAI(
        temperature=0,
        deployment_name=os.getenv("DEPLOYMENT_NAME_GPT3.5"),
        model_name=os.getenv("MODEL_NAME_GPT3.5"),
        **chat_settings,
    )
    models["gpt-3"] = AzureOpenAI(
        temperature=0,
        deployment_name=os.getenv("DEPLOYMENT_NAME_GPT3"),
        model_name=os.getenv("MODEL_NAME_GPT3"),
        **endpoint_settings,
    )
    return models
def init_mod_chain(self):
    """Build the moderation chain.

    Returns:
        LLMChain: A chain that runs ``MOD_PROMPT_OPTIM_v2`` on the
        ``"gpt-4"`` model. The prompt takes one variable, ``input``.
    """
    moderation_template = PromptTemplate(
        input_variables=["input"],
        template=MOD_PROMPT_OPTIM_v2,
    )
    moderation_chain = LLMChain(
        prompt=moderation_template,
        llm=self.llms["gpt-4"],
    )
    return moderation_chain
def init_ans_memory(self):
    """Create the conversation buffer for the answering component.

    Returns:
        ConversationBufferMemory: Configured with memory key
        ``chat_history`` and output key ``output``.
    """
    memory_settings = {"memory_key": "chat_history", "output_key": "output"}
    return ConversationBufferMemory(**memory_settings)
def init_ans_agent(self):
    """Build the tool-using answering agent.

    The agent is a ``conversational-react-description`` agent on the
    ``"gpt-4"`` model with a single "Product Search" tool. It returns its
    intermediate steps and is verbose when ``self.debug`` is set.

    Returns:
        The agent executor produced by ``initialize_agent``.
    """
    product_search = Tool(
        name="Product Search",
        func=lctool_search_allo_api,
        description="""
            To search for products in Allofresh's Database.
            Always use this to verify product names.
            Outputs product names and prices
        """
    )

    # 'format_instructions' (ANS_FORMAT_INSTRUCTIONS) is intentionally left
    # out: per the original note it is only needed for models below gpt-4.
    prompt_overrides = {
        'prefix': ANS_PREFIX,
        'suffix': ANS_SUFFIX,
    }

    return initialize_agent(
        [product_search],
        self.llms["gpt-4"],
        agent="conversational-react-description",
        verbose=self.debug,
        return_intermediate_steps=True,
        agent_kwargs=prompt_overrides,
    )
def init_ans_chain(self):
    """Build the tool-free answering chain.

    Returns:
        LLMChain: A chain that runs ``ANS_CHAIN_PROMPT`` on the ``"gpt-4"``
        model. The prompt takes ``input`` and ``chat_history``.
    """
    answer_template = PromptTemplate(
        input_variables=["input", "chat_history"],
        template=ANS_CHAIN_PROMPT,
    )
    answer_chain = LLMChain(
        prompt=answer_template,
        llm=self.llms["gpt-4"],
    )
    return answer_chain
def init_reco_agent(self):
    """Build the product-recommendation agent.

    The agent is a ``ZeroShotAgent`` on the ``"gpt-4"`` model with two
    tools: "Product Search" and a "No Recommendation" tool that always
    returns the string "No recommendation".

    Returns:
        AgentExecutor: Executor wrapping the agent and its tools; verbose
        when ``self.debug`` is set.
    """
    product_search = Tool(
        name="Product Search",
        func=lctool_search_allo_api,
        description="""
            To search for products in Allofresh's Database.
            Always use this to verify product names.
            Outputs product names and prices
        """
    )
    # Explicit way for the agent to decline to recommend anything.
    no_recommendation = Tool(
        name="No Recommendation",
        func=lambda x: "No recommendation",
        description="""
            Use this if based on the context you don't need to recommend any products
        """
    )
    reco_tools = [product_search, no_recommendation]

    reco_prompt = ZeroShotAgent.create_prompt(
        reco_tools,
        prefix=RECO_PREFIX,
        format_instructions=RECO_FORMAT_INSTRUCTIONS,
        suffix=RECO_SUFFIX,
        input_variables=["input", "agent_scratchpad"],
    )
    reco_chain = LLMChain(llm=self.llms["gpt-4"], prompt=reco_prompt)

    tool_names = [tool.name for tool in reco_tools]
    reco_agent = ZeroShotAgent(llm_chain=reco_chain, allowed_tools=tool_names)

    return AgentExecutor.from_agent_and_tools(
        agent=reco_agent,
        tools=reco_tools,
        verbose=self.debug,
    )
def answer(self, query):
    """Moderate a query, answer it, and append a product recommendation.

    Args:
        query: The user's message.

    Returns:
        tuple: ``(answer, reco)`` when the moderation verdict is "True",
        otherwise ``(FALLBACK_MESSAGE, None)``.
    """
    # The moderation prompt declares a single variable named "input", so
    # the chain must be called with that key (not "query").
    # Strip the verdict: surrounding whitespace/newlines in the model
    # output must not defeat the exact comparison below.
    mod_verdict = self.mod_chain.run({"input": query}).strip()

    if mod_verdict != "True":
        return (FALLBACK_MESSAGE, None)

    # NOTE(review): ans_pipeline is not defined in the visible part of this
    # class -- confirm it exists elsewhere.
    answer = self.ans_pipeline(query)

    # NOTE(review): init_ans_agent does not pass a memory object to
    # initialize_agent, so ans_agent.memory may be None here -- confirm.
    reco = self.reco_agent.run({"input": self.ans_agent.memory.buffer})
    if reco:
        # Record the recommendation in the conversation history.
        self.ans_agent.memory.chat_memory.add_ai_message(reco)

    return (answer, reco)
def answer_optim_v1(self, query, chat_history):
    """Moderate a query and answer it with the tool-free answering chain.

    The tools were unplugged from the 'answering' component and replaced
    with a simple chain.

    Args:
        query: The user's message.
        chat_history: Prior conversation; passed to the chain as
            ``str(chat_history)``.

    Returns:
        The answering chain's output when the moderation verdict is "True",
        otherwise ``FALLBACK_MESSAGE``.
    """
    # Strip the verdict: surrounding whitespace/newlines in the model
    # output must not defeat the exact comparison below.
    mod_verdict = self.mod_chain.run({"input": query}).strip()

    if mod_verdict == "True":
        return self.ans_chain.run({"input": query, "chat_history": str(chat_history)})
    return FALLBACK_MESSAGE
def answer_optim_v2(self, query, chat_history):
    """Route a query to the answering chain or agent based on moderation.

    The moderation chain's verdict selects the path:

    * ``"ANS_CHAIN"`` -- answer with the tool-free chain (no need to access
      the knowledge base).
    * ``"ANS_AGENT"`` -- answer with the tool-using agent (knowledge base
      access needed).
    * anything else -- return ``FALLBACK_MESSAGE``.

    Args:
        query: The user's message.
        chat_history: Prior conversation; passed on as ``str(chat_history)``.

    Returns:
        str: The chain output, the agent's ``output`` with every backslash
        replaced by a forward slash, or ``FALLBACK_MESSAGE``.
    """
    mod_verdict = self.mod_chain.run({"input": query})
    llm_input = {"input": query, "chat_history": str(chat_history)}
    logger.info(f"mod verdict: {mod_verdict}")

    # Compare a stripped copy: surrounding whitespace/newlines in the model
    # output must not push a valid verdict into the fallback branch.
    route = mod_verdict.strip()

    if route == "ANS_CHAIN":
        return self.ans_chain.run(llm_input)

    if route == "ANS_AGENT":
        res = self.ans_agent(llm_input)
        return res['output'].replace("\\", "/")

    return FALLBACK_MESSAGE
def reco_optim_v1(self, chat_history):
    """Ask the recommendation agent for a suggestion.

    Args:
        chat_history: Conversation so far; passed to the agent as ``input``.

    Returns:
        The agent's output, or ``None`` when that output equals
        ``NO_RECO_OUTPUT``.
    """
    suggestion = self.reco_agent.run({"input": chat_history})
    # The agent's output is returned unmodified; only the "no
    # recommendation" sentinel is mapped to None.
    if suggestion == NO_RECO_OUTPUT:
        return None
    return suggestion