wolf1997 commited on
Commit
a0bc067
·
verified ·
1 Parent(s): 3e28a59

Update maps_agent.py

Browse files
Files changed (1) hide show
  1. maps_agent.py +82 -166
maps_agent.py CHANGED
@@ -1,65 +1,56 @@
1
- from langchain.tools import tool
2
- from langchain.prompts import PromptTemplate
3
  from langgraph.graph import StateGraph, START, END
4
- from langgraph.graph.message import add_messages
5
- from langgraph.prebuilt import ToolNode, tools_condition,InjectedState
6
  from langchain_core.messages import (
7
- SystemMessage,
8
  HumanMessage,
9
- AIMessage,
10
- ToolMessage,
11
  )
12
- from langgraph.types import Command, interrupt
13
  from langgraph.checkpoint.memory import MemorySaver
14
- from langchain_core.tools.base import InjectedToolCallId
15
 
16
- #structuring
17
- import ast
18
 
19
- from dataclasses import dataclass
20
  from typing_extensions import TypedDict
21
- from typing import Annotated, Literal
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  #getting current location
25
  import geocoder
26
  import os
27
  import requests
28
- import json
29
- from dotenv import load_dotenv
30
- from os import listdir
31
- from os.path import isfile, join
32
-
33
 
34
  load_dotenv()
35
 
36
  GOOGLE_API_KEY=os.getenv('google_api_key')
37
 
38
-
39
  class State(TypedDict):
40
  """
41
  A dictionnary representing the state of the agent.
42
  """
43
- messages: Annotated[list, add_messages]
44
-
45
  #location data
46
  latitude: str
47
  longitude: str
48
  address: str
 
49
  #results from place search
50
  places: dict
 
 
51
 
52
  def get_current_location_node(state: State):
