Spaces:
Sleeping
Sleeping
| """ | |
| Causal Analysis Service | |
| This service handles all database operations for causal analysis, | |
| providing a clean interface between the database layer and the pure | |
| analytical functions in agentgraph.causal. | |
| """ | |
| import logging | |
| from typing import Dict, List, Any, Optional | |
| from sqlalchemy.orm import Session | |
| from datetime import datetime, timezone | |
| import uuid | |
| import traceback | |
| import numpy as np | |
| from backend.database.models import CausalAnalysis, KnowledgeGraph, PerturbationTest, PromptReconstruction | |
| from backend.database.utils import save_causal_analysis, get_causal_analysis, get_causal_analysis_summary, get_all_causal_analyses, get_knowledge_graph_by_id | |
| from agentgraph.causal.causal_interface import analyze_causal_effects | |
| from backend.database import get_db | |
| from backend.services.task_service import update_task_status | |
| logger = logging.getLogger(__name__) | |
def sanitize_for_json(data: Any) -> Any:
    """
    Recursively convert numpy types to standard Python types for JSON serialization.

    Args:
        data: Arbitrary value, possibly a nested structure of dicts, lists,
            and tuples containing numpy scalars or arrays.

    Returns:
        An equivalent structure built only from JSON-serializable Python
        types (dict, list, int, float, bool, str, None).
    """
    if isinstance(data, dict):
        return {key: sanitize_for_json(value) for key, value in data.items()}
    # Tuples are converted to lists so numpy scalars nested inside them are
    # also sanitized (the original only recursed into lists).
    if isinstance(data, (list, tuple)):
        return [sanitize_for_json(item) for item in data]
    # np.bool_ is checked before np.integer: it is not a Python bool and
    # would otherwise pass through unconverted and break json.dumps.
    if isinstance(data, np.bool_):
        return bool(data)
    # Abstract base classes cover every concrete width (int8..int64, uint*,
    # float16..float64) and survive NumPy 2.0, which removed the np.float_
    # (and narrowed the np.int_-style) aliases the previous version used.
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, np.floating):
        return float(data)
    if isinstance(data, np.ndarray):
        return data.tolist()
    return data
