Mohammad Haghir commited on
Commit
2f7b616
·
1 Parent(s): a18d877
Files changed (2) hide show
  1. agent_utils.py +26 -0
  2. app.py +43 -44
agent_utils.py CHANGED
@@ -112,3 +112,29 @@ def handle_file_tool(input: dict) -> str:
112
  else:
113
  return f"Unsupported file type: .{file_ext}"
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  else:
113
  return f"Unsupported file type: .{file_ext}"
114
 
115
+ @tool
116
+ def add(a: float, b: float):
117
+ """calculate summation of two numbers"""
118
+ return a + b
119
+
120
+ @tool
121
+ def subtract(a: float, b: float):
122
+ """calculate subtraction of two numbers"""
123
+ return a - b
124
+
125
+ @tool
126
+ def multiplication(a: float, b: float):
127
+ """calculate multiplication of two numbers"""
128
+ return a * b
129
+
130
+ @tool
131
+ def division(a: float, b: float):
132
+ """calculate division of two numbers"""
133
+
134
+ return a / b
135
+
136
+ @tool
137
+ def mode(a: float, b: float):
138
+ """calculate remainder of two numbers"""
139
+
140
+ return a % b
app.py CHANGED
@@ -14,23 +14,30 @@ from typing import Annotated
14
  from langchain_groq import ChatGroq
15
  from langchain_core.messages import HumanMessage
16
 
17
- from langgraph.graph import START, END, StateGraph
18
  from langgraph.prebuilt import ToolNode, tools_condition
19
 
20
- from agent_utils import wiki_ret, arxiv_ret, tavily_ret, handle_file_tool
21
 
22
  # (Keep Constants as is)
23
  # --- Constants ---
24
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
25
 
26
  groq_api_key = os.getenv("GROQ_API_KEY")
27
- tools = [wiki_ret, arxiv_ret, tavily_ret, handle_file_tool]
28
- llm = ChatGroq(api_key=groq_api_key, model="gemma2-9b-it")
29
  llm_with_tools = llm.bind_tools(tools)
30
  class GraphState(TypedDict):
31
  messages: str #Annotated[Dict, operator.add]
32
  context: str
33
 
 
 
 
 
 
 
 
34
  # --- Basic Agent Definition ---
35
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
36
  class BasicAgent:
@@ -40,48 +47,39 @@ class BasicAgent:
40
 
41
  def __call__(self, question: str) -> str:
42
  print("question: ", question)
43
- response = (self.graph).invoke({"messages": question})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  return response["messages"]
45
 
46
  def agent(self, state: GraphState):
47
 
48
- # print(f"Agent received question (first 50 chars): {question[:50]}...")
49
- # fixed_answer = "This is a default answer. --- 1"
50
- # print(f"Agent returning fixed answer: {fixed_answer}")
51
- # context = self.wiki_ret(question)
52
- context = state.get("context", "")
53
- question = state.get("messages", "")
54
- print("question agent: ", question)
55
- print("context: ", context)
56
- prompt = f"""
57
- You are a general AI assistant. I will ask you a question.
58
- YOUR FINAL ANSWER should be a number OR
59
- as few words as possible OR a comma separated list of numbers and/or strings.
60
- If you are asked for a number, don't use comma to write your number neither use
61
- units such as $ or percent sign unless specified otherwise. If you are asked for
62
- a string, don't use articles, neither abbreviations (e.g. for cities), and write
63
- the digits in plain text unless specified otherwise. If you are asked for a comma
64
- separated list, apply the above rules depending of whether the element to be put
65
- in the list is a number or a string. Use the tools available to you to answer the question. Question: {question}
66
- For answering the question use this context: {context}, if no context is provided
67
- use your knowledge to answer the question."""
68
- # Your answer must be in the following format:
69
-
70
- # {{"task_id": "task_id_1", "model_answer": "Answer 1 from your model", "reasoning_trace": "The different steps by which your model reached answer 1"}}
71
- # {{"task_id": "task_id_2", "model_answer": "Answer 2 from your model", "reasoning_trace": "The different steps by which your model reached answer 2"}}
72
-
73
- # Just make up a task_id.
74
- # Call the LLM
75
- messages = [HumanMessage(content=prompt)]
76
- print("messages: ", messages)
77
  response = llm_with_tools.invoke(messages)
