TommasoBB committed on
Commit
32c95ee
·
verified ·
1 Parent(s): 9705726

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -23
app.py CHANGED
@@ -8,6 +8,7 @@ import inspect
8
  import pandas as pd
9
  import tools
10
  from smolagents import CodeAgent
 
11
  try:
12
  from smolagents import InferenceClientModel as _HFModel # smolagents >= 1.0
13
  except ImportError:
@@ -15,10 +16,13 @@ except ImportError:
15
  from smolagents.models import HfApiModel as _HFModel
16
  except ImportError:
17
  from smolagents import HfApiModel as _HFModel
18
-
19
  from typing import TypedDict, List, Dict, Any, Optional
20
  from langgraph.graph import StateGraph, START, END
21
- from langchain_core.messages import HumanMessage
 
 
 
 
22
 
23
 
24
  # (Keep Constants as is)
@@ -29,19 +33,21 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
29
  def _build_hf_model(model_name: str):
30
  """Build a text model across smolagents versions."""
31
  for kwargs in (
 
32
  {"model_id": model_name, "max_new_tokens": 2048, "temperature": 0.3},
 
33
  {"repo_id": model_name, "max_new_tokens": 2048, "temperature": 0.3},
34
  ):
35
  try:
36
  return _HFModel(**kwargs)
37
- except (TypeError, Exception):
38
  continue
39
  raise RuntimeError(f"Cannot instantiate model {model_name} with available smolagents version")
40
 
41
 
42
  # Text/math models via smolagents
43
- model = _build_hf_model("Qwen3.5-35B-A3B")
44
- math_model = _build_hf_model("Qwen/Qwen2.5-Math-1.5B")
45
 
46
  # FireRed OCR (Transformers) loaded lazily to avoid startup crashes
47
  _fire_red_model = None
@@ -82,8 +88,6 @@ def _extract_text_from_response(response: Any) -> str:
82
  return str(content)
83
  return str(response)
84
 
85
-
86
-
87
  #define the state
88
  class AgentState(TypedDict):
89
  question: str
@@ -94,6 +98,7 @@ class AgentState(TypedDict):
94
  is_math: Optional[bool]
95
  have_image: Optional[bool]
96
  final_answer: Optional[str] # The final answer produced by the agent
 
97
  messages: List[Dict[str, Any]] # Track conversation with LLM for analysis
98
  #define nodes
99
 
@@ -119,8 +124,9 @@ def classify(state: AgentState) -> str:
119
  "have_image": false
120
  }}
121
  """
122
- messages = [HumanMessage(content=prompt)]
123
- response = model.invoke(messages)
 
124
  # Parse JSON from the model's response
125
  import json, re
126
  match = re.search(r'\{.*?\}', raw, re.DOTALL)
@@ -134,7 +140,6 @@ def classify(state: AgentState) -> str:
134
  have_file = bool(data.get("have_file", False))
135
  is_math = bool(data.get("is_math", False))
136
  have_image = bool(data.get("have_image", False))
137
-
138
  print(f"Classification result: is_searching={is_searching}, have_file={have_file}, is_math={is_math}, have_image={have_image}")
139
  mew_messages = state.get("messages", []) + [
140
  {"role": "system", "content": "Classify the question to determine which tools to use."},
@@ -178,7 +183,7 @@ def handle_image(state: AgentState) -> str:
178
 
179
  # Use ImageReaderTool to download the image as base64
180
  image_reader = tools.ImageReaderTool()
181
- image_data_uri = image_reader(task_id) if task_id and file_name else ""
182
 
183
  if not image_data_uri or image_data_uri.startswith("Failed"):
184
  print(f"Could not download image for task {task_id}")
@@ -203,8 +208,7 @@ Return a JSON object with the following fields:
203
  "transcribed_text": "All text visible in the image transcribed here."
204
  }}"""
205
 
206
-
207
-
208
  try:
209
  # Decode base64 data URI into bytes/PIL image
210
  _, b64_data = image_data_uri.split(",", 1)
