TommasoBB commited on
Commit
00a50e0
·
verified ·
1 Parent(s): 81917a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +318 -7
app.py CHANGED
@@ -1,23 +1,327 @@
1
  import os
2
  import gradio as gr
 
3
  import requests
4
  import inspect
5
  import pandas as pd
 
 
 
 
 
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # --- Basic Agent Definition ---
 
12
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
13
  class BasicAgent:
14
  def __init__(self):
15
- print("BasicAgent initialized.")
16
- def __call__(self, question: str) -> str:
 
 
 
 
 
 
17
  print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
@@ -75,12 +379,19 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
75
  print(f"Running agent on {len(questions_data)} questions...")
76
  for item in questions_data:
77
  task_id = item.get("task_id")
78
- question_text = item.get("question")
 
79
  if not task_id or question_text is None:
80
  print(f"Skipping item with missing task_id or question: {item}")
81
  continue
 
 
 
 
 
 
82
  try:
83
- submitted_answer = agent(question_text)
84
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
85
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
86
  except Exception as e:
 
import os
import inspect
from typing import TypedDict, List, Dict, Any, Optional

import gradio as gr
import pandas as pd
import requests
from gradio_client import file
from smolagents import CodeAgent, HfApiModel
from langgraph.graph import StateGraph, START, END
# BUG FIX: `HumanMessage` is not part of langgraph — `langgraph.messages`
# does not exist and raises ModuleNotFoundError at import time. The class
# is provided by langchain-core (a dependency of langgraph).
from langchain_core.messages import HumanMessage

import tools
# (Keep Constants as is)
# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# --- Models ---
# NOTE(review): depending on the installed smolagents version, HfApiModel
# may take `model_id` rather than `repo_id`, and may be callable rather than
# exposing `.invoke()` — confirm against the pinned smolagents release.
# Vision model for image analysis / OCR
vision_model = HfApiModel(repo_id="FireRedTeam/FireRed-OCR", max_new_tokens=2048, temperature=0.3)
# Dedicated model for math problems
math_model = HfApiModel(repo_id="Qwen/Qwen2.5-Math-1.5B", max_new_tokens=2048, temperature=0.3)
# BUG FIX: `model` is invoked by classify() and handle_file() but was never
# defined anywhere in the file, so those nodes raised NameError at runtime.
# Define a general-purpose instruction model for them.
model = HfApiModel(repo_id="Qwen/Qwen2.5-7B-Instruct", max_new_tokens=2048, temperature=0.3)
# --- Agent state ---
class AgentState(TypedDict):
    """Shared state flowing through the LangGraph workflow.

    Every node receives the full state and returns a dict of channel
    updates; only keys declared here are accepted by the graph.
    """
    question: str
    task_id: Optional[str]
    file_name: Optional[str]
    # Classification flags set by the `classify` node.
    is_searching: Optional[bool]
    have_file: Optional[bool]
    is_math: Optional[bool]
    have_image: Optional[bool]
    # BUG FIX: the handler nodes write these keys as state updates, but they
    # were missing from the schema. LangGraph raises on updates to channels
    # that are not declared, so every handler node failed at runtime.
    search_results: Optional[str]
    image_description: Optional[str]
    transcribed_text: Optional[str]
    extracted_info: Optional[str]
    math_solution: Optional[str]
    final_answer: Optional[str]  # The final answer produced by the agent
    messages: List[Dict[str, Any]]  # Track conversation with LLM for analysis
# --- Graph nodes ---

def read(state: "AgentState") -> dict:
    """Entry node: log the incoming question.

    Returns an empty dict — no state channels are updated.
    (BUG FIX: the return annotation was `str`, but LangGraph nodes
    return a dict of state updates.)
    """
    question = state["question"]
    print(f"Agent is reading the question: {question[:50]}...")
    return {}
def _response_to_dict(response) -> dict:
    """Best-effort conversion of an LLM response into a dict.

    Accepts an already-decoded dict, or an object with a `.content` string
    (e.g. a chat message) containing a JSON object possibly surrounded by
    prose. Returns {} when nothing parseable is found.
    """
    import json
    import re

    if isinstance(response, dict):
        return response
    text = getattr(response, "content", None)
    if text is None:
        text = str(response)
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            pass
    return {}


