"""
Integration tests for perturbation testing API endpoints.

Tests the backend/routers/testing.py endpoints.
"""

import pytest
import json
import os
import tempfile
import shutil
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock

# Import the app
from backend.app import app


@pytest.fixture
def client():
    """Create a TestClient bound to the backend FastAPI app."""
    return TestClient(app)


@pytest.fixture
def temp_prompts_dir(tmp_path):
    """Yield an empty temporary directory for custom jailbreak prompts.

    ``backend.routers.testing.CUSTOM_PROMPTS_DIR`` is patched to this
    directory's path (as a ``str``) for the duration of the test, so the
    upload/delete tests do not operate on the real custom-prompts directory.
    NOTE(review): this assumes the router reads CUSTOM_PROMPTS_DIR at request
    time rather than capturing it at import time — confirm in the router.
    """
    prompts_dir = tmp_path / "custom_jailbreak_prompts"
    prompts_dir.mkdir()

    with patch("backend.routers.testing.CUSTOM_PROMPTS_DIR", str(prompts_dir)):
        yield prompts_dir


class TestListJailbreakPromptSources:
    """Tests for GET /api/testing/jailbreak-prompts/list"""

    def test_list_returns_sources(self, client):
        """Test that list endpoint returns sources."""
        response = client.get("/api/testing/jailbreak-prompts/list")

        assert response.status_code == 200
        data = response.json()
        assert "sources" in data
        assert isinstance(data["sources"], list)

    def test_list_includes_builtin_source(self, client):
        """Test that built-in source is included."""
        response = client.get("/api/testing/jailbreak-prompts/list")
        assert response.status_code == 200

        sources = response.json()["sources"]
        # The built-in dataset is listed under the name "standard".
        assert any(s.get("name") == "standard" for s in sources)

    def test_source_has_required_fields(self, client):
        """Test that sources have required fields."""
        response = client.get("/api/testing/jailbreak-prompts/list")
        assert response.status_code == 200

        required_fields = ("name", "description", "count", "source_type")
        for source in response.json()["sources"]:
            for field in required_fields:
                assert field in source, f"source {source!r} is missing {field!r}"


class TestUploadJailbreakPrompts:
    """Tests for POST /api/testing/jailbreak-prompts/upload"""

    def test_upload_json_prompts(self, client, temp_prompts_dir):
        """Test uploading JSON prompts."""
        prompts = [
            {"name": "Test1", "prompt": "Test prompt 1"},
            {"name": "Test2", "prompt": "Test prompt 2"},
        ]

        response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=test_set",
            files={"file": ("prompts.json", json.dumps(prompts), "application/json")},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"
        assert data["name"] == "test_set"
        assert data["prompt_count"] == 2

    def test_upload_csv_prompts(self, client, temp_prompts_dir):
        """Test uploading CSV prompts."""
        csv_content = (
            "name,prompt,description\n"
            "Test1,Test prompt 1,Desc1\n"
            "Test2,Test prompt 2,Desc2"
        )

        response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=csv_set",
            files={"file": ("prompts.csv", csv_content, "text/csv")},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"
        assert data["prompt_count"] == 2

    def test_upload_invalid_json(self, client, temp_prompts_dir):
        """Test uploading invalid JSON."""
        response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=invalid",
            files={"file": ("prompts.json", "invalid json", "application/json")},
        )

        # NOTE(review): this pins the endpoint's current behavior of reporting
        # malformed JSON as a 500. A 4xx would be more appropriate for bad client
        # input; update this assertion if the router is changed.
        assert response.status_code == 500

    def test_upload_json_missing_prompt_field(self, client, temp_prompts_dir):
        """Test uploading JSON without prompt field."""
        prompts = [{"name": "Test", "content": "No prompt field"}]

        response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=missing",
            files={"file": ("prompts.json", json.dumps(prompts), "application/json")},
        )

        assert response.status_code == 400
        assert "prompt" in response.json()["detail"].lower()

    def test_upload_csv_missing_prompt_column(self, client, temp_prompts_dir):
        """Test uploading CSV without prompt column."""
        csv_content = "name,description\nTest1,Desc1"

        response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=missing",
            files={"file": ("prompts.csv", csv_content, "text/csv")},
        )

        assert response.status_code == 400
        assert "prompt" in response.json()["detail"].lower()

    def test_upload_unsupported_format(self, client, temp_prompts_dir):
        """Test uploading unsupported file format."""
        response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=unsupported",
            files={"file": ("prompts.txt", "text content", "text/plain")},
        )

        assert response.status_code == 400