@@ -275,7 +279,7 @@ def handle_file(state: AgentState) -> str:
275
 
276
  # Use the file_reader tool to fetch the file content
277
  file_reader = tools.FileReaderTool()
278
- file_content = file_reader(task_id) if task_id and file_name else ""
279
 
280
  # Build prompt with the retrieved file content
281
  file_context = ""
@@ -293,8 +297,8 @@ Return a JSON object with the following field:
293
  {{
294
  "extracted_info": "The relevant extracted information from the file."
295
  }}"""
296
- messages = [HumanMessage(content=prompt)]
297
- response = model.invoke(messages)
298
  extracted_info = _extract_text_from_response(response)
299
  print(f"Extracted file info: {extracted_info[:100]}...")
300
  new_messages = state.get("messages", []) + [
@@ -311,8 +315,8 @@ def handle_math(state: AgentState) -> str:
311
  """Agent handles a math problem if classified as a math problem."""
312
  question = state["question"]
313
  print(f"Agent is handling a math problem: {question[:50]}...")
314
- messages = [HumanMessage(content=f"Solve the following math problem step by step:\n\n{question}")]
315
- response = math_model.invoke(messages)
316
  solution = _extract_text_from_response(response)
317
  print(f"Math solution: {solution[:100]}...")
318
  new_messages = state.get("messages", []) + [
@@ -345,11 +349,11 @@ Question: {question}
345
  Context gathered:
346
  {context}
347
  """
348
- messages = [HumanMessage(content=prompt)]
349
  # Use the general model for final answer synthesis
350
- response = model.invoke(messages)
351
  raw_response = _extract_text_from_response(response)
352
-
353
  # Extract the final answer after "FINAL ANSWER:" if present
354
  if "FINAL ANSWER:" in raw_response:
355
  final_answer = raw_response.split("FINAL ANSWER:")[-1].strip()
@@ -360,6 +364,52 @@ Context gathered:
360
  return {"final_answer": final_answer}
361
 
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  def route_after_classify(state: AgentState) -> str:
364
  """Routing function: decide which handler to invoke based on classification."""
365
  if state.get("have_image"):
@@ -383,6 +433,7 @@ agent_graph.add_node("handle_image", handle_image)
383
  agent_graph.add_node("handle_file", handle_file)
384
  agent_graph.add_node("handle_math", handle_math)
385
  agent_graph.add_node("answer", answer)
 
386
 
387
  agent_graph.add_edge(START, "read")
388
  agent_graph.add_edge("read", "classify")
@@ -395,7 +446,11 @@ agent_graph.add_edge("handle_search", "answer")
395
  agent_graph.add_edge("handle_image", "answer")
396
  agent_graph.add_edge("handle_file", "answer")
397
  agent_graph.add_edge("handle_math", "answer")
398
- agent_graph.add_edge("answer", END)
 
 
 
 
399
 
400
  compiled_agent = agent_graph.compile()
401
 
@@ -424,7 +479,8 @@ class BasicAgent:
424
  "have_file": False,
425
  "is_math": False,
426
  "have_image": False,
427
- "final_answer": ""
 
428
  })
429
 
430
  # Extract the final answer from the state
 
8
  import pandas as pd
9
  import tools
10
  from smolagents import CodeAgent
11
+ # Resolve the correct LLM model class across smolagents versions
12
  try:
13
  from smolagents import InferenceClientModel as _HFModel # smolagents >= 1.0
14
  except ImportError:
 
16
  from smolagents.models import HfApiModel as _HFModel
17
  except ImportError:
18
  from smolagents import HfApiModel as _HFModel
 
19
  from typing import TypedDict, List, Dict, Any, Optional
20
  from langgraph.graph import StateGraph, START, END
21
+ from langchain_core.messages import HumanMessage # kept for LangGraph compatibility
22
+
23
+ # Helper to build a smolagents-compatible message list
24
+ def _msg(content: str) -> list:
25
+ return [{"role": "user", "content": content}]
26
 
27
 
28
  # (Keep Constants as is)
 
