dramella committed on
Commit
5555a89
·
1 Parent(s): c1b893b

added validator node

Browse files
Files changed (2) hide show
  1. agent.py +140 -4
  2. tools.py +0 -31
agent.py CHANGED
@@ -15,7 +15,6 @@ from langchain_core.messages import SystemMessage
15
 
16
  from langgraph.prebuilt import ToolNode, tools_condition
17
 
18
- SUPPORTING_FILES_URL = "https://huggingface.co/datasets/gaia-benchmark/GAIA/resolve/main/2023/validation/"
19
 
20
  system_prompt = """You are a general AI assistant. I will ask you a question.
21
 
@@ -60,6 +59,89 @@ def _is_url(path_or_url: str) -> bool:
60
  except:
61
  return False
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def _process_uploaded_file(file_name: str, file_path: str) -> str:
64
  """Process a single local file or file URL and return context for the question."""
65
  try:
@@ -101,7 +183,6 @@ def build_and_compile():
101
  python_code,
102
  image_info,
103
  read_mp3_transcript,
104
- pdf_text_extractor,
105
  ocr_image,
106
  math_solver,
107
  plot_data_tool,
@@ -121,6 +202,7 @@ def build_and_compile():
121
 
122
  llm = init_chat_model("openai:gpt-4.1-mini",temperature=0, seed=42)
123
  llm_with_tools = llm.bind_tools(tools)
 
124
 
125
  def chatbot(state: State):
126
  file_context = ""
@@ -129,18 +211,72 @@ def build_and_compile():
129
  final_prompt = system_prompt + file_context
130
  return {"messages": [llm_with_tools.invoke([SystemMessage(final_prompt)] + state["messages"])]}
131
 
 
 
 
 
 
 
 
 
132
 
133
- graph_builder.add_node("chatbot", chatbot)
 
 
 
 
 
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  tool_node = ToolNode(tools=tools)
136
  graph_builder.add_node("tools", tool_node)
 
137
 
 
138
  graph_builder.add_conditional_edges(
139
  "chatbot",
140
  tools_condition,
 
141
  )
142
- # Any time a tool is called, we return to the chatbot to decide the next step
 
143
  graph_builder.add_edge("tools", "chatbot")
 
 
 
 
144
  graph_builder.add_edge(START, "chatbot")
 
145
  graph = graph_builder.compile()
146
  return graph
 
15
 
16
  from langgraph.prebuilt import ToolNode, tools_condition
17
 
 
18
 
19
  system_prompt = """You are a general AI assistant. I will ask you a question.
20
 
 
59
  except:
60
  return False
61
 
62
# Lowercase English articles; string answers are checked against this set.
_ARTICLES = {"a", "an", "the"}
63
+
64
+ def _sanitize_visible_answer(text: str) -> str:
65
+ """Keep a single-line final answer; strip quotes and leftover tags."""
66
+ if not text:
67
+ return ""
68
+ t = text.strip()
69
+
70
+ if (t.startswith('"') and t.endswith('"')) or (t.startswith("'") and t.endswith("'")):
71
+ t = t[1:-1].strip()
72
+
73
+ lines = [ln.strip() for ln in t.splitlines() if ln.strip()]
74
+ if lines:
75
+ t = lines[-1]
76
+
77
+ t = t.replace("[YOUR FINAL ANSWER]", "").strip()
78
+ t = t.replace("Final answer: ", "").strip()
79
+
80
+
81
+ t = re.sub(r"\s+", " ", t)
82
+ t = re.sub(r"<[^>]*>", "", t)
83
+
84
+ return t
85
+
86
+
87
+ def _is_number_token(s: str) -> bool:
88
+ return bool(re.fullmatch(r"-?\d+(\.\d+)?", s))
89
+
90
+
91
+ def _has_units(s: str) -> bool:
92
+ return bool(re.search(r"\d\s*[A-Za-z%$]", s))
93
+
94
+
95
+ def _has_commas_in_number(s: str) -> bool:
96
+ return bool(re.search(r"\d,\d", s))
97
+
98
+
99
+ def _starts_with_article(s: str) -> bool:
100
+ toks = re.split(r"[,\s]+", s.strip())
101
+ return bool(toks) and toks[0].lower() in _ARTICLES
102
+
103
+
104
def _is_valid_final_answer(ans: str) -> bool:
    """Validate a candidate final answer against the output rules.

    Rules enforced:
      - single, non-empty line (no embedded newlines)
      - a value containing a digit must be a bare number: optionally-signed
        int or decimal, which by construction excludes commas and units
      - a plain string must not start with an English article
      - a comma-separated list is validated element by element, and no
        element may be empty

    Note: the previous version re-checked ``_has_commas_in_number`` and
    ``_has_units`` after ``_is_number_token`` succeeded; those branches were
    unreachable because the number pattern already forbids commas and
    letters, so they have been removed.
    """
    if not ans or "\n" in ans:
        return False

    if "," in ans:
        parts = [p.strip() for p in ans.split(",")]
        if any(not p for p in parts):
            return False
        return all(_is_valid_element(p) for p in parts)

    return _is_valid_element(ans)


