Paperbag committed on
Commit
e75d735
·
1 Parent(s): 5c92669

Enhance agent functionality by integrating message handling and graph structure. Added read_message and answer_message functions to process user input and generate responses using the chat model. Updated AgentState to support both HumanMessage and AIMessage types.

Browse files
Files changed (2) hide show
  1. agent.py +37 -4
  2. app.py +10 -3
agent.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  from typing import TypedDict, List, Dict, Any, Optional
3
  from langgraph.graph import StateGraph, START, END
4
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
5
- from langchain_core.messages import HumanMessage
6
 
7
  # Base Hugging Face LLM used by the chat wrapper
8
  base_llm = HuggingFaceEndpoint(
@@ -12,10 +12,43 @@ base_llm = HuggingFaceEndpoint(
12
  huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
13
  )
14
 
 
15
  model = ChatHuggingFace(llm=base_llm)
16
 
 
17
  class AgentState(TypedDict):
18
- messages: Dict[str, Any]
19
- answer: Dict[str, Any]
20
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
 
 
 
 
2
  from typing import TypedDict, List, Dict, Any, Optional
3
  from langgraph.graph import StateGraph, START, END
4
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
5
+ from langchain_core.messages import HumanMessage, AIMessage
6
 
7
  # Base Hugging Face LLM used by the chat wrapper
8
  base_llm = HuggingFaceEndpoint(
 
12
  huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
13
  )
14
 
15
+ # Chat model that works with LangGraph
16
  model = ChatHuggingFace(llm=base_llm)
17
 
18
+
19
class AgentState(TypedDict):
    """Shared state passed between LangGraph nodes.

    Holds the running conversation; each node receives the state and
    returns an updated copy of it.
    """

    # Full message history, oldest first; the model's reply is appended
    # by answer_message, so the last element is the latest turn.
    messages: List[HumanMessage | AIMessage]
21
+
22
+
23
def read_message(state: AgentState) -> AgentState:
    """Log the incoming question and forward the state unchanged.

    Pass-through node: it performs no mutation, only a diagnostic print
    of the most recent message's content (empty string if no messages).
    """
    history = state["messages"]
    latest = history[-1].content if history else ''
    print(f"Processing question: {latest}")
    # Hand the untouched conversation to the next node in the graph.
    return {"messages": history}
28
+
29
+
30
def answer_message(state: AgentState) -> AgentState:
    """Run the chat model over the conversation and append its reply.

    Invokes the module-level `model` with the full message history and
    returns a new state whose history ends with the model's response.
    """
    history = state["messages"]
    reply = model.invoke(history)
    # Build a fresh list rather than mutating the incoming state.
    return {"messages": [*history, reply]}
36
+
37
+
38
def build_graph():
    """Assemble and compile the agent's two-node LangGraph pipeline.

    Topology: START -> read_message -> answer_message -> END.

    Returns:
        The compiled graph, ready for `.invoke({"messages": [...]})`
        (used by BasicAgent in app.py).
    """
    graph = StateGraph(AgentState)

    # Register the processing nodes.
    for node_name, node_fn in (
        ("read_message", read_message),
        ("answer_message", answer_message),
    ):
        graph.add_node(node_name, node_fn)

    # Wire the linear flow from entry to exit.
    graph.add_edge(START, "read_message")
    graph.add_edge("read_message", "answer_message")
    graph.add_edge("answer_message", END)

    return graph.compile()
app.py CHANGED
@@ -3,6 +3,8 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
 
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
@@ -13,11 +15,16 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
13
  class BasicAgent:
14
  def __init__(self):
15
  print("BasicAgent initialized.")
 
 
16
  def __call__(self, question: str) -> str:
17
  print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from langchain_core.messages import HumanMessage
7
+ from agent import build_graph
8
 
9
  # (Keep Constants as is)
10
  # --- Constants ---
 
15
class BasicAgent:
    """Agent that answers a question by running it through the LangGraph pipeline."""

    def __init__(self):
        print("BasicAgent initialized.")
        # Compile the graph once at construction; reused for every call.
        self.graph = build_graph()

    def __call__(self, question: str) -> str:
        """Answer `question` and return the model's reply text."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        initial_state = {"messages": [HumanMessage(content=question)]}
        final_state = self.graph.invoke(initial_state)
        # The graph appends the model's response last; extract its text.
        answer = final_state["messages"][-1].content
        print(f"Agent returning answer: {answer}")
        return answer
27
+
28
 
29
  def run_and_submit_all( profile: gr.OAuthProfile | None):
30
  """