33
def _build_hf_model(model_name: str):
    """Build a text model across smolagents versions.

    Different smolagents releases disagree on both the model-name keyword
    ("model_id" vs "repo_id") and the token-limit keyword ("max_tokens" vs
    "max_new_tokens"), so every combination is tried in order until one
    constructor signature is accepted.

    Raises:
        RuntimeError: if no keyword combination matches the installed version.
    """
    for name_key in ("model_id", "repo_id"):
        for limit_key in ("max_tokens", "max_new_tokens"):
            try:
                # TypeError means this version rejects one of the keywords.
                return _HFModel(**{name_key: model_name, limit_key: 2048, "temperature": 0.3})
            except TypeError:
                continue
    raise RuntimeError(f"Cannot instantiate model {model_name} with available smolagents version")
46
 
47
 
48
  # Text/math models via smolagents
49
+ model = _build_hf_model("meta-llama/Llama-3.2-3B-Instruct") # General model for classification and final answer synthesis
50
+ math_model = _build_hf_model("deepseek-ai/deepseek-math-7b-instruct")
51
 
52
  # FireRed OCR (Transformers) loaded lazily to avoid startup crashes
53
  _fire_red_model = None
 
88
  return str(content)
89
  return str(response)
90
 
 
 
91
  #define the state
92
  class AgentState(TypedDict):
93
  question: str
 
98
  is_math: Optional[bool]
99
  have_image: Optional[bool]
100
  final_answer: Optional[str] # The final answer produced by the agent
101
+ retry_count: Optional[int] # Number of retries so far
102
  messages: List[Dict[str, Any]] # Track conversation with LLM for analysis
103
  #define nodes
104
 
 
124
  "have_image": false
125
  }}
126
  """
127
+ messages = _msg(prompt)
128
+ response = model(messages)
129
+ raw = _extract_text_from_response(response)
130
  # Parse JSON from the model's response
131
  import json, re
132
  match = re.search(r'\{.*?\}', raw, re.DOTALL)
 
140
  have_file = bool(data.get("have_file", False))
141
  is_math = bool(data.get("is_math", False))
142
  have_image = bool(data.get("have_image", False))
 
143
  print(f"Classification result: is_searching={is_searching}, have_file={have_file}, is_math={is_math}, have_image={have_image}")
144
  mew_messages = state.get("messages", []) + [
145
  {"role": "system", "content": "Classify the question to determine which tools to use."},
 
183
 
184
  # Use ImageReaderTool to download the image as base64
185
  image_reader = tools.ImageReaderTool()
186
+ image_data_uri = image_reader(task_id, file_name) if task_id and file_name else ""
187
 
188
  if not image_data_uri or image_data_uri.startswith("Failed"):
189
  print(f"Could not download image for task {task_id}")
 
208
  "transcribed_text": "All text visible in the image transcribed here."
209
  }}"""
210
 
211
+ # Run OCR through FireRed-OCR using Transformers
 
212
  try:
213
  # Decode base64 data URI into bytes/PIL image
214
  _, b64_data = image_data_uri.split(",", 1)
 
279
 
280
  # Use the file_reader tool to fetch the file content
281
  file_reader = tools.FileReaderTool()
282
+ file_content = file_reader(task_id, file_name) if task_id and file_name else ""
283
 
284
  # Build prompt with the retrieved file content
285
  file_context = ""
 
297
  {{
298
  "extracted_info": "The relevant extracted information from the file."
299
  }}"""
300
+ messages = _msg(prompt)
301
+ response = model(messages)
302
  extracted_info = _extract_text_from_response(response)
303
  print(f"Extracted file info: {extracted_info[:100]}...")
304
  new_messages = state.get("messages", []) + [
 
315
  """Agent handles a math problem if classified as a math problem."""
316
  question = state["question"]
317
  print(f"Agent is handling a math problem: {question[:50]}...")
318
+ messages = _msg(f"Solve the following math problem step by step:\n\n{question}")
319
+ response = math_model(messages)
320
  solution = _extract_text_from_response(response)
321
  print(f"Math solution: {solution[:100]}...")
322
  new_messages = state.get("messages", []) + [
 
349
  Context gathered:
350
  {context}
351
  """
