# DareAlly-Assistant / routerchain.py
# Author: Oluwadamilare EZEKIEL
# Last commit: 8897c58
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.chains.router.llm_router import RouterOutputParser
from langchain.chains.router.multi_prompt import MultiPromptChain
from langchain.memory import (
ConversationSummaryMemory
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from events_loader import faiss_db
import streamlit as st
import os
# API key read from the environment; None when OPENAI_API_KEY is unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Shared gpt-3.5-turbo chat model (temperature 0, streamed output) that every
# chain built in this module uses.
chat_openai = ChatOpenAI(
    api_key=OPENAI_API_KEY,
    model_name="gpt-3.5-turbo",
    temperature=0,
    streaming=True,
)

# A single summary-memory object, backed by the same chat model, shared by
# all chains created below.
memory = ConversationSummaryMemory(llm=chat_openai)
def retrieve_events(question, k=5):
    """Return the text of the events most similar to *question* as one string.

    Runs a similarity search against the module-level ``faiss_db`` store and
    joins the ``page_content`` of the top ``k`` hits, one event per line.

    Args:
        question: The user query to search with.
        k: Number of nearest events to retrieve (defaults to 5, the value the
            function previously hard-coded).

    Returns:
        The matching events separated by newlines, or ``""`` if the search
        fails. In that case the error is shown in the Streamlit UI via
        ``st.error``.
    """
    try:
        docs = faiss_db.similarity_search(question, k=k)
        # Newline-separate events so adjacent ones don't run together in the
        # prompt (a plain "".join glued "...end.Next event..." together).
        return "\n".join(doc.page_content for doc in docs)
    except Exception as e:
        # Best-effort: a retrieval failure should not crash the app; surface
        # the error and fall back to answering without event context.
        st.error(f"Error loading faiss db: {e}")
        return ""
class PromptFactory:
    """Build the persona prompt templates used by the router chain.

    Every template is a ``str.format``-style string whose only placeholder is
    ``{input}``. ``prompt_infos`` pairs each template with the name and
    description that are offered to the router for choosing a destination.

    NOTE(review): the weather, time-keeping and date personas promise
    real-time data, but nothing in this class supplies live data; confirm
    whether a tool/data source is expected to be wired in elsewhere.
    """

    def __init__(self, relevant_event_str) -> None:
        """Create all templates, inlining *relevant_event_str* in the house one.

        Args:
            relevant_event_str: Event text to embed in the house-events prompt
                (for example, the output of ``retrieve_events``). It is stored
                unchanged on ``self.relevant_event_str``.
        """
        self.relevant_event_str = relevant_event_str
        # The templates are later parsed as format strings, so literal braces
        # in the retrieved text must be doubled. Otherwise event text such as
        # "{id}" would be read as a missing template variable and break
        # formatting.
        escaped_events = (
            str(relevant_event_str).replace("{", "{{").replace("}", "}}")
        )
        self.house_events_template = f"""You are an intelligent assistant specialized in managing house events. \
You provide detailed answers about the events happening in the house. \
You summarize the events accurately and provide relevant details.
When a user ask anything related to house, this is the house they are referring to.
If you have not seen any object or do not have a knowledge of where the event occured. just let the user know. rather than giving incorrect information.
Here are relevant events: {escaped_events}.
Here is a question:
{{input}}"""
        self.weather_template = """You are a weather expert. \
You provide real-time weather updates and detailed weather information. \
You ensure that the weather data is accurate and up-to-date.
Here is a question:
{input}"""
        self.time_template = """You are a time-keeping expert. \
You provide accurate and real-time information about the current time. \
You ensure that the time data is precise for the requested location.
Here is a question:
{input}"""
        self.date_template = """You are a date and calendar expert. \
You provide accurate and real-time information about the current date. \
You ensure that the date data is correct for the requested location.
Here is a question:
{input}"""
        self.general_template = """You are a knowledgeable assistant. \
You answer a wide range of general questions with detailed and accurate information. \
You ensure that the answers are clear and helpful.
Here is a question:
{input}"""
        # Router-facing persona catalogue: 'name' becomes the destination key,
        # and 'description' is what the router reads when choosing.
        self.prompt_infos = [
            {
                'name': 'house events assistant',
                'description': 'Good for questions about house events and managing house activities',
                'prompt_template': self.house_events_template
            },
            {
                'name': 'weather expert',
                'description': 'Good for providing real-time weather updates and detailed weather information',
                'prompt_template': self.weather_template
            },
            {
                'name': 'time-keeping expert',
                'description': 'Good for providing accurate and real-time information about the current time',
                'prompt_template': self.time_template
            },
            {
                'name': 'date and calendar expert',
                'description': 'Good for providing accurate and real-time information about the current date',
                'prompt_template': self.date_template
            },
            {
                'name': 'general knowledge assistant',
                'description': 'Good for answering a wide range of general questions with detailed and accurate information',
                'prompt_template': self.general_template
            }
        ]
def destination_chain(prompt_factory):
    """Build one LLMChain per persona described by *prompt_factory*.

    Args:
        prompt_factory: Object exposing ``prompt_infos``, a list of dicts with
            ``'name'`` and ``'prompt_template'`` keys (e.g. ``PromptFactory``).

    Returns:
        Dict mapping each persona name to an ``LLMChain``. Each chain renders
        that persona's template with the single ``input`` variable, using the
        module-level ``chat_openai`` model and the shared ``memory``.

    NOTE(review): the templates only declare ``input``, so whatever the shared
    memory loads is presumably never rendered into these prompts; confirm
    that is intended.
    """
    return {
        info['name']: LLMChain(
            llm=chat_openai,
            prompt=PromptTemplate(
                template=info['prompt_template'],
                input_variables=['input'],
            ),
            memory=memory,
        )
        for info in prompt_factory.prompt_infos
    }
def destinations_str(prompt_factory):
    """Describe the available personas for the router prompt.

    Args:
        prompt_factory: Object exposing ``prompt_infos``, a list of dicts with
            ``'name'`` and ``'description'`` keys.

    Returns:
        One ``"name: description"`` line per persona, joined by newlines
        (``""`` when there are no personas).
    """
    return '\n'.join(
        f"{info['name']}: {info['description']}"
        for info in prompt_factory.prompt_infos
    )
def default_chain():
    """Return the fallback chain that ``multipromptchain`` passes as ``default_chain``.

    A verbose ``ConversationChain`` over the module-level ``chat_openai``
    model and shared ``memory``.
    """
    # Output is published under "text", presumably so the fallback's result
    # key matches the LLMChain destinations' output key; confirm against the
    # installed LangChain version.
    return ConversationChain(
        llm=chat_openai,
        output_key="text",
        memory=memory,
        verbose=True,
    )
def router_template(destinations_str):
    """Fill LangChain's multi-prompt router template with the destination list.

    Args:
        destinations_str: Newline-separated ``"name: description"`` lines,
            e.g. as produced by the module's ``destinations_str()`` helper.

    Returns:
        ``MULTI_PROMPT_ROUTER_TEMPLATE`` with its ``destinations`` field
        substituted.
    """
    return MULTI_PROMPT_ROUTER_TEMPLATE.format(destinations=destinations_str)
def router_prompt(router_template):
    """Wrap a filled router template in a ``PromptTemplate`` for routing.

    Args:
        router_template: Router prompt text (e.g. from ``router_template()``);
            presumably its only remaining placeholder is ``{input}``.

    Returns:
        A ``PromptTemplate`` over the ``input`` variable whose output is parsed
        by ``RouterOutputParser``.
    """
    parser = RouterOutputParser()
    return PromptTemplate(
        template=router_template,
        input_variables=['input'],
        output_parser=parser,
    )
def multipromptchain(router_chain, destination_chains):
    """Assemble the top-level routing chain.

    Args:
        router_chain: Chain that selects a destination (presumably an LLM
            router built from ``router_prompt()``; verify at the call site).
        destination_chains: Dict of persona name -> chain, e.g. from
            ``destination_chain()``.

    Returns:
        A verbose ``MultiPromptChain`` that uses a freshly built
        ``default_chain()`` as its fallback.
    """
    fallback = default_chain()
    return MultiPromptChain(
        router_chain=router_chain,
        destination_chains=destination_chains,
        default_chain=fallback,
        verbose=True,
    )