class CausalService:
    """
    Service for orchestrating causal analysis with database operations.

    This service fetches data from the database, calls pure analytical
    functions from agentgraph.causal, and saves the results back to the
    database.
    """

    # Maps an analysis method to the key under result["scores"] whose values
    # are averaged (in absolute value) into a single causal score.
    _SCORE_KEYS = {
        "graph": "ACE",
        "component": "Feature_Importance",
        "dowhy": "Effect_Estimate",
        "ate": "Effect_Strengths",
    }

    def __init__(self, session: Session):
        # SQLAlchemy session used for every query/write issued by this service.
        self.session = session

    def fetch_analysis_data(self, knowledge_graph_id: int, perturbation_set_id: str) -> Dict[str, Any]:
        """
        Fetch all data needed for causal analysis from the database.

        Args:
            knowledge_graph_id: ID of the knowledge graph.
            perturbation_set_id: ID of the perturbation set.

        Returns:
            Dictionary containing all data needed for analysis, or a dict
            with a single "error" key describing what went wrong.
        """
        try:
            # 1. Query PerturbationTest, filtering by both IDs.
            perturbation_tests = self.session.query(PerturbationTest).filter(
                PerturbationTest.knowledge_graph_id == knowledge_graph_id,
                PerturbationTest.perturbation_set_id == perturbation_set_id
            ).all()
            if not perturbation_tests:
                logger.warning(f"No perturbation tests found for knowledge_graph_id={knowledge_graph_id}, perturbation_set_id={perturbation_set_id}")
                # Debug aid: report which perturbation sets DO exist for this KG.
                all_tests = self.session.query(PerturbationTest).filter(
                    PerturbationTest.knowledge_graph_id == knowledge_graph_id
                ).all()
                logger.warning(f"Available perturbation sets for KG {knowledge_graph_id}: {[t.perturbation_set_id for t in all_tests]}")
                return {"error": "No perturbation tests found for the specified criteria"}

            # All tests in a set share a perturbation_type; sample the first.
            sample_test = perturbation_tests[0]
            logger.info(f"Analyzing {len(perturbation_tests)} perturbation tests of type '{sample_test.perturbation_type}'")

            # 2. Collect all prompt_reconstruction_ids referenced by the tests.
            pr_ids = [test.prompt_reconstruction_id for test in perturbation_tests]

            # 3. Load the corresponding PromptReconstruction rows.
            prompt_reconstructions = self.session.query(PromptReconstruction).filter(
                PromptReconstruction.id.in_(pr_ids)
            ).all()

            # 4. Load the knowledge graph itself.
            kg = self.session.query(KnowledgeGraph).filter_by(id=knowledge_graph_id).first()
            if not kg:
                return {"error": f"Knowledge graph with ID {knowledge_graph_id} not found"}

            # 5. Assemble the analysis input structure consumed by
            #    agentgraph.causal.analyze_causal_effects.
            analysis_data = {
                "perturbation_tests": [test.to_dict() for test in perturbation_tests],
                "dependencies_map": {pr.id: pr.dependencies for pr in prompt_reconstructions},
                "knowledge_graph": kg.graph_data,
                "perturbation_type": sample_test.perturbation_type,
                "perturbation_scores": {test.relation_id: test.perturbation_score for test in perturbation_tests},
                "relation_to_pr_map": {test.relation_id: test.prompt_reconstruction_id for test in perturbation_tests}
            }
            return analysis_data
        except Exception as e:
            logger.error(f"Error while extracting data for analysis: {repr(e)}")
            return {"error": f"Failed to extract analysis data: {repr(e)}"}

    def save_analysis_results(self, method: str, results: Dict[str, Any], knowledge_graph_id: int, perturbation_set_id: str) -> None:
        """
        Save analysis results to the database.

        Results carrying an "error" key are skipped (logged, not raised).

        Args:
            method: Analysis method name.
            results: Results from the analysis.
            knowledge_graph_id: ID of the knowledge graph.
            perturbation_set_id: ID of the perturbation set.
        """
        if "error" in results:
            logger.warning(f"Not saving results for {method} due to error: {results['error']}")
            return

        # Sanitize results to ensure they are JSON serializable.
        sanitized_results = sanitize_for_json(results)

        # Collapse the method-specific score structure into one scalar.
        causal_score = self._calculate_causal_score(method, sanitized_results)

        # Persist via the shared database helper.
        save_causal_analysis(
            self.session,
            knowledge_graph_id=knowledge_graph_id,
            perturbation_set_id=perturbation_set_id,
            analysis_method=method,
            analysis_result=sanitized_results,
            causal_score=causal_score,
            analysis_metadata={
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "method_specific_metadata": sanitized_results.get("metadata", {})
            }
        )

    def _calculate_causal_score(self, method: str, result: Dict[str, Any]) -> float:
        """
        Calculate a single causal score for the method based on results.

        For "graph", "component", "dowhy" and "ate" this is the mean of the
        absolute values under the method's score key (see _SCORE_KEYS).
        For "confounder"/"mscd" it is 0.1 per detected confounder.
        Returns 0.0 for unknown methods, empty scores, or on any error.
        (Fix: the original fell off the end of the try block and implicitly
        returned None for unrecognized methods.)
        """
        try:
            scores = result.get("scores", {})
            if method in self._SCORE_KEYS:
                values = scores.get(self._SCORE_KEYS[method], {})
                if values:
                    return sum(abs(score) for score in values.values()) / len(values)
                return 0.0
            if method in ["confounder", "mscd"]:
                confounders = scores.get("Confounders", {})
                return len(confounders) * 0.1  # Simple heuristic
            # Unknown method: explicit float fallback instead of implicit None.
            return 0.0
        except Exception as e:
            logger.warning(f"Error calculating causal score for {method}: {e}")
            return 0.0

    def run_causal_analysis(self, knowledge_graph_id: int, perturbation_set_id: str, methods: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Run causal analysis with database operations.

        Args:
            knowledge_graph_id: ID of the knowledge graph.
            perturbation_set_id: ID of the perturbation set.
            methods: List of analysis methods to use; defaults to all six.

        Returns:
            Dictionary containing analysis results for each method, or a
            dict with an "error" key on failure.
        """
        if methods is None:
            methods = ['graph', 'component', 'dowhy', 'confounder', 'mscd', 'ate']

        # Fetch everything the pure analysis function needs.
        analysis_data = self.fetch_analysis_data(knowledge_graph_id, perturbation_set_id)
        if "error" in analysis_data:
            return analysis_data

        try:
            # Delegate to the pure (database-free) analysis function.
            results = analyze_causal_effects(analysis_data, methods)
            # Persist each successful method's results.
            for method, result in results.items():
                if "error" not in result:
                    self.save_analysis_results(method, result, knowledge_graph_id, perturbation_set_id)
            return results
        except Exception as e:
            logger.error(f"Error during causal analysis: {repr(e)}")
            return {"error": f"Analysis failed: {repr(e)}"}

    def get_analysis_results(self, knowledge_graph_id: int, method: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get causal analysis results from database.

        Args:
            knowledge_graph_id: ID of the knowledge graph.
            method: Optional filter by analysis method.

        Returns:
            List of causal analysis results.
        """
        return get_all_causal_analyses(
            session=self.session,
            knowledge_graph_id=knowledge_graph_id,
            analysis_method=method
        )

    def get_causal_analysis_summary(self, knowledge_graph_id: str) -> Dict[str, Any]:
        """
        Retrieves a summary of the causal analysis for a given knowledge graph.

        Delegates to the module-level helper imported from backend.database.utils.
        """
        return get_causal_analysis_summary(self.session, knowledge_graph_id)
async def analyze_causal_relationships_task(kg_id: str, task_id: str) -> bool:
    """
    Background task: run causal analysis over every perturbation set of a
    knowledge graph using the CausalService.

    Args:
        kg_id: ID of the knowledge graph to analyze.
        task_id: ID of the task whose status is updated as work progresses.

    Returns:
        True if successful, False otherwise.
    """
    logger.info(f"Starting causal analysis for knowledge graph {kg_id}")
    update_task_status(task_id, "RUNNING", "Analyzing causal relationships")
    try:
        db = next(get_db())
        try:
            graph = get_knowledge_graph_by_id(db, kg_id)
            if not graph:
                logger.error(f"Knowledge graph with ID {kg_id} not found")
                update_task_status(task_id, "FAILED", f"Knowledge graph with ID {kg_id} not found")
                return False

            # Causal analysis only makes sense after perturbation has run.
            if graph.status not in ["perturbed", "analyzed"]:
                update_task_status(task_id, "FAILED", "Knowledge graph must be perturbed before causal analysis")
                return False

            service = CausalService(session=db)

            # Every distinct perturbation_set_id recorded for this graph.
            set_rows = db.query(PerturbationTest.perturbation_set_id).filter_by(knowledge_graph_id=graph.id).distinct().all()
            set_ids = [row[0] for row in set_rows]
            logger.info(f"Found {len(set_rows)} perturbation sets for KG {graph.id}: {set_ids}")
            if not set_rows:
                update_task_status(task_id, "FAILED", "No perturbation tests found for this knowledge graph")
                return False

            # Analyze each perturbation set; the service persists results itself.
            for set_id in set_ids:
                logger.info(f"Running causal analysis for KG {graph.id} with perturbation set {set_id}")
                outcome = service.run_causal_analysis(
                    knowledge_graph_id=graph.id,
                    perturbation_set_id=set_id
                )
                logger.info(f"Causal analysis results for set {set_id}: {outcome}")

            # Mark the graph as analyzed and persist the status change.
            graph.status = "analyzed"
            graph.update_timestamp = datetime.now(timezone.utc)
            db.commit()

            update_task_status(task_id, "COMPLETED", "Causal analysis completed")
            logger.info(f"Causal analysis completed for knowledge graph {kg_id}")
            return True
        finally:
            # Always release the session, even when analysis fails midway.
            db.close()
    except Exception as e:
        logger.error(f"Error in causal analysis: {str(e)}")
        logger.error(traceback.format_exc())
        update_task_status(task_id, "FAILED", f"Error in causal analysis: {str(e)}")
        return False