53
- current_location = geocoder.ip("me")
54
- if current_location.latlng:
55
- latitude, longitude = current_location.latlng
56
- address = current_location.address
57
- return {'latitude':latitude, 'longitude':longitude, 'address':address}
58
- else:
59
- return None
60
-
61
- @tool
62
- def get_current_location_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
63
  """
64
  Tool to get the current location of the user.
65
  agrs: none
@@ -68,111 +59,39 @@ def get_current_location_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
68
  if current_location.latlng:
69
  latitude, longitude = current_location.latlng
70
  address = current_location.address
71
- return Command(update={'messages':[ToolMessage(F'The current location is: address:{address}, longitude:{longitude},lattitude:{latitude}', tool_call_id=tool_call_id)],
72
- 'latitude':latitude,
73
- 'longitude':longitude,
74
- 'address':address})
 
 
 
75
  else:
76
- return None
77
 
78
 
79
- @tool
80
- def find_places_near_me(query:str,state: Annotated[dict, InjectedState],tool_call_id: Annotated[str, InjectedToolCallId]):
81
- """
82
- Use this tool to find locations near me.
83
- args: query - has to be one of the following "car_dealer", "car_rental", "car_repair", "car_wash",
84
- "electric_vehicle_charging_station", "gas_station", "parking", "rest_stop",
85
- "corporate_office", "farm", "ranch", "art_gallery", "art_studio", "auditorium",
86
- "cultural_landmark", "historical_place", "monument", "museum", "performing_arts_theater",
87
- "sculpture", "library", "preschool", "primary_school", "school", "secondary_school",
88
- "university", "adventure_sports_center", "amphitheatre", "amusement_center", "amusement_park",
89
- "aquarium", "banquet_hall", "barbecue_area", "botanical_garden", "bowling_alley", "casino",
90
- "childrens_camp", "comedy_club", "community_center", "concert_hall", "convention_center",
91
- "cultural_center", "cycling_park", "dance_hall", "dog_park", "event_venue", "ferris_wheel",
92
- "garden", "hiking_area", "historical_landmark", "internet_cafe", "karaoke", "marina",
93
- "movie_rental", "movie_theater", "national_park", "night_club", "observation_deck",
94
- "off_roading_area", "opera_house", "park", "philharmonic_hall", "picnic_ground", "planetarium",
95
- "plaza", "roller_coaster", "skateboard_park", "state_park", "tourist_attraction", "video_arcade",
96
- "visitor_center", "water_park", "wedding_venue", "wildlife_park", "wildlife_refuge", "zoo",
97
- "public_bath", "public_bathroom", "stable", "accounting", "atm", "bank", "acai_shop",
98
- "afghani_restaurant", "african_restaurant", "american_restaurant", "asian_restaurant",
99
- "bagel_shop", "bakery", "bar", "bar_and_grill", "barbecue_restaurant", "brazilian_restaurant",
100
- "breakfast_restaurant", "brunch_restaurant", "buffet_restaurant", "cafe", "cafeteria",
101
- "candy_store", "cat_cafe", "chinese_restaurant", "chocolate_factory", "chocolate_shop",
102
- "coffee_shop", "confectionery", "deli", "dessert_restaurant", "dessert_shop", "diner",
103
- "dog_cafe", "donut_shop", "fast_food_restaurant", "fine_dining_restaurant", "food_court",
104
- "french_restaurant", "greek_restaurant", "hamburger_restaurant", "ice_cream_shop", "indian_restaurant",
105
- "indonesian_restaurant", "italian_restaurant", "japanese_restaurant", "juice_shop",
106
- "korean_restaurant", "lebanese_restaurant", "meal_delivery", "meal_takeaway",
107
- "mediterranean_restaurant", "mexican_restaurant", "middle_eastern_restaurant", "pizza_restaurant",
108
- "pub", "ramen_restaurant", "restaurant", "sandwich_shop", "seafood_restaurant", "spanish_restaurant",
109
- "steak_house", "sushi_restaurant", "tea_house", "thai_restaurant", "turkish_restaurant",
110
- "vegan_restaurant", "vegetarian_restaurant", "vietnamese_restaurant", "wine_bar",
111
- "administrative_area_level_1", "administrative_area_level_2", "country", "locality", "postal_code",
112
- "school_district", "city_hall", "courthouse", "embassy", "fire_station", "government_office",
113
- "local_government_office", "neighborhood_police_station", "police", "post_office", "chiropractor",
114
- "dental_clinic", "dentist", "doctor", "drugstore", "hospital", "massage", "medical_lab", "pharmacy",
115
- "physiotherapist", "sauna", "skin_care_clinic", "spa", "tanning_studio", "wellness_center", "yoga_studio",
116
- "apartment_building", "apartment_complex", "condominium_complex", "housing_complex", "bed_and_breakfast",
117
- "budget_japanese_inn", "campground", "camping_cabin", "cottage", "extended_stay_hotel", "farmstay",
118
- "guest_house", "hostel", "hotel", "inn", "japanese_inn", "mobile_home_park", "motel", "private_guest_room",
119
- "resort_hotel", "rv_park", "beach", "church", "hindu_temple", "mosque", "synagogue", "astrologer",
120
- "barber_shop", "beautician", "beauty_salon", "body_art_service", "catering_service", "cemetery",
121
- "child_care_agency", "consultant", "courier_service", "electrician", "florist", "food_delivery", "foot_care",
122
- "funeral_home", "hair_care", "hair_salon", "insurance_agency", "laundry", "lawyer", "locksmith",
123
- "makeup_artist", "moving_company", "nail_salon", "painter", "plumber", "psychic", "real_estate_agency",
124
- "roofing_contractor", "storage", "summer_camp_organizer", "tailor", "telecommunications_service_provider",
125
- "tour_agency", "tourist_information_center", "travel_agency", "veterinary_care", "asian_grocery_store",
126
- "auto_parts_store", "bicycle_store", "book_store", "butcher_shop", "cell_phone_store", "clothing_store",
127
- "convenience_store", "department_store", "discount_store", "electronics_store", "food_store",
128
- "furniture_store", "gift_shop", "grocery_store", "hardware_store", "home_goods_store", "home_improvement_store",
129
- "jewelry_store", "liquor_store", "market", "pet_store", "shoe_store", "shopping_mall", "sporting_goods_store",
130
- "store", "supermarket", "warehouse_store", "wholesaler", "arena", "athletic_field", "fishing_charter",
131
- "fishing_pond", "fitness_center", "golf_course", "gym", "ice_skating_rink", "playground", "ski_resort",
132
- "sports_activity_location", "sports_club", "sports_coaching", "sports_complex", "stadium", "swimming_pool",
133
- "airport", "airstrip", "bus_station", "bus_stop", "ferry_terminal", "heliport", "international_airport",
134
- "light_rail_station", "park_and_ride", "subway_station", "taxi_stand", "train_station", "transit_depot",
135
- "transit_station", "truck_stop"
136
- """
137
- try:
138
- my_longitude=state['longitude']
139
- my_latitude=state['latitude']
140
- response=requests.get(f'https://maps.googleapis.com/maps/api/place/nearbysearch/json?location={my_latitude}%2C{my_longitude}&radius=500&type={query}&key={GOOGLE_API_KEY}')
141
- data=response.json()
142
- places={}
143
- for place in data['results']:
144
- try:
145
- name=place['name']
146
- rating=place['rating']
147
- id=place['place_id']
148
- response=requests.get(f'https://places.googleapis.com/v1/places/{id}?fields=googleMapsLinks.placeUri&key={GOOGLE_API_KEY}')
149
- data=response.json()
150
- link=data['googleMapsLinks']['placeUri']
151
- places[name]= {'rating':rating,
152
- 'google_maps_link':link,
153
- }
154
- except Exception as e:
155
- f'Error: {e}'
156
-
157
- return Command(update={'places':places,
158
- 'messages':[ToolMessage(f'I found {len(places)} places', tool_call_id=tool_call_id)]})
159
- except:
160
- return Command(update={'messages':[ToolMessage('Could not find places based on the query', tool_call_id=tool_call_id)]})
161
 
162
 
163
 
164
- @tool
165
- def look_for_places(query: str, tool_call_id: Annotated[str, InjectedToolCallId]):
166
  """
