Spaces:
Running
Running
| """ | |
| Integration tests for perturbation testing API endpoints. | |
| Tests the backend/routers/testing.py endpoints. | |
| """ | |
| import pytest | |
| import json | |
| import os | |
| import tempfile | |
| import shutil | |
| from fastapi.testclient import TestClient | |
| from unittest.mock import patch, MagicMock | |
| # Import the app | |
| from backend.app import app | |
@pytest.fixture
def client():
    """Provide a FastAPI ``TestClient`` bound to the application under test.

    Registered as a pytest fixture so that test methods can obtain it simply
    by declaring a ``client`` parameter.

    Returns:
        TestClient: a client wrapping the imported ``app``.
    """
    return TestClient(app)
@pytest.fixture
def temp_prompts_dir(tmp_path):
    """Redirect custom prompt storage to a throwaway directory.

    Creates ``custom_jailbreak_prompts`` under pytest's ``tmp_path`` and
    patches ``backend.routers.testing.CUSTOM_PROMPTS_DIR`` to point at it.
    The patch stays active while the requesting test runs and is undone
    when the fixture is torn down.

    Args:
        tmp_path: pytest's built-in per-test temporary directory.

    Yields:
        The path of the temporary prompts directory.
    """
    prompts_dir = tmp_path / "custom_jailbreak_prompts"
    prompts_dir.mkdir()
    # The constant is replaced with the string form of the path.
    with patch("backend.routers.testing.CUSTOM_PROMPTS_DIR", str(prompts_dir)):
        yield prompts_dir
class TestListJailbreakPromptSources:
    """Checks for ``GET /api/testing/jailbreak-prompts/list``."""

    def test_list_returns_sources(self, client):
        """The endpoint answers 200 with a ``sources`` list in the body."""
        resp = client.get("/api/testing/jailbreak-prompts/list")
        assert resp.status_code == 200

        payload = resp.json()
        assert "sources" in payload
        assert isinstance(payload["sources"], list)

    def test_list_includes_builtin_source(self, client):
        """At least one listed source is named ``standard``."""
        payload = client.get("/api/testing/jailbreak-prompts/list").json()
        # Count every entry whose name marks it as the standard/built-in set.
        standard_count = sum(
            1 for src in payload["sources"] if src.get("name") == "standard"
        )
        assert standard_count >= 1

    def test_source_has_required_fields(self, client):
        """Every listed source carries the four expected keys."""
        payload = client.get("/api/testing/jailbreak-prompts/list").json()
        required_keys = ("name", "description", "count", "source_type")
        for src in payload["sources"]:
            for key in required_keys:
                assert key in src
class TestUploadJailbreakPrompts:
    """Checks for ``POST /api/testing/jailbreak-prompts/upload``."""

    def test_upload_json_prompts(self, client, temp_prompts_dir):
        """A well-formed JSON list is accepted and its entries are counted."""
        entries = [
            {"name": "Test1", "prompt": "Test prompt 1"},
            {"name": "Test2", "prompt": "Test prompt 2"},
        ]
        upload = ("prompts.json", json.dumps(entries), "application/json")
        resp = client.post(
            "/api/testing/jailbreak-prompts/upload?name=test_set",
            files={"file": upload},
        )
        assert resp.status_code == 200

        body = resp.json()
        assert body["status"] == "success"
        assert body["name"] == "test_set"
        assert body["prompt_count"] == 2

    def test_upload_csv_prompts(self, client, temp_prompts_dir):
        """A CSV file with a ``prompt`` column is accepted, one prompt per row."""
        csv_text = "name,prompt,description\nTest1,Test prompt 1,Desc1\nTest2,Test prompt 2,Desc2"
        upload = ("prompts.csv", csv_text, "text/csv")
        resp = client.post(
            "/api/testing/jailbreak-prompts/upload?name=csv_set",
            files={"file": upload},
        )
        assert resp.status_code == 200

        body = resp.json()
        assert body["status"] == "success"
        assert body["prompt_count"] == 2

    def test_upload_invalid_json(self, client, temp_prompts_dir):
        """Content that is not parseable JSON is expected to produce a 500."""
        upload = ("prompts.json", "invalid json", "application/json")
        resp = client.post(
            "/api/testing/jailbreak-prompts/upload?name=invalid",
            files={"file": upload},
        )
        assert resp.status_code == 500

    def test_upload_json_missing_prompt_field(self, client, temp_prompts_dir):
        """JSON entries lacking ``prompt`` are rejected with a 400 naming the field."""
        entries = [{"name": "Test", "content": "No prompt field"}]
        upload = ("prompts.json", json.dumps(entries), "application/json")
        resp = client.post(
            "/api/testing/jailbreak-prompts/upload?name=missing",
            files={"file": upload},
        )
        assert resp.status_code == 400
        assert "prompt" in resp.json()["detail"].lower()

    def test_upload_csv_missing_prompt_column(self, client, temp_prompts_dir):
        """A CSV without a ``prompt`` column is rejected with a 400 naming it."""
        csv_text = "name,description\nTest1,Desc1"
        upload = ("prompts.csv", csv_text, "text/csv")
        resp = client.post(
            "/api/testing/jailbreak-prompts/upload?name=missing",
            files={"file": upload},
        )
        assert resp.status_code == 400
        assert "prompt" in resp.json()["detail"].lower()

    def test_upload_unsupported_format(self, client, temp_prompts_dir):
        """A ``.txt`` upload is rejected with a 400."""
        upload = ("prompts.txt", "text content", "text/plain")
        resp = client.post(
            "/api/testing/jailbreak-prompts/upload?name=unsupported",
            files={"file": upload},
        )
        assert resp.status_code == 400
