from langchain.prompts import PromptTemplate
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import (
HumanMessage,
)
from langgraph.checkpoint.memory import MemorySaver
#structuring
import ast
from langchain_core.output_parsers import JsonOutputParser
#error handling with output parser
from langchain.output_parsers import RetryOutputParser
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
#get graph visuals
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod
# state
class State(TypedDict):
    """Dictionary holding the shared state passed between graph nodes."""
    # Text output produced by the most recent node (schedule or status message).
    node_message: str
    # The schedule itself, stored as parsed JSON (dict).
    trip_data: dict
    # The raw user query for the current turn.
    query: str
    # Routing decision produced by the agent node.
    route: str
class llm_nodes:
    """Collection of LangGraph node callables, all backed by a single LLM."""

    def __init__(self, llm: any):
        # The same model instance is shared by every node in the graph.
        self.model = llm

    def schedule_creator_node(self, state: State):
        """Create a trip schedule (JSON) from the user's query.

        Returns a partial state update with 'trip_data' (parsed schedule)
        and 'node_message'. Falls back to the raw LLM text if JSON parsing
        fails even after a retry pass.
        """
        llm = self.model
        parser = JsonOutputParser()
        prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )
        chain = prompt | llm
        # Built separately: nesting the same quote type inside an f-string
        # (f'...{state.get('query')}...') is a syntax error before Python 3.12.
        query_text = (
            f"from this query: {state.get('query')} turn the data into a schedule "
            "into a json format in the output, do not include ```json```, "
            "do not include comments either"
        )
        result = chain.invoke({"query": query_text})
        try:
            parsed = parser.parse(result.content)
            return {'trip_data': parsed,
                    'node_message': parsed}
        except Exception:
            try:
                retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
                # parse_with_prompt expects a PromptValue, not the template
                # object itself — format the template with the actual query.
                parsed = retry_parser.parse_with_prompt(
                    result.content, prompt.format_prompt(query=query_text))
                return {'trip_data': parsed,
                        'node_message': parsed}
            except Exception:
                # Best-effort fallback: keep the raw text so the user still
                # sees the schedule even though it is not valid JSON.
                return {'trip_data': result.content,
                        'node_message': f'created the schedule:{result.content}, but formating failed '}

    def schedule_editor_node(self, state: State):
        """
        Tool to make modifications to the schedule such as add, delete or modify.
        Pass the query to the llm to edit the schedule.
        args: query - the query to edit the schedule.
        return: modified schedule in a json format
        """
        llm = self.model
        file = state['trip_data']
        parser = JsonOutputParser()
        prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )
        chain = prompt | llm
        # Built separately to avoid same-quote nesting in an f-string
        # (syntax error before Python 3.12).
        query_text = (
            f"Edit this schedule: {str(file)} following the instructions in the "
            f"query: {state.get('query')}, and include the changes in the schedule, "
            "but do not mention them specifically, only include the updated "
            "schedule json format in the output, do not include ```json```, "
            "do not include comments either"
        )
        result = chain.invoke({"query": query_text})
        try:
            parsed = parser.parse(result.content)
            return {'trip_data': parsed,
                    'node_message': f'edited the schedule with these changes:{parsed}'}
        except Exception:
            try:
                retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
                # parse_with_prompt expects a PromptValue, not the template.
                parsed = retry_parser.parse_with_prompt(
                    result.content, prompt.format_prompt(query=query_text))
                return {'trip_data': parsed,
                        'node_message': f'edited the schedule with these changes:{parsed}'}
            except Exception:
                return {'trip_data': result.content,
                        'node_message': f'edited the schedule with these changes:{result.content}, but formating failed '}

    def agent_node(self, state: State):
        """Ask the LLM to choose the next node; returns {'route': <choice>}."""
        llm = self.model

        class Form(BaseModel):
            # Constrains the LLM's structured answer to one of the node names.
            route: str = Field(description='Return one of: schedule_creator, schedule_editor, show_schedule')

        parser = JsonOutputParser(pydantic_object=Form)
        instruction = parser.get_format_instructions()
        response = llm.invoke([HumanMessage(content=f"Based on this query: {state['query']}, select the appropriate route. Options are: schedule_creator, schedule_editor, show_schedule\n\n{instruction}")])
        response = parser.parse(response.content)
        route = response.get('route')
        return {'route': route}

    def show_schedule_node(self, state: State):
        """
        Use this tool to get the information about the schedule once it has been loaded.
        args: none
        return: schedule
        """
        schedule = state.get('trip_data')
        if schedule:
            return {"node_message": schedule}
        return {"node_message": "no schedule found, please upload one or add it in the chat"}
def route(state: State):
    """Translate the agent's chosen route into a conditional-edge label.

    Returns None when the stored route is missing or unrecognized.
    """
    destinations = {
        'schedule_creator': 'to_schedule_creator',
        'schedule_editor': 'to_schedule_editor',
        'show_schedule': 'to_show_schedule',
    }
    return destinations.get(state.get('route'))
# langgraph
#loading tools
class Schedule_agent:
    """Wraps the compiled LangGraph schedule agent with chat convenience methods."""

    def __init__(self, llm: any):
        self.agent = self._setup(llm)

    @staticmethod
    def _config(thread_id: str = "1"):
        # All conversation state is checkpointed under this thread id;
        # previously this dict was duplicated in four methods.
        return {"configurable": {"thread_id": thread_id}}

    def _setup(self, llm):
        """Build and compile the graph: agent routes to one of three nodes, then END."""
        nodes = llm_nodes(llm)
        graph_builder = StateGraph(State)
        graph_builder.add_node("agent", nodes.agent_node)
        graph_builder.add_node('schedule_creator', nodes.schedule_creator_node)
        graph_builder.add_node('schedule_editor', nodes.schedule_editor_node)
        graph_builder.add_node('show_schedule', nodes.show_schedule_node)
        graph_builder.set_entry_point("agent")
        graph_builder.add_conditional_edges(
            "agent",
            route,
            {
                'to_schedule_creator': 'schedule_creator',
                'to_schedule_editor': 'schedule_editor',
                'to_show_schedule': 'show_schedule',
            },
        )
        graph_builder.add_edge('schedule_creator', END)
        graph_builder.add_edge('schedule_editor', END)
        graph_builder.add_edge('show_schedule', END)
        # MemorySaver keeps per-thread state so a later query can edit the
        # schedule created earlier in the same thread.
        memory = MemorySaver()
        return graph_builder.compile(checkpointer=memory)

    def display_graph(self):
        """Render the compiled graph as a Mermaid PNG (for notebook display)."""
        return display(
            Image(
                self.agent.get_graph().draw_mermaid_png(
                    draw_method=MermaidDrawMethod.API,
                )
            )
        )

    def chat(self, input: str, thread_id: str = "1"):
        """Run one query through the graph and return the final state dict."""
        return self.agent.invoke({'query': input}, self._config(thread_id))

    def stream(self, input: str, thread_id: str = "1"):
        """Stream node-by-node updates for one query, printing each event."""
        for event in self.agent.stream({'query': input}, self._config(thread_id),
                                       stream_mode="updates"):
            print(event)

    def get_state(self, state_val: str, thread_id: str = "1"):
        """Return one value from the checkpointed state (raises KeyError if absent)."""
        return self.agent.get_state(self._config(thread_id)).values[state_val]

    def update_state(self, data: dict, thread_id: str = "1"):
        """Merge *data* into the checkpointed graph state."""
        return self.agent.update_state(self._config(thread_id), data)