Spaces:
Sleeping
Sleeping
added validator node
Browse files
agent.py
CHANGED
|
@@ -15,7 +15,6 @@ from langchain_core.messages import SystemMessage
|
|
| 15 |
|
| 16 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 17 |
|
| 18 |
-
SUPPORTING_FILES_URL = "https://huggingface.co/datasets/gaia-benchmark/GAIA/resolve/main/2023/validation/"
|
| 19 |
|
| 20 |
system_prompt = """You are a general AI assistant. I will ask you a question.
|
| 21 |
|
|
@@ -60,6 +59,89 @@ def _is_url(path_or_url: str) -> bool:
|
|
| 60 |
except:
|
| 61 |
return False
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
def _process_uploaded_file(file_name: str, file_path: str) -> str:
|
| 64 |
"""Process a single local file or file URL and return context for the question."""
|
| 65 |
try:
|
|
@@ -101,7 +183,6 @@ def build_and_compile():
|
|
| 101 |
python_code,
|
| 102 |
image_info,
|
| 103 |
read_mp3_transcript,
|
| 104 |
-
pdf_text_extractor,
|
| 105 |
ocr_image,
|
| 106 |
math_solver,
|
| 107 |
plot_data_tool,
|
|
@@ -121,6 +202,7 @@ def build_and_compile():
|
|
| 121 |
|
| 122 |
llm = init_chat_model("openai:gpt-4.1-mini",temperature=0, seed=42)
|
| 123 |
llm_with_tools = llm.bind_tools(tools)
|
|
|
|
| 124 |
|
| 125 |
def chatbot(state: State):
|
| 126 |
file_context = ""
|
|
@@ -129,18 +211,72 @@ def build_and_compile():
|
|
| 129 |
final_prompt = system_prompt + file_context
|
| 130 |
return {"messages": [llm_with_tools.invoke([SystemMessage(final_prompt)] + state["messages"])]}
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
tool_node = ToolNode(tools=tools)
|
| 136 |
graph_builder.add_node("tools", tool_node)
|
|
|
|
| 137 |
|
|
|
|
| 138 |
graph_builder.add_conditional_edges(
|
| 139 |
"chatbot",
|
| 140 |
tools_condition,
|
|
|
|
| 141 |
)
|
| 142 |
-
|
|
|
|
| 143 |
graph_builder.add_edge("tools", "chatbot")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
graph_builder.add_edge(START, "chatbot")
|
|
|
|
| 145 |
graph = graph_builder.compile()
|
| 146 |
return graph
|
|
|
|
| 15 |
|
| 16 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 17 |
|
|
|
|
| 18 |
|
| 19 |
system_prompt = """You are a general AI assistant. I will ask you a question.
|
| 20 |
|
|
|
|
| 59 |
except:
|
| 60 |
return False
|
| 61 |
|
| 62 |
+
_ARTICLES = {"a", "an", "the"}
|
| 63 |
+
|
| 64 |
+
def _sanitize_visible_answer(text: str) -> str:
|
| 65 |
+
"""Keep a single-line final answer; strip quotes and leftover tags."""
|
| 66 |
+
if not text:
|
| 67 |
+
return ""
|
| 68 |
+
t = text.strip()
|
| 69 |
+
|
| 70 |
+
if (t.startswith('"') and t.endswith('"')) or (t.startswith("'") and t.endswith("'")):
|
| 71 |
+
t = t[1:-1].strip()
|
| 72 |
+
|
| 73 |
+
lines = [ln.strip() for ln in t.splitlines() if ln.strip()]
|
| 74 |
+
if lines:
|
| 75 |
+
t = lines[-1]
|
| 76 |
+
|
| 77 |
+
t = t.replace("[YOUR FINAL ANSWER]", "").strip()
|
| 78 |
+
t = t.replace("Final answer: ", "").strip()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
t = re.sub(r"\s+", " ", t)
|
| 82 |
+
t = re.sub(r"<[^>]*>", "", t)
|
| 83 |
+
|
| 84 |
+
return t
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _is_number_token(s: str) -> bool:
|
| 88 |
+
return bool(re.fullmatch(r"-?\d+(\.\d+)?", s))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _has_units(s: str) -> bool:
|
| 92 |
+
return bool(re.search(r"\d\s*[A-Za-z%$]", s))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _has_commas_in_number(s: str) -> bool:
|
| 96 |
+
return bool(re.search(r"\d,\d", s))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _starts_with_article(s: str) -> bool:
    """Return True iff the first comma/whitespace-delimited token of *s* is an English article."""
    tokens = re.split(r"[,\s]+", s.strip())
    if not tokens:
        return False
    # Case-insensitive membership in the module-level article set.
    return tokens[0].lower() in _ARTICLES
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _is_valid_final_answer(ans: str) -> bool:
    """Check an answer string against the output-format rules.

    Rules enforced:
    - non-empty, single line
    - any token containing a digit is held to the numeric rules
      (bare digits only: no embedded commas, no units)
    - plain string tokens must not begin with an article
    - comma-separated lists: no empty elements, and every element must
      pass the same per-token checks
    """
    if not ans or "\n" in ans:
        return False

    def token_ok(tok: str) -> bool:
        # A token with any digit is treated as number-like.
        if re.fullmatch(r".*\d.*", tok):
            return (
                _is_number_token(tok)
                and not _has_commas_in_number(tok)
                and not _has_units(tok)
            )
        # Pure-string token: only the leading-article rule applies.
        return not _starts_with_article(tok)

    if "," in ans:
        parts = [p.strip() for p in ans.split(",")]
        if any(not p for p in parts):
            return False
        return all(token_ok(p) for p in parts)

    return token_ok(ans)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
