# backend/services/causal_service.py
"""
Causal Analysis Service
This service handles all database operations for causal analysis,
providing a clean interface between the database layer and the pure
analytical functions in agentgraph.causal.
"""
import logging
from typing import Dict, List, Any, Optional
from sqlalchemy.orm import Session
from datetime import datetime, timezone
import traceback
import numpy as np
from backend.database.models import KnowledgeGraph, PerturbationTest, PromptReconstruction
from backend.database.utils import save_causal_analysis, get_causal_analysis_summary, get_all_causal_analyses, get_knowledge_graph_by_id
from agentgraph.causal.causal_interface import analyze_causal_effects
from backend.database import get_db
from backend.services.task_service import update_task_status
logger = logging.getLogger(__name__)
def sanitize_for_json(data: Any) -> Any:
"""
Recursively convert numpy types to standard Python types for JSON serialization.
"""
if isinstance(data, dict):
return {key: sanitize_for_json(value) for key, value in data.items()}
elif isinstance(data, list):
return [sanitize_for_json(item) for item in data]
    elif isinstance(data, np.integer):
        # Covers all fixed-width numpy integer scalars (np.int8 ... np.uint64)
        return int(data)
    elif isinstance(data, np.floating):
        # Covers np.float16/32/64 (np.float_ was removed in NumPy 2.0)
        return float(data)
elif isinstance(data, np.ndarray):
return data.tolist()
return data
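# Illustrative example of the conversion (hypothetical values, kept in a comment so
# nothing executes at import time):
#
#     sanitize_for_json({"ace": np.float64(0.25), "ids": np.array([1, 2])})
#     # -> {"ace": 0.25, "ids": [1, 2]}, which json.dumps() can serialize directly.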
class CausalService:
"""
Service for orchestrating causal analysis with database operations.
This service fetches data from the database, calls pure analytical functions
from agentgraph.causal, and saves the results back to the database.
"""
def __init__(self, session: Session):
self.session = session
def fetch_analysis_data(self, knowledge_graph_id: int, perturbation_set_id: str) -> Dict[str, Any]:
"""
Fetch all data needed for causal analysis from the database.
Args:
knowledge_graph_id: ID of the knowledge graph
perturbation_set_id: ID of the perturbation set
Returns:
Dictionary containing all data needed for analysis
"""
try:
# 1. Query PerturbationTest, filtering by both IDs
perturbation_tests = self.session.query(PerturbationTest).filter(
PerturbationTest.knowledge_graph_id == knowledge_graph_id,
PerturbationTest.perturbation_set_id == perturbation_set_id
).all()
if not perturbation_tests:
logger.warning(f"No perturbation tests found for knowledge_graph_id={knowledge_graph_id}, perturbation_set_id={perturbation_set_id}")
# Debug: Check what perturbation tests exist
all_tests = self.session.query(PerturbationTest).filter(
PerturbationTest.knowledge_graph_id == knowledge_graph_id
).all()
logger.warning(f"Available perturbation sets for KG {knowledge_graph_id}: {[t.perturbation_set_id for t in all_tests]}")
return {"error": "No perturbation tests found for the specified criteria"}
# Get a sample to report what we're analyzing
sample_test = perturbation_tests[0]
logger.info(f"Analyzing {len(perturbation_tests)} perturbation tests of type '{sample_test.perturbation_type}'")
# 2. Get all prompt_reconstruction_ids
pr_ids = [test.prompt_reconstruction_id for test in perturbation_tests]
# 3. Query PromptReconstruction for these IDs
prompt_reconstructions = self.session.query(PromptReconstruction).filter(
PromptReconstruction.id.in_(pr_ids)
).all()
# 4. Get the knowledge graph data
kg = self.session.query(KnowledgeGraph).filter_by(id=knowledge_graph_id).first()
if not kg:
return {"error": f"Knowledge graph with ID {knowledge_graph_id} not found"}
# 5. Create the analysis data structure
analysis_data = {
"perturbation_tests": [test.to_dict() for test in perturbation_tests],
"dependencies_map": {pr.id: pr.dependencies for pr in prompt_reconstructions},
"knowledge_graph": kg.graph_data,
"perturbation_type": sample_test.perturbation_type,
"perturbation_scores": {test.relation_id: test.perturbation_score for test in perturbation_tests},
"relation_to_pr_map": {test.relation_id: test.prompt_reconstruction_id for test in perturbation_tests}
}
return analysis_data
except Exception as e:
logger.error(f"Error while extracting data for analysis: {repr(e)}")
return {"error": f"Failed to extract analysis data: {repr(e)}"}
def save_analysis_results(self, method: str, results: Dict[str, Any], knowledge_graph_id: int, perturbation_set_id: str) -> None:
"""
Save analysis results to the database.
Args:
method: Analysis method name
results: Results from the analysis
knowledge_graph_id: ID of the knowledge graph
perturbation_set_id: ID of the perturbation set
"""
if "error" in results:
logger.warning(f"Not saving results for {method} due to error: {results['error']}")
return
# Sanitize results to ensure they are JSON serializable
sanitized_results = sanitize_for_json(results)
# Calculate causal score based on method
causal_score = self._calculate_causal_score(method, sanitized_results)
# Save to database
save_causal_analysis(
self.session,
knowledge_graph_id=knowledge_graph_id,
perturbation_set_id=perturbation_set_id,
analysis_method=method,
analysis_result=sanitized_results,
causal_score=causal_score,
analysis_metadata={
"timestamp": datetime.now(timezone.utc).isoformat(),
"method_specific_metadata": sanitized_results.get("metadata", {})
}
)
def _calculate_causal_score(self, method: str, result: Dict[str, Any]) -> float:
"""Calculate a single causal score for the method based on results."""
try:
if method == "graph":
scores = result.get("scores", {})
ace_scores = scores.get("ACE", {})
if ace_scores:
return sum(abs(score) for score in ace_scores.values()) / len(ace_scores)
return 0.0
elif method == "component":
scores = result.get("scores", {})
feature_importance = scores.get("Feature_Importance", {})
if feature_importance:
return sum(abs(score) for score in feature_importance.values()) / len(feature_importance)
return 0.0
elif method == "dowhy":
scores = result.get("scores", {})
effect_estimates = scores.get("Effect_Estimate", {})
if effect_estimates:
return sum(abs(score) for score in effect_estimates.values()) / len(effect_estimates)
return 0.0
elif method in ["confounder", "mscd"]:
scores = result.get("scores", {})
confounders = scores.get("Confounders", {})
return len(confounders) * 0.1 # Simple heuristic
elif method == "ate":
scores = result.get("scores", {})
effect_strengths = scores.get("Effect_Strengths", {})
if effect_strengths:
return sum(abs(score) for score in effect_strengths.values()) / len(effect_strengths)
return 0.0
except Exception as e:
logger.warning(f"Error calculating causal score for {method}: {e}")
return 0.0
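    # Worked example for the "graph" method (hypothetical ACE values): if
    # result["scores"]["ACE"] == {"rel-1": 0.2, "rel-2": -0.4}, the causal score is
    # (|0.2| + |-0.4|) / 2 == 0.3, i.e. the mean absolute average causal effect.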
def run_causal_analysis(self, knowledge_graph_id: int, perturbation_set_id: str, methods: Optional[List[str]] = None) -> Dict[str, Any]:
"""
Run causal analysis with database operations.
Args:
knowledge_graph_id: ID of the knowledge graph
perturbation_set_id: ID of the perturbation set
methods: List of analysis methods to use
Returns:
Dictionary containing analysis results for each method
"""
if methods is None:
methods = ['graph', 'component', 'dowhy', 'confounder', 'mscd', 'ate']
# Fetch data from database
analysis_data = self.fetch_analysis_data(knowledge_graph_id, perturbation_set_id)
if "error" in analysis_data:
return analysis_data
        # Call the pure analysis function from agentgraph.causal
try:
results = analyze_causal_effects(analysis_data, methods)
# Save each method's results to database
for method, result in results.items():
if "error" not in result:
self.save_analysis_results(method, result, knowledge_graph_id, perturbation_set_id)
return results
except Exception as e:
logger.error(f"Error during causal analysis: {repr(e)}")
return {"error": f"Analysis failed: {repr(e)}"}
def get_analysis_results(self, knowledge_graph_id: int, method: Optional[str] = None) -> List[Dict[str, Any]]:
"""
Get causal analysis results from database.
Args:
knowledge_graph_id: ID of the knowledge graph
method: Optional filter by analysis method
Returns:
List of causal analysis results
"""
return get_all_causal_analyses(
session=self.session,
knowledge_graph_id=knowledge_graph_id,
analysis_method=method
)
    def get_causal_analysis_summary(self, knowledge_graph_id: int) -> Dict[str, Any]:
        """
        Get a summary of the causal analyses for a given knowledge graph.
        """
return get_causal_analysis_summary(self.session, knowledge_graph_id)
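# Minimal usage sketch for CausalService (assumes a configured database that already
# contains perturbation tests; the IDs below are hypothetical placeholders):
#
#     session = next(get_db())
#     try:
#         service = CausalService(session=session)
#         results = service.run_causal_analysis(
#             knowledge_graph_id=1,
#             perturbation_set_id="example-set-id",
#         )
#         graph_results = service.get_analysis_results(knowledge_graph_id=1, method="graph")
#     finally:
#         session.close()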
async def analyze_causal_relationships_task(kg_id: str, task_id: str) -> bool:
"""
Task to analyze causal relationships in a knowledge graph using the CausalService.
Returns True if successful, False otherwise.
"""
logger.info(f"Starting causal analysis for knowledge graph {kg_id}")
update_task_status(task_id, "RUNNING", "Analyzing causal relationships")
try:
session = next(get_db())
try:
kg = get_knowledge_graph_by_id(session, kg_id)
if not kg:
logger.error(f"Knowledge graph with ID {kg_id} not found")
update_task_status(task_id, "FAILED", f"Knowledge graph with ID {kg_id} not found")
return False
if kg.status not in ["perturbed", "analyzed"]:
update_task_status(task_id, "FAILED", "Knowledge graph must be perturbed before causal analysis")
return False
# Instantiate the CausalService with the session
causal_service = CausalService(session=session)
# Get all available perturbation_set_ids for this KG
perturbation_sets = session.query(PerturbationTest.perturbation_set_id).filter_by(knowledge_graph_id=kg.id).distinct().all()
logger.info(f"Found {len(perturbation_sets)} perturbation sets for KG {kg.id}: {[ps[0] for ps in perturbation_sets]}")
if not perturbation_sets:
update_task_status(task_id, "FAILED", "No perturbation tests found for this knowledge graph")
return False
# Run analysis for each perturbation set
for (perturbation_set_id,) in perturbation_sets:
logger.info(f"Running causal analysis for KG {kg.id} with perturbation set {perturbation_set_id}")
analysis_results = causal_service.run_causal_analysis(
knowledge_graph_id=kg.id,
perturbation_set_id=perturbation_set_id
)
logger.info(f"Causal analysis results for set {perturbation_set_id}: {analysis_results}")
# Update KG status
kg.status = "analyzed"
kg.update_timestamp = datetime.now(timezone.utc)
session.commit()
update_task_status(task_id, "COMPLETED", "Causal analysis completed")
logger.info(f"Causal analysis completed for knowledge graph {kg_id}")
return True
finally:
session.close()
except Exception as e:
logger.error(f"Error in causal analysis: {str(e)}")
logger.error(traceback.format_exc())
update_task_status(task_id, "FAILED", f"Error in causal analysis: {str(e)}")
return False
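# How the task above might be scheduled (a sketch, not part of this module's API; the
# IDs are hypothetical placeholders). From an async caller it can simply be awaited:
#
#     ok = await analyze_causal_relationships_task(kg_id="example-kg-id", task_id="example-task-id")
#
# or, for example, handed to FastAPI's BackgroundTasks, which accepts async callables:
#
#     background_tasks.add_task(analyze_causal_relationships_task, "example-kg-id", "example-task-id")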