# AgentGraph: backend/services/test_service.py
# Last commit 7bc750c by wu981526092 ("add")
"""
Service for test-related operations
"""
import os
import json
import time
import logging
import traceback
import threading
from typing import Dict, List, Any, Optional
from pathlib import Path
from backend.server_config import TEST_RESULTS_FILE, PROJECT_ROOT
# Module logger. Its dotted name places it under the "agent_monitoring_server"
# logger hierarchy, so records propagate to handlers configured on that parent.
_LOGGER_NAME = "agent_monitoring_server.services.test"
logger = logging.getLogger(_LOGGER_NAME)
class TestService:
"""Service for test-related operations"""
@staticmethod
def test_relation(tester, relation_id: str, model: str, system_prompt: Optional[str] = None) -> Dict[str, Any]:
"""Test a specific relation"""
try:
result = tester.test_relation(relation_id, model, system_prompt)
return result
except Exception as e:
logger.error(f"Error testing relation: {str(e)}")
logger.error(traceback.format_exc())
raise
@staticmethod
def test_relation_with_jailbreak(
tester,
relation_id: str,
model: str,
jailbreak_index: int,
system_prompt: Optional[str] = None
) -> Dict[str, Any]:
"""Test a relation with a jailbreak attack"""
try:
result = tester.test_relation_with_jailbreak(relation_id, model, jailbreak_index, system_prompt)
return result
except Exception as e:
logger.error(f"Error testing relation with jailbreak: {str(e)}")
logger.error(traceback.format_exc())
raise
@staticmethod
def test_relations_by_type(
tester,
relation_type: str,
model: str,
system_prompt: Optional[str] = None
) -> Dict[str, Any]:
"""Test all relations of a specific type"""
try:
# Get relations of the specified type
relations = tester.list_relations(relation_type)
if not relations:
return {
"message": "No relations found of the specified type",
"results": {}
}
# Get relation IDs
relation_ids = [r["id"] for r in relations]
# Test the relations and save results
results = tester.test_multiple_relations(relation_ids, model, system_prompt)
tester.save_results(TEST_RESULTS_FILE, results)
return {
"message": f"Tested {len(relation_ids)} relations of type {relation_type}",
"relation_count": len(relation_ids),
"results": results
}
except Exception as e:
logger.error(f"Error testing relations by type: {str(e)}")
logger.error(traceback.format_exc())
raise
@staticmethod
def run_perturbation_test(
tester,
knowledge_graph: str,
model: str = "gpt-5-mini",
perturbation_type: str = "jailbreak",
relation_type: str = "",
max_jailbreaks: int = 5,
relation_limit: int = 5
) -> Dict[str, Any]:
"""Run perturbation tests on a knowledge graph (using database)"""
try:
# Get a database session
from backend.database.utils import get_db, get_knowledge_graph
session = next(get_db())
try:
# Get the knowledge graph from database
kg = get_knowledge_graph(session, knowledge_graph)
if not kg:
raise FileNotFoundError(f"Knowledge graph '{knowledge_graph}' not found in database")
# Get the knowledge graph content
knowledge_graph_content = kg.graph_data
logger.info(f"Retrieved knowledge graph {knowledge_graph} from database for testing")
# Generate timestamp for this test run
timestamp = int(time.time())
output_path = f"perturbation_results_{timestamp}.json"
# Limit jailbreak techniques if specified
if perturbation_type == 'jailbreak' and max_jailbreaks > 0:
if len(tester.jailbreak_techniques) > max_jailbreaks:
logger.info(f"Limiting jailbreak techniques to {max_jailbreaks}")
tester.jailbreak_techniques = tester.jailbreak_techniques[:max_jailbreaks]
# Start the test in a separate thread
def run_tests():
try:
# Run the tests
tester.run_tests(
knowledge_graph=knowledge_graph_content,
output_file=output_path,
model=model,
perturbation_type=perturbation_type,
relation_type=relation_type,
relation_limit=relation_limit,
max_jailbreaks=max_jailbreaks
)
# Add metadata to the results file
try:
with open(output_path, 'r') as f:
results = json.load(f)
# Add test metadata
results["test_metadata"] = {
"timestamp": timestamp,
"knowledge_graph_id": kg.id,
"knowledge_graph_file": knowledge_graph,
"model": model,
"perturbation_type": perturbation_type,
"relation_type": relation_type,
"max_jailbreaks": max_jailbreaks,
"relation_limit": relation_limit
}
# Write back the updated results
with open(output_path, 'w') as f:
json.dump(results, f, indent=2)
logger.info(f"Added metadata to results file {output_path}")
except Exception as e:
logger.error(f"Error adding metadata to results: {str(e)}")
logger.info(f"Perturbation tests completed. Results saved to {output_path}")
except Exception as e:
logger.error(f"Error running perturbation tests: {str(e)}")
logger.error(traceback.format_exc())
# Create an error results file
error_results = {
"error": True,
"error_message": str(e),
"timestamp": timestamp,
"test_metadata": {
"timestamp": timestamp,
"knowledge_graph_id": kg.id,
"knowledge_graph_file": knowledge_graph,
"model": model,
"perturbation_type": perturbation_type,
"relation_type": relation_type,
"max_jailbreaks": max_jailbreaks,
"relation_limit": relation_limit,
"status": "failed"
}
}
with open(output_path, 'w') as f:
json.dump(error_results, f, indent=2)
# Start the thread
threading.Thread(target=run_tests).start()
# Return response with test information
return {
"status": "in_progress",
"message": f"Perturbation testing started. Results will be saved to {output_path}",
"output_file": output_path,
"timestamp": timestamp,
"knowledge_graph_id": kg.id
}
finally:
session.close()
except Exception as e:
logger.error(f"Error starting perturbation test: {str(e)}")
logger.error(traceback.format_exc())
raise
@staticmethod
def get_test_results() -> Dict[str, Any]:
"""Get test results from the most recent test"""
try:
if not os.path.exists(TEST_RESULTS_FILE):
# Create a simple "no results" response
return {
"status": "no_results",
"message": "No test results are currently available. Run a perturbation test first."
}
try:
with open(TEST_RESULTS_FILE, 'r') as f:
results = json.load(f)
# Add download URL for the results file if it doesn't have an error
if not results.get("error", False):
# If there's an output file stored in the results, add a download link
if "output_file" in results:
output_file = results["output_file"]
results["download_url"] = f"/download/{output_file}"
logger.info(f"Added download URL for {output_file}")
# If there's no output file but we can guess it from timestamp
elif "timestamp" in results:
timestamp = int(results.get("timestamp", time.time()))
guessed_file = f"perturbation_results_{timestamp}.json"
if os.path.exists(guessed_file):
results["output_file"] = guessed_file
results["download_url"] = f"/download/{guessed_file}"
logger.info(f"Added download URL for guessed file {guessed_file}")
return results
except json.JSONDecodeError as e:
# The test results file exists but is not valid JSON
logger.error(f"Invalid JSON in test results file: {str(e)}")
return {
"error": True,
"error_message": f"Test results file contains invalid JSON: {str(e)}",
"timestamp": time.time()
}
except Exception as e:
logger.error(f"Error retrieving test results: {str(e)}")
raise
@staticmethod
def get_test_history() -> Dict[str, List[Dict[str, Any]]]:
"""Get history of all tests"""
try:
# Get all perturbation result files
test_dir = 'datasets/test_results'
Path(test_dir).mkdir(parents=True, exist_ok=True)
result_files = [f for f in os.listdir(test_dir)
if f.startswith('perturbation_results_') and f.endswith('.json')]
# Collect metadata from each file
tests = []
for file in result_files:
try:
with open(os.path.join(test_dir, file), 'r') as f:
data = json.load(f)
# Add output file name to the data
data['output_file'] = file
tests.append(data)
except Exception as e:
logger.error(f"Error reading test result file {file}: {str(e)}")
return {"tests": tests}
except Exception as e:
logger.error(f"Error getting test history: {str(e)}")
return {"tests": [], "error": str(e)}
@staticmethod
def get_specific_test_result(test_id: Optional[str] = None, test_file: Optional[str] = None) -> Dict[str, Any]:
"""Get a specific test result by ID or filename"""
try:
# If file is specified, load directly
if test_file:
file_path = str(PROJECT_ROOT / 'datasets' / 'test_results' / test_file)
if not os.path.exists(file_path):
raise FileNotFoundError(f"Test result file {test_file} not found")
with open(file_path, 'r') as f:
result = json.load(f)
# Add download URL
result['download_url'] = f"/download/{test_file}"
return result
# Otherwise find by ID
elif test_id:
test_dir = 'datasets/test_results'
result_files = [f for f in os.listdir(test_dir)
if f.startswith('perturbation_results_') and f.endswith('.json')]
for file in result_files:
try:
with open(os.path.join(test_dir, file), 'r') as f:
data = json.load(f)
# Check if this is the test we're looking for
if str(data.get('timestamp', '')) == test_id or data.get('id', '') == test_id:
result = data
result['download_url'] = f"/download/{file}"
result['output_file'] = file
return result
except Exception as e:
logger.error(f"Error reading test result file {file}: {str(e)}")
raise FileNotFoundError(f"Test result with ID {test_id} not found")
else:
raise ValueError("Either test_id or test_file must be provided")
except Exception as e:
logger.error(f"Error getting specific test result: {str(e)}")
raise
@staticmethod
def get_progress_status(output_file: str) -> Dict[str, Any]:
"""Get progress status for a test"""
try:
# Construct the progress file path
progress_file = f"progress_{output_file}"
test_dir = 'datasets/test_results'
# Make sure the test_results directory exists
Path(test_dir).mkdir(parents=True, exist_ok=True)
progress_path = os.path.join(test_dir, progress_file)
if not os.path.exists(progress_path):
# Create an empty progress file as a placeholder
try:
with open(progress_path, 'w') as f:
json.dump({
"status": "initializing",
"overall_progress_percentage": 0,
"current_jailbreak": "Preparing...",
"last_tested_relation": "Preparing...",
"created_at": time.time()
}, f)
logger.info(f"Created placeholder progress file: {progress_path}")
except Exception as e:
logger.error(f"Failed to create progress file: {str(e)}")
raise
# Read the progress file
with open(progress_path, 'r') as f:
progress_data = json.load(f)
# Add metadata about the file itself
progress_data["progress_file"] = progress_file
progress_data["last_updated"] = os.path.getmtime(progress_path)
return progress_data
except Exception as e:
logger.error(f"Error getting progress status: {str(e)}")
logger.error(traceback.format_exc())
raise
@staticmethod
def check_progress_file(output_file: str) -> Dict[str, Any]:
"""Check if a progress file exists without returning 404"""
try:
# Construct the progress file path
progress_file = f"progress_{output_file}"
test_dir = 'datasets/test_results'
# Make sure the test_results directory exists
Path(test_dir).mkdir(parents=True, exist_ok=True)
progress_path = os.path.join(test_dir, progress_file)
exists = os.path.exists(progress_path)
return {
"exists": exists,
"file": progress_file,
"path": progress_path,
"last_modified": os.path.getmtime(progress_path) if exists else None,
"status": "ready_to_create" if not exists else "exists"
}
except Exception as e:
logger.error(f"Error checking progress file: {str(e)}")
logger.error(traceback.format_exc())
return {
"exists": False,
"error": str(e)
}