Spaces:
Sleeping
Sleeping
add agent
Browse files- .gitattributes +1 -0
- .gitignore +1 -0
- agents.py +58 -0
- files/1f975693-876d-457b-a649-393859e79bf3.mp3 +3 -0
- files/7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx +0 -0
- files/99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3.mp3 +3 -0
- files/cca530fc-4052-43b2-b130-b30968d8aa44.png +0 -0
- files/f918266a-b3e0-4914-865d-4faa564f1aef.py +35 -0
- qa_graph.py +184 -0
- system_prompt.txt +7 -0
- tools.py +186 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.env
|
agents.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from tools import general_tools, file_agent_tools, data_agent_tools, math_agent_tools
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langgraph_supervisor import create_supervisor


# Shared chat model for every agent built in this module.
llm = ChatOpenAI(model="o4-mini")

# Conversation memory for the general-purpose agent (keyed by thread_id).
memory = MemorySaver()

# The general agent's system prompt lives in a separate text file.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    prompt = f.read()

# General-purpose ReAct agent: web/wiki/arxiv search plus image, python-file
# and audio tools (see tools.general_tools).
general_agent = create_react_agent(
    model=llm,
    tools=general_tools(),
    checkpointer=memory,
    prompt=prompt
)

# Create worker agents for the Excel-question supervisor team.
file_agent = create_react_agent(
    model=llm,
    tools=file_agent_tools(),
    name="file_reader",
    prompt="You read files. Use tools to read files."
)

math_agent = create_react_agent(
    model=llm,
    tools=math_agent_tools(),
    name="calculator",
    prompt="You do math. Use tools for all calculations."
)

data_agent = create_react_agent(
    model=llm,
    tools=data_agent_tools(),
    name="data_processor",
    prompt="You process data. Use tools to filter and extract data."
)

