wolf1997 commited on
Commit
3e28a59
·
verified ·
1 Parent(s): f2032eb

Update schedule_agent.py

Browse files
Files changed (1) hide show
  1. schedule_agent.py +149 -194
schedule_agent.py CHANGED
@@ -1,247 +1,202 @@
1
- from langchain_google_genai import ChatGoogleGenerativeAI
2
-
3
-
4
- from langchain.tools import tool
5
-
6
  from langgraph.graph import StateGraph, START, END
7
- from langgraph.graph.message import add_messages
8
- from langgraph.prebuilt import ToolNode, tools_condition,InjectedState
9
  from langchain_core.messages import (
10
- SystemMessage,
11
  HumanMessage,
12
- AIMessage,
13
- ToolMessage,
14
  )
15
- from langgraph.types import Command, interrupt
16
  from langgraph.checkpoint.memory import MemorySaver
17
- from langchain_core.tools.base import InjectedToolCallId
18
 
19
  #structuring
20
  import ast
21
- from langchain.prompts import PromptTemplate
22
  from langchain_core.output_parsers import JsonOutputParser
23
  #error handling with output parser
24
  from langchain.output_parsers import RetryOutputParser
25
 
26
-
27
- from dataclasses import dataclass
28
  from typing_extensions import TypedDict
29
- from typing import Annotated, Literal
30
- from pydantic import BaseModel, Field
31
-
32
-
33
-
34
- import os
35
- import requests
36
- import json
37
- from dotenv import load_dotenv
38
- from os import listdir
39
- from os.path import isfile, join
40
-
41
-
42
- load_dotenv()
43
-
44
-
45
- # loading the necessary api keys
46
- GOOGLE_API_KEY=os.getenv('google_api_key')
47
 
 
 
 
 
48
 
49
-
50
 
51
- GEMINI_MODEL='gemini-2.0-flash'
52
 
53
- llm = ChatGoogleGenerativeAI(google_api_key=GOOGLE_API_KEY, model=GEMINI_MODEL, temperature=0.3)
54
-
55
  # state
56
  class State(TypedDict):
57
  """
58
  A dictionnary representing the state of the agent.
59
  """
60
- messages: Annotated[list, add_messages]
61
  trip_data: dict
62
-
63
- # defining the tools for the agent to use
64
-
65
- @tool
66
- def local_files_browser(tool_call_id: Annotated[str, InjectedToolCallId]) -> str:
67
- """
68
- tool to list the local schedule files.
69
- args:none
70
- """
71
- mypath=f'schedules/'
72
- onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))]
73
- if not onlyfiles:
74
- return Command(update={'messages':[ToolMessage(f'No files are available, try to upload one',tool_call_id=tool_call_id)]})
75
- else:
76
- return Command(update={'messages':[ToolMessage(f'Here are the available schedules: {onlyfiles}',tool_call_id=tool_call_id)]})
77
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
 
80
- @tool
81
- def schedule_loader(tool_call_id: Annotated[str, InjectedToolCallId],filename: str) -> Command:
82
- """
83
- Use this tool to load the schedule from local directory, which is a text file.
84
- args: filename - the name of the file, include the extention.
85
- return: schedule in a json format
86
- """
87
- try:
88
- with open(f'schedules/{filename}', 'rb') as f:
89
- schedule=f.read()
90
-
91
- try:
92
- parser = JsonOutputParser()
93
- prompt = PromptTemplate(
94
- template="Answer the user query.\n{format_instructions}\n{query}\n",
95
- input_variables=["query"],
96
- partial_variables={"format_instructions": parser.get_format_instructions()},
97
- )
98
-
99
- chain = prompt | llm
100
- result=chain.invoke({"query": f'format this schedule: {str(schedule)} into a json format in the output, do not include ```json```, do not include comments either'})
101
- result=parser.parse(result.content)
102
- return Command(update={'trip_data':result,
103
- 'messages': [ToolMessage('Succesfully uploaded schedule',tool_call_id=tool_call_id)]})
104
- except:
105
  try:
106
  retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
107
  result=retry_parser.parse_with_prompt(result.content, prompt)
108
- return Command(update={'trip_data':result,
109
- 'messages': [ToolMessage('Succesfully uploaded schedule',tool_call_id=tool_call_id)]})
110
  except:
111
- return Command(update={'trip_data':result.content,
112
- 'messages': [ToolMessage(f'loaded the schedule:{result.content}, but formating failed ',tool_call_id=tool_call_id)]})
113
- except:
114
- return Command(update={'messages':[ToolMessage('No Schedule please try a different filename, or include the extention eg. filename.txt',tool_call_id=tool_call_id)]})
115
-
116
-
117
- @tool
118
- def schedule_creator(tool_call_id: Annotated[str, InjectedToolCallId], schedule:str)->str:
119
- """Tool to create a schedule from the chat with the agent
120
- and then uses an llm to structure it.
121
- args: schedule - the schedule from the chat
122
- """
 
 
 
 
123
 
