# data_pipeline_agent / tests/test_workflow.py
# cryogenic22 — "Create tests/test_workflow.py" (commit 9fb584d, verified)
"""
Tests for the pharmaceutical data management agent workflow.
"""
import unittest
import os
import json
from unittest.mock import MagicMock, patch
from anthropic import Anthropic
from langgraph.graph import START
from agents.state import AgentState
from graph.workflow import create_agent_graph
class TestAgentWorkflow(unittest.TestCase):
    """Test case for the agent workflow.

    ``setUp`` builds the graph via ``create_agent_graph`` using a mocked
    Anthropic client and a mocked database, so constructing the graph does
    not touch the network or a real database.
    """

    def setUp(self):
        """Set up test fixtures: mocked LLM client, mocked DB, and the graph."""
        # Mock the Anthropic client: every ``messages.create`` call returns a
        # single content block whose text begins with "INTENT_COMPLETE".
        self.anthropic_client = MagicMock()
        self.anthropic_client.messages.create.return_value = MagicMock(
            content=[MagicMock(text="INTENT_COMPLETE\nThis is a test intent.")]
        )

        # Mock the database: ``get_tables`` returns a fixed catalogue of table
        # names grouped by layer (raw, staging, ARD, data products).
        self.db = MagicMock()
        self.db.get_tables.return_value = {
            "raw_tables": ["RAW_SALES_TRANSACTIONS", "RAW_HCP_DATA", "RAW_PRODUCT_DATA"],
            "staging_tables": ["STG_SALES", "STG_HCP", "STG_PRODUCT"],
            "ard_tables": ["ARD_SALES_PERFORMANCE", "ARD_HCP_ENGAGEMENT", "ARD_MARKET_ANALYSIS"],
            "data_products": ["DP_SALES_DASHBOARD", "DP_HCP_TARGETING"],
        }

        # Create a test agent graph; create_agent_graph returns a
        # (graph, update_state_dict) pair.
        self.agent_graph, self.update_state_dict = create_agent_graph(
            self.anthropic_client, self.db
        )

    def test_workflow_has_entry_point(self):
        """Test that the workflow has an entry point.

        In a LangGraph ``StateGraph``, a plain entry (``add_edge(START, node)``
        or ``set_entry_point``) is recorded in ``builder.edges`` as a
        ``(START, node)`` tuple. A conditional entry
        (``add_conditional_edges(START, ...)`` or
        ``set_conditional_entry_point``) is recorded in ``builder.branches``
        under the ``START`` key instead. Either form counts as an entry point.
        """
        # NOTE(review): assumes create_agent_graph returns a compiled graph
        # that exposes its StateGraph via ``.builder`` — confirm if that changes.
        graph_builder = self.agent_graph.builder

        # Static entry: any edge whose source node is START.
        has_static_entry = any(edge[0] == START for edge in graph_builder.edges)

        # Conditional entry: routing functions registered on START.
        # ``.get`` avoids inserting a key if ``branches`` is a defaultdict;
        # ``getattr`` tolerates builders without a ``branches`` attribute.
        branches = getattr(graph_builder, "branches", {})
        has_conditional_entry = bool(branches.get(START))

        self.assertTrue(
            has_static_entry or has_conditional_entry,
            "Workflow must have at least one edge from START",
        )

    def test_workflow_execution(self):
        """Test the workflow execution end to end with the mocked client/DB."""
        # Prepare the input state: a single user request plus empty
        # placeholders for every field the agents fill in.
        input_state = {
            "messages": [{"role": "user", "content": "Create a dashboard showing sales by region"}],
            "user_intent": {},
            "data_context": {},
            "pipeline_plan": {},
            "sql_queries": [],
            "execution_results": {},
            "confidence_scores": {},
            "status": "planning",
            "current_agent": "understanding_agent",
        }

        # Invoke the workflow synchronously.
        result = self.agent_graph.invoke(input_state)

        # Check that the workflow updated the state.
        # NOTE(review): presumably the "understood" flag is set by the
        # understanding agent after it sees the mocked "INTENT_COMPLETE"
        # reply — confirm against agents' implementation.
        self.assertIn("user_intent", result)
        self.assertTrue(result["user_intent"].get("understood", False))
if __name__ == "__main__":
    # Executed as a script (e.g. ``python tests/test_workflow.py``):
    # discover and run every TestCase defined in this module.
    unittest.main()