# Distinct name: previously this rebound `prompt`, silently shadowing the
# general agent's system prompt read above.
supervisor_prompt = """You are a supervisor. You coordinate file_reader, calculator, and data_processor to solve problems step by step.
Do not do calculations or file reading yourself, use the tools.
Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

# Supervisor: routes spreadsheet questions between the three worker agents.
excel_supervisor = create_supervisor(
    [file_agent, math_agent, data_agent],
    model=llm,
    prompt=supervisor_prompt
).compile()
files/1f975693-876d-457b-a649-393859e79bf3.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:200f767e732b49efef5c05d128903ee4d2c34e66fdce7f5593ac123b2e637673
|
| 3 |
+
size 280868
|
files/7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx
ADDED
|
Binary file (5.29 kB). View file
|
|
|
files/99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b218c951c1f888f0bbe6f46c080f57afc7c9348fffc7ba4da35749ff1e2ac40f
|
| 3 |
+
size 179304
|
files/cca530fc-4052-43b2-b130-b30968d8aa44.png
ADDED
|
files/f918266a-b3e0-4914-865d-4faa564f1aef.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from random import randint
import time

# NOTE(review): this file is a downloaded task attachment (a deliberately
# obfuscated puzzle script), not application code; logic left untouched.

class UhOh(Exception):
    pass

class Hmm:
    def __init__(self):
        # Uniform random draw in [-100, 100].
        self.value = randint(-100, 100)

    def Yeah(self):
        # Succeeds only when the draw was exactly 0; any other value raises.
        if self.value == 0:
            return True
        else:
            raise UhOh()

def Okay():
    # Infinite stream of fresh Hmm instances.
    while True:
        yield Hmm()

def keep_trying(go, first_try=True):
    # Pull Hmm objects from the generator until one holds 0, then return
    # that value -- so the script's final printed output is always 0.
    # NOTE(review): retries recurse instead of looping; an unlucky run of
    # ~1000 misses would hit Python's recursion limit.
    maybe = next(go)
    try:
        if maybe.Yeah():
            return maybe.value
    except UhOh:
        if first_try:
            print("Working...")
            print("Please wait patiently...")
        time.sleep(0.1)
        return keep_trying(go, first_try=False)

if __name__ == "__main__":
    go = Okay()
    print(f"{keep_trying(go)}")
qa_graph.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
from langgraph.graph import START, StateGraph, END
from typing import TypedDict
from agents import general_agent, excel_supervisor
import os

# Re-export the key into the environment only when it is actually set.
# The previous str() coercion wrote the literal string "None" into
# os.environ when the variable was missing, masking the real error.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY is not None:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
@dataclass
class Question:
    # One task record, mirroring the fields of the upstream task JSON.
    task_id: str
    question: str
    # Capitalised to match the key used in the upstream payload.
    Level: str
    # Attachment name; empty string when the task has no file.
    file_name: str
    # Set once the attachment has been downloaded locally; None otherwise.
    local_file_path: str|None = None
| 17 |
+
def get_file_type(file_path: str) -> str:
    """Classify a file by its extension.

    Returns one of "image", "excel", "python", "audio", "unknown",
    or "none" for an empty/missing path.
    """
    if not file_path:
        return "none"

    lowered = file_path.lower()

    # (label, recognised suffixes) pairs, checked in order.
    categories = (
        ("image", ('.png', '.jpg', '.jpeg', '.gif', '.bmp')),
        ("excel", ('.xlsx', '.xls', '.csv')),
        ("python", ('.py',)),
        ("audio", ('.mp3', '.wav', '.m4a', '.ogg')),
    )
    for label, suffixes in categories:
        if lowered.endswith(suffixes):
            return label
    return "unknown"
| 35 |
+
def ask_question(question: str, thread_id: str = "default") -> str:
    """Send a question to the general agent and return its final message text."""
    config = {"configurable": {"thread_id": thread_id}}
    payload = {"messages": [{"role": "user", "content": question}]}

    try:
        response = general_agent.invoke(payload, config=config)
        return response["messages"][-1].content
    except Exception as e:
        # Surface failures as text so the calling graph keeps running.
        return f"Error: {str(e)}"
| 48 |
+
def ask_question_with_file(question: Question, thread_id: str = "default") -> str:
    """Ask the agent a question, with optional file analysis.

    Excel/CSV attachments are routed to the multi-agent supervisor; every
    other file type goes to the general agent with a tool hint appended.
    """
    q = question.question
    # No attachment: answer as a plain question. (Previously the file path
    # was built before this guard -- dead work on the no-file path.)
    if not question.file_name:
        return ask_question(q, thread_id)

    file_path = os.path.join("./files", question.file_name)
    file_type = get_file_type(file_path)

    # Spreadsheets go to the dedicated supervisor team.
    if file_type == "excel":
        result = excel_supervisor.invoke({
            "messages": [
                {"role": "user", "content": f"{q}\n\nFile path: {file_path}"}
            ]
        })
        return result["messages"][-1].content

    # Everything else: append a hint telling the agent which tool to use.
    hints = {
        "image": f"There is an image file at '{file_path}'. Use the analyze_image tool to examine it.",
        "python": f"There is a Python file at '{file_path}'. Use the read_python_file tool to examine it.",
        "audio": f"There is an audio file at '{file_path}'. Use the transcribe_audio tool to process it.",
    }
    hint = hints.get(file_type, f"There is a file at '{file_path}' but I'm not sure what type it is.")
    return ask_question(f"{q}\n\n{hint}", thread_id)
| 78 |
+
# Local smoke-test fixtures: hand-picked level-1 tasks. Entries are
# commented in/out to focus a run on a single task.
test = [
    # {
    #     "task_id": "cca530fc-4052-43b2-b130-b30968d8aa44",
    #     "question": "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation.",
    #     "Level": "1",
    #     "file_name": "cca530fc-4052-43b2-b130-b30968d8aa44.png"
    # },
    # {
    #     "task_id": "1f975693-876d-457b-a649-393859e79bf3",
    #     "question": "Hi, I was out sick from my classes on Friday, so I'm trying to figure out what I need to study for my Calculus mid-term next week. My friend from class sent me an audio recording of Professor Willowbrook giving out the recommended reading for the test, but my headphones are broken :(\n\nCould you please listen to the recording for me and tell me the page numbers I'm supposed to go over? I've attached a file called Homework.mp3 that has the recording. Please provide just the page numbers as a comma-delimited list. And please provide the list in ascending order.",
    #     "Level": "1",
    #     "file_name": "1f975693-876d-457b-a649-393859e79bf3.mp3"
    # },
    {
        "task_id": "7bd855d8-463d-4ed5-93ca-5fe35145f733",
        "question": "The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.",
        "Level": "1",
        "file_name": "7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx"
    }
]

questions = [Question(**item) for item in test]
# Quick visual check of what was loaded.
for q in questions:
    print(q.question)
    print(q.file_name)
    print(q.local_file_path)
| 107 |
+
# State
class State(TypedDict):
    # Shared LangGraph state threaded through every node.
    question: Question
    # Routing choice set by router_node: "query" or "query_with_file".
    decision: str
    # Final answer produced by whichever query node ran.
    answer: str
| 113 |
+
# NODE FUNCTIONS - These are the ones that work with LangGraph
def ask_question_node(state: State) -> dict:
    """Graph node: answer a question that has no attachment."""
    q = state["question"]
    # One conversation thread per task so runs don't share memory.
    answer = ask_question(q.question, f"test_{q.task_id}")
    # LangGraph merges this mapping back into the shared state.
    return {"answer": answer}
| 125 |
+
def ask_question_with_file_node(state: State) -> dict:
    """Graph node: answer a question that ships a file attachment."""
    q = state["question"]
    # One conversation thread per task so runs don't share memory.
    answer = ask_question_with_file(q, f"test_{q.task_id}")
    # LangGraph merges this mapping back into the shared state.
    return {"answer": answer}
| 136 |
+
def router_node(state: State):
    """Router node: record which branch this question should take."""
    # A non-empty file_name means the task ships an attachment.
    branch = "query_with_file" if state["question"].file_name else "query"
    return {"decision": branch}
| 145 |
+
def router_function(state: State):
    """Routing function - returns string to choose path"""
    # Conditional-edge selector: the returned string is looked up in the
    # mapping passed to add_conditional_edges below.
    return state["decision"]
| 149 |
+
# Graph
builder = StateGraph(State)

# Use the NODE functions (not the original functions)
builder.add_node("query_with_file", ask_question_with_file_node)
builder.add_node("query", ask_question_node)
builder.add_node("router", router_node)

# Define edges: every run starts at the router, which picks one of the
# two query nodes; both branches terminate the graph.
builder.add_edge(START, "router")
builder.add_conditional_edges(
    "router",
    router_function,
    {
        "query_with_file": "query_with_file",
        "query": "query",
    },
)
builder.add_edge("query_with_file", END)
builder.add_edge("query", END)

react_graph = builder.compile()
| 172 |
+
if __name__ == "__main__":
    # Run every loaded fixture question through the graph and show results.
    for i, question in enumerate(questions):
        print(f"\n{i}. {question.question}")

        initial_state = {"question": question, "decision": "", "answer": ""}
        result = react_graph.invoke(initial_state)

        print(f"Answer: {result['answer']}")
        print("-" * 50)
system_prompt.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are a general AI assistant.
|
| 2 |
+
I will ask you a question.
|
| 3 |
+
Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
|
| 4 |
+
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
|
| 5 |
+
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
|
| 6 |
+
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
|
| 7 |
+
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
tools.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_community.tools import (
    DuckDuckGoSearchRun,
    WikipediaQueryRun,
    ArxivQueryRun
)
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain_openai import ChatOpenAI

# Stdlib imports grouped together; the previous version imported os twice.
import base64
import json
import os

import pandas as pd
import requests
from huggingface_hub import InferenceClient

from dotenv import load_dotenv

# Load .env before reading any secrets.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# Hugging Face inference client (used for audio transcription).
client = InferenceClient(
    provider="hf-inference",
    api_key=HF_TOKEN,
)

# Text model for reasoning, vision model for image questions.
llm = ChatOpenAI(model="o4-mini")
vision_llm = ChatOpenAI(model="gpt-4o")
| 31 |
+
@tool
def analyze_image(img_path: str, question: str) -> str:
    """Analyze an image and answer a question about it.

    Reads the image from disk, inlines it as a base64 data URL, and asks
    the vision model the given question. Returns the model's answer, or
    an "Error analyzing image: ..." string on failure.
    """
    import mimetypes  # local import: only this tool needs it

    try:
        with open(img_path, "rb") as image_file:
            image_bytes = image_file.read()

        image_base64 = base64.b64encode(image_bytes).decode("utf-8")

        # Previously the data URL always claimed image/jpeg, even for the
        # .png/.gif/.bmp files this pipeline accepts; derive the real type
        # from the extension and keep jpeg as the fallback.
        mime_type = mimetypes.guess_type(img_path)[0] or "image/jpeg"

        message = [
            HumanMessage(
                content=[
                    {"type": "text", "text": question},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{image_base64}"}
                    }
                ]
            )
        ]

        response = vision_llm.invoke(message)
        return response.content

    except Exception as e:
        return f"Error analyzing image: {str(e)}"
| 58 |
+
@tool
def read_excel_file(file_path: str, question: str) -> str:
    """Read an Excel file and return its rows as a JSON string.

    NOTE: `question` is part of the tool signature but unused here; the
    raw rows are returned for the calling agent to reason over.
    """
    try:
        records = pd.read_excel(file_path).to_dict(orient='records')
        return json.dumps(records)
    except Exception as e:
        return f"Error reading Excel file: {str(e)}"
|
| 72 |
+
@tool
def read_python_file(file_path: str, question: str) -> str:
    """Load a Python source file and ask the LLM a question about it."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            code_content = f.read()

        # Same prompt text as before, assembled piecewise.
        prompt = (
            "Here is Python code from a file:\n"
            "\n"
            "```python\n"
            f"{code_content}\n"
            "```\n"
            "\n"
            f"Question: {question}\n"
            "\n"
            "Please analyze the code and answer the question."
        )

        response = llm.invoke([HumanMessage(content=prompt)])
        return response.content

    except Exception as e:
        return f"Error reading Python file: {str(e)}"