def classify(state: "AgentState") -> dict:
    """Agent classifies the question to decide which tools to use.

    Asks the LLM for a JSON object of boolean flags and stores them in the
    state, appending a transcript entry to `messages`.
    """
    question = state["question"].lower()

    # Prompt for the LLM to classify the question.
    prompt = f"""
    You are an agent that classifies questions to determine which tools to use.
    Classify the following question into the categories: 'need to be searched on web/wikipidia', 'has a file in the question', 'is a math problem', 'has an image in the question'.
    Question: {question}
    Return a JSON object with boolean fields for each category, for example:
    {{
        "is_searching": true,
        "have_file": false,
        "is_math": false,
        "have_image": false
    }}
    """
    messages = [HumanMessage(content=prompt)]
    # NOTE(review): `model` must be defined at module level — it was missing
    # from the original file (NameError here).
    response = model.invoke(messages)
    # BUG FIX: the raw response is not guaranteed to be a dict; parse the
    # JSON payload defensively instead of calling .get() on it directly.
    data = _response_to_dict(response)
    is_searching = data.get("is_searching", False)
    have_file = data.get("have_file", False)
    is_math = data.get("is_math", False)
    have_image = data.get("have_image", False)
    print(f"Classification result: is_searching={is_searching}, have_file={have_file}, is_math={is_math}, have_image={have_image}")
    # BUG FIX: typo `mew_messages` -> `new_messages`.
    new_messages = state.get("messages", []) + [
        {"role": "system", "content": "Classify the question to determine which tools to use."},
        {"role": "user", "content": question},
        {"role": "assistant", "content": f"Classification result: is_searching={is_searching}, have_file={have_file}, is_math={is_math}, have_image={have_image}"},
    ]

    return {
        "is_searching": is_searching,
        "have_file": have_file,
        "is_math": is_math,
        "have_image": have_image,
        "messages": new_messages,
    }
def handele_search(state: "AgentState") -> dict:
    """Agent performs a web search when classified as needing search.

    NOTE(review): the misspelled name `handele_search` is kept because the
    graph registers this exact identifier as the "handle_search" node.
    (BUG FIX: the return annotation was `str`; LangGraph nodes return a
    dict of state updates.)
    """
    question = state["question"]
    print(f"Agent is performing a web search for: {question[:50]}...")
    search_results = tools.WebSearchTool().run(question)
    print(f"Search results: {search_results[:100]}...")
    new_messages = state.get("messages", []) + [
        {"role": "system", "content": "Perform a web search if classified as needing search."},
        {"role": "user", "content": question},
        {"role": "assistant", "content": f"Search results: {search_results[:100]}..."},
    ]
    return {
        "search_results": search_results,
        "messages": new_messages,
    }
def handle_image(state: "AgentState") -> dict:
    """Agent handles an image if classified as having an image.

    Downloads the image as base64 and sends it to a vision-capable model
    using a multimodal message format. Returns state updates with the image
    description, any transcribed text, and a transcript entry.
    (BUG FIX: return annotation corrected from `str` to `dict`.)
    """
    question = state["question"]
    task_id = state.get("task_id", "")
    file_name = state.get("file_name", "")

    # Use ImageReaderTool to download the image as base64
    image_reader = tools.ImageReaderTool()
    image_data_uri = image_reader(task_id) if task_id and file_name else ""

    if not image_data_uri or image_data_uri.startswith("Failed"):
        print(f"Could not download image for task {task_id}")
        new_messages = state.get("messages", []) + [
            {"role": "assistant", "content": f"[Could not download image '{file_name}' for analysis.]"}
        ]
        return {
            "image_description": "",
            "transcribed_text": "",
            "messages": new_messages,
        }

    # Build multimodal message with image for a vision-capable model
    prompt_text = f"""Analyze the attached image in detail.
Describe the content of the image and transcribe all text visible in it.

Question: {question}

Return a JSON object with the following fields:
{{
    "image_description": "A detailed description of the image content.",
    "transcribed_text": "All text visible in the image transcribed here."
}}"""

    # Multimodal message: the vision model receives both text and image
    messages = [
        HumanMessage(content=[
            {"type": "text", "text": prompt_text},
            {"type": "image_url", "image_url": {"url": image_data_uri}},
        ])
    ]
    # Use the dedicated vision model (FireRed-OCR) for image analysis
    response = vision_model.invoke(messages)
    # BUG FIX: the response is not guaranteed to be a dict; pull the JSON
    # object out of the text content, falling back to the raw text as the
    # description when parsing fails.
    if isinstance(response, dict):
        data = response
    else:
        import json
        import re
        text = getattr(response, "content", str(response))
        m = re.search(r"\{.*\}", text, re.DOTALL)
        try:
            data = json.loads(m.group(0)) if m else {"image_description": text}
        except json.JSONDecodeError:
            data = {"image_description": text}
    image_description = data.get("image_description", "")
    transcribed_text = data.get("transcribed_text", "")
    print(f"Image description: {image_description[:100]}...")
    print(f"Transcribed text: {transcribed_text[:100]}...")
    new_messages = state.get("messages", []) + [
        {"role": "system", "content": "Analyze and describe the image if classified as having an image."},
        {"role": "user", "content": question},
        {"role": "assistant", "content": f"Image description: {image_description[:100]}..., Transcribed text: {transcribed_text[:100]}..."},
    ]
    return {
        "image_description": image_description,
        "transcribed_text": transcribed_text,
        "messages": new_messages,
    }