class TestDeleteJailbreakPrompts:
    """Tests for DELETE /api/testing/jailbreak-prompts/{name}"""

    def test_delete_custom_prompts(self, client, temp_prompts_dir):
        """Test deleting a custom prompt set that was just uploaded."""
        # Arrange: store a prompt set so there is something to delete.
        prompts = [{"prompt": "Test"}]
        upload_response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=to_delete",
            files={"file": ("prompts.json", json.dumps(prompts), "application/json")},
        )
        # Fail here, rather than on the delete below, if the precondition
        # (a stored prompt set) was never established.
        assert upload_response.status_code == 200

        # Act: delete the set that was just uploaded.
        response = client.delete("/api/testing/jailbreak-prompts/to_delete")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"

    def test_delete_nonexistent_prompts(self, client, temp_prompts_dir):
        """Test that deleting an unknown prompt set returns 404."""
        response = client.delete("/api/testing/jailbreak-prompts/nonexistent")
        assert response.status_code == 404

    def test_cannot_delete_builtin(self, client):
        """Test that deleting the ``standard`` set is refused with a 400."""
        response = client.delete("/api/testing/jailbreak-prompts/standard")
        assert response.status_code == 400
        # The error detail is expected to mention that the set is built-in.
        assert "built-in" in response.json()["detail"].lower()
class TestListPresets:
    """Checks for ``GET /api/testing/presets`` and ``/presets/{name}``."""

    def test_list_presets(self, client):
        """The listing answers 200 and holds at least three presets."""
        resp = client.get("/api/testing/presets")
        assert resp.status_code == 200

        payload = resp.json()
        assert "presets" in payload
        # quick, standard and comprehensive are the expected minimum.
        assert len(payload["presets"]) >= 3

    def test_preset_has_required_fields(self, client):
        """Every listed preset carries the expected keys."""
        payload = client.get("/api/testing/presets").json()
        required_keys = (
            "name",
            "description",
            "jailbreak_techniques",
            "demographics_count",
            "comparison_mode",
        )
        for entry in payload["presets"]:
            for key in required_keys:
                assert key in entry

    def test_get_specific_preset(self, client):
        """Each known preset can be fetched by name and echoes that name."""
        for wanted in ["quick", "standard", "comprehensive"]:
            resp = client.get(f"/api/testing/presets/{wanted}")
            assert resp.status_code == 200
            assert resp.json()["name"] == wanted

    def test_get_invalid_preset(self, client):
        """An unknown preset name yields 404."""
        resp = client.get("/api/testing/presets/invalid")
        assert resp.status_code == 404
class TestListDemographics:
    """Checks for ``GET /api/testing/demographics``."""

    @staticmethod
    def _payload(client):
        """Fetch the demographics endpoint and return its decoded JSON body."""
        return client.get("/api/testing/demographics").json()

    def test_list_demographics(self, client):
        """The endpoint answers 200 with a ``demographics`` key."""
        resp = client.get("/api/testing/demographics")
        assert resp.status_code == 200
        assert "demographics" in resp.json()

    def test_demographics_has_gender_options(self, client):
        """The ``gender`` options include ``male`` and ``female``."""
        payload = self._payload(client)
        assert "gender" in payload["demographics"]
        gender_options = payload["demographics"]["gender"]
        for expected in ("male", "female"):
            assert expected in gender_options

    def test_demographics_has_race_options(self, client):
        """The ``race`` options include ``White`` and ``Black``."""
        payload = self._payload(client)
        assert "race" in payload["demographics"]
        race_options = payload["demographics"]["race"]
        for expected in ("White", "Black"):
            assert expected in race_options

    def test_demographics_has_presets(self, client):
        """The demographic presets include the three expected names."""
        payload = self._payload(client)
        assert "presets" in payload["demographics"]
        preset_names = payload["demographics"]["presets"]
        for expected in ("minimal", "standard", "comprehensive"):
            assert expected in preset_names

    def test_demographics_has_comparison_modes(self, client):
        """The top-level ``comparison_modes`` include the three expected modes."""
        payload = self._payload(client)
        assert "comparison_modes" in payload
        available_modes = payload["comparison_modes"]
        for expected in ("vs_baseline", "all_pairs", "both"):
            assert expected in available_modes

    def test_demographics_has_extended_dimensions(self, client):
        """The top-level ``extended_dimensions`` include the expected entries."""
        payload = self._payload(client)
        assert "extended_dimensions" in payload
        dimensions = payload["extended_dimensions"]
        for expected in ("age", "disability", "socioeconomic"):
            assert expected in dimensions
class TestPerturbEndpoint:
    """Checks for ``POST /api/knowledge-graphs/{kg_id}/perturb``."""

    def test_perturb_with_invalid_kg_id(self, client):
        """Posting to an invalid knowledge-graph id yields an error status."""
        resp = client.post("/api/knowledge-graphs/invalid_id/perturb", json={})
        # Any of 400, 404 or 500 is accepted as the error response.
        acceptable = [400, 404, 500]
        assert resp.status_code in acceptable

    def test_perturb_endpoint_exists(self, client):
        """POST to the perturb route is not rejected as a disallowed method."""
        # The knowledge graph is not expected to exist, so an error status is
        # fine here; only 405 (Method Not Allowed) counts as a failure.
        resp = client.post(
            "/api/knowledge-graphs/test_kg/perturb",
            json={"model": "gpt-4o-mini"},
        )
        assert resp.status_code != 405
class TestAPIWorkflow:
    """End-to-end checks that chain several API calls together."""

    def test_upload_list_delete_workflow(self, client, temp_prompts_dir):
        """Upload a prompt set, see it listed, delete it, see it gone."""
        list_url = "/api/testing/jailbreak-prompts/list"

        # Step 1: upload two prompts under the name "workflow_test".
        entries = [{"prompt": "Test prompt 1"}, {"prompt": "Test prompt 2"}]
        upload = ("prompts.json", json.dumps(entries), "application/json")
        uploaded = client.post(
            "/api/testing/jailbreak-prompts/upload?name=workflow_test",
            files={"file": upload},
        )
        assert uploaded.status_code == 200

        # Step 2: the new set must show up in the source listing.
        listed = client.get(list_url)
        assert listed.status_code == 200
        names_after_upload = [src["name"] for src in listed.json()["sources"]]
        assert "workflow_test" in names_after_upload

        # Step 3: remove the set again.
        deleted = client.delete("/api/testing/jailbreak-prompts/workflow_test")
        assert deleted.status_code == 200

        # Step 4: the listing must no longer mention it.
        relisted = client.get(list_url)
        names_after_delete = [src["name"] for src in relisted.json()["sources"]]
        assert "workflow_test" not in names_after_delete

    def test_get_preset_and_demographics_for_config(self, client):
        """A run config can be assembled from the preset and demographics APIs."""
        preset_resp = client.get("/api/testing/presets/standard")
        assert preset_resp.status_code == 200
        preset = preset_resp.json()

        demo_resp = client.get("/api/testing/demographics")
        assert demo_resp.status_code == 200
        demographics = demo_resp.json()

        # Assemble the pieces in the same order the config lists them.
        max_relations = preset["max_relations"]
        jailbreak_section = {
            "enabled": True,
            "num_techniques": preset["jailbreak_techniques"],
        }
        bias_section = {
            "enabled": True,
            "demographics": demographics["demographics"]["presets"]["standard"],
            "comparison_mode": preset["comparison_mode"],
        }
        config = {
            "model": "gpt-4o-mini",
            "judge_model": "gpt-4o-mini",
            "max_relations": max_relations,
            "jailbreak": jailbreak_section,
            "counterfactual_bias": bias_section,
        }

        # The standard preset is expected to give 10 techniques and 4 demographics.
        assert config["jailbreak"]["num_techniques"] == 10
        assert len(config["counterfactual_bias"]["demographics"]) == 4