|
|
""" |
|
|
Tests for the FastAPI endpoints. |
|
|
""" |
|
|
|
|
|
import pytest |
|
|
from fastapi.testclient import TestClient |
|
|
from httpx import AsyncClient, ASGITransport |
|
|
|
|
|
from app.main import app |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Synchronous test client shared by the class-based endpoint tests in this module.
# NOTE(review): TestClient is not entered as a context manager here, so the
# app's lifespan/startup handlers are presumably not run — confirm that the
# demo graph and tool registry do not depend on them.
client: TestClient = TestClient(app)
|
|
|
|
|
|
|
|
class TestRootEndpoints:
    """Exercise the endpoints mounted at the application root."""

    def test_root(self):
        """GET / answers 200 and exposes ``name``, ``version`` and ``endpoints``."""
        resp = client.get("/")
        assert resp.status_code == 200

        body = resp.json()
        for expected_key in ("name", "version", "endpoints"):
            assert expected_key in body

    def test_health(self):
        """GET /health answers 200 and reports ``status == "healthy"``."""
        resp = client.get("/health")
        assert resp.status_code == 200

        body = resp.json()
        assert body["status"] == "healthy"
|
|
|
|
|
|
|
|
class TestToolsEndpoints:
    """Exercise the tool registry endpoints under ``/tools``."""

    def test_list_tools(self):
        """GET /tools/ returns a non-empty listing containing the expected tools."""
        resp = client.get("/tools/")
        assert resp.status_code == 200

        body = resp.json()
        assert "tools" in body
        assert "total" in body
        assert body["total"] > 0

        # Only the names matter for the membership checks below.
        registered = [entry["name"] for entry in body["tools"]]
        for expected in ("extract_functions", "calculate_complexity"):
            assert expected in registered

    def test_get_tool(self):
        """GET /tools/{name} returns the named tool along with a description."""
        resp = client.get("/tools/extract_functions")
        assert resp.status_code == 200

        tool = resp.json()
        assert tool["name"] == "extract_functions"
        assert "description" in tool

    def test_get_nonexistent_tool(self):
        """GET /tools/{name} for an unregistered name answers 404."""
        resp = client.get("/tools/nonexistent_tool")
        assert resp.status_code == 404
|
|
|
|
|
|
|
|
class TestGraphEndpoints:
    """Exercise graph listing, lookup and creation under ``/graph``."""

    def test_list_graphs(self):
        """GET /graph/ answers 200 with ``graphs`` and ``total`` keys."""
        resp = client.get("/graph/")
        assert resp.status_code == 200

        listing = resp.json()
        for key in ("graphs", "total"):
            assert key in listing

    def test_get_demo_workflow(self):
        """GET /graph/code-review-demo returns the demo graph and its diagram."""
        resp = client.get("/graph/code-review-demo")
        assert resp.status_code == 200

        graph = resp.json()
        assert graph["graph_id"] == "code-review-demo"
        assert graph["name"] == "Code Review Demo"
        assert "mermaid_diagram" in graph

    def test_create_graph(self):
        """POST /graph/create with two valid nodes answers 201 and echoes the graph."""
        nodes = [
            {"name": "start", "handler": "extract_functions"},
            {"name": "end", "handler": "calculate_complexity"},
        ]
        payload = {
            "name": "test_workflow",
            "description": "A test workflow",
            "nodes": nodes,
            "edges": {"start": "end"},
            "entry_point": "start",
        }

        resp = client.post("/graph/create", json=payload)
        assert resp.status_code == 201

        created = resp.json()
        assert "graph_id" in created
        assert created["name"] == "test_workflow"
        assert created["node_count"] == len(nodes)

    def test_create_graph_invalid_handler(self):
        """POST /graph/create naming an unknown handler is rejected."""
        payload = {
            "name": "invalid_workflow",
            "nodes": [{"name": "bad", "handler": "nonexistent_handler"}],
            "edges": {},
        }

        resp = client.post("/graph/create", json=payload)
        # NOTE(review): this pins 404 rather than 400/422 for a bad handler
        # reference; confirm that is the intended API contract.
        assert resp.status_code == 404
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def anyio_backend():
    """Name the event-loop backend for anyio-driven tests (always ``"asyncio"``).

    NOTE(review): presumably consumed by anyio's pytest plugin for tests
    marked ``@pytest.mark.anyio``. The async tests in this module are marked
    ``@pytest.mark.asyncio`` instead, so this fixture may be unused — confirm
    before removing it.
    """
    backend = "asyncio"
    return backend
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_demo_workflow():
    """POST /graph/run on the demo graph with ``async_execution`` off returns a finished run."""
    payload = {
        "graph_id": "code-review-demo",
        "initial_state": {
            "code": "def hello():\n print('world')",
            "quality_threshold": 5.0,
        },
        "async_execution": False,
    }

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as http:
        resp = await http.post("/graph/run", json=payload)
        assert resp.status_code == 200

        run = resp.json()
        assert "run_id" in run
        # Either terminal status is accepted; only the response shape is pinned.
        assert run["status"] in ("completed", "failed")
        assert "execution_log" in run
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_execution():
    """A run submitted with ``async_execution`` on is pending and its state is queryable."""
    payload = {
        "graph_id": "code-review-demo",
        "initial_state": {
            "code": "def test(): pass",
            "quality_threshold": 5.0,
        },
        "async_execution": True,
    }

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as http:
        run_resp = await http.post("/graph/run", json=payload)
        assert run_resp.status_code == 200

        run_info = run_resp.json()
        assert "run_id" in run_info
        assert run_info["status"] == "pending"

        # Look the run up by the id returned from the submission above.
        state_resp = await http.get(f"/graph/state/{run_info['run_id']}")
        assert state_resp.status_code == 200
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_nonexistent_graph():
    """POST /graph/run with an unknown ``graph_id`` answers 404."""
    payload = {"graph_id": "nonexistent-graph", "initial_state": {}}

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as http:
        resp = await http.post("/graph/run", json=payload)
        assert resp.status_code == 404
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session exit code rather than exiting, so
    # forward it; otherwise `python <this file>` exits 0 even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|