167
  Tool to look for places based on the user query and location.
168
  Use this tool for more complex user queries like sentences, and if the location is specified in the query.
169
  Places includes restaurants, bars, speakeasy, games, anything.
170
  args: query - the query has to be in this format eg.Spicy%20Vegetarian%20Food%20in%20Sydney%20Australia.
171
-
172
-
173
  """
 
174
  try:
175
- response=requests.get(f'https://maps.googleapis.com/maps/api/place/textsearch/json?query={query}?&key={GOOGLE_API_KEY}')
176
  data=response.json()
177
  places={}
178
  for place in data['results']:
@@ -196,67 +115,64 @@ def look_for_places(query: str, tool_call_id: Annotated[str, InjectedToolCallId]
196
  except Exception as e:
197
  f'Error: {e}'
198
 
199
- return Command(update={'places':places,
200
- 'messages':[ToolMessage(f'I found {len(places)} places', tool_call_id=tool_call_id)]})
201
  except Exception as e:
202
- return f'Error: error'
203
 
204
- @tool
205
- def show_places_found(state: Annotated[dict, InjectedState]):
206
- """
207
- Tool to get the places found by previous tool calls and to show/display them.
208
- It has links within that can also be used for directions
209
- always show the links
210
- args: none
211
- """
212
- return state['places']
213
-
214
- class maps_agent:
215
  def __init__(self,llm: any):
216
  self.agent=self._setup(llm)
217
 
218
 
219
  def _setup(self,llm):
220
- langgraph_tools=[get_current_location_tool,look_for_places, find_places_near_me,show_places_found]
221
-
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  graph_builder = StateGraph(State)
224
 
225
- # Modification: tell the LLM which tools it can call
226
- llm_with_tools = llm.bind_tools(langgraph_tools)
227
- tool_node = ToolNode(tools=langgraph_tools)
228
- def chatbot(state: State):
229
- """ travel assistant that answers user questions about their trip.
230
- Depending on the request, leverage which tools to use if necessary."""
231
- return {"messages": [llm_with_tools.invoke(state['messages'])]}
232
-
233
- graph_builder.add_node("chatbot", chatbot)
234
 
235
- graph_builder.add_node('current_location',get_current_location_node)
236
- graph_builder.add_node("tools", tool_node)
237
- # Any time a tool is called, we return to the chatbot to decide the next step
238
- graph_builder.set_entry_point("current_location")
239
- graph_builder.add_edge('current_location','chatbot')
240
- graph_builder.add_edge("tools", "chatbot")
241
- graph_builder.add_conditional_edges(
242
- "chatbot",
243
- tools_condition,
244
- )
245
  memory=MemorySaver()
246
  graph=graph_builder.compile(checkpointer=memory)
247
  return graph
248
-
 
 
 
 
 
 
 
 
249
  def get_state(self, state_val:str):
250
  config = {"configurable": {"thread_id": "1"}}
251
  return self.agent.get_state(config).values[state_val]
252
 
253
- def stream(self,input:str):
254
  config = {"configurable": {"thread_id": "1"}}
255
- input_message = HumanMessage(content=input)
256
- for event in self.agent.stream({"messages": [input_message]}, config, stream_mode="values"):
257
- event["messages"][-1].pretty_print()
258
 
259
- def chat(self,input:str):
260
  config = {"configurable": {"thread_id": "1"}}
261
- response=self.agent.invoke({'messages':HumanMessage(content=str(input))},config)
262
- return response['messages'][-1].content
 
1
+
 
2
  from langgraph.graph import StateGraph, START, END
3
+
 
4
  from langchain_core.messages import (
 
5
  HumanMessage,
 
 
6
  )
7
+
8
  from langgraph.checkpoint.memory import MemorySaver
 
9
 
10
+ from langchain_core.output_parsers import JsonOutputParser
 
11
 
 
12
  from typing_extensions import TypedDict
 
13
 
14
 
15
+ #get graph visuals
16
+ from IPython.display import Image, display
17
+ from langchain_core.runnables.graph import MermaidDrawMethod
18
+ from pydantic import BaseModel, Field
19
+ import os
20
+
21
+
22
+
23
+ from typing_extensions import TypedDict
24
+ from typing import Optional
25
+
26
+ from dotenv import load_dotenv
27
+
28
  #getting current location
29
  import geocoder
30
  import os
31
  import requests
 
 
 
 
 
32
 
33
  load_dotenv()
34
 
35
  GOOGLE_API_KEY=os.getenv('google_api_key')
36
 
 
37
  class State(TypedDict):
38
  """
