"""
Causal Analysis Service

This service handles all database operations for causal analysis, providing a
clean interface between the database layer and the pure analytical functions in
agentgraph.causal.
"""

import logging
import traceback
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import numpy as np
from sqlalchemy.orm import Session

# Project imports are kept in their original relative order so as not to
# disturb any import-time dependencies between these modules.
from backend.database.models import CausalAnalysis, KnowledgeGraph, PerturbationTest, PromptReconstruction
from backend.database.utils import (
    save_causal_analysis,
    get_causal_analysis,
    get_causal_analysis_summary,
    get_all_causal_analyses,
    get_knowledge_graph_by_id,
)
from agentgraph.causal.causal_interface import analyze_causal_effects
from backend.database import get_db
from backend.services.task_service import update_task_status

logger = logging.getLogger(__name__)


def _sanitize_key(key: Any) -> Any:
    """Convert a NumPy scalar dict key to its builtin equivalent; leave other keys untouched."""
    # Keys are not sanitized recursively: turning e.g. a tuple key into a list
    # would make it unhashable.
    return key.item() if isinstance(key, np.generic) else key


def sanitize_for_json(data: Any) -> Any:
    """
    Recursively convert numpy types to standard Python types for JSON serialization.

    - dicts: keys that are NumPy scalars become builtins; values are sanitized
      recursively.
    - lists and tuples: sanitized element-wise and returned as a list (the
      same shape ``json`` would emit for a tuple).
    - ndarrays: converted with ``tolist()`` and then sanitized, so object
      arrays holding NumPy scalars are handled too.
    - NumPy bool / integer / floating scalars: converted to bool / int / float.
    - Anything else is returned unchanged.
    """
    if isinstance(data, dict):
        return {_sanitize_key(key): sanitize_for_json(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return [sanitize_for_json(item) for item in data]
    if isinstance(data, np.ndarray):
        return sanitize_for_json(data.tolist())
    # Check against the abstract scalar base classes. Concrete aliases such as
    # ``np.float_`` were removed in NumPy 2.0 and raise AttributeError there.
    if isinstance(data, np.bool_):
        return bool(data)
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, np.floating):
        return float(data)
    return data


class CausalService:
    """
    Service for orchestrating causal analysis with database operations.

    This service fetches data from the database, calls pure analytical functions
    from agentgraph.causal, and saves the results back to the database.
    """

    # Methods whose causal score is the mean absolute value of the entries
    # stored under result["scores"][<key>].
    _MEAN_ABS_SCORE_KEYS: Dict[str, str] = {
        "graph": "ACE",
        "component": "Feature_Importance",
        "dowhy": "Effect_Estimate",
        "ate": "Effect_Strengths",
    }
    # Methods scored by a count heuristic over result["scores"]["Confounders"].
    _CONFOUNDER_METHODS = ("confounder", "mscd")

    def __init__(self, session: Session):
        """Bind the service to the SQLAlchemy session used for all queries and writes."""
        self.session = session

    def fetch_analysis_data(self, knowledge_graph_id: int, perturbation_set_id: str) -> Dict[str, Any]:
        """
        Fetch all data needed for causal analysis from the database.

        Args:
            knowledge_graph_id: ID of the knowledge graph
            perturbation_set_id: ID of the perturbation set

        Returns:
            On success, a dictionary with the keys "perturbation_tests",
            "dependencies_map", "knowledge_graph", "perturbation_type",
            "perturbation_scores" and "relation_to_pr_map". On failure
            (no matching tests, missing knowledge graph, or any exception),
            a dictionary with a single "error" key describing the problem.
        """
        try:
            # 1. Query PerturbationTest, filtering by both IDs.
            perturbation_tests = self.session.query(PerturbationTest).filter(
                PerturbationTest.knowledge_graph_id == knowledge_graph_id,
                PerturbationTest.perturbation_set_id == perturbation_set_id,
            ).all()

            if not perturbation_tests:
                logger.warning(f"No perturbation tests found for knowledge_graph_id={knowledge_graph_id}, perturbation_set_id={perturbation_set_id}")
                # Debug aid: report which perturbation sets do exist for this KG.
                # Only the distinct set IDs are fetched, not every full row.
                available_sets = [
                    row[0]
                    for row in self.session.query(PerturbationTest.perturbation_set_id)
                    .filter(PerturbationTest.knowledge_graph_id == knowledge_graph_id)
                    .distinct()
                    .all()
                ]
                logger.warning(f"Available perturbation sets for KG {knowledge_graph_id}: {available_sets}")
                return {"error": "No perturbation tests found for the specified criteria"}

            # The reported perturbation type comes from the first test only.
            # NOTE(review): presumably a set holds a single perturbation type; confirm.
            sample_test = perturbation_tests[0]
            logger.info(f"Analyzing {len(perturbation_tests)} perturbation tests of type '{sample_test.perturbation_type}'")

            # 2-3. Load the prompt reconstructions referenced by these tests.
            pr_ids = [test.prompt_reconstruction_id for test in perturbation_tests]
            prompt_reconstructions = self.session.query(PromptReconstruction).filter(
                PromptReconstruction.id.in_(pr_ids)
            ).all()

            # 4. Get the knowledge graph data.
            kg = self.session.query(KnowledgeGraph).filter_by(id=knowledge_graph_id).first()
            if not kg:
                return {"error": f"Knowledge graph with ID {knowledge_graph_id} not found"}

            # 5. Assemble the structure consumed by analyze_causal_effects.
            # The per-relation maps keep the last test seen if a relation_id repeats.
            return {
                "perturbation_tests": [test.to_dict() for test in perturbation_tests],
                "dependencies_map": {pr.id: pr.dependencies for pr in prompt_reconstructions},
                "knowledge_graph": kg.graph_data,
                "perturbation_type": sample_test.perturbation_type,
                "perturbation_scores": {test.relation_id: test.perturbation_score for test in perturbation_tests},
                "relation_to_pr_map": {test.relation_id: test.prompt_reconstruction_id for test in perturbation_tests},
            }

        except Exception as e:
            # logger.exception also records the traceback, which repr(e) alone loses.
            logger.exception(f"Error while extracting data for analysis: {repr(e)}")
            return {"error": f"Failed to extract analysis data: {repr(e)}"}

    def save_analysis_results(self, method: str, results: Dict[str, Any], knowledge_graph_id: int, perturbation_set_id: str) -> None:
        """
        Save analysis results to the database.

        Results containing an "error" key are logged and skipped.

        Args:
            method: Analysis method name
            results: Results from the analysis
            knowledge_graph_id: ID of the knowledge graph
            perturbation_set_id: ID of the perturbation set
        """
        if "error" in results:
            logger.warning(f"Not saving results for {method} due to error: {results['error']}")
            return

        # Convert NumPy values so the results can be stored as JSON.
        sanitized_results = sanitize_for_json(results)

        # Collapse the method-specific results into one scalar score.
        causal_score = self._calculate_causal_score(method, sanitized_results)

        save_causal_analysis(
            self.session,
            knowledge_graph_id=knowledge_graph_id,
            perturbation_set_id=perturbation_set_id,
            analysis_method=method,
            analysis_result=sanitized_results,
            causal_score=causal_score,
            analysis_metadata={
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "method_specific_metadata": sanitized_results.get("metadata", {}),
            },
        )

    def _calculate_causal_score(self, method: str, result: Dict[str, Any]) -> float:
        """
        Calculate a single causal score for the method based on results.

        Scoring rules:
        - "graph", "component", "dowhy", "ate": the mean absolute value of
          result["scores"][key], using the key given in _MEAN_ABS_SCORE_KEYS.
          An empty or missing key gives 0.0.
        - "confounder", "mscd": 0.1 per entry in result["scores"]["Confounders"].
        - Any other method, or an error while scoring: 0.0.
        """
        try:
            scores = result.get("scores", {})

            score_key = self._MEAN_ABS_SCORE_KEYS.get(method)
            if score_key is not None:
                values = scores.get(score_key, {})
                if not values:
                    return 0.0
                return sum(abs(score) for score in values.values()) / len(values)

            if method in self._CONFOUNDER_METHODS:
                # Simple heuristic: each reported confounder adds 0.1.
                return len(scores.get("Confounders", {})) * 0.1

        except Exception as e:
            logger.warning(f"Error calculating causal score for {method}: {e}")
            return 0.0

        # No scoring rule for this method: return a neutral float instead of
        # implicitly falling through to None.
        logger.debug(f"No causal scoring rule for method {method}; using 0.0")
        return 0.0

    def run_causal_analysis(self, knowledge_graph_id: int, perturbation_set_id: str, methods: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Run causal analysis with database operations.

        Args:
            knowledge_graph_id: ID of the knowledge graph
            perturbation_set_id: ID of the perturbation set
            methods: List of analysis methods to use. Defaults to
                ['graph', 'component', 'dowhy', 'confounder', 'mscd', 'ate'].

        Returns:
            The dictionary returned by analyze_causal_effects, with analysis
            results for each method. Returns a dictionary with an "error" key
            if data fetching fails or an exception is raised. Only successful
            per-method result dicts are persisted.
        """
        if methods is None:
            methods = ['graph', 'component', 'dowhy', 'confounder', 'mscd', 'ate']

        # Fetch data from the database.
        analysis_data = self.fetch_analysis_data(knowledge_graph_id, perturbation_set_id)
        if "error" in analysis_data:
            return analysis_data

        try:
            # Run the pure analysis, then persist each method's results.
            results = analyze_causal_effects(analysis_data, methods)

            for method, result in results.items():
                # Only per-method result dicts are saved. Without the isinstance
                # check, a non-dict value (such as a top-level "error" string) would
                # get a substring test and then be passed to save_analysis_results.
                if not isinstance(result, dict):
                    continue
                if "error" in result:
                    logger.warning(f"Causal analysis method {method} failed: {result['error']}")
                    continue
                self.save_analysis_results(method, result, knowledge_graph_id, perturbation_set_id)

            return results

        except Exception as e:
            logger.exception(f"Error during causal analysis: {repr(e)}")
            return {"error": f"Analysis failed: {repr(e)}"}

    def get_analysis_results(self, knowledge_graph_id: int, method: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get causal analysis results from database.

        Args:
            knowledge_graph_id: ID of the knowledge graph
            method: Optional filter by analysis method

        Returns:
            List of causal analysis results, as returned by get_all_causal_analyses.
        """
        return get_all_causal_analyses(
            session=self.session,
            knowledge_graph_id=knowledge_graph_id,
            analysis_method=method,
        )

    def get_causal_analysis_summary(self, knowledge_graph_id: str) -> Dict[str, Any]:
        """
        Retrieves a summary of the causal analysis for a given knowledge graph.
        """
        # Not recursive: inside a method body the bare name resolves to the
        # module-level function imported from backend.database.utils.
        return get_causal_analysis_summary(self.session, knowledge_graph_id)


async def analyze_causal_relationships_task(kg_id: str, task_id: str) -> bool:
    """
    Task to analyze causal relationships in a knowledge graph using the CausalService.

    Steps:
    1. Run CausalService.run_causal_analysis once for each distinct
       perturbation set recorded for the graph.
    2. Set the graph's status to "analyzed" and commit.
    3. Mark the task "COMPLETED".

    The task is marked "FAILED" in these cases:
    - the graph does not exist;
    - its status is neither "perturbed" nor "analyzed";
    - it has no perturbation tests;
    - an exception escapes.

    Returns True if successful, False otherwise.
    """
    # NOTE(review): nothing below is awaited, so this coroutine runs its
    # blocking database work directly on the event loop.
    logger.info(f"Starting causal analysis for knowledge graph {kg_id}")
    update_task_status(task_id, "RUNNING", "Analyzing causal relationships")

    try:
        # NOTE(review): get_db() appears to be a generator; only its first
        # yielded session is used, and it is closed explicitly below.
        session = next(get_db())
        try:
            kg = get_knowledge_graph_by_id(session, kg_id)
            if not kg:
                logger.error(f"Knowledge graph with ID {kg_id} not found")
                update_task_status(task_id, "FAILED", f"Knowledge graph with ID {kg_id} not found")
                return False

            if kg.status not in ["perturbed", "analyzed"]:
                update_task_status(task_id, "FAILED", "Knowledge graph must be perturbed before causal analysis")
                return False

            causal_service = CausalService(session=session)

            # Each row is a 1-tuple holding a distinct perturbation_set_id.
            perturbation_sets = session.query(PerturbationTest.perturbation_set_id).filter_by(knowledge_graph_id=kg.id).distinct().all()
            logger.info(f"Found {len(perturbation_sets)} perturbation sets for KG {kg.id}: {[ps[0] for ps in perturbation_sets]}")

            if not perturbation_sets:
                update_task_status(task_id, "FAILED", "No perturbation tests found for this knowledge graph")
                return False

            failed_sets = []
            for (perturbation_set_id,) in perturbation_sets:
                logger.info(f"Running causal analysis for KG {kg.id} with perturbation set {perturbation_set_id}")
                analysis_results = causal_service.run_causal_analysis(
                    knowledge_graph_id=kg.id,
                    perturbation_set_id=perturbation_set_id,
                )
                logger.info(f"Causal analysis results for set {perturbation_set_id}: {analysis_results}")
                if "error" in analysis_results:
                    failed_sets.append(perturbation_set_id)

            if failed_sets:
                logger.warning(f"Causal analysis reported errors for perturbation sets {failed_sets} of KG {kg.id}")

            # NOTE(review): the graph is marked "analyzed" and the task
            # "COMPLETED" even if every perturbation set reported an error.
            kg.status = "analyzed"
            kg.update_timestamp = datetime.now(timezone.utc)
            session.commit()

            update_task_status(task_id, "COMPLETED", "Causal analysis completed")
            logger.info(f"Causal analysis completed for knowledge graph {kg_id}")
            return True
        finally:
            session.close()

    except Exception as e:
        logger.error(f"Error in causal analysis: {str(e)}")
        logger.error(traceback.format_exc())
        update_task_status(task_id, "FAILED", f"Error in causal analysis: {str(e)}")
        return False