def handle_file(state: "AgentState") -> dict:
    """Agent processes the file if classified as having a file.

    Uses the FileReaderTool to download and read the file from the API, then
    asks the LLM to extract the information relevant to the question.
    (BUG FIX: return annotation corrected from `str` to `dict`.)
    """
    question = state["question"]
    task_id = state.get("task_id", "")
    file_name = state.get("file_name", "")

    # Use the file_reader tool to fetch the file content
    file_reader = tools.FileReaderTool()
    file_content = file_reader(task_id) if task_id and file_name else ""

    # Build prompt with the retrieved file content
    file_context = ""
    if file_content:
        file_context = f"\n\n--- Attached file: {file_name} ---\n{file_content}\n--- End of file ---"
    elif file_name:
        file_context = f"\n\n[Note: A file '{file_name}' was referenced but could not be retrieved.]"

    prompt = f"""You are an agent that can read and extract information from files.
Below is the content of the attached file retrieved from the API. Read it carefully and extract any relevant information that could help answer the question.

Question: {question}{file_context}

Return a JSON object with the following field:
{{
    "extracted_info": "The relevant extracted information from the file."
}}"""
    messages = [HumanMessage(content=prompt)]
    # NOTE(review): `model` must be defined at module level — it was missing
    # from the original file (NameError here).
    response = model.invoke(messages)
    # BUG FIX: the response is not guaranteed to be a dict; parse the JSON
    # payload out of the text content defensively.
    if isinstance(response, dict):
        data = response
    else:
        import json
        import re
        text = getattr(response, "content", str(response))
        m = re.search(r"\{.*\}", text, re.DOTALL)
        try:
            data = json.loads(m.group(0)) if m else {"extracted_info": text}
        except json.JSONDecodeError:
            data = {"extracted_info": text}
    extracted_info = data.get("extracted_info", "")
    print(f"Extracted file info: {extracted_info[:100]}...")
    new_messages = state.get("messages", []) + [
        {"role": "system", "content": "Read and extract information from the attached file."},
        {"role": "user", "content": question},
        {"role": "assistant", "content": f"Extracted info: {extracted_info[:100]}..."},
    ]
    return {
        "extracted_info": extracted_info,
        "messages": new_messages,
    }
def handle_math(state: "AgentState") -> dict:
    """Agent handles a math problem if classified as a math problem.

    Sends the question to the dedicated math model and stores its
    step-by-step solution in the state.
    (BUG FIX: return annotation corrected from `str` to `dict`.)
    """
    question = state["question"]
    print(f"Agent is handling a math problem: {question[:50]}...")
    messages = [HumanMessage(content=f"Solve the following math problem step by step:\n\n{question}")]
    response = math_model.invoke(messages)
    # BUG FIX: the prompt requests a free-form step-by-step solution, not a
    # JSON object, so `response.get("solution")` was always empty (or failed
    # outright on a non-dict response). Use the response text itself.
    if isinstance(response, dict):
        solution = response.get("solution", "")
    else:
        solution = getattr(response, "content", None) or str(response)
    print(f"Math solution: {solution[:100]}...")
    new_messages = state.get("messages", []) + [
        {"role": "system", "content": "Handle the question if classified as a math problem."},
        {"role": "user", "content": question},
        {"role": "assistant", "content": f"Math solution: {solution[:100]}..."},
    ]
    return {
        "math_solution": solution,
        "messages": new_messages,
    }
