abtsousa commited on
Commit
03f4295
·
1 Parent(s): 60d1fd6

Enhance OracleBot to accept optional file path for context in answers; add utility to fetch task files from API.

Browse files
Files changed (3) hide show
  1. agent/agent.py +44 -7
  2. app.py +73 -94
  3. utils.py +68 -0
agent/agent.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Literal
2
  from typing_extensions import TypedDict
3
  from langgraph.graph import StateGraph, START, END
@@ -17,10 +18,18 @@ class OracleBot:
17
  self.config = create_agent_config(self.name, self.thread_id)
18
  self.graph = self._build_agent(self.name)
19
 
20
- def answer_question(self, question: str):
21
  """
22
  Answer a question using the LangGraph agent.
 
 
 
 
23
  """
 
 
 
 
24
  messages = [HumanMessage(content=question)]
25
 
26
  for mode, chunk in self.graph.stream({"messages": messages}, config=self.config, stream_mode=["messages", "updates"]): # type: ignore
@@ -48,7 +57,8 @@ class OracleBot:
48
  # Handle final answer messages (no tool calls)
49
  elif hasattr(message, 'content') and message.content:
50
  cprint(f"\n{message.content}\n", color="black", on_color="on_white", attrs=["bold"])
51
-
 
52
  # Look for tool outputs in updates
53
  elif isinstance(chunk, dict) and 'tools' in chunk:
54
  tools_update = chunk['tools']
@@ -57,6 +67,36 @@ class OracleBot:
57
  if hasattr(message, 'content') and message.content:
58
  cprint(f"\n📤 Tool output:\n{message.content}\n", color="green")
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def _build_agent(self, name: str):
61
  """
62
  Get our LangGraph agent with the given model and tools.
@@ -77,10 +117,7 @@ class OracleBot:
77
  graph.add_conditional_edges("agent", tools_condition)
78
  graph.add_edge("tools", "agent")
79
 
80
- # Add memory
81
- memory = InMemorySaver()
82
-
83
- return graph.compile(checkpointer=memory)
84
 
85
  # test
86
  if __name__ == "__main__":
@@ -92,7 +129,7 @@ if __name__ == "__main__":
92
  from config import start_phoenix
93
  start_phoenix()
94
  bot = OracleBot()
95
- bot.answer_question(question)
96
 
97
  except Exception as e:
98
  print(f"Error running agent: {e}")
 
1
+ import os
2
  from typing import Literal
3
  from typing_extensions import TypedDict
4
  from langgraph.graph import StateGraph, START, END
 
18
  self.config = create_agent_config(self.name, self.thread_id)
19
  self.graph = self._build_agent(self.name)
20
 
21
+ def answer_question(self, question: str, file_path: str | None = None):
22
  """
23
  Answer a question using the LangGraph agent.
24
+
25
+ Args:
26
+ question: The question to answer
27
+ file_path: Optional path to a file associated with this question
28
  """
29
+ # Enhance question with file context if available
30
+ if file_path and os.path.exists(file_path):
31
+ question = f"{question}\n\nNote: There is an associated file at: {file_path}\nYou can use the file management tools to read and analyze this file."
32
+
33
  messages = [HumanMessage(content=question)]
34
 
35
  for mode, chunk in self.graph.stream({"messages": messages}, config=self.config, stream_mode=["messages", "updates"]): # type: ignore
 
57
  # Handle final answer messages (no tool calls)
58
  elif hasattr(message, 'content') and message.content:
59
  cprint(f"\n{message.content}\n", color="black", on_color="on_white", attrs=["bold"])
60
+ return message.content # Return final answer
61
+
62
  # Look for tool outputs in updates
63
  elif isinstance(chunk, dict) and 'tools' in chunk:
64
  tools_update = chunk['tools']
 
67
  if hasattr(message, 'content') and message.content:
68
  cprint(f"\n📤 Tool output:\n{message.content}\n", color="green")
69
 
70
+ async def answer_question_async(self, question: str, file_path: str | None = None) -> str:
71
+ """
72
+ Answer a question using the LangGraph agent asynchronously.
73
+
74
+ Args:
75
+ question: The question to answer
76
+ file_path: Optional path to a file associated with this question
77
+
78
+ Returns the final answer as a string.
79
+ """
80
+ from langchain_core.runnables import RunnableConfig
81
+ from typing import cast
82
+
83
+ # Enhance question with file context if available
84
+ if file_path and os.path.exists(file_path):
85
+ question = f"{question}\n\nNote: There is an associated file at: {file_path}\nYou can use the file management tools to read and analyze this file."
86
+
87
+ messages = [HumanMessage(content=question)]
88
+
89
+ # Use LangGraph's built-in ainvoke method
90
+ result = await self.graph.ainvoke({"messages": messages}, config=cast(RunnableConfig, self.config)) # type: ignore
91
+
92
+ # Extract the content from the last message
93
+ if "messages" in result and result["messages"]:
94
+ last_message = result["messages"][-1]
95
+ if hasattr(last_message, 'content'):
96
+ return last_message.content or ""
97
+
98
+ return ""
99
+
100
  def _build_agent(self, name: str):
101
  """
