File size: 1,786 Bytes
9c400b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from dotenv import load_dotenv

from langchain.chains import LLMMathChain
from langchain.llms.openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.utilities.serpapi import SerpAPIWrapper
from langchain.agents import initialize_agent, Tool, AgentExecutor
import chainlit as cl

from src.tools.crypto_coin_price_tool import CryptoCoinPriceTool

# Load environment variables from a .env file at import time, i.e. before the
# chat-start handler constructs any clients. No API keys are passed explicitly
# to the LLM or search wrappers below, so presumably their credentials come from
# the environment populated here — confirm against deployment config.
load_dotenv()

@cl.on_chat_start
def start():
    """Build a per-session LangChain agent and stash it in the Chainlit session.

    Called by Chainlit when a chat session starts. Constructs two OpenAI-backed
    models, a SerpAPI search wrapper, a crypto-price tool and an LLM math chain,
    wraps them as three agent tools ("Search", "Calculator", and the crypto
    tool under its own name), builds an agent executor over them and stores it
    under the ``"agent"`` key of ``cl.user_session``. Returns None.
    """
    # Chat model: only used to back the calculator chain below.
    chat_model = ChatOpenAI(temperature=0, streaming=True)
    # Completion model: the LLM handed to the agent itself.
    # NOTE(review): the agent type requested below is the *chat* variant, yet it
    # receives this completion model while the chat model only powers the
    # calculator — confirm the pairing is intentional.
    completion_model = OpenAI(temperature=0, streaming=True)
    serp = SerpAPIWrapper()
    crypto_price_tool = CryptoCoinPriceTool()
    math_chain = LLMMathChain.from_llm(llm=chat_model, verbose=True)

    # (name, callable, description) for every tool the agent may invoke,
    # in the order the agent is shown them.
    tool_specs = (
        (
            "Search",
            serp.run,
            "useful for when you need to answer questions about current events. You should ask targeted questions",
        ),
        (
            "Calculator",
            math_chain.run,
            "useful for when you need to answer questions about math",
        ),
        (
            crypto_price_tool.name,
            crypto_price_tool.run,
            crypto_price_tool.description,
        ),
    )
    # handle_tool_error=True on every tool — presumably so tool failures are
    # reported back to the agent instead of aborting the run (per LangChain's
    # BaseTool docs this applies to ToolException; verify for your version).
    toolbox = [
        Tool(
            name=tool_name,
            func=tool_fn,
            description=tool_desc,
            handle_tool_error=True,
        )
        for tool_name, tool_fn, tool_desc in tool_specs
    ]

    # handle_parsing_errors=True asks the executor to recover from LLM output it
    # cannot parse rather than raising.
    cl.user_session.set(
        "agent",
        initialize_agent(
            toolbox,
            completion_model,
            agent="chat-zero-shot-react-description",
            verbose=True,
            handle_parsing_errors=True,
        ),
    )


@cl.on_message
async def main(message: cl.Message):
    """Answer one incoming chat message with the session's agent.

    Fetches the agent executor that ``start`` stored under ``"agent"`` in
    ``cl.user_session`` and runs it on ``message.content``. ``agent.run`` is
    synchronous, so it is wrapped with ``cl.make_async`` to be awaited here.
    A ``cl.LangchainCallbackHandler`` created with ``stream_final_answer=True``
    is passed as a callback. Presumably it is what surfaces the agent's output
    in the UI.

    NOTE(review): the string returned by ``run()`` is discarded. The user only
    sees a reply through the callback — confirm nothing is lost when the
    final-answer stream is not triggered.
    """
    session_agent: AgentExecutor = cl.user_session.get("agent")
    ui_handler = cl.LangchainCallbackHandler(stream_final_answer=True)

    run_agent = cl.make_async(session_agent.run)
    await run_agent(message.content, callbacks=[ui_handler])