# AgentGraph/tests/integration/test_perturbation_api.py
# Author: wu981526092
# Commit 795b72e: Add comprehensive perturbation testing system with E2E tests
"""
Integration tests for perturbation testing API endpoints.
Tests the backend/routers/testing.py endpoints.
"""
import pytest
import json
import os
import tempfile
import shutil
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
# Import the app
from backend.app import app
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the backend FastAPI ``app``."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def temp_prompts_dir(tmp_path):
    """Yield a fresh directory used in place of the custom prompt store.

    A ``custom_jailbreak_prompts`` directory is created under pytest's
    ``tmp_path`` and ``backend.routers.testing.CUSTOM_PROMPTS_DIR`` is patched
    to its string path while the requesting test runs. The patch is undone
    when the fixture is torn down.
    """
    scratch_dir = tmp_path / "custom_jailbreak_prompts"
    scratch_dir.mkdir()
    # Redirect the router's module-level setting to the scratch directory.
    redirect = patch("backend.routers.testing.CUSTOM_PROMPTS_DIR", str(scratch_dir))
    with redirect:
        yield scratch_dir
class TestListJailbreakPromptSources:
    """Checks for the ``GET /api/testing/jailbreak-prompts/list`` endpoint."""

    # Endpoint shared by every test in this class.
    _ENDPOINT = "/api/testing/jailbreak-prompts/list"

    def test_list_returns_sources(self, client):
        """The endpoint answers 200 with a ``sources`` list in the payload."""
        resp = client.get(self._ENDPOINT)
        assert resp.status_code == 200
        payload = resp.json()
        assert "sources" in payload
        assert isinstance(payload["sources"], list)

    def test_list_includes_builtin_source(self, client):
        """At least one listed source is named ``standard``."""
        listed = client.get(self._ENDPOINT).json()["sources"]
        # Collect every entry carrying the standard/builtin name.
        standard_entries = [entry for entry in listed if entry.get("name") == "standard"]
        assert standard_entries

    def test_source_has_required_fields(self, client):
        """Every listed source exposes the expected descriptive keys."""
        expected_keys = ("name", "description", "count", "source_type")
        payload = client.get(self._ENDPOINT).json()
        for entry in payload["sources"]:
            for key in expected_keys:
                assert key in entry
class TestUploadJailbreakPrompts:
    """Checks for the ``POST /api/testing/jailbreak-prompts/upload`` endpoint."""

    @staticmethod
    def _post_upload(client, set_name, filename, body, content_type):
        """POST ``body`` as a multipart ``file`` upload under ``set_name``."""
        return client.post(
            f"/api/testing/jailbreak-prompts/upload?name={set_name}",
            files={"file": (filename, body, content_type)},
        )

    def test_upload_json_prompts(self, client, temp_prompts_dir):
        """A JSON list of two prompts is accepted and counted."""
        payload = json.dumps(
            [
                {"name": "Test1", "prompt": "Test prompt 1"},
                {"name": "Test2", "prompt": "Test prompt 2"},
            ]
        )
        resp = self._post_upload(
            client, "test_set", "prompts.json", payload, "application/json"
        )
        assert resp.status_code == 200
        result = resp.json()
        assert result["status"] == "success"
        assert result["name"] == "test_set"
        assert result["prompt_count"] == 2

    def test_upload_csv_prompts(self, client, temp_prompts_dir):
        """A CSV file with two data rows is accepted and counted."""
        rows = "name,prompt,description\nTest1,Test prompt 1,Desc1\nTest2,Test prompt 2,Desc2"
        resp = self._post_upload(client, "csv_set", "prompts.csv", rows, "text/csv")
        assert resp.status_code == 200
        result = resp.json()
        assert result["status"] == "success"
        assert result["prompt_count"] == 2

    def test_upload_invalid_json(self, client, temp_prompts_dir):
        """A ``.json`` upload that is not parseable JSON yields a 500."""
        resp = self._post_upload(
            client, "invalid", "prompts.json", "invalid json", "application/json"
        )
        assert resp.status_code == 500

    def test_upload_json_missing_prompt_field(self, client, temp_prompts_dir):
        """JSON entries without a ``prompt`` key are rejected with a 400."""
        payload = json.dumps([{"name": "Test", "content": "No prompt field"}])
        resp = self._post_upload(
            client, "missing", "prompts.json", payload, "application/json"
        )
        assert resp.status_code == 400
        # The error detail is expected to mention the missing field.
        assert "prompt" in resp.json()["detail"].lower()

    def test_upload_csv_missing_prompt_column(self, client, temp_prompts_dir):
        """A CSV without a ``prompt`` column is rejected with a 400."""
        rows = "name,description\nTest1,Desc1"
        resp = self._post_upload(client, "missing", "prompts.csv", rows, "text/csv")
        assert resp.status_code == 400
        # The error detail is expected to mention the missing column.
        assert "prompt" in resp.json()["detail"].lower()

    def test_upload_unsupported_format(self, client, temp_prompts_dir):
        """A ``.txt`` upload is rejected with a 400."""
        resp = self._post_upload(
            client, "unsupported", "prompts.txt", "text content", "text/plain"
        )
        assert resp.status_code == 400
