thivy committed on
Commit
7022d5d
·
1 Parent(s): 81917a3

add agent

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
agents.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from tools import general_tools, file_agent_tools, data_agent_tools, math_agent_tools
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langgraph_supervisor import create_supervisor

# Shared chat model for every agent in this module.
llm = ChatOpenAI(model="o4-mini")

# Conversation memory so the general agent keeps per-thread context.
memory = MemorySaver()

# The general agent's system prompt lives in a separate text file.
# Renamed from `prompt` so it is not shadowed by the supervisor prompt below.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    general_prompt = f.read()

# General-purpose ReAct agent: search plus image/python-file/audio tools
# (see tools.general_tools).
general_agent = create_react_agent(
    model=llm,
    tools=general_tools(),
    checkpointer=memory,
    prompt=general_prompt,
)

# Specialist agents coordinated by the Excel supervisor below.
file_agent = create_react_agent(
    model=llm,
    tools=file_agent_tools(),
    name="file_reader",
    prompt="You read files. Use tools to read files."
)

math_agent = create_react_agent(
    model=llm,
    tools=math_agent_tools(),
    name="calculator",
    prompt="You do math. Use tools for all calculations."
)

data_agent = create_react_agent(
    model=llm,
    tools=data_agent_tools(),
    name="data_processor",
    prompt="You process data. Use tools to filter and extract data."
)