def _process_uploaded_file(file_name: str, file_path: str) -> str:
|
| 146 |
"""Process a single local file or file URL and return context for the question."""
|
| 147 |
try:
|
|
|
|
| 183 |
python_code,
|
| 184 |
image_info,
|
| 185 |
read_mp3_transcript,
|
|
|
|
| 186 |
ocr_image,
|
| 187 |
math_solver,
|
| 188 |
plot_data_tool,
|
|
|
|
| 202 |
|
| 203 |
llm = init_chat_model("openai:gpt-4.1-mini",temperature=0, seed=42)
|
| 204 |
llm_with_tools = llm.bind_tools(tools)
|
| 205 |
+
final_llm = llm.bind(response_format={"type": "json_object"})
|
| 206 |
|
| 207 |
def chatbot(state: State):
|
| 208 |
file_context = ""
|
|
|
|
| 211 |
final_prompt = system_prompt + file_context
|
| 212 |
return {"messages": [llm_with_tools.invoke([SystemMessage(final_prompt)] + state["messages"])]}
|
| 213 |
|
| 214 |
+
def validator(state: State):
    """
    Ensure the last assistant message is a valid final answer per system rules.
    If invalid, rewrite once with final_llm (JSON) and output only final_answer.

    NOTE(review): closes over `final_llm` and `system_prompt` from the
    enclosing `build_and_compile` scope; returns a LangGraph-style partial
    state update ({"messages": [...]}).
    """
    # Get last assistant message text. `content` may be missing/empty on
    # some message objects, so fall back to str(last).
    last = state["messages"][-1]
    text = getattr(last, "content", "") or str(last)

    # 1) sanitize — collapse the raw reply to a single candidate line
    clean = _sanitize_visible_answer(text)

    # 2) validate — if the sanitized answer already complies, short-circuit
    if _is_valid_final_answer(clean):
        # Replace the last message with the sanitized one-line answer
        return {"messages": [{"role": "assistant", "content": clean}]}

    # 3) one-shot fixer pass (no tools, JSON enforced)
    fix_instruction = (
        "Rewrite the final answer to comply with these rules:\n"
        "- Output only the final answer (single line), no extra words.\n"
        "- Numbers should always be expressed as digits.\n"
        "- If number: no commas, no units.\n"
        "- If string: no leading articles ('a','an','the'); no abbreviations.\n"
        "- If list: comma-separated; apply the same rules to each element.\n\n"
        "Return JSON: {\"final_answer\": \"...\"}."
    )
    msgs = [
        SystemMessage(system_prompt),
        {"role": "user", "content": fix_instruction + f"\n\nOriginal answer:\n{clean}"}
    ]
    fixed = final_llm.invoke(msgs)
    # Coerce to string defensively; content can be None or a non-str payload.
    fixed_text = str(getattr(fixed, "content", "") or "").strip()
    try:
        obj = json.loads(fixed_text)
        # NOTE(review): assumes "final_answer" maps to a string —
        # a non-str JSON value here would raise and hit the fallback.
        fa = (obj.get("final_answer") or "").strip()
    except Exception:
        # fallback: keep sanitized original if JSON parsing fails
        fa = clean

    # Re-sanitize the rewritten answer and re-validate once.
    fa = _sanitize_visible_answer(fa)
    if not _is_valid_final_answer(fa):
        # last resort: keep last line of whatever we have
        fa = (fa or clean).splitlines()[-1].strip()

    return {"messages": [{"role": "assistant", "content": fa}]}