39
  A dictionnary representing the state of the agent.
40
  """
41
+ node_message: str
42
+ query: str
43
  #location data
44
  latitude: str
45
  longitude: str
46
  address: str
47
+ place_query: str
48
  #results from place search
49
  places: dict
50
+ route:str
51
+
52
 
53
  def get_current_location_node(state: State):
 
 
 
 
 
 
 
 
 
 
54
  """
55
  Tool to get the current location of the user.
56
  agrs: none
 
59
  if current_location.latlng:
60
  latitude, longitude = current_location.latlng
61
  address = current_location.address
62
+ return {
63
+ 'latitude':latitude,
64
+ 'longitude':longitude,
65
+ 'address':address,
66
+ 'node_message':{'latitude':latitude,
67
+ 'longitude':longitude,
68
+ 'address':address}}
69
  else:
70
+ return {'node_message':'failed'}
71
 
72
 
73
+ def router_node(state=State):
74
+
75
+
76
+ route=state.get('route')
77
+ if route=='look_for_places':
78
+ return 'to_look_for_places'
79
+ elif route=='current_loc':
80
+ return 'to_current_loc'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
 
83
 
84
+ def look_for_places_node(state: State):
 
85
  """
86
  Tool to look for places based on the user query and location.
87
  Use this tool for more complex user queries like sentences, and if the location is specified in the query.
88
  Places includes restaurants, bars, speakeasy, games, anything.
89
  args: query - the query has to be in this format eg.Spicy%20Vegetarian%20Food%20in%20Sydney%20Australia.
90
+ Alaways include the links in the respons, but not longitude or latitude
 
91
  """
92
+
93
  try:
94
+ response=requests.get(f'https://maps.googleapis.com/maps/api/place/textsearch/json?query={state.get('place_query')}?&key={GOOGLE_API_KEY}')
95
  data=response.json()
96
  places={}
97
  for place in data['results']:
 
115
  except Exception as e:
116
  f'Error: {e}'
117
 
118
+ return {'places':places,
119
+ 'node_message':places}
120
  except Exception as e:
121
+ return {'node_message': e}
122
 
123
+ class Maps_agent:
 
 
 
 
 
 
 
 
 
 
124
  def __init__(self,llm: any):
125
  self.agent=self._setup(llm)
126
 
127
 
128
  def _setup(self,llm):
129
+ # langgraph_tools=[get_current_location_tool,look_for_places, show_places_found]
130
+ def agent_node(state:State):
131
+ class Form(BaseModel):
132
+ route: str = Field(description= 'return current_loc or look_for_places')
133
+ place_query: Optional[str] = Field(description= ' if the query is to look for a place return the place_query has to be in this format eg.Spicy%20Vegetarian%20Food%20in%20Sydney%20Australia')
134
+ parser=JsonOutputParser(pydantic_object=Form)
135
+ instruction=parser.get_format_instructions()
136
+ response=llm.invoke([HumanMessage(content=f'based on this query:{state['query']}, return current_loc to get the current location or look_for_places for the route '+'\n\n'+instruction)])
137
+ response=parser.parse(response.content)
138
+ route=response.get('route')
139
+ place_query=response.get('place_query')
140
+ return {'route':route,
141
+ 'place_query': place_query}
142
 
143
  graph_builder = StateGraph(State)
144
 
 
 
 
 
 
 
 
 
 
145
 
146
+ graph_builder.add_node('current_loc', get_current_location_node)
147
+ graph_builder.add_node('look_for_places',look_for_places_node)
148
+
149
+ graph_builder.add_node('agent',agent_node)
150
+ graph_builder.add_edge(START,'agent')
151
+ graph_builder.add_conditional_edges('agent',router_node,{'to_current_loc':'current_loc', 'to_look_for_places':'look_for_places'})
152
+ graph_builder.add_edge('current_loc',END)
153
+ graph_builder.add_edge('look_for_places',END)
 
 
154
  memory=MemorySaver()
155
  graph=graph_builder.compile(checkpointer=memory)
156
  return graph
157
+
158
+ def display_graph(self):
159
+ return display(
160
+ Image(
161
+ self.agent.get_graph().draw_mermaid_png(
162
+ draw_method=MermaidDrawMethod.API,
163
+ )
164
+ )
165
+ )
166
  def get_state(self, state_val:str):
167
  config = {"configurable": {"thread_id": "1"}}
168
  return self.agent.get_state(config).values[state_val]
169
 
170
+ def chat(self,input:str):
171
  config = {"configurable": {"thread_id": "1"}}
172
+ response=self.agent.invoke({'query':input},config)
173
+ return response.get('node_message')
 
174
 
175
+ def stream(self,input:str):
176
  config = {"configurable": {"thread_id": "1"}}
177
+ for event in self.agent.stream({'query':input}, config, stream_mode="updates"):
178
+ print(event)