Spaces:
Running
Running
File size: 13,330 Bytes
795b72e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 |
"""
Integration tests for perturbation testing API endpoints.
Tests the backend/routers/testing.py endpoints.
"""
import pytest
import json
import os
import tempfile
import shutil
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
# Import the app
from backend.app import app
@pytest.fixture
def client():
    """Provide a FastAPI ``TestClient`` wired to the backend application."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def temp_prompts_dir(tmp_path):
    """Yield a fresh, empty directory standing in for the custom-prompts store.

    While the test runs, ``backend.routers.testing.CUSTOM_PROMPTS_DIR`` is
    patched to this directory's path. Presumably the router reads that module
    attribute at request time, so uploads and deletes land here. Verify this
    if the router ever caches the value at import.
    """
    store = tmp_path.joinpath("custom_jailbreak_prompts")
    store.mkdir()
    # The patched value is the str() form of the path, not a Path object.
    target_setting = "backend.routers.testing.CUSTOM_PROMPTS_DIR"
    with patch(target_setting, str(store)):
        yield store
class TestListJailbreakPromptSources:
    """Tests for GET /api/testing/jailbreak-prompts/list"""

    def test_list_returns_sources(self, client):
        """The list endpoint responds 200 with a ``sources`` list."""
        response = client.get("/api/testing/jailbreak-prompts/list")
        assert response.status_code == 200
        data = response.json()
        assert "sources" in data
        assert isinstance(data["sources"], list)

    def test_list_includes_builtin_source(self, client):
        """The built-in ``standard`` source appears in the listing."""
        response = client.get("/api/testing/jailbreak-prompts/list")
        # Check the status first so an error response fails clearly instead of
        # raising KeyError on the body.
        assert response.status_code == 200
        sources = response.json()["sources"]
        assert any(s.get("name") == "standard" for s in sources)

    def test_source_has_required_fields(self, client):
        """Every listed source has name, description, count and source_type."""
        response = client.get("/api/testing/jailbreak-prompts/list")
        assert response.status_code == 200
        sources = response.json()["sources"]
        # Without this guard an empty list would skip the loop and pass vacuously.
        assert sources, "expected at least one prompt source"
        required = ("name", "description", "count", "source_type")
        for source in sources:
            missing = [field for field in required if field not in source]
            assert not missing, f"source {source.get('name')!r} missing fields {missing}"
class TestUploadJailbreakPrompts:
    """Tests for POST /api/testing/jailbreak-prompts/upload"""

    @staticmethod
    def _upload(client, set_name, filename, content, content_type):
        """POST ``content`` as the multipart ``file`` field under ``set_name``."""
        return client.post(
            f"/api/testing/jailbreak-prompts/upload?name={set_name}",
            files={"file": (filename, content, content_type)},
        )

    def test_upload_json_prompts(self, client, temp_prompts_dir):
        """A JSON array of prompt objects is accepted and counted."""
        payload = json.dumps([
            {"name": "Test1", "prompt": "Test prompt 1"},
            {"name": "Test2", "prompt": "Test prompt 2"},
        ])
        resp = self._upload(client, "test_set", "prompts.json", payload, "application/json")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "success"
        assert body["name"] == "test_set"
        assert body["prompt_count"] == 2

    def test_upload_csv_prompts(self, client, temp_prompts_dir):
        """A CSV with a ``prompt`` column is accepted and counted per data row."""
        rows = [
            "name,prompt,description",
            "Test1,Test prompt 1,Desc1",
            "Test2,Test prompt 2,Desc2",
        ]
        resp = self._upload(client, "csv_set", "prompts.csv", "\n".join(rows), "text/csv")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "success"
        assert body["prompt_count"] == 2

    def test_upload_invalid_json(self, client, temp_prompts_dir):
        """Malformed JSON in a .json upload is rejected."""
        resp = self._upload(client, "invalid", "prompts.json", "invalid json", "application/json")
        # NOTE(review): this pins the current behavior, where malformed JSON
        # surfaces as a 500 rather than a client-error 400.
        assert resp.status_code == 500

    def test_upload_json_missing_prompt_field(self, client, temp_prompts_dir):
        """JSON entries without a ``prompt`` key yield a 400 mentioning 'prompt'."""
        payload = json.dumps([{"name": "Test", "content": "No prompt field"}])
        resp = self._upload(client, "missing", "prompts.json", payload, "application/json")
        assert resp.status_code == 400
        assert "prompt" in resp.json()["detail"].lower()

    def test_upload_csv_missing_prompt_column(self, client, temp_prompts_dir):
        """A CSV lacking a ``prompt`` column yields a 400 mentioning 'prompt'."""
        resp = self._upload(
            client, "missing", "prompts.csv", "name,description\nTest1,Desc1", "text/csv"
        )
        assert resp.status_code == 400
        assert "prompt" in resp.json()["detail"].lower()

    def test_upload_unsupported_format(self, client, temp_prompts_dir):
        """A file extension other than .json/.csv is rejected with 400."""
        resp = self._upload(client, "unsupported", "prompts.txt", "text content", "text/plain")
        assert resp.status_code == 400
class TestDeleteJailbreakPrompts:
    """Tests for DELETE /api/testing/jailbreak-prompts/{name}"""

    def test_delete_custom_prompts(self, client, temp_prompts_dir):
        """Deleting a previously uploaded custom prompt set succeeds."""
        # Arrange: upload a one-prompt set so there is something to delete.
        prompts = [{"prompt": "Test"}]
        upload_response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=to_delete",
            files={"file": ("prompts.json", json.dumps(prompts), "application/json")},
        )
        # Fail at the setup step if the upload was rejected. Otherwise a broken
        # upload would surface as a misleading delete failure below.
        assert upload_response.status_code == 200

        # Act / assert: delete the set.
        response = client.delete("/api/testing/jailbreak-prompts/to_delete")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"

    def test_delete_nonexistent_prompts(self, client, temp_prompts_dir):
        """Deleting a name that was never uploaded returns 404."""
        # temp_prompts_dir swaps in an empty store, so 'nonexistent' is absent.
        response = client.delete("/api/testing/jailbreak-prompts/nonexistent")
        assert response.status_code == 404

    def test_cannot_delete_builtin(self, client):
        """The built-in ``standard`` dataset cannot be deleted (400)."""
        response = client.delete("/api/testing/jailbreak-prompts/standard")
        assert response.status_code == 400
        assert "built-in" in response.json()["detail"].lower()
class TestListPresets:
    """Tests for GET /api/testing/presets"""

    def test_list_presets(self, client):
        """At least three presets are listed (expected: quick, standard, comprehensive)."""
        resp = client.get("/api/testing/presets")
        assert resp.status_code == 200
        payload = resp.json()
        assert "presets" in payload
        assert len(payload["presets"]) >= 3

    def test_preset_has_required_fields(self, client):
        """Each listed preset exposes the fields a test config is built from."""
        expected_fields = (
            "name",
            "description",
            "jailbreak_techniques",
            "demographics_count",
            "comparison_mode",
        )
        for entry in client.get("/api/testing/presets").json()["presets"]:
            for field in expected_fields:
                assert field in entry

    def test_get_specific_preset(self, client):
        """Each named preset can be fetched individually and echoes its name."""
        for wanted in ("quick", "standard", "comprehensive"):
            resp = client.get(f"/api/testing/presets/{wanted}")
            assert resp.status_code == 200
            assert resp.json()["name"] == wanted

    def test_get_invalid_preset(self, client):
        """An unknown preset name returns 404."""
        resp = client.get("/api/testing/presets/invalid")
        assert resp.status_code == 404
class TestListDemographics:
    """Tests for GET /api/testing/demographics"""

    @staticmethod
    def _fetch(client):
        """Issue the GET request against the demographics endpoint."""
        return client.get("/api/testing/demographics")

    def test_list_demographics(self, client):
        """The endpoint responds 200 with a ``demographics`` section."""
        resp = self._fetch(client)
        assert resp.status_code == 200
        assert "demographics" in resp.json()

    def test_demographics_has_gender_options(self, client):
        """Gender options include at least 'male' and 'female'."""
        section = self._fetch(client).json()["demographics"]
        assert "gender" in section
        for option in ("male", "female"):
            assert option in section["gender"]

    def test_demographics_has_race_options(self, client):
        """Race options include at least 'White' and 'Black'."""
        section = self._fetch(client).json()["demographics"]
        assert "race" in section
        for option in ("White", "Black"):
            assert option in section["race"]

    def test_demographics_has_presets(self, client):
        """Demographic presets minimal/standard/comprehensive are offered."""
        section = self._fetch(client).json()["demographics"]
        assert "presets" in section
        for preset in ("minimal", "standard", "comprehensive"):
            assert preset in section["presets"]

    def test_demographics_has_comparison_modes(self, client):
        """Comparison modes vs_baseline/all_pairs/both are offered."""
        payload = self._fetch(client).json()
        assert "comparison_modes" in payload
        for mode in ("vs_baseline", "all_pairs", "both"):
            assert mode in payload["comparison_modes"]

    def test_demographics_has_extended_dimensions(self, client):
        """Extended dimensions age/disability/socioeconomic are offered."""
        payload = self._fetch(client).json()
        assert "extended_dimensions" in payload
        for dimension in ("age", "disability", "socioeconomic"):
            assert dimension in payload["extended_dimensions"]
class TestPerturbEndpoint:
    """Tests for POST /api/knowledge-graphs/{kg_id}/perturb"""

    def test_perturb_with_invalid_kg_id(self, client):
        """Perturbing an unknown kg_id is rejected with an error status."""
        response = client.post(
            "/api/knowledge-graphs/invalid_id/perturb",
            json={},
        )
        # Any of 400, 404 or 500 is accepted. The test only requires that the
        # request is rejected; it does not pin which error path the router takes.
        assert response.status_code in (400, 404, 500)

    def test_perturb_endpoint_exists(self, client):
        """Smoke-check that POST to the perturb path is not rejected as 405."""
        response = client.post(
            "/api/knowledge-graphs/test_kg/perturb",
            json={"model": "gpt-4o-mini"},
        )
        # 405 Method Not Allowed means the path matched a route that does not
        # accept POST. This check is weak: a completely unregistered path returns
        # 404, which cannot be told apart here from a "KG not found" 404.
        assert response.status_code != 405
class TestAPIWorkflow:
    """Integration tests for complete API workflows."""

    def test_upload_list_delete_workflow(self, client, temp_prompts_dir):
        """Test complete workflow: upload -> list -> delete -> list."""
        list_url = "/api/testing/jailbreak-prompts/list"

        # 1. Upload prompts. The temp_prompts_dir fixture swaps in an empty,
        #    isolated custom-prompts store for this test.
        prompts = [{"prompt": "Test prompt 1"}, {"prompt": "Test prompt 2"}]
        upload_response = client.post(
            "/api/testing/jailbreak-prompts/upload?name=workflow_test",
            files={"file": ("prompts.json", json.dumps(prompts), "application/json")},
        )
        assert upload_response.status_code == 200

        # 2. List sources and verify the upload is present.
        list_response = client.get(list_url)
        assert list_response.status_code == 200
        names = {s["name"] for s in list_response.json()["sources"]}
        assert "workflow_test" in names

        # 3. Delete the uploaded prompts.
        delete_response = client.delete("/api/testing/jailbreak-prompts/workflow_test")
        assert delete_response.status_code == 200

        # 4. Verify deletion. Check the status first so a failing list call
        #    fails clearly instead of raising KeyError on the response body.
        list_response = client.get(list_url)
        assert list_response.status_code == 200
        names = {s["name"] for s in list_response.json()["sources"]}
        assert "workflow_test" not in names

    def test_get_preset_and_demographics_for_config(self, client):
        """Test getting preset and demographics to build config."""
        # Get a preset
        preset_response = client.get("/api/testing/presets/standard")
        assert preset_response.status_code == 200
        preset = preset_response.json()

        # Get demographics
        demo_response = client.get("/api/testing/demographics")
        assert demo_response.status_code == 200
        demographics = demo_response.json()

        # Build a perturbation config from the two responses; the KeyErrors
        # raised here if any field is absent are part of what this test checks.
        config = {
            "model": "gpt-4o-mini",
            "judge_model": "gpt-4o-mini",
            "max_relations": preset["max_relations"],
            "jailbreak": {
                "enabled": True,
                "num_techniques": preset["jailbreak_techniques"],
            },
            "counterfactual_bias": {
                "enabled": True,
                "demographics": demographics["demographics"]["presets"]["standard"],
                "comparison_mode": preset["comparison_mode"],
            },
        }

        # These checks effectively pin two server-side values:
        #  - the "standard" preset's jailbreak_techniques (10), and
        #  - the size of the "standard" demographics preset (4 entries).
        assert config["jailbreak"]["num_techniques"] == 10
        assert len(config["counterfactual_bias"]["demographics"]) == 4
|