# RAG_FRIDAY / app.py
# Commit c2efcb6 (verified) by itsanurag: "Update app.py"
import streamlit as st
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chat_models import ChatHuggingFace
from langchain.llms import HuggingFacePipeline
from langchain.tools import DuckDuckGoSearchRun
from pydantic import BaseModel, validator
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline

import utils
# Page chrome: tab title/icon, then the visible heading and tagline.
# Streamlit expects set_page_config before any other st.* element call,
# so it stays at the top of the rendered output.
st.set_page_config(page_icon="🌐", page_title="ChatWeb")
st.header("Chatbot with Internet Access")
st.write("Equipped with internet access, enables users to ask questions about recent events")
class ChatbotTools:
    """Streamlit chatbot that answers questions with a ReAct agent plus web search.

    A locally loaded Hugging Face causal LM drives a zero-shot ReAct agent
    that can call a DuckDuckGo search tool for up-to-date information.
    """

    # Default Hugging Face model id used when no model_name is given.
    MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    # Cap on tokens generated per LLM call. Without an explicit limit,
    # transformers uses its generation defaults (max_length=20 unless the
    # model's generation config overrides it), which truncates ReAct steps.
    MAX_NEW_TOKENS = 512

    def __init__(self, model_name=MODEL_NAME):
        """Load the tokenizer and model, reusing a cached copy when available.

        Args:
            model_name: Hugging Face model id. Defaults to MODEL_NAME.
        """
        self.model_name = model_name
        self.tokenizer, self.model = self._load_model(model_name)

    @staticmethod
    @st.cache_resource
    def _load_model(model_name):
        """Load and cache the (tokenizer, model) pair for *model_name*.

        Streamlit re-executes the whole script, and therefore re-creates
        ChatbotTools, on every interaction. st.cache_resource keeps the
        loaded weights alive across reruns instead of reloading them each time.

        Returns:
            A ``(tokenizer, model)`` tuple.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return tokenizer, model

    def setup_agent(self):
        """Build a zero-shot ReAct agent that can search DuckDuckGo.

        Returns:
            The agent produced by ``initialize_agent``.
        """
        ddg_search = DuckDuckGoSearchRun()
        tools = [
            Tool(
                name="DuckDuckGoSearch",
                func=ddg_search.run,
                description="Useful for when you need to answer questions about current events. You should ask targeted questions",
            )
        ]
        # ChatHuggingFace requires an `llm=` wrapper and cannot be built from a
        # raw model/tokenizer pair, so the previous construction failed. Wrap
        # the local model in a text-generation pipeline and expose it to
        # LangChain as a plain LLM, which the ReAct agent accepts.
        text_generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=self.MAX_NEW_TOKENS,
        )
        llm = HuggingFacePipeline(pipeline=text_generator)
        agent = initialize_agent(
            tools=tools,
            llm=llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            # Let the agent recover from malformed LLM output instead of raising.
            handle_parsing_errors=True,
            verbose=True,
        )
        return agent

    @staticmethod
    def enable_chat_history(func):
        """Decorator placeholder: returns *func* unchanged.

        NOTE(review): the name suggests it was meant to initialise and replay
        the chat history. It does neither. main() initialises
        st.session_state.messages itself.
        """
        return func

    def main(self):
        """Render the query box and, if a query was entered, answer it."""
        # The history list must exist before anything appends to it. Without
        # this, the first query raises because st.session_state has no
        # "messages" entry.
        if "messages" not in st.session_state:
            st.session_state["messages"] = []

        user_query = st.text_input("Ask me anything!")
        if not user_query:
            return

        # Build the agent only when there is something to answer.
        agent = self.setup_agent()
        with st.container():
            # utils.display_msg is defined outside this file. Presumably it
            # renders the message and may also record it in
            # st.session_state.messages (verify).
            utils.display_msg(user_query, 'user')
            # Separate container so the agent's intermediate steps render
            # below the user message.
            st_cb = StreamlitCallbackHandler(st.container())
            response = agent.run(user_query, callbacks=[st_cb])
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.write(response)
if __name__ == "__main__":
    # `streamlit run` executes this script as __main__, so the app is built
    # and rendered from here on each script run.
    app = ChatbotTools()
    app.main()