# DSA_Search / workflow.py
# Author: Jorge Londoño
# (last update: commit 66bb091)
# https://langchain-ai.github.io/langgraph/how-tos/memory/manage-conversation-history/#build-the-agent
# https://docs.tavily.com/docs/rest-api/api-reference
import os
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState, StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
import logging
# Module-level logger per stdlib convention; INFO so debug traces stay quiet
# unless the level is lowered explicitly.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# In-memory checkpointer: conversation state persists per thread id for the
# lifetime of the process only.
memory = MemorySaver()
# Whitelist of domains the Tavily search tool may return results from,
# grouped by topic (data structures & algorithms, discrete math, general
# programming practice sites).
# NOTE: every entry must end with a comma — adjacent string literals in a
# Python list are implicitly concatenated, which silently merges two domains
# into one bogus entry and drops both from the filter.
dsa_search_domains = [
    # DSA
    'algs4.cs.princeton.edu',
    'chalmersgu-data-structure-courses.github.io',
    'pressbooks.palni.org/anopenguidetodatastructuresandalgorithms',
    'en.wikibooks.org/wiki/Algorithms',
    'people.mpi-inf.mpg.de',
    'jeffe.cs.illinois.edu/teaching/algorithms',
    'opendatastructures.org',
    'github.com/aibooks14',
    'open.umn.edu/opentextbooks',
    'opendsa-server.cs.vt.edu/OpenDSA/Books',
    'www.programiz.com',  # fix: comma was missing here, merging this with the next entry
    # Discrete Math
    'discrete.openmathbooks.org',
    'stephendavies.org',
    'www.fecundity.com',
    'ocw.mit.edu',
    'discretemath.org',
    'www.khanacademy.org',
    # More general
    'www.w3schools.com',
    'www.geeksforgeeks.org',
    'leetcode.com',
    'www.hackerrank.com',
    'www.freecodecamp.org',
    'www.codechef.com',
    'www.w3resource.com',
    'www.hackerearth.com',
    'openstax.org',
]
# Tavily web-search tool, limited to the curated dsa_search_domains so the
# agent only cites course-quality sources; up to 8 hits per query.
tavily_search = TavilySearchResults(max_results=8, verbose=False, include_domains=dsa_search_domains)
tools = [tavily_search]
# ToolNode executes whichever tool calls the LLM emits and appends the
# resulting ToolMessages to the conversation state.
tool_node = ToolNode(tools)
# --- Alternative LLM backends tried previously, kept for reference ---
# from langchain_openai import ChatOpenAI
# logger.debug(f"LLM model: {os.getenv('OPENAI_MODEL_NAME')}")
# llm = ChatOpenAI(model=os.getenv("OPENAI_MODEL_NAME"))
# from huggingface_hub import InferenceClient
# llm = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# DOESN'T support bind_tools
# from langchain_huggingface import HuggingFaceEndpoint
# llm = HuggingFaceEndpoint(
# repo_id="HuggingFaceH4/zephyr-7b-beta",
# max_length=128,
# temperature=0.5,
# huggingfacehub_api_token=os.getenv('HUGGINGFACEHUB_API_TOKEN'),
# )
# --- Active backend: Groq-hosted model selected via GROQ_MODEL env var ---
from langchain_groq import ChatGroq
logger.debug(f"LLM model: {os.getenv('GROQ_MODEL')}")
# Low temperature for factual, reproducible answers.
# NOTE(review): `model_name` is a legacy alias for `model` in recent
# langchain-groq releases — confirm against the pinned version.
llm = ChatGroq(model_name=os.getenv('GROQ_MODEL'), temperature=0.1)
# from langchain_mistralai import ChatMistralAI
# logger.debug(f"LLM model: {os.getenv('MISTRAL_MODEL_NAME')}")
# llm = ChatMistralAI(
# model=os.getenv('MISTRAL_MODEL_NAME'),
# temperature=0,
# max_retries=2,
# )
# Attach the search tool schema so the model can emit tool calls.
bound_model = llm.bind_tools(tools)
def should_continue(state: MessagesState):
    """Route after the agent step: execute tools if the LLM requested any, else finish."""
    last_message = state["messages"][-1]
    logger.debug(f'***should_continue*** : last_message = {last_message}')
    # A message carrying tool calls means the model wants a web search.
    return "action" if last_message.tool_calls else END
def search_agent(state: MessagesState):
    """Run the tool-bound LLM over the conversation so far and append its reply."""
    history = state["messages"]
    reply = bound_model.invoke(history)
    return {"messages": reply}
# Define a new graph over the standard message-list state.
workflow = StateGraph(MessagesState)
# Define the two nodes we will cycle between: the LLM step and tool execution.
workflow.add_node("agent", search_agent)
workflow.add_node("action", tool_node)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.add_edge(START, "agent")
# We now add a conditional edge: after each agent step, should_continue picks
# either "action" (run the requested tool) or END (final answer produced).
workflow.add_conditional_edges(
"agent",
should_continue,
["action", END],
)
# Tool results feed back into the agent, closing the agent<->tool loop.
workflow.add_edge("action", "agent")
# Compile with the in-memory checkpointer so each thread id keeps its history.
app = workflow.compile(checkpointer=memory)