78
  print("response: ", response)
79
- # cleaned_text = re.sub(r"<think>.*?</think>", "", response.content, flags=re.DOTALL)
80
-
81
- # json_start = response.find('{')
82
- # json_end = response.rfind('}') + 1
83
- # json_str = response[json_start:json_end]
84
- # res = json.loads(json_str)
85
  return {"messages": response}
86
 
87
  def create_graph(self):
@@ -90,12 +88,13 @@ class BasicAgent:
90
  builder.add_node("tools", ToolNode(tools = tools))
91
 
92
  builder.add_edge(START, "agent")
93
- builder.add_conditional_edges("agent", tools_condition, ["agent", END])
 
94
  builder.add_edge("agent", END)
95
  graph = builder.compile()
96
- image = graph.get_graph().draw_mermaid_png()
97
- with open("output_graph.png", "wb") as file:
98
- file.write(image)
99
  return graph
100
 
101
 
 
14
  from langchain_groq import ChatGroq
15
  from langchain_core.messages import HumanMessage
16
 
17
+ from langgraph.graph import START, END, StateGraph, MessagesState
18
  from langgraph.prebuilt import ToolNode, tools_condition
19
 
20
+ from agent_utils import wiki_ret, arxiv_ret, tavily_ret, handle_file_tool, add, subtract, division, multiplication, mode
21
 
22
  # (Keep Constants as is)
23
  # --- Constants ---
24
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
25
 
26
  groq_api_key = os.getenv("GROQ_API_KEY")
27
+ tools = [wiki_ret, arxiv_ret, tavily_ret, handle_file_tool, add, subtract, division, multiplication, mode]
28
+ llm = ChatGroq(api_key=groq_api_key, model="qwen-qwq-32b")
29
  llm_with_tools = llm.bind_tools(tools)
30
  class GraphState(TypedDict):
31
  messages: str #Annotated[Dict, operator.add]
32
  context: str
33
 
34
+ # --- Basic Agent Definition ---
35
+ # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
36
+ llm_with_tools = llm.bind_tools(tools)
37
+ class GraphState(MessagesState):
38
+ # messages: Annotated[BaseMessage, operator.add]
39
+ context: str
40
+
41
  # --- Basic Agent Definition ---
42
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
43
  class BasicAgent:
 
47
 
48
  def __call__(self, question: str) -> str:
49
  print("question: ", question)
50
+
51
+ prompt = f"""
52
+ You are an AI assistant designed to answer user questions using available tools.
53
+ Provide your final answer in one of the following formats:
54
+
55
+ A plain number (without commas, currency symbols, or percent signs unless explicitly requested).
56
+
57
+ A concise phrase (no articles or abbreviations).
58
+
59
+ With as few words as possible.
60
+
61
+ A comma-separated list of numbers and/or strings, following the above rules.
62
+
63
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
64
+
65
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
66
+
67
+ If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
68
+
69
+ Use the tools at your disposal to find the correct answer. Question: {question}
70
+ """
71
+
72
+ messages = [HumanMessage(content=prompt)]
73
+ response = (self.graph).invoke({"messages": messages})
74
  return response["messages"]
75
 
76
  def agent(self, state: GraphState):
77
 
78
+
79
+ messages = state["messages"]
80
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  response = llm_with_tools.invoke(messages)
82
  print("response: ", response)
 
 
 
 
 
 
83
  return {"messages": response}
84
 
85
  def create_graph(self):
 
88
  builder.add_node("tools", ToolNode(tools = tools))
89
 
90
  builder.add_edge(START, "agent")
91
+ builder.add_conditional_edges("agent", tools_condition, ["tools", END])
92
+ builder.add_edge("tools", "agent")
93
  builder.add_edge("agent", END)
94
  graph = builder.compile()
95
+ # image = graph.get_graph().draw_mermaid_png()
96
+ # with open("output_graph.png", "wb") as file:
97
+ # file.write(image)
98
  return graph
99
 
100