class TestDeleteJailbreakPrompts:
    """Tests for DELETE /api/testing/jailbreak-prompts/{name}"""

    def test_delete_custom_prompts(self, client, temp_prompts_dir):
        """Test deleting custom prompts that were uploaded beforehand."""
        # Arrange: upload a prompt set so there is something to delete.
        prompts = [{"prompt": "Test"}]
        upload_response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=to_delete",
            files={"file": ("prompts.json", json.dumps(prompts), "application/json")},
        )
        # Fail on the setup step itself rather than on a misleading 404 from
        # the delete call if the upload was not accepted.
        assert upload_response.status_code == 200

        # Act: delete the set that was just uploaded.
        response = client.delete("/api/testing/jailbreak-prompts/to_delete")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"

    def test_delete_nonexistent_prompts(self, client, temp_prompts_dir):
        """Test that deleting a name that was never uploaded returns 404."""
        response = client.delete("/api/testing/jailbreak-prompts/nonexistent")
        assert response.status_code == 404

    def test_cannot_delete_builtin(self, client):
        """Test that the ``standard`` dataset cannot be deleted."""
        response = client.delete("/api/testing/jailbreak-prompts/standard")
        assert response.status_code == 400
        # The error detail is expected to say the dataset is built-in.
        assert "built-in" in response.json()["detail"].lower()
class TestListPresets:
    """Checks for the ``GET /api/testing/presets`` endpoints."""

    # Collection endpoint; individual presets live under ``<this>/<name>``.
    _ENDPOINT = "/api/testing/presets"

    def test_list_presets(self, client):
        """The listing answers 200 and carries at least three presets."""
        resp = client.get(self._ENDPOINT)
        assert resp.status_code == 200
        payload = resp.json()
        assert "presets" in payload
        # quick, standard and comprehensive are the expected minimum.
        assert len(payload["presets"]) >= 3

    def test_preset_has_required_fields(self, client):
        """Every listed preset exposes the expected configuration keys."""
        expected_keys = (
            "name",
            "description",
            "jailbreak_techniques",
            "demographics_count",
            "comparison_mode",
        )
        payload = client.get(self._ENDPOINT).json()
        for entry in payload["presets"]:
            for key in expected_keys:
                assert key in entry

    def test_get_specific_preset(self, client):
        """Each known preset can be fetched by name and echoes that name."""
        for wanted in ("quick", "standard", "comprehensive"):
            resp = client.get(f"/api/testing/presets/{wanted}")
            assert resp.status_code == 200
            fetched = resp.json()
            assert fetched["name"] == wanted

    def test_get_invalid_preset(self, client):
        """Requesting the preset name ``invalid`` returns 404."""
        resp = client.get("/api/testing/presets/invalid")
        assert resp.status_code == 404
class TestListDemographics:
    """Checks for the ``GET /api/testing/demographics`` endpoint."""

    # Endpoint shared by every test in this class.
    _ENDPOINT = "/api/testing/demographics"

    def test_list_demographics(self, client):
        """The endpoint answers 200 with a ``demographics`` entry."""
        resp = client.get(self._ENDPOINT)
        assert resp.status_code == 200
        payload = resp.json()
        assert "demographics" in payload

    def test_demographics_has_gender_options(self, client):
        """The gender options include ``male`` and ``female``."""
        payload = client.get(self._ENDPOINT).json()
        assert "gender" in payload["demographics"]
        gender_options = payload["demographics"]["gender"]
        for expected in ("male", "female"):
            assert expected in gender_options

    def test_demographics_has_race_options(self, client):
        """The race options include ``White`` and ``Black``."""
        payload = client.get(self._ENDPOINT).json()
        assert "race" in payload["demographics"]
        race_options = payload["demographics"]["race"]
        for expected in ("White", "Black"):
            assert expected in race_options

    def test_demographics_has_presets(self, client):
        """The demographic presets include the three named tiers."""
        payload = client.get(self._ENDPOINT).json()
        assert "presets" in payload["demographics"]
        preset_options = payload["demographics"]["presets"]
        for expected in ("minimal", "standard", "comprehensive"):
            assert expected in preset_options

    def test_demographics_has_comparison_modes(self, client):
        """The top-level payload lists the three comparison modes."""
        payload = client.get(self._ENDPOINT).json()
        assert "comparison_modes" in payload
        mode_options = payload["comparison_modes"]
        for expected in ("vs_baseline", "all_pairs", "both"):
            assert expected in mode_options

    def test_demographics_has_extended_dimensions(self, client):
        """The top-level payload lists the extended dimensions."""
        payload = client.get(self._ENDPOINT).json()
        assert "extended_dimensions" in payload
        dimension_options = payload["extended_dimensions"]
        for expected in ("age", "disability", "socioeconomic"):
            assert expected in dimension_options
