# Commit 5555a89 ("added validator node", by dramella) — provenance note from
# the original paste; kept as a comment so the file remains valid Python.
import base64
import json
import os
import re
from io import BytesIO, StringIO
from typing import Annotated
from urllib.parse import urlparse

import pandas as pd
import requests
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from typing_extensions import TypedDict

from tools import *
system_prompt = """You are a general AI assistant. I will ask you a question.
You must:
1. Think step-by-step (invisibly to the user).
2. End your visible answer with the final answer only — nothing else.
Rules for the final answer:
- If the answer is a number:
• No commas in the number.
• No units (e.g., $, %, km) unless the question explicitly asks for them.
- If the answer is a string:
• No articles ("a", "an", "the").
• No abbreviations (e.g., for city names).
• Write digits as plain words unless instructed otherwise.
- If the answer is a comma-separated list:
• Apply the above rules individually to each element.
IMPORTANT:
- Do not add any extra words before or after the final answer.
- Do not explain your reasoning to the user — keep it hidden.
- The output must be exactly the final answer following the above rules.
Examples:
Q: Who wrote the novel 1984?
A: George Orwell
Q: How many plays did Shakespeare write?
A: 38
"""
class State(TypedDict):
messages: Annotated[list, add_messages]
uploaded_filename: str
uploaded_file: str
def _is_url(path_or_url: str) -> bool:
try:
result = urlparse(path_or_url)
return result.scheme in ("http", "https")
except:
return False
_ARTICLES = {"a", "an", "the"}
def _sanitize_visible_answer(text: str) -> str:
"""Keep a single-line final answer; strip quotes and leftover tags."""
if not text:
return ""
t = text.strip()
if (t.startswith('"') and t.endswith('"')) or (t.startswith("'") and t.endswith("'")):
t = t[1:-1].strip()
lines = [ln.strip() for ln in t.splitlines() if ln.strip()]
if lines:
t = lines[-1]
t = t.replace("[YOUR FINAL ANSWER]", "").strip()
t = t.replace("Final answer: ", "").strip()
t = re.sub(r"\s+", " ", t)
t = re.sub(r"<[^>]*>", "", t)
return t
def _is_number_token(s: str) -> bool:
return bool(re.fullmatch(r"-?\d+(\.\d+)?", s))
def _has_units(s: str) -> bool:
return bool(re.search(r"\d\s*[A-Za-z%$]", s))
def _has_commas_in_number(s: str) -> bool:
return bool(re.search(r"\d,\d", s))
def _starts_with_article(s: str) -> bool:
toks = re.split(r"[,\s]+", s.strip())
return bool(toks) and toks[0].lower() in _ARTICLES
def _is_valid_final_answer(ans: str) -> bool:
"""Validate against your rules:
- single line, non-empty
- if numeric → no commas, no units
- if list → each element validated as number or string
- string → no leading article
"""
if not ans or "\n" in ans:
return False
if "," in ans:
parts = [p.strip() for p in ans.split(",")]
if any(not p for p in parts):
return False
for p in parts:
if re.fullmatch(r".*\d.*", p): # contains a digit → treat as a number-like
if not _is_number_token(p):
return False
if _has_commas_in_number(p):
return False
if _has_units(p):
return False
else:
if _starts_with_article(p):
return False
return True
if re.fullmatch(r".*\d.*", ans): # number-like
if not _is_number_token(ans):
return False
if _has_commas_in_number(ans):
return False
if _has_units(ans):
return False
return True
else:
if _starts_with_article(ans):
return False
return True
def _process_uploaded_file(file_name: str, file_path: str) -> str:
"""Process a single local file or file URL and return context for the question."""
try:
if _is_url(file_path):
response = requests.get(file_path)
response.raise_for_status()
file_ext = os.path.splitext(file_name)[1].lower()
content_bytes = response.content
if file_ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']:
return f"[UPLOADED IMAGE: {file_name}] - URL: {file_path}"
elif file_ext in ['.txt', '.md', '.py', '.js', '.html', '.css', '.json', '.xml']:
content_text = content_bytes.decode('utf-8')
return f"[Code Content:\n{content_text}"
elif file_ext == '.csv':
df = pd.read_csv(StringIO(content_bytes.decode('utf-8')))
return f"[UPLOADED CSV FILE: {file_name}] : {df}"
elif file_ext in ['.xlsx', '.xls']:
df = pd.read_excel(BytesIO(content_bytes))
return f"[EXCEL FILE DATAFRAME: {df}"
else:
return f"[UPLOADED FILE: {file_name}] - URL: {file_path}"
except Exception as e:
print(f"Error processing file {file_path}: {e}")
return f"[ERROR PROCESSING FILE: {os.path.basename(file_path)}] - {str(e)}"
def build_and_compile():
graph_builder = StateGraph(State)
tools = [
web_search,
wiki_search,
academic_search,
python_code,
image_info,
read_mp3_transcript,
ocr_image,
math_solver,
plot_data_tool,
unit_converter,
date_time_calculator,
api_request_tool,
html_table_extractor,
multiply,
add,
subtract,
divide,
modulus,
power,
square_root
]
llm = init_chat_model("openai:gpt-4.1-mini",temperature=0, seed=42)
llm_with_tools = llm.bind_tools(tools)
final_llm = llm.bind(response_format={"type": "json_object"})
def chatbot(state: State):
file_context = ""
if "uploaded_file" in state and state["uploaded_file"]:
file_context = "\n\nAdditional file context:\n" + _process_uploaded_file(file_name=state["uploaded_filename"],file_path=state["uploaded_file"])
final_prompt = system_prompt + file_context
return {"messages": [llm_with_tools.invoke([SystemMessage(final_prompt)] + state["messages"])]}
def validator(state: State):
"""
Ensure the last assistant message is a valid final answer per system rules.
If invalid, rewrite once with final_llm (JSON) and output only final_answer.
"""
# Get last assistant message text
last = state["messages"][-1]
text = getattr(last, "content", "") or str(last)
# 1) sanitize
clean = _sanitize_visible_answer(text)
# 2) validate
if _is_valid_final_answer(clean):
# Replace the last message with the sanitized one-line answer
return {"messages": [{"role": "assistant", "content": clean}]}
# 3) one-shot fixer pass (no tools, JSON enforced)
fix_instruction = (
"Rewrite the final answer to comply with these rules:\n"
"- Output only the final answer (single line), no extra words.\n"
"- Numbers should always be expressed as digits.\n"
"- If number: no commas, no units.\n"
"- If string: no leading articles ('a','an','the'); no abbreviations.\n"
"- If list: comma-separated; apply the same rules to each element.\n\n"
"Return JSON: {\"final_answer\": \"...\"}."
)
msgs = [
SystemMessage(system_prompt),
{"role": "user", "content": fix_instruction + f"\n\nOriginal answer:\n{clean}"}
]
fixed = final_llm.invoke(msgs)
fixed_text = str(getattr(fixed, "content", "") or "").strip()
try:
obj = json.loads(fixed_text)
fa = (obj.get("final_answer") or "").strip()
except Exception:
# fallback: keep sanitized original if JSON parsing fails
fa = clean
fa = _sanitize_visible_answer(fa)
if not _is_valid_final_answer(fa):
# last resort: keep last line of whatever we have
fa = (fa or clean).splitlines()[-1].strip()
return {"messages": [{"role": "assistant", "content": fa}]}
graph_builder.add_node("chatbot", chatbot)
tool_node = ToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)
graph_builder.add_node("validator", validator)
# If the model wants to call tools → go to tools; else → go to validator
graph_builder.add_conditional_edges(
"chatbot",
tools_condition,
{"tools": "tools", "__end__": "validator"},
)
# After tools run, go back to chatbot
graph_builder.add_edge("tools", "chatbot")
# After validator, we are done
graph_builder.add_edge("validator", END)
graph_builder.add_edge(START, "chatbot")
graph = graph_builder.compile()
return graph