124
- try:
125
- parser = JsonOutputParser()
126
- prompt = PromptTemplate(
127
- template="Answer the user query.\n{format_instructions}\n{query}\n",
128
- input_variables=["query"],
129
- partial_variables={"format_instructions": parser.get_format_instructions()},
130
- )
131
-
132
- chain = prompt | llm
133
- result=chain.invoke({"query": f'format this schedule: {str(schedule)} into a json format in the output, do not include ```json```, do not include comments either'})
134
- result=parser.parse(result.content)
135
- return Command(update={'trip_data':result,
136
- 'messages': [ToolMessage('Succesfully created schedule',tool_call_id=tool_call_id)]})
137
- except:
138
- try:
139
- retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
140
- result=retry_parser.parse_with_prompt(result.content, prompt)
141
- return Command(update={'trip_data':result,
142
- 'messages': [ToolMessage('Succesfully created schedule',tool_call_id=tool_call_id)]})
143
- except:
144
- return Command(update={'trip_data':result.content,
145
- 'messages': [ToolMessage(f'created the schedule:{result.content}, but formating failed ',tool_call_id=tool_call_id)]})
146
-
147
-
148
- @tool
149
- def get_schedule(state: Annotated[dict, InjectedState])-> str:
150
- """
151
- Use this tool to get the information about the schedule once it has been loaded.
152
- args: none
153
- return: schedule
154
- """
155
- return state['trip_data']
156
-
157
- @tool
158
- def schedule_editor(query:str,state: Annotated[dict, InjectedState],tool_call_id: Annotated[str, InjectedToolCallId])-> str:
159
- """
160
- Tool to make modifications to the schedule such as add, delete or modify.
161
- Pass the query to the llm to edit the schedule.
162
- args: query - the query to edit the schedule.
163
- return: modified schedule in a json format
164
- """
165
- file=state['trip_data']
166
- try:
167
- parser = JsonOutputParser()
168
- prompt = PromptTemplate(
169
- template="Answer the user query.\n{format_instructions}\n{query}\n",
170
- input_variables=["query"],
171
- partial_variables={"format_instructions": parser.get_format_instructions()},
172
- )
173
-
174
- chain = prompt | llm
175
- result=chain.invoke({"query": f'Edit this schedule: {str(file)} following the instructions in the query: {query}, and include the changes in the schedule, but do not mention them specifically, only include the updated schedule json format in the output, do not include ```json```, do not include comments either'})
176
- result=parser.parse(result.content)
177
- return Command(update={'trip_data':result,
178
- 'messages': [ToolMessage(f'edited the schedule with these changes:{result} ',tool_call_id=tool_call_id)]})
179
- except:
180
- try:
181
- retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
182
- result=retry_parser.parse_with_prompt(result.content, prompt)
183
- return Command(update={'trip_data':result,
184
- 'messages': [ToolMessage(f'edited the schedule with these changes:{result} ',tool_call_id=tool_call_id)]})
185
- except:
186
- return Command(update={'trip_data':result.content,
187
- 'messages': [ToolMessage(f'edited the schedule with these changes:{result}, but formating failed ',tool_call_id=tool_call_id)]})
188
-
189
- @tool
190
- def save_schedule(state: Annotated[dict, InjectedState],tool_call_id: Annotated[str, InjectedToolCallId], filename: str) -> str:
191
- """
192
- Tool to save the schedule with a specified filename.
193
- agrs: filename the name of the file, no need to include the extentions of the file
194
  """
195
- file= state['trip_data']
196
- return f'{filename} saved (however, since this is a demo, it is not really saved:))'
197
-
198
-
199
-
200
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  class Schedule_agent:
202
- def __init__(self,llm:any):
203
  self.agent=self._setup(llm)
204
  def _setup(self,llm):
205
-
206
- langgraph_tools=[get_schedule,schedule_creator,local_files_browser, save_schedule, schedule_editor,schedule_loader]
207
-
208
 
209
  graph_builder = StateGraph(State)
210
 
211
- # Modification: tell the LLM which tools it can call
212
- llm_with_tools = llm.bind_tools(langgraph_tools)
213
- tool_node = ToolNode(tools=langgraph_tools)
214
- def chatbot(state: State):
215
- """ travel assistant that answers user questions about their trip.
216
- Depending on the request, leverage which tools to use if necessary."""
217
- return {"messages": [llm_with_tools.invoke(state['messages'])]}
218
 