class TestPerturbEndpoint:
    """Checks for the ``POST /api/knowledge-graphs/{kg_id}/perturb`` endpoint."""

    def test_perturb_with_invalid_kg_id(self, client):
        """Posting an empty body for the id ``invalid_id`` produces an error status."""
        resp = client.post("/api/knowledge-graphs/invalid_id/perturb", json={})
        # Any of these error statuses is accepted for the unknown graph id.
        accepted_statuses = (400, 404, 500)
        assert resp.status_code in accepted_statuses

    def test_perturb_endpoint_exists(self, client):
        """POST to the perturb route is not rejected as a disallowed method."""
        body = {"model": "gpt-4o-mini"}
        resp = client.post("/api/knowledge-graphs/test_kg/perturb", json=body)
        # 405 is "Method Not Allowed"; any other status passes, including the
        # error returned for the missing ``test_kg`` graph.
        assert resp.status_code != 405
class TestAPIWorkflow:
    """End-to-end checks that chain several testing API calls together."""

    def test_upload_list_delete_workflow(self, client, temp_prompts_dir):
        """Upload a prompt set, see it listed, delete it, see it gone."""
        list_url = "/api/testing/jailbreak-prompts/list"

        # Step 1: upload two prompts under the name ``workflow_test``.
        body = json.dumps([{"prompt": "Test prompt 1"}, {"prompt": "Test prompt 2"}])
        uploaded = client.post(
            "/api/testing/jailbreak-prompts/upload?name=workflow_test",
            files={"file": ("prompts.json", body, "application/json")},
        )
        assert uploaded.status_code == 200

        # Step 2: the new set must show up in the listing.
        listed = client.get(list_url)
        assert listed.status_code == 200
        names_after_upload = [entry["name"] for entry in listed.json()["sources"]]
        assert "workflow_test" in names_after_upload

        # Step 3: remove the set again.
        deleted = client.delete("/api/testing/jailbreak-prompts/workflow_test")
        assert deleted.status_code == 200

        # Step 4: the listing must no longer mention it.
        relisted = client.get(list_url)
        names_after_delete = [entry["name"] for entry in relisted.json()["sources"]]
        assert "workflow_test" not in names_after_delete

    def test_get_preset_and_demographics_for_config(self, client):
        """A run config can be assembled from the preset and demographics payloads."""
        # Fetch the ``standard`` preset.
        preset_resp = client.get("/api/testing/presets/standard")
        assert preset_resp.status_code == 200
        standard_preset = preset_resp.json()

        # Fetch the demographics catalogue.
        demographics_resp = client.get("/api/testing/demographics")
        assert demographics_resp.status_code == 200
        catalogue = demographics_resp.json()

        # Pull the pieces in the same order the config uses them.
        max_relations = standard_preset["max_relations"]
        jailbreak_section = {
            "enabled": True,
            "num_techniques": standard_preset["jailbreak_techniques"],
        }
        bias_section = {
            "enabled": True,
            "demographics": catalogue["demographics"]["presets"]["standard"],
            "comparison_mode": standard_preset["comparison_mode"],
        }
        run_config = {
            "model": "gpt-4o-mini",
            "judge_model": "gpt-4o-mini",
            "max_relations": max_relations,
            "jailbreak": jailbreak_section,
            "counterfactual_bias": bias_section,
        }

        # The standard preset is expected to carry 10 techniques and the
        # standard demographic preset 4 entries.
        assert run_config["jailbreak"]["num_techniques"] == 10
        assert len(run_config["counterfactual_bias"]["demographics"]) == 4