| 95 |
+
@tool
def transcribe_audio(file_path: str, question: str) -> str:
    """Transcribe audio file.

    Posts the raw audio bytes to the Whisper inference endpoint and
    returns the transcription text, or an "Error transcribing audio: ..."
    string on failure.
    """
    try:
        headers = {
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "audio/mpeg"  # Add this line for MP3 files
        }
        API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3"

        with open(file_path, "rb") as f:
            audio_bytes = f.read()

        response = requests.post(API_URL, headers=headers, data=audio_bytes)
        # Fail loudly on HTTP errors (caught below) instead of returning
        # an opaque error payload to the agent.
        response.raise_for_status()
        data = json.loads(response.content.decode("utf-8"))

        # The endpoint typically answers {"text": ...}; honour the declared
        # -> str return type instead of handing a dict back to the agent.
        if isinstance(data, dict) and "text" in data:
            return data["text"]
        return json.dumps(data)

    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
| 117 |
+
#### Excel supervisor agent


def general_tools():
    """Toolbelt for the general-purpose agent: web/wiki/arxiv search plus
    the image, python-file and audio helpers defined above."""
    return [
        DuckDuckGoSearchRun(),
        WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
        ArxivQueryRun(api_wrapper=ArxivAPIWrapper()),
        analyze_image,
        read_python_file,
        transcribe_audio,
    ]