# Supervisor prompt with the answer-formatting rules (kept verbatim).
supervisor_prompt = """You are a supervisor. You coordinate file_reader, calculator, and data_processor to solve problems step by step.
Do not do calculations or file reading yourself, use the tools.
Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

# Supervisor: multi-agent graph routing between the three specialists.
excel_supervisor = create_supervisor(
    [file_agent, math_agent, data_agent],
    model=llm,
    prompt=supervisor_prompt
).compile()
files/1f975693-876d-457b-a649-393859e79bf3.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:200f767e732b49efef5c05d128903ee4d2c34e66fdce7f5593ac123b2e637673
3
+ size 280868
files/7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx ADDED
Binary file (5.29 kB). View file
 
files/99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3.mp3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b218c951c1f888f0bbe6f46c080f57afc7c9348fffc7ba4da35749ff1e2ac40f
3
+ size 179304
files/cca530fc-4052-43b2-b130-b30968d8aa44.png ADDED
files/f918266a-b3e0-4914-865d-4faa564f1aef.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from random import randint
import time

class UhOh(Exception):
    """Raised by Hmm.Yeah when the held value is not zero."""
    pass

class Hmm:
    """Holds a random integer drawn uniformly from [-100, 100]."""
    def __init__(self):
        self.value = randint(-100, 100)

    def Yeah(self):
        """Return True when the value is exactly zero; raise UhOh otherwise."""
        if self.value == 0:
            return True
        raise UhOh()

def Okay():
    """Yield an endless stream of fresh Hmm instances."""
    while True:
        yield Hmm()

def keep_trying(go, first_try=True):
    """Consume Hmm objects from *go* until one holds zero, then return it (0).

    Prints a short status message on the first failed attempt only and
    sleeps 0.1s after every failure.
    """
    while True:
        candidate = next(go)
        try:
            if candidate.Yeah():
                return candidate.value
        except UhOh:
            if first_try:
                print("Working...")
                print("Please wait patiently...")
            time.sleep(0.1)
            first_try = False

if __name__ == "__main__":
    go = Okay()
    print(f"{keep_trying(go)}")
qa_graph.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from langgraph.graph import START, StateGraph, END
3
+ from typing import TypedDict
4
+ from agents import general_agent, excel_supervisor
5
import os

# Propagate the key only when it is actually present: the original
# `os.environ[...] = str(os.getenv(...))` stored the literal string "None"
# when the key was missing, masking a configuration error later on.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY is not None:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
8
+
9
+ @dataclass
10
+ class Question:
11
+ task_id: str
12
+ question: str
13
+ Level: str
14
+ file_name: str
15
+ local_file_path: str|None = None
16
+
17
def get_file_type(file_path: str) -> str:
    """Determine file type from extension.

    Returns one of "image", "excel", "python", "audio", "unknown",
    or "none" for an empty path.
    """
    if not file_path:
        return "none"

    lowered = file_path.lower()

    # Extension groups checked in the same order as the original chain.
    categories = (
        ("image", ('.png', '.jpg', '.jpeg', '.gif', '.bmp')),
        ("excel", ('.xlsx', '.xls', '.csv')),
        ("python", ('.py',)),
        ("audio", ('.mp3', '.wav', '.m4a', '.ogg')),
    )
    for kind, suffixes in categories:
        if lowered.endswith(suffixes):
            return kind
    return "unknown"
34
+
35
def ask_question(question: str, thread_id: str = "default") -> str:
    """Send *question* to the general agent and return its last message text.

    Any failure is reported as an "Error: ..." string instead of raising,
    so the calling graph node never crashes.
    """
    run_config = {"configurable": {"thread_id": thread_id}}
    try:
        result = general_agent.invoke(
            {"messages": [{"role": "user", "content": question}]},
            config=run_config,
        )
        return result["messages"][-1].content
    except Exception as e:
        return f"Error: {str(e)}"
47
+
48
def ask_question_with_file(question: Question, thread_id: str = "default") -> str:
    """Ask the agent a question, with optional file analysis.

    Excel questions are delegated to the excel_supervisor multi-agent graph;
    other file types go to the general agent with a tool hint appended to the
    question text. Fix vs. original: the guard for a missing attachment now
    runs before the file path is built.
    """
    q = question.question
    if not question.file_name:
        # No attachment: plain question for the general agent.
        return ask_question(q, thread_id)

    file_path = f"./files/{question.file_name}"
    file_type = get_file_type(file_path)

    if file_type == "excel":
        # Spreadsheets are handled by the supervisor coordinating the
        # file_reader / calculator / data_processor agents.
        enhanced_question = f"{q}\n\nFile path: {file_path}"
        result = excel_supervisor.invoke({
            "messages": [
                {"role": "user", "content": enhanced_question}
            ]
        })
        return result["messages"][-1].content

    # All other types: hint the general agent toward the right tool.
    if file_type == "image":
        enhanced_question = f"{q}\n\nThere is an image file at '{file_path}'. Use the analyze_image tool to examine it."
    elif file_type == "python":
        enhanced_question = f"{q}\n\nThere is a Python file at '{file_path}'. Use the read_python_file tool to examine it."
    elif file_type == "audio":
        enhanced_question = f"{q}\n\nThere is an audio file at '{file_path}'. Use the transcribe_audio tool to process it."
    else:
        enhanced_question = f"{q}\n\nThere is a file at '{file_path}' but I'm not sure what type it is."

    return ask_question(enhanced_question, thread_id)
77
+
78
# Local smoke-test fixtures. The commented entries are alternative tasks
# kept for reference (image and audio based).
test = [
    # {
    #     "task_id": "cca530fc-4052-43b2-b130-b30968d8aa44",
    #     "question": "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation.",
    #     "Level": "1",
    #     "file_name": "cca530fc-4052-43b2-b130-b30968d8aa44.png"
    # },
    # {
    #     "task_id": "1f975693-876d-457b-a649-393859e79bf3",
    #     "question": "Hi, I was out sick from my classes on Friday, so I'm trying to figure out what I need to study for my Calculus mid-term next week. My friend from class sent me an audio recording of Professor Willowbrook giving out the recommended reading for the test, but my headphones are broken :(\n\nCould you please listen to the recording for me and tell me the page numbers I'm supposed to go over? I've attached a file called Homework.mp3 that has the recording. Please provide just the page numbers as a comma-delimited list. And please provide the list in ascending order.",
    #     "Level": "1",
    #     "file_name": "1f975693-876d-457b-a649-393859e79bf3.mp3"
    # },
    {
        "task_id": "7bd855d8-463d-4ed5-93ca-5fe35145f733",
        "question": "The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.",
        "Level": "1",
        "file_name": "7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx"
    }
]

questions = [Question(**item) for item in test]
for q in questions:
    print(q.question)
    print(q.file_name)
    print(q.local_file_path)
104
+
105
+
106
+
107
# Shared LangGraph state passed between the router and the query nodes.
class State(TypedDict):
    question: Question  # the task currently being answered
    decision: str       # router output: "query" or "query_with_file"
    answer: str         # final answer text produced by a query node
112
+
113
# NODE FUNCTIONS - LangGraph-compatible wrappers around the plain helpers.
def ask_question_node(state: State) -> dict:
    """Graph node: answer a question that has no attached file."""
    task = state["question"]
    # Delegate to the plain helper; the returned dict updates the state.
    return {"answer": ask_question(task.question, f"test_{task.task_id}")}
124
+
125
def ask_question_with_file_node(state: State) -> dict:
    """Graph node: answer a question that carries an attached file."""
    task = state["question"]
    # Delegate to the plain helper; the returned dict updates the state.
    return {"answer": ask_question_with_file(task, f"test_{task.task_id}")}
135
+
136
def router_node(state: State):
    """Router node: record which branch to take, based on file presence."""
    has_file = bool(state["question"].file_name)
    return {"decision": "query_with_file" if has_file else "query"}
144
+
145
def router_function(state: State):
    """Conditional-edge selector: forward the decision recorded by router_node."""
    return state["decision"]
148
+
149
# Graph wiring: START -> router -> (query | query_with_file) -> END.
builder = StateGraph(State)

# Register the node wrappers (not the plain helper functions).
for node_name, node_fn in (
    ("query_with_file", ask_question_with_file_node),
    ("query", ask_question_node),
    ("router", router_node),
):
    builder.add_node(node_name, node_fn)

builder.add_edge(START, "router")
builder.add_conditional_edges(
    "router",
    router_function,
    {
        "query_with_file": "query_with_file",
        "query": "query",
    },
)
builder.add_edge("query_with_file", END)
builder.add_edge("query", END)

react_graph = builder.compile()
171
+
172
if __name__ == "__main__":
    # Run every fixture question through the graph and print its answer.
    for i, question in enumerate(questions):
        print(f"\n{i}. {question.question}")
        result = react_graph.invoke({
            "question": question,
            "decision": "",
            "answer": "",
        })
        print(f"Answer: {result['answer']}")
        print("-" * 50)
system_prompt.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ You are a general AI assistant.
2
+ I will ask you a question.
3
+ Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
4
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
5
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
6
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
7
+ If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
tools.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import base64
import json
import os

# Third-party
import pandas as pd
import requests
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from langchain_community.tools import (
    DuckDuckGoSearchRun,
    WikipediaQueryRun,
    ArxivQueryRun
)
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

# Load HF_TOKEN (and any other secrets) from .env before creating clients.
# Fix vs. original: the duplicate `import os` is removed and imports are
# grouped stdlib / third-party.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# Hugging Face inference client (used for audio transcription).
client = InferenceClient(
    provider="hf-inference",
    api_key=HF_TOKEN,
)

# Text model for code analysis; vision model for image questions.
llm = ChatOpenAI(model="o4-mini")
vision_llm = ChatOpenAI(model="gpt-4o")
30
+
31
@tool
def analyze_image(img_path: str, question: str) -> str:
    """Analyze an image and answer a question about it."""
    try:
        # Inline the image as base64 so it can travel in the message payload.
        with open(img_path, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode("utf-8")

        vision_message = HumanMessage(
            content=[
                {"type": "text", "text": question},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}
                }
            ]
        )
        return vision_llm.invoke([vision_message]).content

    except Exception as e:
        return f"Error analyzing image: {str(e)}"
57
+
58
@tool
def read_excel_file(file_path: str, question: str) -> str:
    """Read and analyze an Excel file to answer a question."""
    try:
        # Serialise row-by-row so the agent can reason over raw records.
        frame = pd.read_excel(file_path)
        return json.dumps(frame.to_dict(orient='records'))
    except Exception as e:
        return f"Error reading Excel file: {str(e)}"
71
+
72
@tool
def read_python_file(file_path: str, question: str) -> str:
    """Read and analyze a Python file to answer a question."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            code_content = f.read()

        # The text model does the actual code analysis.
        analysis_prompt = f"""Here is Python code from a file:

```python
{code_content}
```

Question: {question}

Please analyze the code and answer the question."""

        return llm.invoke([HumanMessage(content=analysis_prompt)]).content

    except Exception as e:
        return f"Error reading Python file: {str(e)}"