102
  Get our LangGraph agent with the given model and tools.
 
117
  graph.add_conditional_edges("agent", tools_condition)
118
  graph.add_edge("tools", "agent")
119
 
120
+ return graph.compile()
 
 
 
121
 
122
  # test
123
  if __name__ == "__main__":
 
129
  from config import start_phoenix
130
  start_phoenix()
131
  bot = OracleBot()
132
+ bot.answer_question(question, None)
133
 
134
  except Exception as e:
135
  print(f"Error running agent: {e}")
app.py CHANGED
@@ -4,14 +4,12 @@ import requests
4
  import pandas as pd
5
  from os import getenv
6
  from dotenv import load_dotenv
7
- from langchain_core.messages import HumanMessage
8
- from langchain_core.runnables import RunnableConfig
9
  import asyncio
10
- from typing import cast
11
 
12
  from agent.agent import OracleBot
13
- from agent.config import create_agent_config
14
  from config import start_phoenix, APP_NAME, DEFAULT_API_URL
 
15
 
16
  load_dotenv()
17
 
@@ -28,38 +26,8 @@ start_phoenix()
28
  # # (in this case, it appends messages to the list, rather than overwriting them)
29
  # messages: Annotated[list, add_messages]
30
 
31
- class BasicAgent:
32
- def __init__(self):
33
- self.agent = OracleBot()
34
-
35
- async def __call__(self, question: str) -> str:
36
- print(f"Agent received question: {question}")
37
-
38
- # Create configuration like in main.py
39
- config = create_agent_config(app_name=APP_NAME)
40
-
41
- # Call the agent with the question and config (like main.py)
42
- answer = await self.agent.ainvoke(
43
- {"messages": [HumanMessage(content=question)]},
44
- cast(RunnableConfig, config)
45
- )
46
-
47
- print(f"Agent returning answer: {answer}")
48
-
49
- # Extract content from the last message in the response
50
- if "messages" in answer and answer["messages"]:
51
- last_message = answer["messages"][-1]
52
- if hasattr(last_message, 'content'):
53
- content = last_message.content
54
- else:
55
- content = str(last_message)
56
- else:
57
- content = str(answer)
58
-
59
- return str(content) if content is not None else ""
60
-
61
  # Simplified concurrent processor: launch all tasks immediately and await them together
62
- async def process_questions(agent: BasicAgent, questions_data: list):
63
  print(f"Running agent on {len(questions_data)} questions concurrently (simple fan-out)...")
64
 
65
  async def handle(item: dict):
@@ -69,7 +37,16 @@ async def process_questions(agent: BasicAgent, questions_data: list):
69
  print(f"Skipping item with missing task_id or question: {item}")
70
  return None
71
  try:
72
- submitted_answer = await agent(question_text)
 
 
 
 
 
 
 
 
 