| 131 |
+
|
| 132 |
+
# Simple file tools
@tool
def read_excel(file_path: str) -> str:
    """Read any Excel file and return as JSON."""
    records = pd.read_excel(file_path).to_dict(orient='records')
    return json.dumps(records)
|
| 139 |
+
# Simple math tools
@tool
def add(a: float, b: float) -> float:
    """Return the sum of two numbers."""
    return a + b

@tool
def sum_list(numbers: list) -> float:
    """Return the total of a list of numbers."""
    total = sum(numbers)
    return total
+
|
| 150 |
+
# Simple data tools
@tool
def extract_values(data: str, column: str) -> list:
    """Extract all values from a column in JSON data.

    `column` is matched case-insensitively as a substring of each key, so
    "sales" also matches "Total Sales". Cells that cannot be coerced to
    float are skipped.
    """
    parsed = json.loads(data)
    values = []
    for row in parsed:
        for key, value in row.items():
            if column.lower() in key.lower():
                try:
                    values.append(float(value))
                # Previously a bare `except:`, which also swallowed
                # SystemExit/KeyboardInterrupt; only catch conversion errors.
                except (TypeError, ValueError):
                    pass
    return values
+
|
| 165 |
+
@tool
def filter_rows(data: str, exclude_words: list) -> str:
    """Drop every row whose stringified cells contain any exclude word."""
    rows = json.loads(data)
    lowered_words = [word.lower() for word in exclude_words]

    kept = []
    for row in rows:
        # Flatten the row to one lowercase haystack for substring checks.
        haystack = " ".join(str(cell).lower() for cell in row.values())
        if all(word not in haystack for word in lowered_words):
            kept.append(row)
    return json.dumps(kept)
|
| 176 |
+
def file_agent_tools():
    """Tools for the file_reader worker agent."""
    return [read_excel]

def math_agent_tools():
    """Tools for the calculator worker agent."""
    return [add, sum_list]

def data_agent_tools():
    """Tools for the data_processor worker agent."""
    return [extract_values, filter_rows]