File size: 5,508 Bytes
ce92321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from langchain_community.utilities import SQLDatabase
from langchain_tavily import TavilySearch
from langgraph.prebuilt import create_react_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain.chat_models import init_chat_model
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph,START,END
from typing import TypedDict
from langchain import hub
import os
# Load variables from a .env file (if present) into os.environ before the
# API keys below are read.
load_dotenv()
# API keys; os.environ[...] raises KeyError if a variable is not set.
mistralapi=os.environ['MISTRAL']
# NOTE(review): "taviliy" is a misspelling of "tavily"; news() and search()
# reference this exact name, so rename every use together if changing it.
taviliy=os.environ['TAVILY']

# One chat model shared by the router, every agent node and the
# final-answer node.
llm=init_chat_model("ministral-8b-latest",model_provider="mistralai",api_key=mistralapi)

# Bitcoin SQLite database. "sqlite:///db/..." is a path relative to the
# current working directory, not to this source file.
db=SQLDatabase.from_uri("sqlite:///db/bitcoin_data.db")
# SQL tools handed to the ReAct agent built in analysis().
toolkit=SQLDatabaseToolkit(db=db,llm=llm)
sqltools=toolkit.get_tools()
# Schema description injected into querygen()'s prompt as {datainfo}.
table_info=db.get_table_info()
# System prompt for the SQL agent, fetched from the LangChain Hub at import
# time (requires network access) and rendered for SQLite with top_k=5.
sqlprompt=hub.pull("langchain-ai/sql-agent-system-prompt")
sqlprompt=sqlprompt.format(dialect="SQLite", top_k=5)

# Shared LangGraph state. Each node reads from this dict, writes its own
# keys into it and returns it.
#   userinput   - the user's request text (supplied by the caller)
#   movewhere   - routing label written by router()
#   aimessages  - worker-agent outputs, appended by news/analysis/search
#   query       - SQL-agent questions written by querygen()
#   finalanswer - reply written by finalnode() or random()
GraphState = TypedDict(
    "GraphState",
    {
        "userinput": str,
        "movewhere": str,
        "aimessages": list[str],
        "query": str,
        "finalanswer": str,
    },
)


def router(state:GraphState):
    userinput=state["userinput"]
    prompt=ChatPromptTemplate.from_messages([
        ("system","You are assistant who have to decide what user need right now reponsd with 'News' or 'Analysis' or 'Research' or 'all' 'none' dont use any puctuation, symbols or extra text in the reponse. Remember when you respond with analysis we run sql queries on a database with bitcoin high low volume and rsi mcad etc respond this only when the user asks about statistical analysis on the bitcoin data. When you select Research we search internet for any financial oredictions or articles. You must first understand prompt and check which option is best for better responses. Respond with 'all' when i the query of user is diverse like getting some financial advice or require some extra info other than seraching old database. If none of the options is required or question is not about bitcoin or related terms or personalities respond 'none'"),
        ("user",f"Here is the input from user '{userinput}'")
    ])
    chain=prompt | llm
    response=chain.invoke({'userinput':userinput})
    print(response.content)
    state['movewhere']=response.content.lower()
    return state

def querygen(state: GraphState):
    """Turn the user's request into questions for the SQL agent.

    Gives the LLM the database schema (``table_info``) and the user's text,
    asks for five plain statistical questions, and stores the raw reply in
    ``state["query"]`` for the analysis node to consume.

    Returns:
        The same ``state`` dict, mutated in place.
    """
    user_text = state["userinput"]
    question_template = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a assistant who would generate questions for an agent which generates sql queries for data {datainfo}. Analyze the user input and generate the questions for that agent that gives maximum information that user need for better answer. You should ask only 5 statistical questions without heading just simple questions",
            ),
            ("user", "Here is a prompt from user '{userinput}'"),
        ]
    )
    reply = (question_template | llm).invoke(
        {"datainfo": table_info, "userinput": user_text}
    )
    state["query"] = reply.content
    return state

def news(state: GraphState):
    """Gather recent news for the request via a Tavily-backed ReAct agent.

    Builds a ``TavilySearch`` tool limited to ``topic="news"`` (at most 5
    results), runs a ReAct agent on ``state["userinput"]`` and appends the
    agent's final message text to ``state["aimessages"]``.

    Returns:
        The same ``state`` dict, mutated in place.
    """
    news_tool = TavilySearch(tavily_api_key=taviliy, max_results=5, topic="news")
    question = state["userinput"]
    conversation = [{"role": "user", "content": question}]
    agent = create_react_agent(llm, tools=[news_tool])
    result = agent.invoke({"messages": conversation})
    state["aimessages"].append(result["messages"][-1].content)
    return state