73
  return {
74
  "log": {"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer},
75
  "payload": {"task_id": task_id, "submitted_answer": submitted_answer},
@@ -107,7 +84,7 @@ async def run_and_submit_all( profile: gr.OAuthProfile | None):
107
 
108
  # 1. Instantiate Agent ( modify this part to create your agent)
109
  try:
110
- agent = BasicAgent()
111
  except Exception as e:
112
  print(f"Error instantiating agent: {e}")
113
  return f"Error initializing agent: {e}", None
@@ -137,64 +114,66 @@ async def run_and_submit_all( profile: gr.OAuthProfile | None):
137
  return f"An unexpected error occurred fetching questions: {e}", None
138
 
139
  # 3. Run your Agent concurrently (simple gather)
140
- results_log, answers_payload = await process_questions(agent, questions_data)
141
-
142
- # Remove everything before "FINAL ANSWER: " in submitted answers
143
- for answer in answers_payload:
144
- if "submitted_answer" in answer:
145
- answer["submitted_answer"] = answer["submitted_answer"].split("FINAL ANSWER: ", 1)[-1].strip()
146
-
147
- if not answers_payload:
148
- print("Agent did not produce any answers to submit.")
149
- return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
150
-
151
- # 4. Prepare Submission
152
- submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
153
- status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
154
- print(status_update)
155
-
156
- # 5. Submit
157
- print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
158
- try:
159
- response = requests.post(submit_url, json=submission_data, timeout=60)
160
- response.raise_for_status()
161
- result_data = response.json()
162
- final_status = (
163
- f"Submission Successful!\n"
164
- f"User: {result_data.get('username')}\n"
165
- f"Overall Score: {result_data.get('score', 'N/A')}% "
166
- f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
167
- f"Message: {result_data.get('message', 'No message received.')}"
168
- )
169
- print("Submission successful.")
170
- results_df = pd.DataFrame(results_log)
171
- return final_status, results_df
172
- except requests.exceptions.HTTPError as e:
173
- error_detail = f"Server responded with status {e.response.status_code}."
174
  try:
175
- error_json = e.response.json()
176
- error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
177
- except requests.exceptions.JSONDecodeError:
178
- error_detail += f" Response: {e.response.text[:500]}"
179
- status_message = f"Submission Failed: {error_detail}"
180
- print(status_message)
181
- results_df = pd.DataFrame(results_log)
182
- return status_message, results_df
183
- except requests.exceptions.Timeout:
184
- status_message = "Submission Failed: The request timed out."
185
- print(status_message)
186
- results_df = pd.DataFrame(results_log)
187
- return status_message, results_df
188
- except requests.exceptions.RequestException as e:
189
- status_message = f"Submission Failed: Network error - {e}"
190
- print(status_message)
191
- results_df = pd.DataFrame(results_log)
192
- return status_message, results_df
193
- except Exception as e:
194
- status_message = f"An unexpected error occurred during submission: {e}"
195
- print(status_message)
196
- results_df = pd.DataFrame(results_log)
197
- return status_message, results_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
 
200
  # --- Build Gradio Interface using Blocks ---
 
4
  import pandas as pd
5
  from os import getenv
6
  from dotenv import load_dotenv
 
 
7
  import asyncio
8
+ import tempfile
9
 
10
  from agent.agent import OracleBot
 
11
  from config import start_phoenix, APP_NAME, DEFAULT_API_URL
12
+ from utils import fetch_task_file, extract_task_id_from_question_data
13
 
14
  load_dotenv()
15
 
 
26
  # # (in this case, it appends messages to the list, rather than overwriting them)
27
  # messages: Annotated[list, add_messages]
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Simplified concurrent processor: launch all tasks immediately and await them together
30
+ async def process_questions(agent: OracleBot, questions_data: list, working_dir: str):
31
  print(f"Running agent on {len(questions_data)} questions concurrently (simple fan-out)...")
32
 
33
  async def handle(item: dict):
 
37
  print(f"Skipping item with missing task_id or question: {item}")
38
  return None
39
  try:
40
+ # Fetch associated file if it exists
41
+ file_path = fetch_task_file(task_id, working_dir)
42
+ if file_path:
43
+ print(f"Found file for task {task_id}: {file_path}")
44
+
45
+ # Pass file_path to agent
46
+ submitted_answer = await agent.answer_question_async(question_text, file_path)
47
+ # Extract everything after "FINAL ANSWER: "
48
+ if "FINAL ANSWER: " in submitted_answer:
49
+ submitted_answer = submitted_answer.split("FINAL ANSWER: ", 1)[-1].strip()
50
  return {
51
  "log": {"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer},
52
  "payload": {"task_id": task_id, "submitted_answer": submitted_answer},
 
84
 
85
  # 1. Instantiate Agent ( modify this part to create your agent)
86
  try:
87
+ agent = OracleBot()
88
  except Exception as e:
89
  print(f"Error instantiating agent: {e}")
90
  return f"Error initializing agent: {e}", None
 
114
  return f"An unexpected error occurred fetching questions: {e}", None
115
 
116
  # 3. Run your Agent concurrently (simple gather)
117
+ # Create a temporary working directory for this session
118
+ with tempfile.TemporaryDirectory() as working_dir:
119
+ results_log, answers_payload = await process_questions(agent, questions_data, working_dir)
120
+
121
+ # Remove everything before "FINAL ANSWER: " in submitted answers
122
+ for answer in answers_payload:
123
+ if "submitted_answer" in answer:
124
+ answer["submitted_answer"] = answer["submitted_answer"].split("FINAL ANSWER: ", 1)[-1].strip()
125
+
126
+ if not answers_payload:
127
+ print("Agent did not produce any answers to submit.")
128
+ return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
129
+
130
+ # 4. Prepare Submission
131
+ submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
132
+ status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
133
+ print(status_update)
134
+
135
+ # 5. Submit
136
+ print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  try:
138
+ response = requests.post(submit_url, json=submission_data, timeout=60)
139
+ response.raise_for_status()
140
+ result_data = response.json()
141
+ final_status = (
142
+ f"Submission Successful!\n"
143
+ f"User: {result_data.get('username')}\n"
144
+ f"Overall Score: {result_data.get('score', 'N/A')}% "
145
+ f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
146
+ f"Message: {result_data.get('message', 'No message received.')}"
147
+ )
148
+ print("Submission successful.")
149
+ results_df = pd.DataFrame(results_log)
150
+ return final_status, results_df
151
+ except requests.exceptions.HTTPError as e:
152
+ error_detail = f"Server responded with status {e.response.status_code}."
153
+ try:
154
+ error_json = e.response.json()
155
+ error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
156
+ except requests.exceptions.JSONDecodeError:
157
+ error_detail += f" Response: {e.response.text[:500]}"
158
+ status_message = f"Submission Failed: {error_detail}"
159
+ print(status_message)
160
+ results_df = pd.DataFrame(results_log)
161
+ return status_message, results_df
162
+ except requests.exceptions.Timeout:
163
+ status_message = "Submission Failed: The request timed out."
164
+ print(status_message)
165
+ results_df = pd.DataFrame(results_log)
166
+ return status_message, results_df
167
+ except requests.exceptions.RequestException as e:
168
+ status_message = f"Submission Failed: Network error - {e}"
169
+ print(status_message)
170
+ results_df = pd.DataFrame(results_log)
171
+ return status_message, results_df
172
+ except Exception as e:
173
+ status_message = f"An unexpected error occurred during submission: {e}"
174
+ print(status_message)
175
+ results_df = pd.DataFrame(results_log)
176
+ return status_message, results_df
177
 
178
 
179
  # --- Build Gradio Interface using Blocks ---
utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import tempfile
4
+ from pathlib import Path
5
+ from config import DEFAULT_API_URL
6
+
7
+
8
+ def fetch_task_file(task_id: str, working_dir: str) -> str | None:
9
+ """
10
+ Fetch the file associated with a task_id from the API and save it to the working directory.
11
+
12
+ Args:
13
+ task_id: The task ID to fetch the file for
14
+ working_dir: The working directory to save the file to
15
+
16
+ Returns:
17
+ The path to the downloaded file, or None if no file exists or error occurred
18
+ """
19
+ try:
20
+ files_url = f"{DEFAULT_API_URL}/files/{task_id}"
21
+ response = requests.get(files_url, timeout=30)
22
+
23
+ if response.status_code == 404:
24
+ # No file associated with this task
25
+ return None
26
+ elif response.status_code == 200:
27
+ # Try to determine filename from content-disposition header
28
+ filename = f"task_{task_id}_file"
29
+ if 'content-disposition' in response.headers:
30
+ content_disp = response.headers['content-disposition']
31
+ if 'filename=' in content_disp:
32
+ filename = content_disp.split('filename=')[1].strip('"')
33
+
34
+ # If content type suggests a specific extension
35
+ content_type = response.headers.get('content-type', '')
36
+ if 'json' in content_type and not filename.endswith('.json'):
37
+ filename += '.json'
38
+ elif 'text' in content_type and not filename.endswith('.txt'):
39
+ filename += '.txt'
40
+ elif 'csv' in content_type and not filename.endswith('.csv'):
41
+ filename += '.csv'
42
+
43
+ # Save file to working directory
44
+ file_path = os.path.join(working_dir, filename)
45
+ with open(file_path, 'wb') as f:
46
+ f.write(response.content)
47
+
48
+ print(f"Downloaded file for task {task_id}: {file_path}")
49
+ return file_path
50
+ else:
51
+ response.raise_for_status()
52
+
53
+ except Exception as e:
54
+ print(f"Error fetching file for task {task_id}: {e}")
55
+ return None
56
+
57
+
58
+ def extract_task_id_from_question_data(question_data: dict) -> str | None:
59
+ """
60
+ Extract task_id from question data dictionary.
61
+
62
+ Args:
63
+ question_data: Dictionary containing question information
64
+
65
+ Returns:
66
+ The task_id if found, None otherwise
67
+ """
68
+ return question_data.get("task_id")