from typing import Annotated, TypedDict

from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition

from tools import extract_text
|
|
|
|
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph."""

    # Conversation history. ``add_messages`` is attached as the reducer for
    # this key, so a node's returned messages are merged into the existing
    # list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
|
|
|
|
class Agent:
    """Tool-calling chat agent built on a two-node LangGraph state graph.

    The graph alternates between an ``assistant`` node (the chat model with
    tools bound) and a ``tools`` node (executes requested tool calls) until
    the model answers without asking for a tool.

    Attributes:
        tools: The tools exposed to the model (``extract_text`` and a
            DuckDuckGo search tool).
        chat_with_tools: The chat model with ``tools`` bound to it.
        agent: The compiled graph; invoke it through :meth:`call_agent`.
    """

    def __init__(self, llm):
        """Wrap *llm* as a chat model, bind the tools and compile the graph.

        Args:
            llm: The model object handed to ``ChatHuggingFace(llm=...)``
                (presumably a ``HuggingFaceEndpoint`` -- confirm with caller).
        """
        chat = ChatHuggingFace(llm=llm, verbose=True)

        # Both the ToolNode (graph construction) and the assistant node
        # (every graph step) need these after __init__ returns, so they are
        # stored on the instance rather than kept as locals.
        self.tools = [extract_text, DuckDuckGoSearchRun()]
        self.chat_with_tools = chat.bind_tools(self.tools)

        self._initialize_graph()

    def _initialize_graph(self):
        """Assemble the assistant/tools loop and compile it into ``self.agent``."""
        builder = StateGraph(AgentState)

        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(self.tools))

        builder.add_edge(START, "assistant")
        # tools_condition inspects the assistant's latest message: it routes
        # to "tools" when tool calls are present and ends the run otherwise.
        builder.add_conditional_edges("assistant", tools_condition)
        # After the tools have run, hand their output back to the model.
        builder.add_edge("tools", "assistant")

        self.agent = builder.compile()

    def call_agent(self, messages):
        """Run the graph on *messages* and return its final state.

        Args:
            messages: Initial value for the ``"messages"`` key of the state.

        Returns:
            Whatever ``self.agent.invoke`` returns for that input state.
        """
        return self.agent.invoke({"messages": messages})

    def assistant(self, state: "AgentState") -> dict:
        """Graph node: ask the tool-aware chat model for the next message.

        Args:
            state: Current graph state; only ``state["messages"]`` is read.

        Returns:
            A state update whose ``"messages"`` list holds the model's reply.
        """
        reply = self.chat_with_tools.invoke(state["messages"])
        return {"messages": [reply]}
|
|
|
|
|
|
| |