Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Testing Service | |
| This service handles all database operations for perturbation testing, | |
| providing a clean interface between the database layer and the pure | |
| testing functions in agentgraph.testing. | |
| """ | |
| import uuid | |
| import logging | |
| from typing import Dict, List, Any, Optional | |
| from sqlalchemy.orm import Session | |
| from datetime import datetime, timezone | |
| import traceback | |
| from backend.database.models import ( | |
| PerturbationTest, KnowledgeGraph, PromptReconstruction | |
| ) | |
| from backend.database.utils import ( | |
| get_knowledge_graph_by_id, get_knowledge_graph, | |
| get_prompt_reconstructions_for_kg | |
| ) | |
| # Import pure testing functions | |
| from agentgraph.testing import ( | |
| run_jailbreak_tests, run_counterfactual_bias_tests, | |
| validate_testing_data, prepare_testing_data, | |
| load_litellm_config, run_knowledge_graph_tests | |
| ) | |
| from backend.database import get_db | |
| from backend.services.task_service import update_task_status | |
| logger = logging.getLogger(__name__) | |
class TestingService:
    """
    Service for handling perturbation testing with database operations.

    This service acts as an abstraction layer between the database and the pure
    testing functions in agentgraph.testing. It handles:
    - Fetching test data from database
    - Calling pure testing functions
    - Saving test results back to database

    Methods report failures through their return value (an ``{'error': ...}``
    dict, ``None`` or an empty list) instead of raising, and log the traceback.
    """

    def __init__(self, session: "Session"):
        """
        Bind the service to a database session.

        Args:
            session: SQLAlchemy session used for every query and commit
        """
        self.session = session

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_prompt_mapping(reconstructed_prompts: Any) -> Dict[Any, Any]:
        """
        Return a ``relation_id -> reconstructed_prompt`` mapping.

        NOTE(review): within this module the result of
        get_prompt_reconstructions_for_kg is used both as an iterable of rows
        and as a ready-made mapping, so both shapes are accepted here.
        Confirm the helper's real return type.
        """
        if isinstance(reconstructed_prompts, dict):
            return reconstructed_prompts
        return {pr.relation_id: pr.reconstructed_prompt for pr in reconstructed_prompts}

    @staticmethod
    def _resolve_score(perturbation_score: Optional[float], test_result: Dict[str, Any]) -> Any:
        """
        Pick the score to persist.

        An explicitly passed score always wins, including ``0.0``; only ``None``
        falls back to ``test_result['perturbation_score']`` (default ``0.0``).
        """
        if perturbation_score is not None:
            return perturbation_score
        return test_result.get('perturbation_score', 0.0)

    @staticmethod
    def _created_at_sort_key(created_at: Optional[datetime]) -> datetime:
        """
        Map a ``created_at`` value to a naive UTC datetime usable for ordering.

        ``None`` sorts first (``datetime.min``). Timezone-aware values are
        converted to UTC and stripped of tzinfo, so naive and aware rows can be
        compared without raising ``TypeError``.
        """
        if created_at is None:
            return datetime.min
        if created_at.utcoffset() is not None:
            return created_at.astimezone(timezone.utc).replace(tzinfo=None)
        return created_at

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def fetch_testing_data(self, knowledge_graph_identifier: str) -> Dict[str, Any]:
        """
        Fetch all data needed for testing from the database.

        Args:
            knowledge_graph_identifier: Identifier of the knowledge graph to test

        Returns:
            Dictionary produced by prepare_testing_data, extended with
            'knowledge_graph_id' and 'knowledge_graph_identifier', or a
            dictionary with a single 'error' key on failure
        """
        try:
            # Get knowledge graph
            kg = get_knowledge_graph(self.session, knowledge_graph_identifier)
            if not kg:
                return {'error': f'Knowledge graph {knowledge_graph_identifier} not found'}

            # Get reconstructed prompts
            reconstructed_prompts = get_prompt_reconstructions_for_kg(
                self.session, knowledge_graph_identifier
            )
            if not reconstructed_prompts:
                return {
                    'error': f'No prompt reconstructions found for knowledge graph {knowledge_graph_identifier}. '
                             'Please run prompt reconstruction first.'
                }

            # Prepare testing data using pure function
            testing_data = prepare_testing_data(
                knowledge_graph=kg.graph_data,
                reconstructed_prompts=self._as_prompt_mapping(reconstructed_prompts)
            )

            # Add metadata
            testing_data['knowledge_graph_id'] = kg.id
            testing_data['knowledge_graph_identifier'] = knowledge_graph_identifier
            return testing_data
        except Exception as e:
            # Service boundary: callers expect an error dict, not an exception.
            logger.exception(f"Error fetching testing data for {knowledge_graph_identifier}: {e}")
            return {'error': f'Failed to fetch testing data: {str(e)}'}

    def save_test_result(
        self,
        knowledge_graph_id: int,
        relation_id: str,
        perturbation_type: str,
        test_result: Dict[str, Any],
        perturbation_score: Optional[float] = None,
        test_metadata: Optional[Dict[str, Any]] = None,
        perturbation_set_id: Optional[str] = None
    ) -> "Optional[PerturbationTest]":
        """
        Save test result to database.

        The row is committed immediately; on any error the session is rolled
        back and ``None`` is returned.

        Args:
            knowledge_graph_id: ID of the knowledge graph
            relation_id: ID of the relation tested
            perturbation_type: Type of perturbation test
            test_result: Test result data
            perturbation_score: Score from the test; when None, the score stored
                in test_result is used (0.0 if absent)
            test_metadata: Additional test metadata (defaults to an empty dict)
            perturbation_set_id: ID of the perturbation set (a new UUID4 string
                is generated when not given)

        Returns:
            PerturbationTest object if successful, None otherwise
        """
        try:
            # Create new test result
            test = PerturbationTest(
                knowledge_graph_id=knowledge_graph_id,
                relation_id=relation_id,
                perturbation_type=perturbation_type,
                test_result=test_result,
                perturbation_score=self._resolve_score(perturbation_score, test_result),
                test_metadata=test_metadata or {},
                perturbation_set_id=perturbation_set_id or str(uuid.uuid4()),
                # Same naive-UTC value datetime.utcnow() produced, without the
                # deprecated call.
                created_at=datetime.now(timezone.utc).replace(tzinfo=None)
            )
            self.session.add(test)
            self.session.commit()
            logger.info(f"Saved test result for relation {relation_id}, type {perturbation_type}")
            return test
        except Exception as e:
            logger.exception(f"Error saving test result: {e}")
            self.session.rollback()
            return None

    def run_perturbation_tests(
        self,
        knowledge_graph_identifier: str,
        perturbation_types: List[str],
        max_relations: Optional[int] = None,
        model: str = "gpt-5-mini",
        **test_kwargs
    ) -> Dict[str, Any]:
        """
        Run perturbation tests with database operations.

        Each perturbation type is run independently: a failure in one type is
        recorded under that type's key and does not stop the others.

        Args:
            knowledge_graph_identifier: Identifier of the knowledge graph
            perturbation_types: List of perturbation types to test
                ("jailbreak" and "counterfactual_bias" are supported)
            max_relations: Maximum number of relations to test
            model: Model to use for testing
            **test_kwargs: Additional arguments forwarded to the testing function

        Returns:
            Dictionary containing test results for each perturbation type, or
            the error dictionary from fetch_testing_data if the data could not
            be loaded
        """
        # Fetch data from database
        testing_data = self.fetch_testing_data(knowledge_graph_identifier)
        if "error" in testing_data:
            return testing_data

        # Load model configurations; tests still run with an empty list.
        try:
            model_configs = load_litellm_config()
        except Exception as e:
            logger.warning(f"Failed to load model configs: {e}")
            model_configs = []

        # Pure testing function for each supported perturbation type.
        runners = {
            "jailbreak": run_jailbreak_tests,
            "counterfactual_bias": run_counterfactual_bias_tests,
        }

        results: Dict[str, Any] = {}
        for perturbation_type in perturbation_types:
            try:
                logger.info(f"Running {perturbation_type} tests on knowledge graph {knowledge_graph_identifier}")

                runner = runners.get(perturbation_type)
                if runner is None:
                    logger.error(f"Unknown perturbation type: {perturbation_type}")
                    results[perturbation_type] = {"error": f"Unknown perturbation type: {perturbation_type}"}
                    continue

                # Generate unique set ID for this test run
                perturbation_set_id = str(uuid.uuid4())

                test_results = runner(
                    testing_data=testing_data,
                    model=model,
                    max_relations=max_relations,
                    model_configs=model_configs,
                    **test_kwargs
                )

                # Check for errors in test results
                if "error" in test_results:
                    results[perturbation_type] = test_results
                    continue

                # Save test results to database
                saved_results = []
                for relation_result in test_results.get('relations', []):
                    relation_id = relation_result.get('relation_id')
                    if not relation_id:
                        # Results without a relation ID cannot be stored.
                        continue
                    saved_test = self.save_test_result(
                        knowledge_graph_id=testing_data["knowledge_graph_id"],
                        relation_id=relation_id,
                        perturbation_type=perturbation_type,
                        test_result=relation_result,
                        perturbation_score=relation_result.get('perturbation_score'),
                        test_metadata={
                            'model': model,
                            'test_timestamp': datetime.now(timezone.utc).isoformat(),
                            'perturbation_metadata': test_results.get('perturbation_metadata', {})
                        },
                        perturbation_set_id=perturbation_set_id
                    )
                    if saved_test:
                        saved_results.append({
                            'relation_id': relation_id,
                            'test_id': saved_test.id,
                            'perturbation_score': saved_test.perturbation_score
                        })

                # Store results with metadata
                results[perturbation_type] = {
                    'test_results': test_results,
                    'saved_results': saved_results,
                    'perturbation_set_id': perturbation_set_id,
                    'summary': test_results.get('summary', {}),
                    'metadata': test_results.get('perturbation_metadata', {})
                }
                logger.info(f"Completed {perturbation_type} tests: {len(saved_results)} results saved")
            except Exception as e:
                logger.exception(f"Error running {perturbation_type} tests: {e}")
                results[perturbation_type] = {'error': f'Failed to run {perturbation_type} tests: {str(e)}'}
        return results

    def get_test_results(
        self,
        knowledge_graph_id: int,
        perturbation_type: Optional[str] = None,
        perturbation_set_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Get test results from database.

        Args:
            knowledge_graph_id: ID of the knowledge graph
            perturbation_type: Filter by perturbation type (optional)
            perturbation_set_id: Filter by perturbation set ID (optional)

        Returns:
            List of test result dictionaries; an empty list if the query fails
        """
        try:
            query = self.session.query(PerturbationTest).filter_by(
                knowledge_graph_id=knowledge_graph_id
            )
            if perturbation_type:
                query = query.filter_by(perturbation_type=perturbation_type)
            if perturbation_set_id:
                query = query.filter_by(perturbation_set_id=perturbation_set_id)

            return [
                {
                    'id': test.id,
                    'relation_id': test.relation_id,
                    'perturbation_type': test.perturbation_type,
                    'perturbation_score': test.perturbation_score,
                    'test_result': test.test_result,
                    'test_metadata': test.test_metadata,
                    'perturbation_set_id': test.perturbation_set_id,
                    'created_at': test.created_at.isoformat() if test.created_at else None
                }
                for test in query.all()
            ]
        except Exception as e:
            logger.exception(f"Error getting test results: {e}")
            return []

    def get_test_summary(self, knowledge_graph_id: int) -> Dict[str, Any]:
        """
        Get summary of test results for a knowledge graph.

        Args:
            knowledge_graph_id: ID of the knowledge graph

        Returns:
            Dictionary with 'total_tests', 'perturbation_types',
            'tests_by_type', 'average_scores' and 'latest_test', or a
            dictionary with a single 'error' key on failure. Tests without a
            score are ignored when averaging; a type with no scored tests
            averages to 0.0.
        """
        try:
            tests = self.session.query(PerturbationTest).filter_by(
                knowledge_graph_id=knowledge_graph_id
            ).all()
            if not tests:
                # Same keys as the populated summary so callers can rely on them.
                return {
                    'total_tests': 0,
                    'perturbation_types': [],
                    'tests_by_type': {},
                    'average_scores': {},
                    'latest_test': None
                }

            # Group by perturbation type
            by_type: Dict[Any, List[Any]] = {}
            for test in tests:
                by_type.setdefault(test.perturbation_type, []).append(test)

            # Calculate averages
            average_scores = {}
            for ptype, type_tests in by_type.items():
                scores = [t.perturbation_score for t in type_tests if t.perturbation_score is not None]
                average_scores[ptype] = sum(scores) / len(scores) if scores else 0.0

            # Find latest test
            latest_test = max(tests, key=lambda t: self._created_at_sort_key(t.created_at))

            return {
                'total_tests': len(tests),
                'perturbation_types': list(by_type.keys()),
                'tests_by_type': {ptype: len(type_tests) for ptype, type_tests in by_type.items()},
                'average_scores': average_scores,
                'latest_test': {
                    'id': latest_test.id,
                    'perturbation_type': latest_test.perturbation_type,
                    'created_at': latest_test.created_at.isoformat() if latest_test.created_at else None
                }
            }
        except Exception as e:
            logger.exception(f"Error getting test summary: {e}")
            return {'error': f'Failed to get test summary: {str(e)}'}