|
| 260 |
+
|
| 261 |
+
graph_builder.add_node("chatbot", chatbot)
|
| 262 |
tool_node = ToolNode(tools=tools)
|
| 263 |
graph_builder.add_node("tools", tool_node)
|
| 264 |
+
graph_builder.add_node("validator", validator)
|
| 265 |
|
| 266 |
+
# If the model wants to call tools → go to tools; else → go to validator
|
| 267 |
graph_builder.add_conditional_edges(
|
| 268 |
"chatbot",
|
| 269 |
tools_condition,
|
| 270 |
+
{"tools": "tools", "__end__": "validator"},
|
| 271 |
)
|
| 272 |
+
|
| 273 |
+
# After tools run, go back to chatbot
|
| 274 |
graph_builder.add_edge("tools", "chatbot")
|
| 275 |
+
|
| 276 |
+
# After validator, we are done
|
| 277 |
+
graph_builder.add_edge("validator", END)
|
| 278 |
+
|
| 279 |
graph_builder.add_edge(START, "chatbot")
|
| 280 |
+
|
| 281 |
graph = graph_builder.compile()
|
| 282 |
return graph
|
tools.py
CHANGED
|
@@ -211,37 +211,6 @@ def read_mp3_transcript(path: str) -> str:
|
|
| 211 |
return _fmt_error("read_mp3_transcript", e)
|
| 212 |
|
| 213 |
|
| 214 |
-
@tool("pdf_text_extractor")
def pdf_text_extractor(args: str) -> str:
    """Extract text from a PDF. Usage:
    - 'path/to/file.pdf'
    - 'path/to/file.pdf|pages=1-3' (1-indexed inclusive range)
    Returns a concatenated text excerpt (truncated)."""
    try:
        # pdfplumber is optional at import time; report its absence via the
        # shared tool error format instead of crashing the agent.
        if pdfplumber is None:
            raise RuntimeError("pdfplumber not installed")
        path, start, end = args, None, None
        # Optional page-range suffix anchored at end of the argument string.
        m = re.search(r"\|pages=(\d+)-(\d+)$", args.strip())
        if m:
            path = args[: args.rfind("|pages=")]
            start, end = int(m.group(1)), int(m.group(2))
        text_parts: List[str] = []
        with pdfplumber.open(path) as pdf:
            total = len(pdf.pages)
            # Clamp the requested 1-indexed range to the actual page count.
            # NOTE(review): start=0 or end=0 are falsy, so 'pages=0-0' silently
            # falls back to the full-document defaults — confirm if intended.
            s = max(1, start) if start else 1
            e = min(end, total) if end else total
            for p in range(s - 1, e):
                page = pdf.pages[p]
                # extract_text() returns None for image-only pages.
                text_parts.append(page.extract_text() or "")
        text = "\n".join(text_parts).strip()
        if not text:
            text = "(no extractable text)"
        meta = {"path": path, "pages": f"{start or 1}-{end or 'end'}"}
        # Truncate to keep tool output within the LLM context budget.
        return _fmt_block("PDFText", meta, _truncate(text, 4000))
    except Exception as e:
        # All failures surface through the shared error formatter.
        return _fmt_error("pdf_text_extractor", e)
|
| 243 |
-
|
| 244 |
-
|
| 245 |
@tool("ocr_image")
|
| 246 |
def ocr_image(path: str) -> str:
|
| 247 |
"""Run OCR on an image and return extracted text (requires pytesseract + Tesseract installed)."""
|
|
|
|
| 211 |
return _fmt_error("read_mp3_transcript", e)
|
| 212 |
|
| 213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
@tool("ocr_image")
|
| 215 |
def ocr_image(path: str) -> str:
|
| 216 |
"""Run OCR on an image and return extracted text (requires pytesseract + Tesseract installed)."""
|