352
+ messages = _msg(prompt)
353
  # Use the general model for final answer synthesis
354
+ response = model(messages)
355
  raw_response = _extract_text_from_response(response)
356
+
357
  # Extract the final answer after "FINAL ANSWER:" if present
358
  if "FINAL ANSWER:" in raw_response:
359
  final_answer = raw_response.split("FINAL ANSWER:")[-1].strip()
 
364
  return {"final_answer": final_answer}
365
 
366
 
367
def evaluate(state: AgentState) -> dict:
    """Ask the LLM whether the current final_answer is adequate.

    An inadequate answer increments retry_count and clears the routing flags
    so the graph can loop back for another pass; an adequate answer leaves
    the state untouched.
    """
    import json, re

    question = state["question"]
    candidate = state.get("final_answer", "")
    attempts = state.get("retry_count", 0) or 0

    prompt = f"""You are a strict evaluator. Given the question and a candidate answer, decide if the answer is complete, relevant, and not an error message.

Question: {question}
Candidate answer: {candidate}

Return ONLY a JSON object:
{{"is_adequate": true}} if the answer looks correct and complete,
{{"is_adequate": false}} if the answer is wrong, incomplete, an error, or just says it could not find information."""

    verdict_raw = _extract_text_from_response(model(_msg(prompt)))

    # Pull the first JSON object out of the model's reply, if there is one.
    verdict = {}
    found = re.search(r'\{.*?\}', verdict_raw, re.DOTALL)
    if found:
        try:
            verdict = json.loads(found.group())
        except json.JSONDecodeError:
            pass

    # A missing or unparseable verdict counts as adequate so we never loop
    # forever on a model that refuses to emit JSON.
    is_adequate = bool(verdict.get("is_adequate", True))
    print(f"Evaluation: is_adequate={is_adequate}, retry_count={attempts}")

    if is_adequate:
        return {
            "retry_count": attempts,
            "is_searching": state.get("is_searching"),
            "have_file": state.get("have_file"),
            "is_math": state.get("is_math"),
            "have_image": state.get("have_image"),
        }
    # Inadequate: bump the retry counter and reset every routing flag.
    return {
        "retry_count": attempts + 1,
        "is_searching": False,
        "have_file": False,
        "is_math": False,
        "have_image": False,
    }
402
+
403
+
404
def route_after_evaluate(state: AgentState) -> str:
    """Routing after evaluation: retry via web search while retries remain, else end."""
    attempts = state.get("retry_count", 0) or 0
    # Only a freshly-incremented retry_count within the 2-attempt budget loops back.
    if not 0 < attempts <= 2:
        return END
    print(f"Answer inadequate — retry {attempts}/2, routing to web search")
    return "handle_search"
411
+
412
+
413
  def route_after_classify(state: AgentState) -> str:
414
  """Routing function: decide which handler to invoke based on classification."""
415
  if state.get("have_image"):
 
433
  agent_graph.add_node("handle_file", handle_file)
434
  agent_graph.add_node("handle_math", handle_math)
435
  agent_graph.add_node("answer", answer)
436
+ agent_graph.add_node("evaluate", evaluate)
437
 
438
  agent_graph.add_edge(START, "read")
439
  agent_graph.add_edge("read", "classify")
 
446
  agent_graph.add_edge("handle_image", "answer")
447
  agent_graph.add_edge("handle_file", "answer")
448
  agent_graph.add_edge("handle_math", "answer")
449
+ agent_graph.add_edge("answer", "evaluate")
450
+ agent_graph.add_conditional_edges(
451
+ "evaluate",
452
+ route_after_evaluate,
453
+ )
454
 
455
  compiled_agent = agent_graph.compile()
456
 
 
479
  "have_file": False,
480
  "is_math": False,
481
  "have_image": False,
482
+ "final_answer": "",
483
+ "retry_count": 0
484
  })
485
 
486
  # Extract the final answer from the state