class TestDeleteJailbreakPrompts:
    """Tests for DELETE /api/testing/jailbreak-prompts/{name}"""

    def test_delete_custom_prompts(self, client, temp_prompts_dir):
        """Test deleting custom prompts."""
        # Setup: upload a prompt set to delete. Assert the upload succeeded so a
        # setup failure is not misreported as a delete failure.
        prompts = [{"prompt": "Test"}]
        upload_response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=to_delete",
            files={"file": ("prompts.json", json.dumps(prompts), "application/json")},
        )
        assert upload_response.status_code == 200

        response = client.delete("/api/testing/jailbreak-prompts/to_delete")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"

    def test_delete_nonexistent_prompts(self, client, temp_prompts_dir):
        """Test deleting prompts that don't exist."""
        response = client.delete("/api/testing/jailbreak-prompts/nonexistent")
        assert response.status_code == 404

    def test_cannot_delete_builtin(self, client):
        """Test that built-in dataset cannot be deleted."""
        response = client.delete("/api/testing/jailbreak-prompts/standard")
        assert response.status_code == 400
        assert "built-in" in response.json()["detail"].lower()


class TestListPresets:
    """Tests for GET /api/testing/presets"""

    def test_list_presets(self, client):
        """Test listing available presets."""
        response = client.get("/api/testing/presets")

        assert response.status_code == 200
        data = response.json()
        assert "presets" in data
        assert len(data["presets"]) >= 3  # quick, standard, comprehensive

    def test_preset_has_required_fields(self, client):
        """Test that presets have required fields."""
        response = client.get("/api/testing/presets")
        assert response.status_code == 200

        required_fields = (
            "name",
            "description",
            "jailbreak_techniques",
            "demographics_count",
            "comparison_mode",
        )
        for preset in response.json()["presets"]:
            for field in required_fields:
                assert field in preset, f"preset {preset!r} is missing {field!r}"

    def test_get_specific_preset(self, client):
        """Test getting a specific preset."""
        for preset_name in ["quick", "standard", "comprehensive"]:
            response = client.get(f"/api/testing/presets/{preset_name}")
            assert response.status_code == 200
            data = response.json()
            assert data["name"] == preset_name

    def test_get_invalid_preset(self, client):
        """Test getting an invalid preset."""
        response = client.get("/api/testing/presets/invalid")
        assert response.status_code == 404


class TestListDemographics:
    """Tests for GET /api/testing/demographics"""

    def test_list_demographics(self, client):
        """Test listing available demographics."""
        response = client.get("/api/testing/demographics")

        assert response.status_code == 200
        data = response.json()
        assert "demographics" in data

    def test_demographics_has_gender_options(self, client):
        """Test that gender options are provided."""
        response = client.get("/api/testing/demographics")
        data = response.json()

        assert "gender" in data["demographics"]
        genders = data["demographics"]["gender"]
        assert "male" in genders
        assert "female" in genders

    def test_demographics_has_race_options(self, client):
        """Test that race options are provided."""
        response = client.get("/api/testing/demographics")
        data = response.json()

        assert "race" in data["demographics"]
        races = data["demographics"]["race"]
        assert "White" in races
        assert "Black" in races

    def test_demographics_has_presets(self, client):
        """Test that demographic presets are provided."""
        response = client.get("/api/testing/demographics")
        data = response.json()

        assert "presets" in data["demographics"]
        presets = data["demographics"]["presets"]
        assert "minimal" in presets
        assert "standard" in presets
        assert "comprehensive" in presets

    def test_demographics_has_comparison_modes(self, client):
        """Test that comparison modes are provided."""
        response = client.get("/api/testing/demographics")
        data = response.json()

        assert "comparison_modes" in data
        modes = data["comparison_modes"]
        assert "vs_baseline" in modes
        assert "all_pairs" in modes
        assert "both" in modes

    def test_demographics_has_extended_dimensions(self, client):
        """Test that extended dimensions are provided."""
        response = client.get("/api/testing/demographics")
        data = response.json()

        assert "extended_dimensions" in data
        dims = data["extended_dimensions"]
        assert "age" in dims
        assert "disability" in dims
        assert "socioeconomic" in dims


