Update app.py
Browse files
app.py
CHANGED
|
@@ -8,6 +8,7 @@ import inspect
|
|
| 8 |
import pandas as pd
|
| 9 |
import tools
|
| 10 |
from smolagents import CodeAgent
|
|
|
|
| 11 |
try:
|
| 12 |
from smolagents import InferenceClientModel as _HFModel # smolagents >= 1.0
|
| 13 |
except ImportError:
|
|
@@ -15,10 +16,13 @@ except ImportError:
|
|
| 15 |
from smolagents.models import HfApiModel as _HFModel
|
| 16 |
except ImportError:
|
| 17 |
from smolagents import HfApiModel as _HFModel
|
| 18 |
-
|
| 19 |
from typing import TypedDict, List, Dict, Any, Optional
|
| 20 |
from langgraph.graph import StateGraph, START, END
|
| 21 |
-
from langchain_core.messages import HumanMessage
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
# (Keep Constants as is)
|
|
@@ -29,19 +33,21 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
| 29 |
def _build_hf_model(model_name: str):
|
| 30 |
"""Build a text model across smolagents versions."""
|
| 31 |
for kwargs in (
|
|
|
|
| 32 |
{"model_id": model_name, "max_new_tokens": 2048, "temperature": 0.3},
|
|
|
|
| 33 |
{"repo_id": model_name, "max_new_tokens": 2048, "temperature": 0.3},
|
| 34 |
):
|
| 35 |
try:
|
| 36 |
return _HFModel(**kwargs)
|
| 37 |
-
except
|
| 38 |
continue
|
| 39 |
raise RuntimeError(f"Cannot instantiate model {model_name} with available smolagents version")
|
| 40 |
|
| 41 |
|
| 42 |
# Text/math models via smolagents
|
| 43 |
-
model = _build_hf_model("
|
| 44 |
-
math_model = _build_hf_model("
|
| 45 |
|
| 46 |
# FireRed OCR (Transformers) loaded lazily to avoid startup crashes
|
| 47 |
_fire_red_model = None
|
|
@@ -82,8 +88,6 @@ def _extract_text_from_response(response: Any) -> str:
|
|
| 82 |
return str(content)
|
| 83 |
return str(response)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
#define the state
|
| 88 |
class AgentState(TypedDict):
|
| 89 |
question: str
|
|
@@ -94,6 +98,7 @@ class AgentState(TypedDict):
|
|
| 94 |
is_math: Optional[bool]
|
| 95 |
have_image: Optional[bool]
|
| 96 |
final_answer: Optional[str] # The final answer produced by the agent
|
|
|
|
| 97 |
messages: List[Dict[str, Any]] # Track conversation with LLM for analysis
|
| 98 |
#define nodes
|
| 99 |
|
|
@@ -119,8 +124,9 @@ def classify(state: AgentState) -> str:
|
|
| 119 |
"have_image": false
|
| 120 |
}}
|
| 121 |
"""
|
| 122 |
-
messages =
|
| 123 |
-
response = model
|
|
|
|
| 124 |
# Parse JSON from the model's response
|
| 125 |
import json, re
|
| 126 |
match = re.search(r'\{.*?\}', raw, re.DOTALL)
|
|
@@ -134,7 +140,6 @@ def classify(state: AgentState) -> str:
|
|
| 134 |
have_file = bool(data.get("have_file", False))
|
| 135 |
is_math = bool(data.get("is_math", False))
|
| 136 |
have_image = bool(data.get("have_image", False))
|
| 137 |
-
|
| 138 |
print(f"Classification result: is_searching={is_searching}, have_file={have_file}, is_math={is_math}, have_image={have_image}")
|
| 139 |
mew_messages = state.get("messages", []) + [
|
| 140 |
{"role": "system", "content": "Classify the question to determine which tools to use."},
|
|
@@ -178,7 +183,7 @@ def handle_image(state: AgentState) -> str:
|
|
| 178 |
|
| 179 |
# Use ImageReaderTool to download the image as base64
|
| 180 |
image_reader = tools.ImageReaderTool()
|
| 181 |
-
image_data_uri = image_reader(task_id) if task_id and file_name else ""
|
| 182 |
|
| 183 |
if not image_data_uri or image_data_uri.startswith("Failed"):
|
| 184 |
print(f"Could not download image for task {task_id}")
|
|
@@ -203,8 +208,7 @@ Return a JSON object with the following fields:
|
|
| 203 |
"transcribed_text": "All text visible in the image transcribed here."
|
| 204 |
}}"""
|
| 205 |
|
| 206 |
-
|
| 207 |
-
|
| 208 |
try:
|
| 209 |
# Decode base64 data URI into bytes/PIL image
|
| 210 |
_, b64_data = image_data_uri.split(",", 1)
|
|
@@ -275,7 +279,7 @@ def handle_file(state: AgentState) -> str:
|
|
| 275 |
|
| 276 |
# Use the file_reader tool to fetch the file content
|
| 277 |
file_reader = tools.FileReaderTool()
|
| 278 |
-
file_content = file_reader(task_id) if task_id and file_name else ""
|
| 279 |
|
| 280 |
# Build prompt with the retrieved file content
|
| 281 |
file_context = ""
|
|
@@ -293,8 +297,8 @@ Return a JSON object with the following field:
|
|
| 293 |
{{
|
| 294 |
"extracted_info": "The relevant extracted information from the file."
|
| 295 |
}}"""
|
| 296 |
-
messages =
|
| 297 |
-
response = model
|
| 298 |
extracted_info = _extract_text_from_response(response)
|
| 299 |
print(f"Extracted file info: {extracted_info[:100]}...")
|
| 300 |
new_messages = state.get("messages", []) + [
|
|
@@ -311,8 +315,8 @@ def handle_math(state: AgentState) -> str:
|
|
| 311 |
"""Agent handles a math problem if classified as a math problem."""
|
| 312 |
question = state["question"]
|
| 313 |
print(f"Agent is handling a math problem: {question[:50]}...")
|
| 314 |
-
messages =
|
| 315 |
-
response = math_model
|
| 316 |
solution = _extract_text_from_response(response)
|
| 317 |
print(f"Math solution: {solution[:100]}...")
|
| 318 |
new_messages = state.get("messages", []) + [
|
|
@@ -345,11 +349,11 @@ Question: {question}
|
|
| 345 |
Context gathered:
|
| 346 |
{context}
|
| 347 |
"""
|
| 348 |
-
messages =
|
| 349 |
# Use the general model for final answer synthesis
|
| 350 |
-
response = model
|
| 351 |
raw_response = _extract_text_from_response(response)
|
| 352 |
-
|
| 353 |
# Extract the final answer after "FINAL ANSWER:" if present
|
| 354 |
if "FINAL ANSWER:" in raw_response:
|
| 355 |
final_answer = raw_response.split("FINAL ANSWER:")[-1].strip()
|
|
@@ -360,6 +364,52 @@ Context gathered:
|
|
| 360 |
return {"final_answer": final_answer}
|
| 361 |
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
def route_after_classify(state: AgentState) -> str:
|
| 364 |
"""Routing function: decide which handler to invoke based on classification."""
|
| 365 |
if state.get("have_image"):
|
|
@@ -383,6 +433,7 @@ agent_graph.add_node("handle_image", handle_image)
|
|
| 383 |
agent_graph.add_node("handle_file", handle_file)
|
| 384 |
agent_graph.add_node("handle_math", handle_math)
|
| 385 |
agent_graph.add_node("answer", answer)
|
|
|
|
| 386 |
|
| 387 |
agent_graph.add_edge(START, "read")
|
| 388 |
agent_graph.add_edge("read", "classify")
|
|
@@ -395,7 +446,11 @@ agent_graph.add_edge("handle_search", "answer")
|
|
| 395 |
agent_graph.add_edge("handle_image", "answer")
|
| 396 |
agent_graph.add_edge("handle_file", "answer")
|
| 397 |
agent_graph.add_edge("handle_math", "answer")
|
| 398 |
-
agent_graph.add_edge("answer",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
|
| 400 |
compiled_agent = agent_graph.compile()
|
| 401 |
|
|
@@ -424,7 +479,8 @@ class BasicAgent:
|
|
| 424 |
"have_file": False,
|
| 425 |
"is_math": False,
|
| 426 |
"have_image": False,
|
| 427 |
-
"final_answer": ""
|
|
|
|
| 428 |
})
|
| 429 |
|
| 430 |
# Extract the final answer from the state
|
|
|
|
| 8 |
import pandas as pd
|
| 9 |
import tools
|
| 10 |
from smolagents import CodeAgent
|
| 11 |
+
# Resolve the correct LLM model class across smolagents versions
|
| 12 |
try:
|
| 13 |
from smolagents import InferenceClientModel as _HFModel # smolagents >= 1.0
|
| 14 |
except ImportError:
|
|
|
|
| 16 |
from smolagents.models import HfApiModel as _HFModel
|
| 17 |
except ImportError:
|
| 18 |
from smolagents import HfApiModel as _HFModel
|
|
|
|
| 19 |
from typing import TypedDict, List, Dict, Any, Optional
|
| 20 |
from langgraph.graph import StateGraph, START, END
|
| 21 |
+
from langchain_core.messages import HumanMessage # kept for LangGraph compatibility
|
| 22 |
+
|
| 23 |
+
# Helper to build a smolagents-compatible message list
|
| 24 |
+
def _msg(content: str) -> list:
|
| 25 |
+
return [{"role": "user", "content": content}]
|
| 26 |
|
| 27 |
|
| 28 |
# (Keep Constants as is)
|
|
|
|
| 33 |
def _build_hf_model(model_name: str):
|
| 34 |
"""Build a text model across smolagents versions."""
|
| 35 |
for kwargs in (
|
| 36 |
+
{"model_id": model_name, "max_tokens": 2048, "temperature": 0.3},
|
| 37 |
{"model_id": model_name, "max_new_tokens": 2048, "temperature": 0.3},
|
| 38 |
+
{"repo_id": model_name, "max_tokens": 2048, "temperature": 0.3},
|
| 39 |
{"repo_id": model_name, "max_new_tokens": 2048, "temperature": 0.3},
|
| 40 |
):
|
| 41 |
try:
|
| 42 |
return _HFModel(**kwargs)
|
| 43 |
+
except TypeError:
|
| 44 |
continue
|
| 45 |
raise RuntimeError(f"Cannot instantiate model {model_name} with available smolagents version")
|
| 46 |
|
| 47 |
|
| 48 |
# Text/math models via smolagents
|
| 49 |
+
model = _build_hf_model("meta-llama/Llama-3.2-3B-Instruct") # General model for classification and final answer synthesis
|
| 50 |
+
math_model = _build_hf_model("deepseek-ai/deepseek-math-7b-instruct")
|
| 51 |
|
| 52 |
# FireRed OCR (Transformers) loaded lazily to avoid startup crashes
|
| 53 |
_fire_red_model = None
|
|
|
|
| 88 |
return str(content)
|
| 89 |
return str(response)
|
| 90 |
|
|
|
|
|
|
|
| 91 |
#define the state
|
| 92 |
class AgentState(TypedDict):
|
| 93 |
question: str
|
|
|
|
| 98 |
is_math: Optional[bool]
|
| 99 |
have_image: Optional[bool]
|
| 100 |
final_answer: Optional[str] # The final answer produced by the agent
|
| 101 |
+
retry_count: Optional[int] # Number of retries so far
|
| 102 |
messages: List[Dict[str, Any]] # Track conversation with LLM for analysis
|
| 103 |
#define nodes
|
| 104 |
|
|
|
|
| 124 |
"have_image": false
|
| 125 |
}}
|
| 126 |
"""
|
| 127 |
+
messages = _msg(prompt)
|
| 128 |
+
response = model(messages)
|
| 129 |
+
raw = _extract_text_from_response(response)
|
| 130 |
# Parse JSON from the model's response
|
| 131 |
import json, re
|
| 132 |
match = re.search(r'\{.*?\}', raw, re.DOTALL)
|
|
|
|
| 140 |
have_file = bool(data.get("have_file", False))
|
| 141 |
is_math = bool(data.get("is_math", False))
|
| 142 |
have_image = bool(data.get("have_image", False))
|
|
|
|
| 143 |
print(f"Classification result: is_searching={is_searching}, have_file={have_file}, is_math={is_math}, have_image={have_image}")
|
| 144 |
mew_messages = state.get("messages", []) + [
|
| 145 |
{"role": "system", "content": "Classify the question to determine which tools to use."},
|
|
|
|
| 183 |
|
| 184 |
# Use ImageReaderTool to download the image as base64
|
| 185 |
image_reader = tools.ImageReaderTool()
|
| 186 |
+
image_data_uri = image_reader(task_id, file_name) if task_id and file_name else ""
|
| 187 |
|
| 188 |
if not image_data_uri or image_data_uri.startswith("Failed"):
|
| 189 |
print(f"Could not download image for task {task_id}")
|
|
|
|
| 208 |
"transcribed_text": "All text visible in the image transcribed here."
|
| 209 |
}}"""
|
| 210 |
|
| 211 |
+
# Run OCR through FireRed-OCR using Transformers
|
|
|
|
| 212 |
try:
|
| 213 |
# Decode base64 data URI into bytes/PIL image
|
| 214 |
_, b64_data = image_data_uri.split(",", 1)
|
|
|
|
| 279 |
|
| 280 |
# Use the file_reader tool to fetch the file content
|
| 281 |
file_reader = tools.FileReaderTool()
|
| 282 |
+
file_content = file_reader(task_id, file_name) if task_id and file_name else ""
|
| 283 |
|
| 284 |
# Build prompt with the retrieved file content
|
| 285 |
file_context = ""
|
|
|
|
| 297 |
{{
|
| 298 |
"extracted_info": "The relevant extracted information from the file."
|
| 299 |
}}"""
|
| 300 |
+
messages = _msg(prompt)
|
| 301 |
+
response = model(messages)
|
| 302 |
extracted_info = _extract_text_from_response(response)
|
| 303 |
print(f"Extracted file info: {extracted_info[:100]}...")
|
| 304 |
new_messages = state.get("messages", []) + [
|
|
|
|
| 315 |
"""Agent handles a math problem if classified as a math problem."""
|
| 316 |
question = state["question"]
|
| 317 |
print(f"Agent is handling a math problem: {question[:50]}...")
|
| 318 |
+
messages = _msg(f"Solve the following math problem step by step:\n\n{question}")
|
| 319 |
+
response = math_model(messages)
|
| 320 |
solution = _extract_text_from_response(response)
|
| 321 |
print(f"Math solution: {solution[:100]}...")
|
| 322 |
new_messages = state.get("messages", []) + [
|
|
|
|
| 349 |
Context gathered:
|
| 350 |
{context}
|
| 351 |
"""
|
| 352 |
+
messages = _msg(prompt)
|
| 353 |
# Use the general model for final answer synthesis
|
| 354 |
+
response = model(messages)
|
| 355 |
raw_response = _extract_text_from_response(response)
|
| 356 |
+
|
| 357 |
# Extract the final answer after "FINAL ANSWER:" if present
|
| 358 |
if "FINAL ANSWER:" in raw_response:
|
| 359 |
final_answer = raw_response.split("FINAL ANSWER:")[-1].strip()
|
|
|
|
| 364 |
return {"final_answer": final_answer}
|
| 365 |
|
| 366 |
|
| 367 |
+
def evaluate(state: AgentState) -> dict:
|
| 368 |
+
"""LLM evaluates whether the current final_answer is adequate.
|
| 369 |
+
If not, increments retry_count so the graph can loop back."""
|
| 370 |
+
import json, re
|
| 371 |
+
question = state["question"]
|
| 372 |
+
current_answer = state.get("final_answer", "")
|
| 373 |
+
retry_count = state.get("retry_count", 0) or 0
|
| 374 |
+
|
| 375 |
+
prompt = f"""You are a strict evaluator. Given the question and a candidate answer, decide if the answer is complete, relevant, and not an error message.
|
| 376 |
+
|
| 377 |
+
Question: {question}
|
| 378 |
+
Candidate answer: {current_answer}
|
| 379 |
+
|
| 380 |
+
Return ONLY a JSON object:
|
| 381 |
+
{{"is_adequate": true}} if the answer looks correct and complete,
|
| 382 |
+
{{"is_adequate": false}} if the answer is wrong, incomplete, an error, or just says it could not find information."""
|
| 383 |
+
|
| 384 |
+
response = model(_msg(prompt))
|
| 385 |
+
raw = _extract_text_from_response(response)
|
| 386 |
+
match = re.search(r'\{.*?\}', raw, re.DOTALL)
|
| 387 |
+
data = {}
|
| 388 |
+
if match:
|
| 389 |
+
try:
|
| 390 |
+
data = json.loads(match.group())
|
| 391 |
+
except json.JSONDecodeError:
|
| 392 |
+
pass
|
| 393 |
+
is_adequate = bool(data.get("is_adequate", True)) # default: accept
|
| 394 |
+
print(f"Evaluation: is_adequate={is_adequate}, retry_count={retry_count}")
|
| 395 |
+
return {
|
| 396 |
+
"retry_count": retry_count + (0 if is_adequate else 1),
|
| 397 |
+
"is_searching": False if not is_adequate else state.get("is_searching"),
|
| 398 |
+
"have_file": False if not is_adequate else state.get("have_file"),
|
| 399 |
+
"is_math": False if not is_adequate else state.get("is_math"),
|
| 400 |
+
"have_image": False if not is_adequate else state.get("have_image"),
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def route_after_evaluate(state: AgentState) -> str:
|
| 405 |
+
"""If answer was inadequate and retries remain, search web for more context."""
|
| 406 |
+
retry_count = state.get("retry_count", 0) or 0
|
| 407 |
+
if retry_count > 0 and retry_count <= 2:
|
| 408 |
+
print(f"Answer inadequate — retry {retry_count}/2, routing to web search")
|
| 409 |
+
return "handle_search"
|
| 410 |
+
return END
|
| 411 |
+
|
| 412 |
+
|
| 413 |
def route_after_classify(state: AgentState) -> str:
|
| 414 |
"""Routing function: decide which handler to invoke based on classification."""
|
| 415 |
if state.get("have_image"):
|
|
|
|
| 433 |
agent_graph.add_node("handle_file", handle_file)
|
| 434 |
agent_graph.add_node("handle_math", handle_math)
|
| 435 |
agent_graph.add_node("answer", answer)
|
| 436 |
+
agent_graph.add_node("evaluate", evaluate)
|
| 437 |
|
| 438 |
agent_graph.add_edge(START, "read")
|
| 439 |
agent_graph.add_edge("read", "classify")
|
|
|
|
| 446 |
agent_graph.add_edge("handle_image", "answer")
|
| 447 |
agent_graph.add_edge("handle_file", "answer")
|
| 448 |
agent_graph.add_edge("handle_math", "answer")
|
| 449 |
+
agent_graph.add_edge("answer", "evaluate")
|
| 450 |
+
agent_graph.add_conditional_edges(
|
| 451 |
+
"evaluate",
|
| 452 |
+
route_after_evaluate,
|
| 453 |
+
)
|
| 454 |
|
| 455 |
compiled_agent = agent_graph.compile()
|
| 456 |
|
|
|
|
| 479 |
"have_file": False,
|
| 480 |
"is_math": False,
|
| 481 |
"have_image": False,
|
| 482 |
+
"final_answer": "",
|
| 483 |
+
"retry_count": 0
|
| 484 |
})
|
| 485 |
|
| 486 |
# Extract the final answer from the state
|