219
- graph_builder.add_node("chatbot", chatbot)
220
 
221
 
222
- graph_builder.add_node("tools", tool_node)
 
 
223
  # Any time a tool is called, we return to the chatbot to decide the next step
224
- graph_builder.set_entry_point("chatbot")
225
- graph_builder.add_edge("tools", "chatbot")
226
  graph_builder.add_conditional_edges(
227
- "chatbot",
228
- tools_condition,
 
 
 
 
229
  )
 
 
 
230
  memory=MemorySaver()
231
  graph=graph_builder.compile(checkpointer=memory)
232
  return graph
233
 
234
- def stream(self,input:str):
235
- config = {"configurable": {"thread_id": "1"}}
236
- input_message = HumanMessage(content=input)
237
- for event in self.agent.stream({"messages": [input_message]}, config, stream_mode="values"):
238
- event["messages"][-1].pretty_print()
239
 
 
 
 
 
 
 
 
 
240
  def chat(self,input:str):
241
  config = {"configurable": {"thread_id": "1"}}
242
- response=self.agent.invoke({'messages':HumanMessage(content=str(input))},config)
243
- return response['messages'][-1].content
 
 
 
 
 
 
 
244
 
245
  def get_state(self, state_val:str):
246
  config = {"configurable": {"thread_id": "1"}}
247
- return self.agent.get_state(config).values[state_val]
 
 
 
 
 
1
+ from langchain.prompts import PromptTemplate
 
 
 
 
2
  from langgraph.graph import StateGraph, START, END
3
+
 
4
  from langchain_core.messages import (
5
+
6
  HumanMessage,
7
+
 
8
  )
9
+
10
  from langgraph.checkpoint.memory import MemorySaver
11
+
12
 
13
  #structuring
14
  import ast
 
15
  from langchain_core.output_parsers import JsonOutputParser
16
  #error handling with output parser
17
  from langchain.output_parsers import RetryOutputParser
18
 
 
 
19
  from typing_extensions import TypedDict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ from pydantic import BaseModel, Field
22
+ #get graph visuals
23
+ from IPython.display import Image, display
24
+ from langchain_core.runnables.graph import MermaidDrawMethod
25
 
 
26
 
 
27
 
 
 
28
  # state
29
  class State(TypedDict):
30
  """
31
  A dictionnary representing the state of the agent.
32
  """
33
+ node_message: str
34
  trip_data: dict
35
+ query: str
36
+ route:str
37
+
38
+
39
+ class llm_nodes:
40
+
41
+ def __init__(self, llm:any):
42
+ self.model=llm
43
+ def schedule_creator_node(self,state:State):
44
+ llm=self.model
45
+ parser = JsonOutputParser()
46
+ prompt = PromptTemplate(
47
+ template="Answer the user query.\n{format_instructions}\n{query}\n",
48
+ input_variables=["query"],
49
+ partial_variables={"format_instructions": parser.get_format_instructions()},
50
+ )
51
+
52
+ chain = prompt | llm
53
+ result=chain.invoke({"query": f'from this query: {state.get('query')} turn the data into a schedule into a json format in the output, do not include ```json```, do not include comments either'})
54
+ try:
55
+
56
+ result=parser.parse(result.content)
57
+ return {'trip_data':result,
58
+ 'node_message':result}
59
+ except:
60
+ try:
61
+ retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
62
+ result=retry_parser.parse_with_prompt(result.content, prompt)
63
+ return {'trip_data':result,
64
+ 'node_message':result}
65
+ except:
66
+ return {'trip_data':result.content,
67
+ 'node_message': f'created the schedule:{result.content}, but formating failed '}
68
+
69
 
70
 
71
+ def schedule_editor_node(self,state:State):
72
+ """
73
+ Tool to make modifications to the schedule such as add, delete or modify.
74
+ Pass the query to the llm to edit the schedule.
75
+ args: query - the query to edit the schedule.
76
+ return: modified schedule in a json format
77
+ """
78
+ llm=self.model
79
+ file=state['trip_data']
80
+ # result=llm.invoke(f'Edit this schedule: {str(file)} following the instructions in the query: {query}, and include the changes in the schedule, but do not mention them specifically, only include the updated schedule json format in the output, do not include ```json```, do not include comments either')
81
+ parser = JsonOutputParser()
82
+ prompt = PromptTemplate(
83
+ template="Answer the user query.\n{format_instructions}\n{query}\n",
84
+ input_variables=["query"],
85
+ partial_variables={"format_instructions": parser.get_format_instructions()},
86
+ )
87
+
88
+ chain = prompt | llm
89
+ result=chain.invoke({"query": f'Edit this schedule: {str(file)} following the instructions in the query: {state.get('query')}, and include the changes in the schedule, but do not mention them specifically, only include the updated schedule json format in the output, do not include ```json```, do not include comments either'})
90
+ try:
91
+
92
+ result=parser.parse(result.content)
93
+ return {'trip_data':result,
94
+ 'node_message': f'edited the schedule with these changes:{result}'}
95
+ except:
96
  try:
97
  retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
98
  result=retry_parser.parse_with_prompt(result.content, prompt)
99
+ return {'trip_data':result,
100
+ 'node_message': f'edited the schedule with these changes:{result}'}
101
  except:
102
+ return {'trip_data':result.content,
103
+ 'node_message': f'edited the schedule with these changes:{result}, but formating failed '}
104
+
105
+
106
+ def agent_node(self,state:State):
107
+ llm=self.model
108
+ class Form(BaseModel):
109
+ route: str = Field(description= 'Return one of: schedule_creator, schedule_editor, show_schedule')
110
+
111
+ parser=JsonOutputParser(pydantic_object=Form)
112
+ instruction=parser.get_format_instructions()
113
+ response=llm.invoke([HumanMessage(content=f"Based on this query: {state['query']}, select the appropriate route. Options are: schedule_creator, schedule_editor, show_schedule\n\n{instruction}")])
114
+ response=parser.parse(response.content)
115
+ route=response.get('route')
116
+
117
+ return {'route':route}
118
 
119
+ def show_schedule_node(self,state: State):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  """
121
+ Use this tool to get the information about the schedule once it has been loaded.
122
+ args: none
123
+ return: schedule
124
+ """
125
+ schedule=state.get('trip_data')
126
+ if schedule:
127
+ return {"node_message":schedule}
128
+ else:
129
+ return{"node_message":"no schedule found, please upload one or add it in the chat"}
130
+
131
+ def route(state:State):
132
+ route=state.get('route')
133
+ routing_map={
134
+ 'schedule_creator': 'to_schedule_creator',
135
+ 'schedule_editor': 'to_schedule_editor',
136
+ 'show_schedule': 'to_show_schedule'
137
+ }
138
+ return routing_map.get(route)
139
+
140
+ # langgraph
141
+ #loading tools
142
  class Schedule_agent:
143
+ def __init__(self, llm:any):
144
  self.agent=self._setup(llm)
145
  def _setup(self,llm):
146
+ nodes=llm_nodes(llm)
147
+
 
148
 
149
  graph_builder = StateGraph(State)
150
 
 
 
 
 
 
 
 
151
 
152
+ graph_builder.add_node("agent",nodes.agent_node)
153
 
154
 
155
+ graph_builder.add_node('schedule_creator', nodes.schedule_creator_node)
156
+ graph_builder.add_node('schedule_editor', nodes.schedule_editor_node)
157
+ graph_builder.add_node('show_schedule',nodes.show_schedule_node)
158
  # Any time a tool is called, we return to the chatbot to decide the next step
159
+ graph_builder.set_entry_point("agent")
 
160
  graph_builder.add_conditional_edges(
161
+ "agent",
162
+ route,{
163
+ 'to_schedule_creator': 'schedule_creator',
164
+ 'to_schedule_editor': 'schedule_editor',
165
+ 'to_show_schedule': 'show_schedule'
166
+ }
167
  )
168
+ graph_builder.add_edge('schedule_creator',END)
169
+ graph_builder.add_edge('schedule_editor',END)
170
+ graph_builder.add_edge('show_schedule',END)
171
  memory=MemorySaver()
172
  graph=graph_builder.compile(checkpointer=memory)
173
  return graph
174
 
 
 
 
 
 
175
 
176
+ def display_graph(self):
177
+ return display(
178
+ Image(
179
+ self.agent.get_graph().draw_mermaid_png(
180
+ draw_method=MermaidDrawMethod.API,
181
+ )
182
+ )
183
+ )
184
  def chat(self,input:str):
185
  config = {"configurable": {"thread_id": "1"}}
186
+ response=self.agent.invoke({'query':input
187
+ },config)
188
+ return response
189
+
190
+ def stream(self,input:str):
191
+ config = {"configurable": {"thread_id": "1"}}
192
+ for event in self.agent.stream({'query':input
193
+ }, config, stream_mode="updates"):
194
+ print(event)
195
 
196
  def get_state(self, state_val:str):
197
  config = {"configurable": {"thread_id": "1"}}
198
+ return self.agent.get_state(config).values[state_val]
199
+
200
+ def update_state(self, data: dict):
201
+ config = {"configurable": {"thread_id": "1"}}
202
+ return self.agent.update_state(config, data)