class TestPerturbEndpoint:
    """Tests for POST /api/knowledge-graphs/{kg_id}/perturb"""

    def test_perturb_with_invalid_kg_id(self, client):
        """Test perturb endpoint with invalid kg_id."""
        response = client.post(
            "/api/knowledge-graphs/invalid_id/perturb",
            json={},
        )

        # Any error status is acceptable here: 400, 404, or 500.
        assert response.status_code in [400, 404, 500]

    def test_perturb_endpoint_exists(self, client):
        """Test that the perturb route accepts POST requests."""
        response = client.post(
            "/api/knowledge-graphs/test_kg/perturb",
            json={"model": "gpt-4o-mini"},
        )

        # 405 Method Not Allowed would mean the path exists but not for POST.
        # NOTE(review): this cannot tell a missing route (404 Not Found) apart
        # from a missing knowledge graph (also plausibly 404). It only rules
        # out a POST-method mismatch.
        assert response.status_code != 405


class TestAPIWorkflow:
    """Integration tests for complete API workflows."""

    def test_upload_list_delete_workflow(self, client, temp_prompts_dir):
        """Test complete workflow: upload -> list -> delete."""
        # 1. Upload prompts
        prompts = [{"prompt": "Test prompt 1"}, {"prompt": "Test prompt 2"}]
        upload_response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=workflow_test",
            files={"file": ("prompts.json", json.dumps(prompts), "application/json")},
        )
        assert upload_response.status_code == 200

        # 2. List sources and verify our upload is there
        list_response = client.get("/api/testing/jailbreak-prompts/list")
        assert list_response.status_code == 200
        names = [s["name"] for s in list_response.json()["sources"]]
        assert "workflow_test" in names

        # 3. Delete the uploaded prompts
        delete_response = client.delete("/api/testing/jailbreak-prompts/workflow_test")
        assert delete_response.status_code == 200

        # 4. Verify deletion
        list_response = client.get("/api/testing/jailbreak-prompts/list")
        assert list_response.status_code == 200
        names = [s["name"] for s in list_response.json()["sources"]]
        assert "workflow_test" not in names

    def test_get_preset_and_demographics_for_config(self, client):
        """Test getting preset and demographics to build config."""
        # Get a preset
        preset_response = client.get("/api/testing/presets/standard")
        assert preset_response.status_code == 200
        preset = preset_response.json()

        # Get demographics
        demo_response = client.get("/api/testing/demographics")
        assert demo_response.status_code == 200
        demographics = demo_response.json()

        # Verify we can build a config from these
        config = {
            "model": "gpt-4o-mini",
            "judge_model": "gpt-4o-mini",
            "max_relations": preset["max_relations"],
            "jailbreak": {
                "enabled": True,
                "num_techniques": preset["jailbreak_techniques"],
            },
            "counterfactual_bias": {
                "enabled": True,
                "demographics": demographics["demographics"]["presets"]["standard"],
                "comparison_mode": preset["comparison_mode"],
            },
        }

        # These values pin the current contents of the "standard" preset
        # (10 jailbreak techniques) and the "standard" demographics preset
        # (4 entries); update them if those presets change.
        assert config["jailbreak"]["num_techniques"] == 10
        assert len(config["counterfactual_bias"]["demographics"]) == 4