File size: 13,559 Bytes
c2ea5ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
Causal Analysis Service

This service handles all database operations for causal analysis,
providing a clean interface between the database layer and the pure
analytical functions in agentgraph.causal.
"""

import logging
from typing import Dict, List, Any, Optional
from sqlalchemy.orm import Session
from datetime import datetime, timezone
import uuid
import traceback
import numpy as np

from backend.database.models import CausalAnalysis, KnowledgeGraph, PerturbationTest, PromptReconstruction
from backend.database.utils import save_causal_analysis, get_causal_analysis, get_causal_analysis_summary, get_all_causal_analyses, get_knowledge_graph_by_id
from agentgraph.causal.causal_interface import analyze_causal_effects
from backend.database import get_db
from backend.services.task_service import update_task_status

logger = logging.getLogger(__name__)


def sanitize_for_json(data: Any) -> Any:
    """
    Recursively convert numpy types to standard Python types for JSON serialization.

    Handles dicts, lists, tuples, numpy scalars (integer, floating, bool) and
    ndarrays. Any other value is returned unchanged.

    Args:
        data: Arbitrary value, possibly containing numpy objects.

    Returns:
        An equivalent structure built only from JSON-serializable Python types
        (tuples are preserved as tuples; json.dumps encodes them as arrays).
    """
    if isinstance(data, dict):
        return {key: sanitize_for_json(value) for key, value in data.items()}
    if isinstance(data, list):
        return [sanitize_for_json(item) for item in data]
    if isinstance(data, tuple):
        # Sanitize elements too — a numpy scalar inside a tuple would
        # otherwise slip through and break json.dumps.
        return tuple(sanitize_for_json(item) for item in data)
    # Abstract scalar bases cover every concrete width (int8..int64,
    # uint8..uint64, float16..float64) and, unlike the aliases np.int_ /
    # np.float_, survive the NumPy 2.0 alias removals.
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, np.floating):
        return float(data)
    if isinstance(data, np.bool_):
        # np.bool_ is not a Python bool subclass and is rejected by json.dumps.
        return bool(data)
    if isinstance(data, np.ndarray):
        # tolist() yields nested plain Python scalars.
        return data.tolist()
    return data


class CausalService:
    """
    Service for orchestrating causal analysis with database operations.

    This service fetches data from the database, calls pure analytical functions
    from agentgraph.causal, and saves the results back to the database.
    """

    # Maps each averaging-based analysis method to the key inside
    # result["scores"] whose mean absolute value is that method's causal score.
    _SCORE_KEYS = {
        "graph": "ACE",
        "component": "Feature_Importance",
        "dowhy": "Effect_Estimate",
        "ate": "Effect_Strengths",
    }

    def __init__(self, session: Session):
        # The session is owned by the caller; this service never commits
        # or closes it.
        self.session = session

    def fetch_analysis_data(self, knowledge_graph_id: int, perturbation_set_id: str) -> Dict[str, Any]:
        """
        Fetch all data needed for causal analysis from the database.

        Args:
            knowledge_graph_id: ID of the knowledge graph
            perturbation_set_id: ID of the perturbation set

        Returns:
            Dictionary containing all data needed for analysis, or a dict with
            an "error" key describing why the data could not be assembled.
        """
        try:
            # 1. Query PerturbationTest, filtering by both IDs
            perturbation_tests = self.session.query(PerturbationTest).filter(
                PerturbationTest.knowledge_graph_id == knowledge_graph_id,
                PerturbationTest.perturbation_set_id == perturbation_set_id
            ).all()

            if not perturbation_tests:
                logger.warning(f"No perturbation tests found for knowledge_graph_id={knowledge_graph_id}, perturbation_set_id={perturbation_set_id}")
                # Debug aid: list which perturbation sets DO exist for this KG
                # so a caller can spot a mismatched set ID quickly.
                all_tests = self.session.query(PerturbationTest).filter(
                    PerturbationTest.knowledge_graph_id == knowledge_graph_id
                ).all()
                logger.warning(f"Available perturbation sets for KG {knowledge_graph_id}: {[t.perturbation_set_id for t in all_tests]}")
                return {"error": "No perturbation tests found for the specified criteria"}

            # Get a sample to report what we're analyzing; all tests in a set
            # are assumed to share one perturbation_type — TODO confirm.
            sample_test = perturbation_tests[0]
            logger.info(f"Analyzing {len(perturbation_tests)} perturbation tests of type '{sample_test.perturbation_type}'")

            # 2. Get all prompt_reconstruction_ids
            pr_ids = [test.prompt_reconstruction_id for test in perturbation_tests]

            # 3. Query PromptReconstruction for these IDs
            prompt_reconstructions = self.session.query(PromptReconstruction).filter(
                PromptReconstruction.id.in_(pr_ids)
            ).all()

            # 4. Get the knowledge graph data
            kg = self.session.query(KnowledgeGraph).filter_by(id=knowledge_graph_id).first()
            if not kg:
                return {"error": f"Knowledge graph with ID {knowledge_graph_id} not found"}

            # 5. Create the analysis data structure consumed by
            # agentgraph.causal.analyze_causal_effects.
            analysis_data = {
                "perturbation_tests": [test.to_dict() for test in perturbation_tests],
                "dependencies_map": {pr.id: pr.dependencies for pr in prompt_reconstructions},
                "knowledge_graph": kg.graph_data,
                "perturbation_type": sample_test.perturbation_type,
                "perturbation_scores": {test.relation_id: test.perturbation_score for test in perturbation_tests},
                "relation_to_pr_map": {test.relation_id: test.prompt_reconstruction_id for test in perturbation_tests}
            }

            return analysis_data

        except Exception as e:
            logger.error(f"Error while extracting data for analysis: {repr(e)}")
            return {"error": f"Failed to extract analysis data: {repr(e)}"}

    def save_analysis_results(self, method: str, results: Dict[str, Any], knowledge_graph_id: int, perturbation_set_id: str) -> None:
        """
        Save analysis results to the database.

        Results containing an "error" key are skipped (logged, not raised).

        Args:
            method: Analysis method name
            results: Results from the analysis
            knowledge_graph_id: ID of the knowledge graph
            perturbation_set_id: ID of the perturbation set
        """
        if "error" in results:
            logger.warning(f"Not saving results for {method} due to error: {results['error']}")
            return

        # Sanitize results to ensure they are JSON serializable (numpy types
        # leak out of the analytical functions).
        sanitized_results = sanitize_for_json(results)

        # Calculate a single scalar score summarizing this method's result
        causal_score = self._calculate_causal_score(method, sanitized_results)

        # Save to database
        save_causal_analysis(
            self.session,
            knowledge_graph_id=knowledge_graph_id,
            perturbation_set_id=perturbation_set_id,
            analysis_method=method,
            analysis_result=sanitized_results,
            causal_score=causal_score,
            analysis_metadata={
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "method_specific_metadata": sanitized_results.get("metadata", {})
            }
        )

    @staticmethod
    def _mean_abs(scores: Dict[str, float]) -> float:
        """Mean of absolute values over a score mapping; 0.0 when empty."""
        if not scores:
            return 0.0
        return sum(abs(score) for score in scores.values()) / len(scores)

    def _calculate_causal_score(self, method: str, result: Dict[str, Any]) -> float:
        """
        Calculate a single causal score for the method based on results.

        For graph/component/dowhy/ate the score is the mean absolute value of
        the method-specific score mapping (see _SCORE_KEYS). For
        confounder/mscd it is 0.1 per detected confounder. Unknown methods
        and any internal error yield 0.0.
        """
        try:
            scores = result.get("scores", {})
            if method in self._SCORE_KEYS:
                return self._mean_abs(scores.get(self._SCORE_KEYS[method], {}))
            if method in ("confounder", "mscd"):
                # Simple heuristic: each identified confounder contributes 0.1.
                return len(scores.get("Confounders", {})) * 0.1
        except Exception as e:
            logger.warning(f"Error calculating causal score for {method}: {e}")

        return 0.0

    def run_causal_analysis(self, knowledge_graph_id: int, perturbation_set_id: str, methods: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Run causal analysis with database operations.

        Args:
            knowledge_graph_id: ID of the knowledge graph
            perturbation_set_id: ID of the perturbation set
            methods: List of analysis methods to use; defaults to all six.

        Returns:
            Dictionary containing analysis results for each method, or a dict
            with an "error" key if data fetching or analysis failed outright.
        """
        if methods is None:
            methods = ['graph', 'component', 'dowhy', 'confounder', 'mscd', 'ate']

        # Fetch data from database
        analysis_data = self.fetch_analysis_data(knowledge_graph_id, perturbation_set_id)
        if "error" in analysis_data:
            return analysis_data

        # Call the pure analysis function, then persist per-method results
        try:
            results = analyze_causal_effects(analysis_data, methods)

            # Save each successful method's results to database (methods that
            # errored individually are still returned to the caller).
            for method, result in results.items():
                if "error" not in result:
                    self.save_analysis_results(method, result, knowledge_graph_id, perturbation_set_id)

            return results

        except Exception as e:
            logger.error(f"Error during causal analysis: {repr(e)}")
            return {"error": f"Analysis failed: {repr(e)}"}

    def get_analysis_results(self, knowledge_graph_id: int, method: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get causal analysis results from database.

        Args:
            knowledge_graph_id: ID of the knowledge graph
            method: Optional filter by analysis method

        Returns:
            List of causal analysis results
        """
        return get_all_causal_analyses(
            session=self.session,
            knowledge_graph_id=knowledge_graph_id,
            analysis_method=method
        )

    def get_causal_analysis_summary(self, knowledge_graph_id: int) -> Dict[str, Any]:
        """
        Retrieves a summary of the causal analysis for a given knowledge graph.
        """
        # NOTE: annotated int for consistency with the other methods; the
        # previous str annotation looked like a typo (runtime-unaffected).
        return get_causal_analysis_summary(self.session, knowledge_graph_id)

async def analyze_causal_relationships_task(kg_id: str, task_id: str) -> bool:
    """
    Task to analyze causal relationships in a knowledge graph using the CausalService.

    Runs one causal analysis per perturbation set found for the graph, then
    marks the graph status as "analyzed". Task progress is reported through
    update_task_status.

    Args:
        kg_id: ID of the knowledge graph to analyze.
        task_id: ID of the task record used for status reporting.

    Returns:
        True if the analysis pipeline completed, False on any failure.
    """
    logger.info(f"Starting causal analysis for knowledge graph {kg_id}")
    update_task_status(task_id, "RUNNING", "Analyzing causal relationships")
    try:
        # get_db() is a generator dependency; take one session and guarantee
        # it is closed in the inner finally even if analysis raises.
        session = next(get_db())
        try:
            kg = get_knowledge_graph_by_id(session, kg_id)
            if not kg:
                logger.error(f"Knowledge graph with ID {kg_id} not found")
                update_task_status(task_id, "FAILED", f"Knowledge graph with ID {kg_id} not found")
                return False
            # "analyzed" is allowed so a re-run on an already-analyzed graph works.
            if kg.status not in ("perturbed", "analyzed"):
                update_task_status(task_id, "FAILED", "Knowledge graph must be perturbed before causal analysis")
                return False

            # Instantiate the CausalService with the session
            causal_service = CausalService(session=session)

            # Get all available perturbation_set_ids for this KG
            perturbation_sets = session.query(PerturbationTest.perturbation_set_id).filter_by(knowledge_graph_id=kg.id).distinct().all()
            logger.info(f"Found {len(perturbation_sets)} perturbation sets for KG {kg.id}: {[ps[0] for ps in perturbation_sets]}")

            if not perturbation_sets:
                update_task_status(task_id, "FAILED", "No perturbation tests found for this knowledge graph")
                return False

            # Run analysis for each perturbation set independently.
            for (perturbation_set_id,) in perturbation_sets:
                logger.info(f"Running causal analysis for KG {kg.id} with perturbation set {perturbation_set_id}")
                analysis_results = causal_service.run_causal_analysis(
                    knowledge_graph_id=kg.id,
                    perturbation_set_id=perturbation_set_id
                )
                # NOTE(review): a per-set {"error": ...} result is logged but
                # does not fail the task — confirm this best-effort behavior
                # is intended.
                logger.info(f"Causal analysis results for set {perturbation_set_id}: {analysis_results}")

            # Update KG status and persist the transition
            kg.status = "analyzed"
            kg.update_timestamp = datetime.now(timezone.utc)
            session.commit()

            update_task_status(task_id, "COMPLETED", "Causal analysis completed")
            logger.info(f"Causal analysis completed for knowledge graph {kg_id}")
            return True
        finally:
            session.close()
    except Exception as e:
        # logger.exception records the current traceback automatically,
        # replacing the manual traceback.format_exc() log line.
        logger.exception(f"Error in causal analysis: {e}")
        update_task_status(task_id, "FAILED", f"Error in causal analysis: {str(e)}")
        return False