def _parse_perturbation_config(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Normalize the raw task config into the flat settings used by perturb_knowledge_graph_task."""
    raw = config or {}
    # `or {}` also covers sections explicitly set to None.
    jailbreak_config = raw.get("jailbreak") or {}
    bias_config = raw.get("counterfactual_bias") or {}

    demographics_config = bias_config.get("demographics")
    if demographics_config is None:
        demographics_config = [
            {"gender": "male", "race": "White"},
            {"gender": "female", "race": "White"},
            {"gender": "male", "race": "Black"},
            {"gender": "female", "race": "Black"},
        ]

    # Determine which tests to run
    perturbation_types = []
    if jailbreak_config.get("enabled", True):
        perturbation_types.append("jailbreak")
    if bias_config.get("enabled", True):
        perturbation_types.append("counterfactual_bias")

    return {
        "raw": raw,
        "model": raw.get("model", "gpt-4o-mini"),
        "judge_model": raw.get("judge_model", "gpt-4o-mini"),
        "max_relations": raw.get("max_relations"),
        "num_techniques": jailbreak_config.get("num_techniques", 10),
        "comparison_mode": bias_config.get("comparison_mode", "both"),
        "include_baseline": bias_config.get("include_baseline", True),
        # Raises KeyError if an entry lacks "gender" or "race".
        "demographics": [(d["gender"], d["race"]) for d in demographics_config],
        "perturbation_types": perturbation_types,
    }


async def perturb_knowledge_graph_task(
    kg_id: str,
    task_id: str,
    config: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Background task for perturbing a knowledge graph.

    This uses the pure functions from agentgraph.testing. Task progress and the
    final outcome are reported through update_task_status; every failure path
    (including an invalid config) marks the task "FAILED".

    The task fails when: no test type is enabled, the knowledge graph is not
    found, its status is not one of "enriched" / "perturbed" / "analyzed",
    it has no prompt reconstructions, every perturbation type returned an
    error, or an exception is raised.

    Args:
        kg_id: Knowledge graph ID
        task_id: Task ID for status tracking
        config: Optional configuration dictionary with:
            - model: LLM model to use (default: gpt-4o-mini)
            - judge_model: Judge model for evaluation (default: gpt-4o-mini)
            - max_relations: Max relations to test (default: None)
            - jailbreak: Jailbreak test config
              (enabled, default True; num_techniques, default 10)
            - counterfactual_bias: Bias test config
              (enabled, default True; comparison_mode, default "both";
              include_baseline, default True; demographics, a list of
              {"gender", "race"} dicts)
            Other keys are not read here; the whole dictionary is stored in
            each saved test's metadata.

    Returns:
        True if successful, False otherwise.
    """
    logger.info(f"Starting knowledge graph perturbation task {task_id} for KG {kg_id}")
    update_task_status(task_id, "RUNNING", "Perturbing knowledge graph")

    try:
        # Parsed inside the try so a malformed config marks the task FAILED
        # instead of leaving it stuck in RUNNING.
        settings = _parse_perturbation_config(config)
        model = settings["model"]
        judge_model = settings["judge_model"]
        perturbation_types = settings["perturbation_types"]

        if not perturbation_types:
            update_task_status(task_id, "FAILED", "No perturbation tests enabled")
            return False

        session = next(get_db())
        try:
            kg = get_knowledge_graph_by_id(session, kg_id)
            if not kg:
                logger.error(f"Knowledge graph with ID {kg_id} not found")
                update_task_status(task_id, "FAILED", f"Knowledge graph with ID {kg_id} not found")
                return False
            if kg.status not in ["enriched", "perturbed", "analyzed"]:
                update_task_status(task_id, "FAILED", "Knowledge graph must be enriched before perturbation")
                return False

            # 1. Fetch data for testing
            update_task_status(task_id, "RUNNING", "Fetching data for testing", 10)
            reconstructed_prompts = get_prompt_reconstructions_for_kg(session, kg.id)
            if not reconstructed_prompts:
                update_task_status(task_id, "FAILED", "No prompt reconstructions found for this knowledge graph.")
                return False
            # NOTE(review): elsewhere in this module the same helper's result is
            # treated as rows with relation_id / reconstructed_prompt attributes.
            # Accept both shapes; confirm the helper's real return type.
            if not isinstance(reconstructed_prompts, dict):
                reconstructed_prompts = {
                    pr.relation_id: pr.reconstructed_prompt for pr in reconstructed_prompts
                }

            # 2. Prepare testing data
            update_task_status(task_id, "RUNNING", "Preparing testing data", 25)
            testing_data = prepare_testing_data(
                knowledge_graph=kg.graph_data,
                reconstructed_prompts=reconstructed_prompts
            )

            # 3. Define progress callback
            def progress_callback(current, total, message):
                # Scale progress from 25% to 80%; a zero total reports 25%.
                fraction = (current / total) if total else 0.0
                fraction = min(max(fraction, 0.0), 1.0)
                update_task_status(task_id, "RUNNING", message, 25 + int(fraction * 55))

            # 4. Run tests with configuration
            # Reported at 25 so the callback (which starts at 25) never moves
            # the progress backwards.
            update_task_status(task_id, "RUNNING", f"Running perturbation tests: {', '.join(perturbation_types)}", 25)
            test_results = run_knowledge_graph_tests(
                testing_data=testing_data,
                perturbation_types=perturbation_types,
                model=model,
                max_relations=settings["max_relations"],
                progress_callback=progress_callback,
                # Jailbreak specific
                num_techniques=settings["num_techniques"],
                judge_model=judge_model,
                # Counterfactual bias specific
                demographics=settings["demographics"],
                include_baseline=settings["include_baseline"],
                comparison_mode=settings["comparison_mode"],
            )
            update_task_status(task_id, "RUNNING", "Tests completed, saving results", 80)

            # 5. Save results
            # One query for all reconstructions instead of one per relation;
            # setdefault keeps the first row seen for a relation.
            reconstruction_ids = {}
            for reconstruction in session.query(PromptReconstruction).filter_by(
                knowledge_graph_id=kg.id
            ).all():
                reconstruction_ids.setdefault(reconstruction.relation_id, reconstruction.id)

            failed_types = []
            for p_type, p_results in test_results.items():
                if "error" in p_results:
                    logger.error(f"Error during {p_type} test: {p_results['error']}")
                    failed_types.append(p_type)
                    continue

                perturbation_set_id = str(uuid.uuid4())
                for relation_result in p_results.get('relations', []):
                    relation_id = relation_result.get("relation_id")
                    if relation_id is None:
                        logger.warning(f"{p_type} result has no relation_id. Skipping saving test result.")
                        continue

                    # Find prompt_reconstruction_id
                    prompt_reconstruction_id = reconstruction_ids.get(relation_id)
                    if prompt_reconstruction_id is None:
                        logger.warning(f"Could not find prompt reconstruction for relation {relation_id}. Skipping saving test result.")
                        continue

                    session.add(PerturbationTest(
                        knowledge_graph_id=kg.id,
                        prompt_reconstruction_id=prompt_reconstruction_id,
                        relation_id=relation_id,
                        perturbation_type=p_type,
                        perturbation_set_id=perturbation_set_id,
                        test_result=relation_result,
                        perturbation_score=relation_result.get("perturbation_score"),
                        test_metadata={
                            "model": model,
                            "judge_model": judge_model,
                            'test_timestamp': datetime.now(timezone.utc).isoformat(),
                            'config': settings["raw"],
                        }
                    ))

            # Do not mark the graph as perturbed when no test type succeeded.
            if failed_types and len(failed_types) == len(test_results):
                update_task_status(
                    task_id, "FAILED",
                    f"All perturbation tests failed: {', '.join(str(t) for t in failed_types)}"
                )
                return False

            # Update status
            kg.status = "perturbed"
            kg.update_timestamp = datetime.now(timezone.utc)
            session.commit()

            update_task_status(task_id, "COMPLETED", "Knowledge graph perturbed successfully")
            logger.info(f"Knowledge graph {kg_id} perturbed successfully")
            return True
        except Exception:
            # Discard partially added rows before the session is closed.
            session.rollback()
            raise
        finally:
            session.close()
    except Exception as e:
        error_message = f"Error perturbing knowledge graph: {str(e)}"
        logger.error(error_message)
        logger.error(traceback.format_exc())
        update_task_status(task_id, "FAILED", error_message)
        return False