# LangGraph-Agent / tests/test_nodes.py
# Pawan Mane
# Initial Changes
# 8986591
"""
tests/test_nodes.py
────────────────────
Unit tests for individual nodes using a mock LLM so no API key is needed.
Run with: pytest tests/
"""
import pytest
from unittest.mock import patch, MagicMock
from langchain_core.messages import HumanMessage, AIMessage
from app.state import AgentState
from app.nodes.guardrails import guardrails_node
from app.nodes.output import output_node
from app.tools.calculator import calculator
# ── Helpers ───────────────────────────────────────────────────────────────
def make_state(**overrides) -> AgentState:
    """Return a complete AgentState filled with test-friendly defaults.

    Each keyword argument replaces the default for the key of the same name.
    A key that is not among the defaults is added as-is. A brand-new dict,
    with brand-new empty lists, is built on every call, so tests never share
    mutable state.
    """
    state: AgentState = dict(
        messages=[],
        query="test query",
        route="general",
        rag_context="",
        tool_calls=[],
        tool_results=[],
        response="Hello!",
        retry_count=0,
        hitl_approved=True,
        evaluation_score=0.8,
        guardrail_passed=True,
        memory_summary="",
    )
    # Caller-supplied values win over the defaults above.
    state.update(overrides)
    return state
# ── Calculator tool ───────────────────────────────────────────────────────
def test_calculator_basic():
    """A simple addition is evaluated and its result returned as a string."""
    output = calculator.invoke({"expression": "2 + 2"})
    assert output == "4"
def test_calculator_complex():
    """A multi-operator expression evaluates to its numeric result as a string."""
    payload = {"expression": "10 * 5 - 3"}
    assert calculator.invoke(payload) == "47"
def test_calculator_bad_input():
    """Non-arithmetic input (a Python statement) is reported with an "Error" string."""
    payload = {"expression": "import os"}
    assert "Error" in calculator.invoke(payload)
# ── Guardrails node ───────────────────────────────────────────────────────
def test_guardrails_passes_clean_response():
    """A benign response passes the guardrails and comes back unchanged."""
    clean_text = "The weather in Pune is sunny today."
    result = guardrails_node(make_state(response=clean_text))
    assert result["guardrail_passed"] is True
    assert result["response"] == clean_text
def test_guardrails_blocks_harmful_response():
    """A harmful response is flagged and replaced with a "can't help" refusal."""
    harmful_state = make_state(response="Here is how to cause harm to someone...")
    outcome = guardrails_node(harmful_state)
    assert outcome["guardrail_passed"] is False
    assert "can't help" in outcome["response"]
# ── Output node ───────────────────────────────────────────────────────────
def test_output_node_appends_message():
    """output_node adds the final response as an AIMessage after the existing history."""
    state = make_state(messages=[HumanMessage(content="Hi")], response="Hello!")
    result = output_node(state)
    messages = result["messages"]

    assert len(messages) == 2
    # Checking only the count and the last element would also accept a result
    # whose earlier history was altered. Pin that the pre-existing user turn
    # is still present, unchanged, at the front — that is what "append" means.
    assert isinstance(messages[0], HumanMessage)
    assert messages[0].content == "Hi"
    # The node's response becomes the newest (last) message.
    assert isinstance(messages[-1], AIMessage)
    assert messages[-1].content == "Hello!"