def analysis(state: GraphState):
    """Answer the generated questions with the SQL ReAct agent.

    Feeds ``state["query"]`` (written by ``querygen``) to a ReAct agent that
    has the SQL toolkit tools and the hub-pulled SQL system prompt, then
    appends the agent's final message text to ``state["aimessages"]``.

    Returns:
        The same ``state`` dict, mutated in place.
    """
    conversation = [{"role": "user", "content": state["query"]}]
    sql_agent = create_react_agent(llm, tools=sqltools, prompt=sqlprompt)
    outcome = sql_agent.invoke({"messages": conversation})
    state["aimessages"].append(outcome["messages"][-1].content)
    return state

def search(state: GraphState):
    """Research the request on the web via a finance-focused ReAct agent.

    Builds a ``TavilySearch`` tool limited to ``topic="finance"`` (at most 5
    results), runs a ReAct agent on ``state["userinput"]`` and appends the
    agent's final message text to ``state["aimessages"]``.

    Returns:
        The same ``state`` dict, mutated in place.
    """
    question = state["userinput"]
    finance_tool = TavilySearch(
        tavily_api_key=taviliy, max_results=5, topic="finance"
    )
    agent = create_react_agent(llm, tools=[finance_tool])
    conversation = [{"role": "user", "content": question}]
    result = agent.invoke({"messages": conversation})
    state["aimessages"].append(result["messages"][-1].content)
    return state

def finalnode(state:GraphState):
    userinput=state['userinput']
    aimessages=state['aimessages']
    prompt=ChatPromptTemplate.from_messages([
        ("system","You are an agent who have to write final comprehensive answer based on user query and the provided docs given by different AI agents here are some docs that might help {aimessages}"),
        ("user","{userinput}")
    ])
    chain=prompt | llm
    response=chain.invoke({"aimessages":aimessages,"userinput":userinput})
    state['finalanswer']=response.content
    return state

def runall(state: GraphState):
    """Pass-through entry node for the "all" route.

    Does no work of its own. It only gives the "all" branch a named first
    node, from which the sequential querygen -> analysis -> news -> search ->
    final-answer chain starts.

    Returns:
        ``state`` unchanged (the very same object).
    """
    untouched = state
    return untouched

def random(state: GraphState):
    """Fallback node for the "none" route: answer with the bare chat model.

    Sends ``state["userinput"]`` directly to the LLM, with no tools and no
    system prompt, and stores the reply text in ``state["finalanswer"]``.

    NOTE: this name would shadow the stdlib ``random`` module if that were
    imported here; it is kept because the graph registers this function
    under it.

    Returns:
        The same ``state`` dict, mutated in place.
    """
    state["finalanswer"] = llm.invoke(state["userinput"]).content
    return state

# Route label (stored by router() in state["movewhere"]) -> first node of
# the branch that handles it.
ROUTE_TARGETS = {
    "news": "news",
    "analysis": "querygen",
    "research": "search",
    "none": "random",
    "all": "runall",
}


def _pick_route(state):
    """Return the path-map key for the conditional edge leaving "decide".

    The stored label is normalized: surrounding whitespace, quotes and
    punctuation are removed and it is lower-cased. Anything that is still
    not a key of ``ROUTE_TARGETS`` falls back to "none". An off-format
    router reply is then answered by the plain-LLM node instead of aborting
    the run with an unmapped key.
    """
    label = str(state.get("movewhere") or "").strip().strip("'\"`.,!;:").strip().lower()
    return label if label in ROUTE_TARGETS else "none"


builder=StateGraph(GraphState)
builder.add_node("decide",router)
builder.add_node("random",random)

builder.add_edge(START,"decide")
builder.add_conditional_edges("decide",_pick_route,ROUTE_TARGETS)
# "all" route: every worker runs one after another, each appending to
# state["aimessages"], before a single final-answer node.
builder.add_sequence([
    ("runall",runall),
    ("queerygenall",querygen),  # node name kept as-is: it appears in stream output
    ("analyzeall",analysis),
    ("newsall",news),
    ("searchall",search),
    ("finalnodeall",finalnode)
])
builder.add_sequence([("search",search),("finalsearch",finalnode)])
builder.add_sequence([("news",news),("finalnews",finalnode)])
builder.add_sequence([("querygen",querygen),("analyze",analysis),("finalanalyze",finalnode)])
builder.add_edge("random",END)
# End every branch explicitly instead of relying on LangGraph stopping at a
# node that has no outgoing edge.
builder.add_edge("finalnodeall",END)
builder.add_edge("finalsearch",END)
builder.add_edge("finalnews",END)
builder.add_edge("finalanalyze",END)

graph=builder.compile()