def answer(state: "AgentState") -> dict:
    """Synthesize a final answer from all gathered context in messages.

    Collects every assistant-role transcript entry as context, asks a
    general model using the GAIA answer template, and extracts the text
    after "FINAL ANSWER:".
    """
    question = state["question"]
    messages_history = state.get("messages", [])

    # Build context summary from all assistant messages
    context_parts = [msg["content"] for msg in messages_history if msg.get("role") == "assistant"]
    context = "\n".join(context_parts) if context_parts else "No additional context gathered."

    prompt = f"""You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.

Question: {question}

Context gathered:
{context}
"""
    messages = [HumanMessage(content=prompt)]
    # Use the general model for final answer synthesis.
    # BUG FIX: "Qwen3.5-35B-A3B" is not a valid Hugging Face repo id (it has
    # no namespace). NOTE(review): confirm the exact repo id available to
    # your HF inference provider/token.
    general_model = HfApiModel(repo_id="Qwen/Qwen3-30B-A3B", max_new_tokens=2048, temperature=0.3)
    response = general_model.invoke(messages)
    raw_response = response.content if hasattr(response, 'content') else str(response)

    # Extract the final answer after "FINAL ANSWER:" if present
    if "FINAL ANSWER:" in raw_response:
        final_answer = raw_response.split("FINAL ANSWER:")[-1].strip()
    else:
        final_answer = raw_response.strip()

    print(f"Final answer: {final_answer[:100]}...")
    return {"final_answer": final_answer}
def route_after_classify(state: "AgentState") -> str:
    """Routing function: pick the next node from the classification flags.

    Flags are checked in priority order (image > file > math > search);
    when none is set the flow goes straight to the final `answer` node.
    """
    priority = (
        ("have_image", "handle_image"),
        ("have_file", "handle_file"),
        ("is_math", "handle_math"),
        ("is_searching", "handle_search"),
    )
    for flag, node in priority:
        if state.get(flag):
            return node
    # Default: go straight to answer
    return "answer"
+
# --- Workflow graph wiring ---
agent_graph = StateGraph(AgentState)

# Register each node under the name the edges refer to (note: the
# "handle_search" node is backed by the misspelled `handele_search`).
_node_registry = {
    "read": read,
    "classify": classify,
    "handle_search": handele_search,
    "handle_image": handle_image,
    "handle_file": handle_file,
    "handle_math": handle_math,
    "answer": answer,
}
for _node_name, _node_fn in _node_registry.items():
    agent_graph.add_node(_node_name, _node_fn)

# Linear prefix: entry -> read -> classify.
agent_graph.add_edge(START, "read")
agent_graph.add_edge("read", "classify")
# classify fans out to exactly one handler (or directly to answer); the
# router returns the target node's name.
agent_graph.add_conditional_edges(
    "classify",
    route_after_classify,
)

# Every handler funnels into the final answer node.
for _handler_name in ("handle_search", "handle_image", "handle_file", "handle_math"):
    agent_graph.add_edge(_handler_name, "answer")
agent_graph.add_edge("answer", END)

compiled_agent = agent_graph.compile()
+
# --- Basic Agent Definition ---

# ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
class BasicAgent:
    """Thin wrapper that runs the compiled LangGraph workflow on a question."""

    def __init__(self):
        # Tool instances are created once and reused across all questions.
        self.file_reader = tools.FileReaderTool()
        self.image_reader = tools.ImageReaderTool()
        self.web_search = tools.WebSearchTool()
        self.tools = [self.file_reader, self.image_reader, self.web_search]
        self.vision_model = vision_model  # FireRedTeam/FireRed-OCR for image tasks
        print("Agent initialized.")

    def __call__(self, question: str, task_id: str = "", file_name: str = "") -> str:
        """Answer `question`; `task_id`/`file_name` locate an optional attachment.

        Runs the compiled LangGraph workflow and returns its final answer
        (or a placeholder string when the graph produced none).
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        # Run the LangGraph workflow
        result_state = compiled_agent.invoke({
            "question": question,
            "task_id": task_id,
            "file_name": file_name,
            "messages": [],
            "is_searching": False,
            "have_file": False,
            "is_math": False,
            "have_image": False,
            "final_answer": "",
        })

        # BUG FIX: .get()'s default only applies when the key is missing; the
        # state initializes final_answer to "", so an empty answer slipped
        # through. `or` also covers the empty-string case.
        final_answer = result_state.get("final_answer") or "No answer produced."
        print(f"Agent returning answer: {final_answer[:100]}...")
        return final_answer
 
326
  def run_and_submit_all( profile: gr.OAuthProfile | None):
327
  """
 
379
  print(f"Running agent on {len(questions_data)} questions...")
380
  for item in questions_data:
381
  task_id = item.get("task_id")
382
+ # Handle both "Question" (dataset format) and "question" (API format)
383
+ question_text = item.get("question") or item.get("Question")
384
  if not task_id or question_text is None:
385
  print(f"Skipping item with missing task_id or question: {item}")
386
  continue
387
+
388
+ # Check for attached file
389
+ file_name = item.get("file_name", "")
390
+ if file_name:
391
+ print(f"Task {task_id} has attached file: {file_name}")
392
+
393
  try:
394
+ submitted_answer = agent(question_text, task_id=task_id, file_name=file_name)
395
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
396
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
397
  except Exception as e: