Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- gradio_app.py +23 -0
- maps_agent.py +264 -0
- schedule_agent.py +198 -0
- schedules/test.txt +16 -0
- travel_agent.py +189 -0
gradio_app.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from travel_agent import travel_agent
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import os
|
| 4 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
load_dotenv()
|
| 7 |
+
|
| 8 |
+
GOOGLE_API_KEY=os.getenv('google_api_key')
|
| 9 |
+
GEMINI_MODEL='gemini-2.0-flash'
|
| 10 |
+
llm = ChatGoogleGenerativeAI(google_api_key=GOOGLE_API_KEY, model=GEMINI_MODEL, temperature=0.3)
|
| 11 |
+
|
| 12 |
+
#initializing the agent
|
| 13 |
+
travel_ai=travel_agent(llm)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def chatbot(input, history):
|
| 17 |
+
#no need for history since agent has state memory already
|
| 18 |
+
response=travel_ai.chatbot(input)
|
| 19 |
+
return response
|
| 20 |
+
demo = gr.ChatInterface(chatbot, type="messages", autofocus=False)
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
demo.launch()
|
maps_agent.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
from langchain.tools import tool
|
| 4 |
+
from langchain.prompts import PromptTemplate
|
| 5 |
+
from langgraph.graph import StateGraph, START, END
|
| 6 |
+
from langgraph.graph.message import add_messages
|
| 7 |
+
from langgraph.prebuilt import ToolNode, tools_condition,InjectedState
|
| 8 |
+
from langchain_core.messages import (
|
| 9 |
+
SystemMessage,
|
| 10 |
+
HumanMessage,
|
| 11 |
+
AIMessage,
|
| 12 |
+
ToolMessage,
|
| 13 |
+
)
|
| 14 |
+
from langgraph.types import Command, interrupt
|
| 15 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 16 |
+
from langchain_core.tools.base import InjectedToolCallId
|
| 17 |
+
|
| 18 |
+
#structuring
|
| 19 |
+
import ast
|
| 20 |
+
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from typing_extensions import TypedDict
|
| 23 |
+
from typing import Annotated, Literal
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
#getting current location
|
| 27 |
+
import geocoder
|
| 28 |
+
import os
|
| 29 |
+
import requests
|
| 30 |
+
import json
|
| 31 |
+
from dotenv import load_dotenv
|
| 32 |
+
from os import listdir
|
| 33 |
+
from os.path import isfile, join
|
| 34 |
+
from werkzeug.utils import secure_filename
|
| 35 |
+
|
| 36 |
+
load_dotenv()
|
| 37 |
+
|
| 38 |
+
GOOGLE_API_KEY=os.getenv('google_api_key')
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class State(TypedDict):
|
| 42 |
+
"""
|
| 43 |
+
A dictionnary representing the state of the agent.
|
| 44 |
+
"""
|
| 45 |
+
messages: Annotated[list, add_messages]
|
| 46 |
+
|
| 47 |
+
#location data
|
| 48 |
+
latitude: str
|
| 49 |
+
longitude: str
|
| 50 |
+
address: str
|
| 51 |
+
#results from place search
|
| 52 |
+
places: dict
|
| 53 |
+
|
| 54 |
+
def get_current_location_node(state: State):
|
| 55 |
+
current_location = geocoder.ip("me")
|
| 56 |
+
if current_location.latlng:
|
| 57 |
+
latitude, longitude = current_location.latlng
|
| 58 |
+
address = current_location.address
|
| 59 |
+
return {'latitude':latitude, 'longitude':longitude, 'address':address}
|
| 60 |
+
else:
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
@tool
|
| 64 |
+
def get_current_location_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
|
| 65 |
+
"""
|
| 66 |
+
Tool to get the current location of the user.
|
| 67 |
+
agrs: none
|
| 68 |
+
"""
|
| 69 |
+
current_location = geocoder.ip("me")
|
| 70 |
+
if current_location.latlng:
|
| 71 |
+
latitude, longitude = current_location.latlng
|
| 72 |
+
address = current_location.address
|
| 73 |
+
return Command(update={'messages':[ToolMessage(F'The current location is: address:{address}, longitude:{longitude},lattitude:{latitude}', tool_call_id=tool_call_id)],
|
| 74 |
+
'latitude':latitude,
|
| 75 |
+
'longitude':longitude,
|
| 76 |
+
'address':address})
|
| 77 |
+
else:
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@tool
|
| 82 |
+
def find_places_near_me(query:str,state: Annotated[dict, InjectedState],tool_call_id: Annotated[str, InjectedToolCallId]):
|
| 83 |
+
"""
|
| 84 |
+
Use this tool to find locations near me.
|
| 85 |
+
args: query - has to be one of the following "car_dealer", "car_rental", "car_repair", "car_wash",
|
| 86 |
+
"electric_vehicle_charging_station", "gas_station", "parking", "rest_stop",
|
| 87 |
+
"corporate_office", "farm", "ranch", "art_gallery", "art_studio", "auditorium",
|
| 88 |
+
"cultural_landmark", "historical_place", "monument", "museum", "performing_arts_theater",
|
| 89 |
+
"sculpture", "library", "preschool", "primary_school", "school", "secondary_school",
|
| 90 |
+
"university", "adventure_sports_center", "amphitheatre", "amusement_center", "amusement_park",
|
| 91 |
+
"aquarium", "banquet_hall", "barbecue_area", "botanical_garden", "bowling_alley", "casino",
|
| 92 |
+
"childrens_camp", "comedy_club", "community_center", "concert_hall", "convention_center",
|
| 93 |
+
"cultural_center", "cycling_park", "dance_hall", "dog_park", "event_venue", "ferris_wheel",
|
| 94 |
+
"garden", "hiking_area", "historical_landmark", "internet_cafe", "karaoke", "marina",
|
| 95 |
+
"movie_rental", "movie_theater", "national_park", "night_club", "observation_deck",
|
| 96 |
+
"off_roading_area", "opera_house", "park", "philharmonic_hall", "picnic_ground", "planetarium",
|
| 97 |
+
"plaza", "roller_coaster", "skateboard_park", "state_park", "tourist_attraction", "video_arcade",
|
| 98 |
+
"visitor_center", "water_park", "wedding_venue", "wildlife_park", "wildlife_refuge", "zoo",
|
| 99 |
+
"public_bath", "public_bathroom", "stable", "accounting", "atm", "bank", "acai_shop",
|
| 100 |
+
"afghani_restaurant", "african_restaurant", "american_restaurant", "asian_restaurant",
|
| 101 |
+
"bagel_shop", "bakery", "bar", "bar_and_grill", "barbecue_restaurant", "brazilian_restaurant",
|
| 102 |
+
"breakfast_restaurant", "brunch_restaurant", "buffet_restaurant", "cafe", "cafeteria",
|
| 103 |
+
"candy_store", "cat_cafe", "chinese_restaurant", "chocolate_factory", "chocolate_shop",
|
| 104 |
+
"coffee_shop", "confectionery", "deli", "dessert_restaurant", "dessert_shop", "diner",
|
| 105 |
+
"dog_cafe", "donut_shop", "fast_food_restaurant", "fine_dining_restaurant", "food_court",
|
| 106 |
+
"french_restaurant", "greek_restaurant", "hamburger_restaurant", "ice_cream_shop", "indian_restaurant",
|
| 107 |
+
"indonesian_restaurant", "italian_restaurant", "japanese_restaurant", "juice_shop",
|
| 108 |
+
"korean_restaurant", "lebanese_restaurant", "meal_delivery", "meal_takeaway",
|
| 109 |
+
"mediterranean_restaurant", "mexican_restaurant", "middle_eastern_restaurant", "pizza_restaurant",
|
| 110 |
+
"pub", "ramen_restaurant", "restaurant", "sandwich_shop", "seafood_restaurant", "spanish_restaurant",
|
| 111 |
+
"steak_house", "sushi_restaurant", "tea_house", "thai_restaurant", "turkish_restaurant",
|
| 112 |
+
"vegan_restaurant", "vegetarian_restaurant", "vietnamese_restaurant", "wine_bar",
|
| 113 |
+
"administrative_area_level_1", "administrative_area_level_2", "country", "locality", "postal_code",
|
| 114 |
+
"school_district", "city_hall", "courthouse", "embassy", "fire_station", "government_office",
|
| 115 |
+
"local_government_office", "neighborhood_police_station", "police", "post_office", "chiropractor",
|
| 116 |
+
"dental_clinic", "dentist", "doctor", "drugstore", "hospital", "massage", "medical_lab", "pharmacy",
|
| 117 |
+
"physiotherapist", "sauna", "skin_care_clinic", "spa", "tanning_studio", "wellness_center", "yoga_studio",
|
| 118 |
+
"apartment_building", "apartment_complex", "condominium_complex", "housing_complex", "bed_and_breakfast",
|
| 119 |
+
"budget_japanese_inn", "campground", "camping_cabin", "cottage", "extended_stay_hotel", "farmstay",
|
| 120 |
+
"guest_house", "hostel", "hotel", "inn", "japanese_inn", "mobile_home_park", "motel", "private_guest_room",
|
| 121 |
+
"resort_hotel", "rv_park", "beach", "church", "hindu_temple", "mosque", "synagogue", "astrologer",
|
| 122 |
+
"barber_shop", "beautician", "beauty_salon", "body_art_service", "catering_service", "cemetery",
|
| 123 |
+
"child_care_agency", "consultant", "courier_service", "electrician", "florist", "food_delivery", "foot_care",
|
| 124 |
+
"funeral_home", "hair_care", "hair_salon", "insurance_agency", "laundry", "lawyer", "locksmith",
|
| 125 |
+
"makeup_artist", "moving_company", "nail_salon", "painter", "plumber", "psychic", "real_estate_agency",
|
| 126 |
+
"roofing_contractor", "storage", "summer_camp_organizer", "tailor", "telecommunications_service_provider",
|
| 127 |
+
"tour_agency", "tourist_information_center", "travel_agency", "veterinary_care", "asian_grocery_store",
|
| 128 |
+
"auto_parts_store", "bicycle_store", "book_store", "butcher_shop", "cell_phone_store", "clothing_store",
|
| 129 |
+
"convenience_store", "department_store", "discount_store", "electronics_store", "food_store",
|
| 130 |
+
"furniture_store", "gift_shop", "grocery_store", "hardware_store", "home_goods_store", "home_improvement_store",
|
| 131 |
+
"jewelry_store", "liquor_store", "market", "pet_store", "shoe_store", "shopping_mall", "sporting_goods_store",
|
| 132 |
+
"store", "supermarket", "warehouse_store", "wholesaler", "arena", "athletic_field", "fishing_charter",
|
| 133 |
+
"fishing_pond", "fitness_center", "golf_course", "gym", "ice_skating_rink", "playground", "ski_resort",
|
| 134 |
+
"sports_activity_location", "sports_club", "sports_coaching", "sports_complex", "stadium", "swimming_pool",
|
| 135 |
+
"airport", "airstrip", "bus_station", "bus_stop", "ferry_terminal", "heliport", "international_airport",
|
| 136 |
+
"light_rail_station", "park_and_ride", "subway_station", "taxi_stand", "train_station", "transit_depot",
|
| 137 |
+
"transit_station", "truck_stop"
|
| 138 |
+
"""
|
| 139 |
+
try:
|
| 140 |
+
my_longitude=state['longitude']
|
| 141 |
+
my_latitude=state['latitude']
|
| 142 |
+
response=requests.get(f'https://maps.googleapis.com/maps/api/place/nearbysearch/json?location={my_latitude}%2C{my_longitude}&radius=500&type={query}&key={GOOGLE_API_KEY}')
|
| 143 |
+
data=response.json()
|
| 144 |
+
places={}
|
| 145 |
+
for place in data['results']:
|
| 146 |
+
try:
|
| 147 |
+
name=place['name']
|
| 148 |
+
rating=place['rating']
|
| 149 |
+
id=place['place_id']
|
| 150 |
+
response=requests.get(f'https://places.googleapis.com/v1/places/{id}?fields=googleMapsLinks.placeUri&key={GOOGLE_API_KEY}')
|
| 151 |
+
data=response.json()
|
| 152 |
+
link=data['googleMapsLinks']['placeUri']
|
| 153 |
+
places[name]= {'rating':rating,
|
| 154 |
+
'google_maps_link':link,
|
| 155 |
+
}
|
| 156 |
+
except Exception as e:
|
| 157 |
+
f'Error: {e}'
|
| 158 |
+
|
| 159 |
+
return Command(update={'places':places,
|
| 160 |
+
'messages':[ToolMessage(f'I found {len(places)} places', tool_call_id=tool_call_id)]})
|
| 161 |
+
except:
|
| 162 |
+
return Command(update={'messages':[ToolMessage('Could not find places based on the query', tool_call_id=tool_call_id)]})
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@tool
|
| 167 |
+
def look_for_places(query: str, tool_call_id: Annotated[str, InjectedToolCallId]):
|
| 168 |
+
"""
|
| 169 |
+
Tool to look for places based on the user query and location.
|
| 170 |
+
Use this tool for more complex user queries like sentences, and if the location is specified in the query.
|
| 171 |
+
Places includes restaurants, bars, speakeasy, games, anything.
|
| 172 |
+
args: query - the query has to be in this format eg.Spicy%20Vegetarian%20Food%20in%20Sydney%20Australia.
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
"""
|
| 176 |
+
try:
|
| 177 |
+
response=requests.get(f'https://maps.googleapis.com/maps/api/place/textsearch/json?query={query}?&key={GOOGLE_API_KEY}')
|
| 178 |
+
data=response.json()
|
| 179 |
+
places={}
|
| 180 |
+
for place in data['results']:
|
| 181 |
+
try:
|
| 182 |
+
name=place['name']
|
| 183 |
+
rating=place['rating']
|
| 184 |
+
id=place['place_id']
|
| 185 |
+
price_level=place['price_level']
|
| 186 |
+
address=place['formatted_address']
|
| 187 |
+
lattitude=place['geometry']['location']['lat']
|
| 188 |
+
longitude=place['geometry']['location']['lng']
|
| 189 |
+
response=requests.get(f'https://places.googleapis.com/v1/places/{id}?fields=googleMapsLinks.placeUri&key={GOOGLE_API_KEY}')
|
| 190 |
+
data=response.json()
|
| 191 |
+
link=data['googleMapsLinks']['placeUri']
|
| 192 |
+
places[name]= {'address': address,
|
| 193 |
+
'rating':rating,
|
| 194 |
+
'Price_level':price_level,
|
| 195 |
+
'google_maps_link':link,
|
| 196 |
+
'longitude':longitude,
|
| 197 |
+
'latitude':lattitude}
|
| 198 |
+
except Exception as e:
|
| 199 |
+
f'Error: {e}'
|
| 200 |
+
|
| 201 |
+
return Command(update={'places':places,
|
| 202 |
+
'messages':[ToolMessage(f'I found {len(places)} places', tool_call_id=tool_call_id)]})
|
| 203 |
+
except Exception as e:
|
| 204 |
+
return f'Error: error'
|
| 205 |
+
|
| 206 |
+
@tool
|
| 207 |
+
def show_places_found(state: Annotated[dict, InjectedState]):
|
| 208 |
+
"""
|
| 209 |
+
Tool to get the places found by previous tool calls and to show/display them.
|
| 210 |
+
It has links within that can also be used for directions
|
| 211 |
+
always show the links
|
| 212 |
+
args: none
|
| 213 |
+
"""
|
| 214 |
+
return state['places']
|
| 215 |
+
|
| 216 |
+
class maps_agent:
|
| 217 |
+
def __init__(self,llm: any):
|
| 218 |
+
self.agent=self._setup(llm)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _setup(self,llm):
|
| 222 |
+
langgraph_tools=[get_current_location_tool,look_for_places, find_places_near_me,show_places_found]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
graph_builder = StateGraph(State)
|
| 226 |
+
|
| 227 |
+
# Modification: tell the LLM which tools it can call
|
| 228 |
+
llm_with_tools = llm.bind_tools(langgraph_tools)
|
| 229 |
+
tool_node = ToolNode(tools=langgraph_tools)
|
| 230 |
+
def chatbot(state: State):
|
| 231 |
+
""" travel assistant that answers user questions about their trip.
|
| 232 |
+
Depending on the request, leverage which tools to use if necessary."""
|
| 233 |
+
return {"messages": [llm_with_tools.invoke(state['messages'])]}
|
| 234 |
+
|
| 235 |
+
graph_builder.add_node("chatbot", chatbot)
|
| 236 |
+
|
| 237 |
+
graph_builder.add_node('current_location',get_current_location_node)
|
| 238 |
+
graph_builder.add_node("tools", tool_node)
|
| 239 |
+
# Any time a tool is called, we return to the chatbot to decide the next step
|
| 240 |
+
graph_builder.set_entry_point("current_location")
|
| 241 |
+
graph_builder.add_edge('current_location','chatbot')
|
| 242 |
+
graph_builder.add_edge("tools", "chatbot")
|
| 243 |
+
graph_builder.add_conditional_edges(
|
| 244 |
+
"chatbot",
|
| 245 |
+
tools_condition,
|
| 246 |
+
)
|
| 247 |
+
memory=MemorySaver()
|
| 248 |
+
graph=graph_builder.compile(checkpointer=memory)
|
| 249 |
+
return graph
|
| 250 |
+
|
| 251 |
+
def get_state(self, state_val:str):
|
| 252 |
+
config = {"configurable": {"thread_id": "1"}}
|
| 253 |
+
return self.agent.get_state(config).values[state_val]
|
| 254 |
+
|
| 255 |
+
def stream(self,input:str):
|
| 256 |
+
config = {"configurable": {"thread_id": "1"}}
|
| 257 |
+
input_message = HumanMessage(content=input)
|
| 258 |
+
for event in self.agent.stream({"messages": [input_message]}, config, stream_mode="values"):
|
| 259 |
+
event["messages"][-1].pretty_print()
|
| 260 |
+
|
| 261 |
+
def chatbot(self,input:str):
|
| 262 |
+
config = {"configurable": {"thread_id": "1"}}
|
| 263 |
+
response=self.agent.invoke({'messages':HumanMessage(content=str(input))},config)
|
| 264 |
+
return response['messages'][-1].content
|
schedule_agent.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
from langchain.tools import tool
|
| 5 |
+
|
| 6 |
+
from langgraph.graph import StateGraph, START, END
|
| 7 |
+
from langgraph.graph.message import add_messages
|
| 8 |
+
from langgraph.prebuilt import ToolNode, tools_condition,InjectedState
|
| 9 |
+
from langchain_core.messages import (
|
| 10 |
+
SystemMessage,
|
| 11 |
+
HumanMessage,
|
| 12 |
+
AIMessage,
|
| 13 |
+
ToolMessage,
|
| 14 |
+
)
|
| 15 |
+
from langgraph.types import Command, interrupt
|
| 16 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 17 |
+
from langchain_core.tools.base import InjectedToolCallId
|
| 18 |
+
|
| 19 |
+
#structuring
|
| 20 |
+
import ast
|
| 21 |
+
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from typing_extensions import TypedDict
|
| 24 |
+
from typing import Annotated, Literal
|
| 25 |
+
from pydantic import BaseModel, Field
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
import os
|
| 30 |
+
import requests
|
| 31 |
+
import json
|
| 32 |
+
from dotenv import load_dotenv
|
| 33 |
+
from os import listdir
|
| 34 |
+
from os.path import isfile, join
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
load_dotenv()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# loading the necessary api keys
|
| 41 |
+
GOOGLE_API_KEY=os.getenv('google_api_key')
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
GEMINI_MODEL='gemini-2.0-flash'
|
| 47 |
+
|
| 48 |
+
llm = ChatGoogleGenerativeAI(google_api_key=GOOGLE_API_KEY, model=GEMINI_MODEL, temperature=0.3)
|
| 49 |
+
|
| 50 |
+
# state
|
| 51 |
+
class State(TypedDict):
|
| 52 |
+
"""
|
| 53 |
+
A dictionnary representing the state of the agent.
|
| 54 |
+
"""
|
| 55 |
+
messages: Annotated[list, add_messages]
|
| 56 |
+
trip_data: dict
|
| 57 |
+
|
| 58 |
+
# defining the tools for the agent to use
|
| 59 |
+
|
| 60 |
+
@tool
|
| 61 |
+
def local_files_browser(tool_call_id: Annotated[str, InjectedToolCallId]) -> str:
|
| 62 |
+
"""
|
| 63 |
+
tool to list the local schedule files.
|
| 64 |
+
args:none
|
| 65 |
+
"""
|
| 66 |
+
mypath=f'schedules/'
|
| 67 |
+
onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))]
|
| 68 |
+
if not onlyfiles:
|
| 69 |
+
return Command(update={'messages':[ToolMessage(f'No files are available, try to upload one',tool_call_id=tool_call_id)]})
|
| 70 |
+
else:
|
| 71 |
+
return Command(update={'messages':[ToolMessage(f'Here are the available schedules: {onlyfiles}',tool_call_id=tool_call_id)]})
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@tool
|
| 76 |
+
def schedule_loader(tool_call_id: Annotated[str, InjectedToolCallId],state: Annotated[dict, InjectedState],filename: str) -> str:
|
| 77 |
+
"""
|
| 78 |
+
Use this tool to load the schedule from local directory, which is a text file.
|
| 79 |
+
args: filename - the name of the file, include the extention.
|
| 80 |
+
"""
|
| 81 |
+
try:
|
| 82 |
+
with open(f'schedules/{filename}', 'rb') as f:
|
| 83 |
+
schedule=f.read()
|
| 84 |
+
result=llm.invoke(f'format this schedule: {str(schedule)} into a json format in the output, do not include ```json```, do not include comments either')
|
| 85 |
+
try:
|
| 86 |
+
return Command(update={'trip_data':ast.literal_eval(result.content),
|
| 87 |
+
'messages': [ToolMessage('Succesfully uploaded schedule',tool_call_id=tool_call_id)]})
|
| 88 |
+
except:
|
| 89 |
+
return Command(update={'messages': [ToolMessage('something went wrong',tool_call_id=tool_call_id)]})
|
| 90 |
+
except:
|
| 91 |
+
return Command(update={'messages':[ToolMessage('No Schedule please try a different filename, or include the extention eg. filename.txt',tool_call_id=tool_call_id)]},
|
| 92 |
+
goto='local_files_browser')
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@tool
|
| 96 |
+
def schedule_creator(tool_call_id: Annotated[str, InjectedToolCallId], schedule:str)->str:
|
| 97 |
+
"""Tool to create a schedule from the chat with the agent
|
| 98 |
+
and then uses an llm to structure it.
|
| 99 |
+
args: schedule - the schedule from the chat
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
result=llm.invoke(f'format this schedule: {str(schedule)} into a json format in the output, do not include ```json```, do not include comments either')
|
| 103 |
+
return Command(update={'trip_data': ast.literal_eval(result.content),
|
| 104 |
+
'messages':[ToolMessage(f'added a schedule from the chat{ast.literal_eval(result.content)}', tool_call_id=tool_call_id)
|
| 105 |
+
]})
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@tool
|
| 110 |
+
def get_schedule(state: Annotated[dict, InjectedState])-> str:
|
| 111 |
+
"""
|
| 112 |
+
Use this tool to get the information about the schedule once it has been loaded.
|
| 113 |
+
args: none
|
| 114 |
+
return: schedule
|
| 115 |
+
"""
|
| 116 |
+
return state['trip_data']
|
| 117 |
+
|
| 118 |
+
@tool
|
| 119 |
+
def schedule_editor(query:str,state: Annotated[dict, InjectedState],tool_call_id: Annotated[str, InjectedToolCallId])-> str:
|
| 120 |
+
"""
|
| 121 |
+
Tool to make modifications to the schedule such as add, delete or modify.
|
| 122 |
+
Pass the query to the llm to edit the schedule.
|
| 123 |
+
args: query - the query to edit the schedule.
|
| 124 |
+
"""
|
| 125 |
+
file=state['trip_data']
|
| 126 |
+
result=llm.invoke(f'Edit this schedule: {str(file)} following the instructions in the query: {query}, and include the changes in the schedule, but do not mention them specifically, only include the updated schedule json format in the output, do not include ```json```, do not include comments either')
|
| 127 |
+
try:
|
| 128 |
+
return Command(
|
| 129 |
+
update={'trip_data':ast.literal_eval(result.content),
|
| 130 |
+
'messages':[ToolMessage(f'edited the schedule with these changes:{ast.literal_eval(result.content)} ', tool_call_id=tool_call_id)
|
| 131 |
+
]})
|
| 132 |
+
except:
|
| 133 |
+
return Command(
|
| 134 |
+
update={'trip_data':result.content,
|
| 135 |
+
'messages':[ToolMessage(f'edited the schedule with these changes:{result.content}, but formating failed ', tool_call_id=tool_call_id)
|
| 136 |
+
]})
|
| 137 |
+
|
| 138 |
+
@tool
|
| 139 |
+
def save_schedule(state: Annotated[dict, InjectedState],tool_call_id: Annotated[str, InjectedToolCallId], filename: str) -> str:
|
| 140 |
+
"""
|
| 141 |
+
Tool to save the schedule with a specified filename.
|
| 142 |
+
agrs: filename the name of the file, no need to include the extentions of the file
|
| 143 |
+
"""
|
| 144 |
+
file= state['trip_data']
|
| 145 |
+
with open(f"schedules/{filename}.txt", "w") as f:
|
| 146 |
+
f.write(file)
|
| 147 |
+
return f'{filename} saved'
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Schedule_agent:
|
| 153 |
+
def __init__(self,llm:any):
|
| 154 |
+
self.agent=self._setup(llm)
|
| 155 |
+
def _setup(self,llm):
|
| 156 |
+
|
| 157 |
+
langgraph_tools=[get_schedule,schedule_creator,local_files_browser, save_schedule, schedule_editor,schedule_loader]
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
graph_builder = StateGraph(State)
|
| 161 |
+
|
| 162 |
+
# Modification: tell the LLM which tools it can call
|
| 163 |
+
llm_with_tools = llm.bind_tools(langgraph_tools)
|
| 164 |
+
tool_node = ToolNode(tools=langgraph_tools)
|
| 165 |
+
def chatbot(state: State):
|
| 166 |
+
""" travel assistant that answers user questions about their trip.
|
| 167 |
+
Depending on the request, leverage which tools to use if necessary."""
|
| 168 |
+
return {"messages": [llm_with_tools.invoke(state['messages'])]}
|
| 169 |
+
|
| 170 |
+
graph_builder.add_node("chatbot", chatbot)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
graph_builder.add_node("tools", tool_node)
|
| 174 |
+
# Any time a tool is called, we return to the chatbot to decide the next step
|
| 175 |
+
graph_builder.set_entry_point("chatbot")
|
| 176 |
+
graph_builder.add_edge("tools", "chatbot")
|
| 177 |
+
graph_builder.add_conditional_edges(
|
| 178 |
+
"chatbot",
|
| 179 |
+
tools_condition,
|
| 180 |
+
)
|
| 181 |
+
memory=MemorySaver()
|
| 182 |
+
graph=graph_builder.compile(checkpointer=memory)
|
| 183 |
+
return graph
|
| 184 |
+
|
| 185 |
+
def stream(self,input:str):
|
| 186 |
+
config = {"configurable": {"thread_id": "1"}}
|
| 187 |
+
input_message = HumanMessage(content=input)
|
| 188 |
+
for event in self.agent.stream({"messages": [input_message]}, config, stream_mode="values"):
|
| 189 |
+
event["messages"][-1].pretty_print()
|
| 190 |
+
|
| 191 |
+
def chatbot(self,input:str):
|
| 192 |
+
config = {"configurable": {"thread_id": "1"}}
|
| 193 |
+
response=self.agent.invoke({'messages':HumanMessage(content=str(input))},config)
|
| 194 |
+
return response['messages'][-1].content
|
| 195 |
+
|
| 196 |
+
def get_state(self, state_val:str):
|
| 197 |
+
config = {"configurable": {"thread_id": "1"}}
|
| 198 |
+
return self.agent.get_state(config).values[state_val]
|
schedules/test.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
7:00 AM: Wake up and stretch
|
| 2 |
+
7:30 AM: Breakfast
|
| 3 |
+
8:00 AM: Work on project A
|
| 4 |
+
10:00 AM: Morning break
|
| 5 |
+
10:15 AM: Work on project B
|
| 6 |
+
12:00 PM: Lunch break
|
| 7 |
+
1:00 PM: Client meeting
|
| 8 |
+
2:00 PM: Work on emails
|
| 9 |
+
3:30 PM: Afternoon break
|
| 10 |
+
4:00 PM: Work on project A
|
| 11 |
+
5:30 PM: Exercise or go for a walk
|
| 12 |
+
6:00 PM: Dinner
|
| 13 |
+
7:00 PM: Relax and watch TV
|
| 14 |
+
8:00 PM: Read a book
|
| 15 |
+
9:00 PM: Prepare for bed
|
| 16 |
+
10:00 PM: Sleep
|
travel_agent.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from maps_agent import maps_agent
|
| 2 |
+
from schedule_agent import Schedule_agent
|
| 3 |
+
|
| 4 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 5 |
+
from langchain.agents import load_tools
|
| 6 |
+
from langchain.tools import Tool,tool,StructuredTool
|
| 7 |
+
|
| 8 |
+
from langgraph.graph import StateGraph
|
| 9 |
+
from langgraph.graph.message import add_messages
|
| 10 |
+
from langgraph.prebuilt import ToolNode, tools_condition
|
| 11 |
+
from langchain_core.messages import (
|
| 12 |
+
HumanMessage,
|
| 13 |
+
)
|
| 14 |
+
from pydantic import BaseModel, Field
|
| 15 |
+
import pytz
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
|
| 18 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 19 |
+
|
| 20 |
+
from typing_extensions import TypedDict
|
| 21 |
+
from typing import Annotated
|
| 22 |
+
#get graph visuals
|
| 23 |
+
from IPython.display import Image, display
|
| 24 |
+
from langchain_core.runnables.graph import MermaidDrawMethod
|
| 25 |
+
import os
|
| 26 |
+
import requests
|
| 27 |
+
from dotenv import load_dotenv
|
| 28 |
+
load_dotenv()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
GOOGLE_API_KEY=os.getenv('google_api_key')
|
| 32 |
+
pse=os.getenv('pse')
|
| 33 |
+
OPENWEATHERMAP_API_KEY=os.getenv('open_weather_key')
|
| 34 |
+
os.environ['OPENWEATHERMAP_API_KEY']=OPENWEATHERMAP_API_KEY
|
| 35 |
+
|
| 36 |
+
GEMINI_MODEL='gemini-2.0-flash'
|
| 37 |
+
llm = ChatGoogleGenerativeAI(google_api_key=GOOGLE_API_KEY, model=GEMINI_MODEL, temperature=0.3)
|
| 38 |
+
|
| 39 |
+
class State(TypedDict):
|
| 40 |
+
messages:Annotated[list, add_messages]
|
| 41 |
+
|
| 42 |
+
maps_ai=maps_agent(llm)
|
| 43 |
+
schedule_ai=Schedule_agent(llm)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# extra tools for question answering
|
| 47 |
+
|
| 48 |
+
# initializing time and date tool
|
| 49 |
+
|
| 50 |
+
#creating a schema
|
| 51 |
+
class time_tool_schema(BaseModel):
|
| 52 |
+
continent: str = Field(description='continent')
|
| 53 |
+
city: str = Field(description='city')
|
| 54 |
+
|
| 55 |
+
def date_time_tool(continent: str,city: str) -> str:
|
| 56 |
+
"""
|
| 57 |
+
tool to get the current date and time in a city.
|
| 58 |
+
|
| 59 |
+
"""
|
| 60 |
+
city=city.replace(' ','_')
|
| 61 |
+
continent=continent.replace(' ','_')
|
| 62 |
+
query=continent+'/'+city
|
| 63 |
+
timezone = pytz.timezone(query)
|
| 64 |
+
# Get the current time in UTC, and then convert it to the Marrakech timezone
|
| 65 |
+
utc_now = datetime.now(pytz.utc) # Get current time in UTC
|
| 66 |
+
localized_time = utc_now.astimezone(timezone) # Convert to Marrakech time
|
| 67 |
+
time=localized_time.strftime('%Y-%m-%d %H:%M:%S')
|
| 68 |
+
return time
|
| 69 |
+
|
| 70 |
+
current_date_time_tool=StructuredTool.from_function(name='current_date_time_tool', func=date_time_tool, description='To get the current date and time in any city',args_schema=time_tool_schema, return_direct=True)
|
| 71 |
+
|
| 72 |
+
def google_image_search(query: str) -> str:
|
| 73 |
+
"""Search for images using Google Custom Search API
|
| 74 |
+
args: query
|
| 75 |
+
return: image url
|
| 76 |
+
"""
|
| 77 |
+
# Define the API endpoint for Google Custom Search
|
| 78 |
+
url = "https://www.googleapis.com/customsearch/v1"
|
| 79 |
+
|
| 80 |
+
params = {
|
| 81 |
+
"q": query,
|
| 82 |
+
"cx": pse,
|
| 83 |
+
"key": GOOGLE_API_KEY,
|
| 84 |
+
"searchType": "image", # Search for images
|
| 85 |
+
"num": 1 # Number of results to fetch
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
# Make the request to the Google Custom Search API
|
| 89 |
+
response = requests.get(url, params=params)
|
| 90 |
+
data = response.json()
|
| 91 |
+
|
| 92 |
+
# Check if the response contains image results
|
| 93 |
+
if 'items' in data:
|
| 94 |
+
# Extract the first image result
|
| 95 |
+
image_url = data['items'][0]['link']
|
| 96 |
+
return image_url
|
| 97 |
+
else:
|
| 98 |
+
return "Sorry, no images were found for your query."
|
| 99 |
+
|
| 100 |
+
google_image_tool=Tool(name='google_image_tool', func=google_image_search, description='Use this tool to search for images using Google Custom Search API')
|
| 101 |
+
|
| 102 |
+
@tool
|
| 103 |
+
def schedule_manager(query:str):
|
| 104 |
+
"""
|
| 105 |
+
Use this tool for any schedule related queries
|
| 106 |
+
this tool can:
|
| 107 |
+
list the local files
|
| 108 |
+
load a schedule
|
| 109 |
+
make edits to the schedule
|
| 110 |
+
answer questions about the schedule
|
| 111 |
+
save the schedule
|
| 112 |
+
args:query - pass the schedule related queries directly here
|
| 113 |
+
"""
|
| 114 |
+
response=schedule_ai.chatbot(str(query))
|
| 115 |
+
return response
|
| 116 |
+
|
| 117 |
+
@tool
|
| 118 |
+
def maps_tool(query: str):
|
| 119 |
+
"""
|
| 120 |
+
Use this tool for any maps or location related queries
|
| 121 |
+
all the context is provided in the tool, simply pass the query
|
| 122 |
+
this tool can:
|
| 123 |
+
get the current location
|
| 124 |
+
find nearby places
|
| 125 |
+
find places in different locations
|
| 126 |
+
show the places that have been found
|
| 127 |
+
args:query - maps or location related queries
|
| 128 |
+
"""
|
| 129 |
+
response=maps_ai.chatbot(str(query))
|
| 130 |
+
return response
|
| 131 |
+
|
| 132 |
+
class travel_agent:
|
| 133 |
+
def __init__(self,llm: any):
|
| 134 |
+
self.agent=self._setup(llm)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _setup(self,llm):
|
| 138 |
+
api_tools=load_tools(['openweathermap-api','wikipedia'])
|
| 139 |
+
langgraph_tools=[current_date_time_tool,google_image_tool,schedule_manager,maps_tool]+api_tools
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
graph_builder = StateGraph(State)
|
| 143 |
+
|
| 144 |
+
# Modification: tell the LLM which tools it can call
|
| 145 |
+
llm_with_tools = llm.bind_tools(langgraph_tools)
|
| 146 |
+
tool_node = ToolNode(tools=langgraph_tools)
|
| 147 |
+
def chatbot(state: State):
|
| 148 |
+
""" travel assistant that answers user questions about their trip.
|
| 149 |
+
Depending on the request, leverage which tools to use if necessary."""
|
| 150 |
+
return {"messages": [llm_with_tools.invoke(state['messages'])]}
|
| 151 |
+
|
| 152 |
+
graph_builder.add_node("chatbot", chatbot)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
graph_builder.add_node("tools", tool_node)
|
| 156 |
+
# Any time a tool is called, we return to the chatbot to decide the next step
|
| 157 |
+
graph_builder.set_entry_point("chatbot")
|
| 158 |
+
|
| 159 |
+
graph_builder.add_edge("tools", "chatbot")
|
| 160 |
+
graph_builder.add_conditional_edges(
|
| 161 |
+
"chatbot",
|
| 162 |
+
tools_condition,
|
| 163 |
+
)
|
| 164 |
+
memory=MemorySaver()
|
| 165 |
+
graph=graph_builder.compile(checkpointer=memory)
|
| 166 |
+
return graph
|
| 167 |
+
|
| 168 |
+
def display_graph(self):
|
| 169 |
+
return display(
|
| 170 |
+
Image(
|
| 171 |
+
self.agent.get_graph().draw_mermaid_png(
|
| 172 |
+
draw_method=MermaidDrawMethod.API,
|
| 173 |
+
)
|
| 174 |
+
)
|
| 175 |
+
)
|
| 176 |
+
def get_state(self, state_val:str):
|
| 177 |
+
config = {"configurable": {"thread_id": "1"}}
|
| 178 |
+
return self.agent.get_state(config).values[state_val]
|
| 179 |
+
|
| 180 |
+
def stream(self,input:str):
|
| 181 |
+
config = {"configurable": {"thread_id": "1"}}
|
| 182 |
+
input_message = HumanMessage(content=input)
|
| 183 |
+
for event in self.agent.stream({"messages": [input_message]}, config, stream_mode="values"):
|
| 184 |
+
event["messages"][-1].pretty_print()
|
| 185 |
+
|
| 186 |
+
def chatbot(self,input:str):
|
| 187 |
+
config = {"configurable": {"thread_id": "1"}}
|
| 188 |
+
response=self.agent.invoke({'messages':HumanMessage(content=str(input))},config)
|
| 189 |
+
return response['messages'][-1].content
|