94
+
95
@tool
def transcribe_audio(file_path: str, question: str) -> str:
    """Transcribe audio file."""
    try:
        API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3"
        headers = {
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "audio/mpeg" # Add this line for MP3 files
        }

        # Post the raw audio bytes to the Whisper inference endpoint.
        with open(file_path, "rb") as audio_file:
            payload = audio_file.read()
        response = requests.request("POST", API_URL, headers=headers, data=payload)
        return json.loads(response.content.decode("utf-8"))

    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
116
+
117
#### Tool sets (general agent + Excel supervisor agents) ####

def general_tools():
    """Tool set for the general-purpose agent: search plus file analysis."""
    return [
        DuckDuckGoSearchRun(),
        WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
        ArxivQueryRun(api_wrapper=ArxivAPIWrapper()),
        analyze_image,
        read_python_file,
        transcribe_audio,
    ]
130
+
131
+
132
# Simple file tools
@tool
def read_excel(file_path: str) -> str:
    """Read any Excel file and return as JSON."""
    records = pd.read_excel(file_path).to_dict(orient='records')
    return json.dumps(records)
138
+
139
# Simple math tools
@tool
def add(a: float, b: float) -> float:
    """Add two numbers."""
    result = a + b
    return result

@tool
def sum_list(numbers: list) -> float:
    """Sum a list of numbers."""
    total = sum(numbers)
    return total
149
+
150
# Simple data tools
@tool
def extract_values(data: str, column: str) -> list:
    """Extract all values from a column in JSON data.

    Matches any key containing *column* (case-insensitive) and collects the
    values that convert to float; non-numeric cells are skipped.
    """
    parsed = json.loads(data)
    wanted = column.lower()  # hoisted: loop-invariant
    values = []
    for row in parsed:
        for key, value in row.items():
            # Substring match lets callers pass e.g. "price" for "Unit Price".
            if wanted in key.lower():
                try:
                    values.append(float(value))
                except (TypeError, ValueError):
                    # Skip non-numeric cells; the original bare `except:`
                    # also swallowed KeyboardInterrupt/SystemExit.
                    pass
    return values
164
+
165
@tool
def filter_rows(data: str, exclude_words: list) -> str:
    """Remove rows containing any of the exclude words."""
    rows = json.loads(data)
    kept = []
    for row in rows:
        # Flatten the row's values into one lowercase searchable string.
        haystack = " ".join(str(v).lower() for v in row.values())
        if all(word.lower() not in haystack for word in exclude_words):
            kept.append(row)
    return json.dumps(kept)
175
+
176
def file_agent_tools():
    """Tools for the file_reader agent."""
    return [read_excel]

def math_agent_tools():
    """Tools for the calculator agent."""
    return [add, sum_list]

def data_agent_tools():
    """Tools for the data_processor agent."""
    return [extract_values, filter_rows]