def _is_valid_element(item: str) -> bool:
    """Validate one answer element: a bare number, or an article-free string."""
    if re.search(r"\d", item):
        # Digit present -> must be exactly an optionally-signed int/decimal.
        return _is_number_token(item)
    return not _starts_with_article(item)
143
+
144
+
145
  def _process_uploaded_file(file_name: str, file_path: str) -> str:
146
  """Process a single local file or file URL and return context for the question."""
147
  try:
 
183
  python_code,
184
  image_info,
185
  read_mp3_transcript,
 
186
  ocr_image,
187
  math_solver,
188
  plot_data_tool,
 
202
 
203
  llm = init_chat_model("openai:gpt-4.1-mini",temperature=0, seed=42)
204
  llm_with_tools = llm.bind_tools(tools)
205
+ final_llm = llm.bind(response_format={"type": "json_object"})
206
 
207
  def chatbot(state: State):
208
  file_context = ""
 
211
  final_prompt = system_prompt + file_context
212
  return {"messages": [llm_with_tools.invoke([SystemMessage(final_prompt)] + state["messages"])]}
213
 
214
+ def validator(state: State):
215
+ """
216
+ Ensure the last assistant message is a valid final answer per system rules.
217
+ If invalid, rewrite once with final_llm (JSON) and output only final_answer.
218
+ """
219
+ # Get last assistant message text
220
+ last = state["messages"][-1]
221
+ text = getattr(last, "content", "") or str(last)
222
 
223
+ # 1) sanitize
224
+ clean = _sanitize_visible_answer(text)
225
+
226
+ # 2) validate
227
+ if _is_valid_final_answer(clean):
228
+ # Replace the last message with the sanitized one-line answer
229
+ return {"messages": [{"role": "assistant", "content": clean}]}
230
 
231
+ # 3) one-shot fixer pass (no tools, JSON enforced)
232
+ fix_instruction = (
233
+ "Rewrite the final answer to comply with these rules:\n"
234
+ "- Output only the final answer (single line), no extra words.\n"
235
+ "- Numbers should always be expressed as digits.\n"
236
+ "- If number: no commas, no units.\n"
237
+ "- If string: no leading articles ('a','an','the'); no abbreviations.\n"
238
+ "- If list: comma-separated; apply the same rules to each element.\n\n"
239
+ "Return JSON: {\"final_answer\": \"...\"}."
240
+ )
241
+ msgs = [
242
+ SystemMessage(system_prompt),
243
+ {"role": "user", "content": fix_instruction + f"\n\nOriginal answer:\n{clean}"}
244
+ ]
245
+ fixed = final_llm.invoke(msgs)
246
+ fixed_text = str(getattr(fixed, "content", "") or "").strip()
247
+ try:
248
+ obj = json.loads(fixed_text)
249
+ fa = (obj.get("final_answer") or "").strip()
250
+ except Exception:
251
+ # fallback: keep sanitized original if JSON parsing fails
252
+ fa = clean
253
+
254
+ fa = _sanitize_visible_answer(fa)
255
+ if not _is_valid_final_answer(fa):
256
+ # last resort: keep last line of whatever we have
257
+ fa = (fa or clean).splitlines()[-1].strip()
258
+
259
+ return {"messages": [{"role": "assistant", "content": fa}]}
260
+
261
+ graph_builder.add_node("chatbot", chatbot)
262
  tool_node = ToolNode(tools=tools)
