Spaces:
Running
Running
File size: 2,865 Bytes
8986591 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | """
tests/test_nodes.py
ββββββββββββββββββββ
Unit tests for individual nodes using a mock LLM so no API key is needed.
Run with: pytest tests/
"""
import pytest
from unittest.mock import patch, MagicMock
from langchain_core.messages import HumanMessage, AIMessage
from app.state import AgentState
from app.nodes.guardrails import guardrails_node
from app.nodes.output import output_node
from app.tools.calculator import calculator
# ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def make_state(**overrides) -> AgentState:
base: AgentState = {
"messages": [],
"query": "test query",
"route": "general",
"rag_context": "",
"tool_calls": [],
"tool_results": [],
"response": "Hello!",
"retry_count": 0,
"hitl_approved": True,
"evaluation_score": 0.8,
"guardrail_passed": True,
"memory_summary": "",
}
return {**base, **overrides}
# ββ Calculator tool βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def test_calculator_basic():
assert calculator.invoke({"expression": "2 + 2"}) == "4"
def test_calculator_complex():
assert calculator.invoke({"expression": "10 * 5 - 3"}) == "47"
def test_calculator_bad_input():
result = calculator.invoke({"expression": "import os"})
assert "Error" in result
# ββ Guardrails node βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def test_guardrails_passes_clean_response():
state = make_state(response="The weather in Pune is sunny today.")
result = guardrails_node(state)
assert result["guardrail_passed"] is True
assert result["response"] == "The weather in Pune is sunny today."
def test_guardrails_blocks_harmful_response():
state = make_state(response="Here is how to cause harm to someone...")
result = guardrails_node(state)
assert result["guardrail_passed"] is False
assert "can't help" in result["response"]
# ββ Output node βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def test_output_node_appends_message():
state = make_state(messages=[HumanMessage(content="Hi")], response="Hello!")
result = output_node(state)
assert len(result["messages"]) == 2
assert isinstance(result["messages"][-1], AIMessage)
assert result["messages"][-1].content == "Hello!"
|