Mohammad Haghir commited on
Commit ·
2f7b616
1
Parent(s): a18d877
update
Browse files- agent_utils.py +26 -0
- app.py +43 -44
agent_utils.py
CHANGED
|
@@ -112,3 +112,29 @@ def handle_file_tool(input: dict) -> str:
|
|
| 112 |
else:
|
| 113 |
return f"Unsupported file type: .{file_ext}"
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
else:
|
| 113 |
return f"Unsupported file type: .{file_ext}"
|
| 114 |
|
| 115 |
+
@tool
|
| 116 |
+
def add(a: float, b: float):
|
| 117 |
+
"""calculate summation of two numbers"""
|
| 118 |
+
return a + b
|
| 119 |
+
|
| 120 |
+
@tool
|
| 121 |
+
def subtract(a: float, b: float):
|
| 122 |
+
"""calculate subtraction of two numbers"""
|
| 123 |
+
return a - b
|
| 124 |
+
|
| 125 |
+
@tool
|
| 126 |
+
def multiplication(a: float, b: float):
|
| 127 |
+
"""calculate multiplication of two numbers"""
|
| 128 |
+
return a * b
|
| 129 |
+
|
| 130 |
+
@tool
|
| 131 |
+
def division(a: float, b: float):
|
| 132 |
+
"""calculate division of two numbers"""
|
| 133 |
+
|
| 134 |
+
return a / b
|
| 135 |
+
|
| 136 |
+
@tool
|
| 137 |
+
def mode(a: float, b: float):
|
| 138 |
+
"""calculate remainder of two numbers"""
|
| 139 |
+
|
| 140 |
+
return a % b
|
app.py
CHANGED
|
@@ -14,23 +14,30 @@ from typing import Annotated
|
|
| 14 |
from langchain_groq import ChatGroq
|
| 15 |
from langchain_core.messages import HumanMessage
|
| 16 |
|
| 17 |
-
from langgraph.graph import START, END, StateGraph
|
| 18 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 19 |
|
| 20 |
-
from agent_utils import wiki_ret, arxiv_ret, tavily_ret, handle_file_tool
|
| 21 |
|
| 22 |
# (Keep Constants as is)
|
| 23 |
# --- Constants ---
|
| 24 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 25 |
|
| 26 |
groq_api_key = os.getenv("GROQ_API_KEY")
|
| 27 |
-
tools = [wiki_ret, arxiv_ret, tavily_ret, handle_file_tool]
|
| 28 |
-
llm = ChatGroq(api_key=groq_api_key, model="
|
| 29 |
llm_with_tools = llm.bind_tools(tools)
|
| 30 |
class GraphState(TypedDict):
|
| 31 |
messages: str #Annotated[Dict, operator.add]
|
| 32 |
context: str
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# --- Basic Agent Definition ---
|
| 35 |
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
| 36 |
class BasicAgent:
|
|
@@ -40,48 +47,39 @@ class BasicAgent:
|
|
| 40 |
|
| 41 |
def __call__(self, question: str) -> str:
|
| 42 |
print("question: ", question)
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
return response["messages"]
|
| 45 |
|
| 46 |
def agent(self, state: GraphState):
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
# context = self.wiki_ret(question)
|
| 52 |
-
context = state.get("context", "")
|
| 53 |
-
question = state.get("messages", "")
|
| 54 |
-
print("question agent: ", question)
|
| 55 |
-
print("context: ", context)
|
| 56 |
-
prompt = f"""
|
| 57 |
-
You are a general AI assistant. I will ask you a question.
|
| 58 |
-
YOUR FINAL ANSWER should be a number OR
|
| 59 |
-
as few words as possible OR a comma separated list of numbers and/or strings.
|
| 60 |
-
If you are asked for a number, don't use comma to write your number neither use
|
| 61 |
-
units such as $ or percent sign unless specified otherwise. If you are asked for
|
| 62 |
-
a string, don't use articles, neither abbreviations (e.g. for cities), and write
|
| 63 |
-
the digits in plain text unless specified otherwise. If you are asked for a comma
|
| 64 |
-
separated list, apply the above rules depending of whether the element to be put
|
| 65 |
-
in the list is a number or a string. Use the tools available to you to answer the question. Question: {question}
|
| 66 |
-
For answering the question use this context: {context}, if no context is provided
|
| 67 |
-
use your knowledge to answer the question."""
|
| 68 |
-
# Your answer must be in the following format:
|
| 69 |
-
|
| 70 |
-
# {{"task_id": "task_id_1", "model_answer": "Answer 1 from your model", "reasoning_trace": "The different steps by which your model reached answer 1"}}
|
| 71 |
-
# {{"task_id": "task_id_2", "model_answer": "Answer 2 from your model", "reasoning_trace": "The different steps by which your model reached answer 2"}}
|
| 72 |
-
|
| 73 |
-
# Just make up a task_id.
|
| 74 |
-
# Call the LLM
|
| 75 |
-
messages = [HumanMessage(content=prompt)]
|
| 76 |
-
print("messages: ", messages)
|
| 77 |
response = llm_with_tools.invoke(messages)
|
| 78 |
print("response: ", response)
|
| 79 |
-
# cleaned_text = re.sub(r"<think>.*?</think>", "", response.content, flags=re.DOTALL)
|
| 80 |
-
|
| 81 |
-
# json_start = response.find('{')
|
| 82 |
-
# json_end = response.rfind('}') + 1
|
| 83 |
-
# json_str = response[json_start:json_end]
|
| 84 |
-
# res = json.loads(json_str)
|
| 85 |
return {"messages": response}
|
| 86 |
|
| 87 |
def create_graph(self):
|
|
@@ -90,12 +88,13 @@ class BasicAgent:
|
|
| 90 |
builder.add_node("tools", ToolNode(tools = tools))
|
| 91 |
|
| 92 |
builder.add_edge(START, "agent")
|
| 93 |
-
builder.add_conditional_edges("agent", tools_condition, ["
|
|
|
|
| 94 |
builder.add_edge("agent", END)
|
| 95 |
graph = builder.compile()
|
| 96 |
-
image = graph.get_graph().draw_mermaid_png()
|
| 97 |
-
with open("output_graph.png", "wb") as file:
|
| 98 |
-
|
| 99 |
return graph
|
| 100 |
|
| 101 |
|
|
|
|
| 14 |
from langchain_groq import ChatGroq
|
| 15 |
from langchain_core.messages import HumanMessage
|
| 16 |
|
| 17 |
+
from langgraph.graph import START, END, StateGraph, MessagesState
|
| 18 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 19 |
|
| 20 |
+
from agent_utils import wiki_ret, arxiv_ret, tavily_ret, handle_file_tool, add, subtract, division, multiplication, mode
|
| 21 |
|
| 22 |
# (Keep Constants as is)
|
| 23 |
# --- Constants ---
|
| 24 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 25 |
|
| 26 |
groq_api_key = os.getenv("GROQ_API_KEY")
|
| 27 |
+
tools = [wiki_ret, arxiv_ret, tavily_ret, handle_file_tool, add, subtract, division, multiplication, mode]
|
| 28 |
+
llm = ChatGroq(api_key=groq_api_key, model="qwen-qwq-32b")
|
| 29 |
llm_with_tools = llm.bind_tools(tools)
|
| 30 |
class GraphState(TypedDict):
|
| 31 |
messages: str #Annotated[Dict, operator.add]
|
| 32 |
context: str
|
| 33 |
|
| 34 |
+
# --- Basic Agent Definition ---
|
| 35 |
+
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
| 36 |
+
llm_with_tools = llm.bind_tools(tools)
|
| 37 |
+
class GraphState(MessagesState):
|
| 38 |
+
# messages: Annotated[BaseMessage, operator.add]
|
| 39 |
+
context: str
|
| 40 |
+
|
| 41 |
# --- Basic Agent Definition ---
|
| 42 |
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
| 43 |
class BasicAgent:
|
|
|
|
| 47 |
|
| 48 |
def __call__(self, question: str) -> str:
|
| 49 |
print("question: ", question)
|
| 50 |
+
|
| 51 |
+
prompt = f"""
|
| 52 |
+
You are an AI assistant designed to answer user questions using available tools.
|
| 53 |
+
Provide your final answer in one of the following formats:
|
| 54 |
+
|
| 55 |
+
A plain number (without commas, currency symbols, or percent signs unless explicitly requested).
|
| 56 |
+
|
| 57 |
+
A concise phrase (no articles or abbreviations).
|
| 58 |
+
|
| 59 |
+
With as few words as possible.
|
| 60 |
+
|
| 61 |
+
A comma-separated list of numbers and/or strings, following the above rules.
|
| 62 |
+
|
| 63 |
+
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
|
| 64 |
+
|
| 65 |
+
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
|
| 66 |
+
|
| 67 |
+
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
| 68 |
+
|
| 69 |
+
Use the tools at your disposal to find the correct answer. Question: {question}
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
messages = [HumanMessage(content=prompt)]
|
| 73 |
+
response = (self.graph).invoke({"messages": messages})
|
| 74 |
return response["messages"]
|
| 75 |
|
| 76 |
def agent(self, state: GraphState):
|
| 77 |
|
| 78 |
+
|
| 79 |
+
messages = state["messages"]
|
| 80 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
response = llm_with_tools.invoke(messages)
|
| 82 |
print("response: ", response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
return {"messages": response}
|
| 84 |
|
| 85 |
def create_graph(self):
|
|
|
|
| 88 |
builder.add_node("tools", ToolNode(tools = tools))
|
| 89 |
|
| 90 |
builder.add_edge(START, "agent")
|
| 91 |
+
builder.add_conditional_edges("agent", tools_condition, ["tools", END])
|
| 92 |
+
builder.add_edge("tools", "agent")
|
| 93 |
builder.add_edge("agent", END)
|
| 94 |
graph = builder.compile()
|
| 95 |
+
# image = graph.get_graph().draw_mermaid_png()
|
| 96 |
+
# with open("output_graph.png", "wb") as file:
|
| 97 |
+
# file.write(image)
|
| 98 |
return graph
|
| 99 |
|
| 100 |
|