263
  graph_builder.add_node("tools", tool_node)
264
+ graph_builder.add_node("validator", validator)
265
 
266
+ # If the model wants to call tools → go to tools; else → go to validator
267
  graph_builder.add_conditional_edges(
268
  "chatbot",
269
  tools_condition,
270
+ {"tools": "tools", "__end__": "validator"},
271
  )
272
+
273
+ # After tools run, go back to chatbot
274
  graph_builder.add_edge("tools", "chatbot")
275
+
276
+ # After validator, we are done
277
+ graph_builder.add_edge("validator", END)
278
+
279
  graph_builder.add_edge(START, "chatbot")
280
+
281
  graph = graph_builder.compile()
282
  return graph
tools.py CHANGED
@@ -211,37 +211,6 @@ def read_mp3_transcript(path: str) -> str:
211
  return _fmt_error("read_mp3_transcript", e)
212
 
213
 
214
@tool("pdf_text_extractor")
def pdf_text_extractor(args: str) -> str:
    """Extract text from a PDF. Usage:
    - 'path/to/file.pdf'
    - 'path/to/file.pdf|pages=1-3' (1-indexed inclusive range)
    Returns a concatenated text excerpt (truncated)."""
    try:
        # pdfplumber is an optional dependency; missing install is routed
        # through the same error formatter as any other failure.
        if pdfplumber is None:
            raise RuntimeError("pdfplumber not installed")
        path, start, end = args, None, None
        # An optional trailing "|pages=A-B" suffix selects a 1-indexed
        # inclusive page range; everything before it is the file path.
        m = re.search(r"\|pages=(\d+)-(\d+)$", args.strip())
        if m:
            path = args[: args.rfind("|pages=")]
            start, end = int(m.group(1)), int(m.group(2))
        text_parts: List[str] = []
        with pdfplumber.open(path) as pdf:
            total = len(pdf.pages)
            # Clamp the requested range to the document's actual page count.
            s = max(1, start) if start else 1
            e = min(end, total) if end else total
            for p in range(s - 1, e):  # convert 1-indexed range to 0-indexed pages
                page = pdf.pages[p]
                # extract_text() may return None for image-only pages.
                text_parts.append(page.extract_text() or "")
        text = "\n".join(text_parts).strip()
        if not text:
            text = "(no extractable text)"
        meta = {"path": path, "pages": f"{start or 1}-{end or 'end'}"}
        # _truncate caps the excerpt at 4000 chars before block formatting.
        return _fmt_block("PDFText", meta, _truncate(text, 4000))
    except Exception as e:
        return _fmt_error("pdf_text_extractor", e)
243
-
244
-
245
  @tool("ocr_image")
246
  def ocr_image(path: str) -> str:
247
  """Run OCR on an image and return extracted text (requires pytesseract + Tesseract installed)."""
 
211
  return _fmt_error("read_mp3_transcript", e)
212
 
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  @tool("ocr_image")
215
  def ocr_image(path: str) -> str:
216
  """Run OCR on an image and return extracted text (requires pytesseract + Tesseract installed)."""