# digital-twin-backend / graph_engine.py
# AlexKurian's picture
# Backend (LFS enabled)
# 9e7bb72
"""
Graph mutation and simulation engine.
Applies policies to the causal graph and runs multi-pass propagation.
"""
import copy
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
class GraphState:
    """Represents the state of the causal graph.

    Nodes and edges are plain dicts. The attributes ``type``, ``value`` and
    ``weight`` may live either at the top level of the dict or inside a
    nested ``data`` dict; the nested value wins when both are present. A node
    counts as enabled only if neither its top-level nor its nested
    ``enabled`` flag is falsy. All reads go through the private helpers below,
    so mutations, undo and simulation agree on where a value lives.
    """

    # Mutation types that rewrite an edge's weight.
    _WEIGHT_MUTATIONS = ('reduce_edge_weight', 'increase_edge_weight')

    def __init__(self, nodes: List[Dict], edges: List[Dict]):
        """
        Initialize graph state.

        Args:
            nodes: List of node dicts with {id, label, type, enabled, value, ...};
                attributes may also be nested under a ``data`` dict.
            edges: List of edge dicts with {id, source, target, weight, ...};
                ``weight`` may also be nested under a ``data`` dict.
        """
        # Deep copies so that mutating this state never touches the caller's lists.
        self.nodes = copy.deepcopy(nodes)
        self.edges = copy.deepcopy(edges)
        self.history: List[Dict] = []  # change records consumed by undo()
        self.baseline_snapshot: Optional[Dict[str, Any]] = None

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _data(item: Dict) -> Dict:
        """Return ``item['data']`` if it is a dict, else an empty dict."""
        data = item.get('data')
        return data if isinstance(data, dict) else {}

    @classmethod
    def _attr(cls, item: Dict, key: str, default: Any = None) -> Any:
        """Read ``key`` from the nested ``data`` dict first, then from the item itself."""
        data = cls._data(item)
        if key in data:
            return data[key]
        return item.get(key, default)

    @classmethod
    def _is_enabled(cls, node: Dict) -> bool:
        """Return False if either the top-level or the nested ``enabled`` flag is falsy."""
        return bool(node.get('enabled', True)) and bool(cls._data(node).get('enabled', True))

    @staticmethod
    def _set_weight(edge: Dict, weight: Any) -> None:
        """Write ``weight`` to ``edge['data']['weight']``, replacing a missing/non-dict ``data``."""
        if not isinstance(edge.get('data'), dict):
            edge['data'] = {}
        edge['data']['weight'] = weight

    # ---------------------------------------------------------------- loading
    @classmethod
    def from_file(cls, filepath: str) -> 'GraphState':
        """Load graph state from a UTF-8 JSON file with top-level 'nodes'/'edges' lists.

        The full parsed document is kept as ``baseline_snapshot`` for ``reset()``.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        instance = cls(data.get('nodes', []), data.get('edges', []))
        instance.baseline_snapshot = copy.deepcopy(data)
        return instance

    def to_dict(self) -> Dict[str, Any]:
        """Export graph state to dict.

        Note: 'nodes' and 'edges' are the live internal lists, not copies.
        """
        return {
            'nodes': self.nodes,
            'edges': self.edges,
            'timestamp': datetime.now().isoformat()
        }

    def get_node(self, node_id: str) -> Optional[Dict]:
        """Return the first node whose 'id' equals ``node_id``, or None."""
        return next((node for node in self.nodes if node['id'] == node_id), None)

    def get_edge(self, source: str, target: str) -> Optional[Dict]:
        """Return the first edge from ``source`` to ``target``, or None."""
        return next(
            (edge for edge in self.edges
             if edge['source'] == source and edge['target'] == target),
            None,
        )

    # -------------------------------------------------------------- mutations
    def apply_mutation(self, mutation: Dict) -> Dict:
        """
        Apply a single mutation and return its change record.

        Supported types:
            * ``disable_node``: needs ``node_id``; sets the node's top-level
              ``enabled`` flag to False.
            * ``reduce_edge_weight`` / ``increase_edge_weight``: need
              ``source``, ``target`` and ``new_weight``; writes the new weight
              to ``edge['data']['weight']`` (the direction is not validated).

        Args:
            mutation: {type, node_id, source, target, new_weight, reason, ...}

        Returns:
            Change record for audit trail ({type, before, after, timestamp}).
            The record is also appended to ``self.history``.

        Raises:
            KeyError: A required mutation field is missing.
            ValueError: Unknown mutation type, or the node/edge does not exist.
        """
        mutation_type = mutation['type']
        change = {
            'type': mutation_type,
            'before': None,
            'after': None,
            'timestamp': datetime.now().isoformat()
        }
        if mutation_type == 'disable_node':
            node_id = mutation['node_id']
            node = self.get_node(node_id)
            if node is None:
                raise ValueError(f"Node not found: {node_id}")
            # Record the *effective* state, i.e. what the simulation actually sees.
            change['before'] = {'id': node_id, 'enabled': self._is_enabled(node)}
            node['enabled'] = False
            change['after'] = {'id': node_id, 'enabled': False}
        elif mutation_type in self._WEIGHT_MUTATIONS:
            source = mutation['source']
            target = mutation['target']
            edge = self.get_edge(source, target)
            if edge is None:
                raise ValueError(f"Edge not found: {source} -> {target}")
            new_weight = mutation['new_weight']  # read before touching the edge
            old_weight = self._attr(edge, 'weight', 0.5)
            change['before'] = {'source': source, 'target': target, 'weight': old_weight}
            self._set_weight(edge, new_weight)
            change['after'] = {'source': source, 'target': target, 'weight': new_weight}
        else:
            # Previously unknown types were recorded as "applied" with no effect.
            raise ValueError(f"Unknown mutation type: {mutation_type}")
        self.history.append(change)
        return change

    def apply_policy(self, policy: Dict) -> Dict:
        """
        Apply all mutations in a policy, collecting per-mutation errors.

        A failing mutation does not stop the remaining ones (best effort).

        Args:
            policy: Policy dict with optional 'policy_id' and 'mutations' list.

        Returns:
            Result dict with 'policy_id', 'mutations_applied' (change records),
            'timestamp' and 'errors' ({mutation: type, error: message}).
        """
        results = {
            'policy_id': policy.get('policy_id'),
            'mutations_applied': [],
            'timestamp': datetime.now().isoformat(),
            'errors': []
        }
        # `or []` also tolerates an explicit "mutations": null.
        for mutation in policy.get('mutations') or []:
            try:
                result = self.apply_mutation(mutation)
            except Exception as e:  # deliberate boundary: report, keep going
                results['errors'].append({
                    'mutation': mutation.get('type'),
                    'error': str(e)
                })
            else:
                results['mutations_applied'].append(result)
        return results

    # ------------------------------------------------------------- simulation
    def run_simulation(self, passes: int = 6) -> Dict[str, Any]:
        """
        Multi-pass value propagation through the causal graph.

        Initialization:
            * Sector nodes (``type == 'sector'``) start at their ``value``
              (default 100).
            * All other nodes start at 0.

        Each pass then walks every edge once. If both endpoints exist and are
        enabled and the source value is positive, it adds
        ``source_value * weight`` (default weight 0.5) to the target. Values
        therefore accumulate across passes.

        Args:
            passes: Number of sweeps over all edges (default 6, mirroring
                config.SIMULATION_PASSES).

        Returns:
            Dict with 'node_values' and 'outputs' ({'co2', 'aqi'}, 0 if absent).
        """
        # id -> node, first occurrence wins (matches get_node()).
        nodes_by_id: Dict[str, Dict] = {}
        for node in self.nodes:
            nodes_by_id.setdefault(node['id'], node)
        # Node flags do not change during a simulation; evaluate them once.
        enabled = {node_id: self._is_enabled(node) for node_id, node in nodes_by_id.items()}

        node_values: Dict[str, Any] = {}
        for node in self.nodes:
            if self._attr(node, 'type', 'intermediate') == 'sector':
                node_values[node['id']] = self._attr(node, 'value', 100)
            else:
                node_values[node['id']] = 0

        # Multi-pass propagation (captures cascading effects)
        for _ in range(passes):
            for edge in self.edges:
                source_id = edge['source']
                target_id = edge['target']
                # Skip edges whose endpoints are missing or disabled.
                if source_id not in nodes_by_id or target_id not in nodes_by_id:
                    continue
                if not (enabled[source_id] and enabled[target_id]):
                    continue
                source_val = node_values.get(source_id, 0)
                if source_val > 0:
                    weight = self._attr(edge, 'weight', 0.5)
                    node_values[target_id] = node_values.get(target_id, 0) + source_val * weight

        return {
            'node_values': node_values,
            'outputs': {
                'co2': node_values.get('co2', 0),
                'aqi': node_values.get('aqi', 0)
            }
        }

    # ----------------------------------------------------------- reset / undo
    def reset(self):
        """Reset the graph to its baseline state and clear history.

        Restores from ``baseline_snapshot`` when one exists (set by
        ``from_file``). Otherwise every recorded mutation is reverted via
        ``undo``, so the graph is not left mutated with its undo history
        discarded.
        """
        if self.baseline_snapshot:
            self.nodes = copy.deepcopy(self.baseline_snapshot.get('nodes', []))
            self.edges = copy.deepcopy(self.baseline_snapshot.get('edges', []))
        else:
            self.undo(len(self.history))
        self.history = []

    def undo(self, steps: int = 1) -> bool:
        """Revert the last ``steps`` mutations, newest first.

        Returns:
            True if all requested steps were reverted. False if history ran
            out first; steps reverted before that point stay reverted.
        """
        for _ in range(steps):
            if not self.history:
                return False
            change = self.history.pop()
            before = change['before']
            if change['type'] == 'disable_node':
                node = self.get_node(before['id'])
                if node is not None:
                    node['enabled'] = before['enabled']
            elif change['type'] in self._WEIGHT_MUTATIONS:
                edge = self.get_edge(before['source'], before['target'])
                if edge is not None:
                    self._set_weight(edge, before['weight'])
        return True
class ImpactAnalyzer:
    """Analyzes impact of policies by comparing baseline vs post-policy states."""

    def __init__(self, baseline_state: 'GraphState', post_policy_state: 'GraphState'):
        """
        Initialize analyzer.

        Args:
            baseline_state: GraphState before policy.
            post_policy_state: GraphState after policy.

        Both objects only need a ``run_simulation()`` method returning a dict
        with 'node_values' and 'outputs' ({'co2', 'aqi'}).
        """
        self.baseline = baseline_state
        self.post_policy = post_policy_state

    @staticmethod
    def _output_impact(baseline: Any, post: Any) -> Dict[str, Any]:
        """Build the {baseline, post_policy, change_absolute, change_pct} record for one output.

        ``change_pct`` is 0 when the baseline is not positive, because there is
        no meaningful denominator.
        """
        return {
            'baseline': baseline,
            'post_policy': post,
            'change_absolute': post - baseline,
            'change_pct': (
                ((post - baseline) / baseline * 100)
                if baseline > 0 else 0
            )
        }

    def calculate_impact(self) -> Dict[str, Any]:
        """
        Calculate impact metrics by simulating both states.

        Returns:
            Dict with 'co2' and 'aqi' impact records (see ``_output_impact``)
            and a 'cascade_analysis' entry (see ``analyze_cascade``).
        """
        baseline_sim = self.baseline.run_simulation()
        post_sim = self.post_policy.run_simulation()
        impact: Dict[str, Any] = {
            metric: self._output_impact(
                baseline_sim['outputs'][metric], post_sim['outputs'][metric]
            )
            for metric in ('co2', 'aqi')
        }
        impact['cascade_analysis'] = self.analyze_cascade(baseline_sim, post_sim)
        return impact

    def analyze_cascade(self, baseline_sim: Dict, post_sim: Dict,
                        top_n: int = 10) -> Dict[str, Any]:
        """
        Rank nodes by the magnitude of their percentage change between two runs.

        Node selection:
            * Only nodes present in the baseline run with a baseline value
              above 0.001 are considered; smaller baselines would make the
              percentage meaningless.
            * A node missing from the post-policy run counts as 0.

        This ranks by change size only. It does not classify effects as
        1st/2nd/3rd order by graph distance.

        Args:
            baseline_sim: ``run_simulation()`` result for the baseline state.
            post_sim: ``run_simulation()`` result for the post-policy state.
            top_n: How many nodes to list in 'most_affected_nodes' (default 10).

        Returns:
            Dict with:
                * 'most_affected_nodes': [(node_id, change_pct), ...], sorted
                  by ``abs(change_pct)`` descending.
                * 'all_node_changes': {node_id: {baseline, post_policy, change_pct}}.
                * 'summary': counts of reductions/increases and the average
                  change_pct (0 when no node qualifies).
        """
        baseline_vals = baseline_sim['node_values']
        post_vals = post_sim['node_values']
        node_changes: Dict[str, Dict[str, Any]] = {}
        for node_id, baseline_val in baseline_vals.items():
            if baseline_val > 0.001:  # avoid division by very small numbers
                post_val = post_vals.get(node_id, 0)
                node_changes[node_id] = {
                    'baseline': baseline_val,
                    'post_policy': post_val,
                    'change_pct': ((post_val - baseline_val) / baseline_val) * 100
                }

        # Sort by magnitude of change (stable: ties keep baseline order).
        sorted_changes = sorted(
            node_changes.items(),
            key=lambda item: abs(item[1]['change_pct']),
            reverse=True
        )
        pcts = [d['change_pct'] for d in node_changes.values()]
        return {
            'most_affected_nodes': [(node_id, d['change_pct']) for node_id, d in sorted_changes[:top_n]],
            'all_node_changes': node_changes,
            'summary': {
                'nodes_with_reduction': sum(1 for p in pcts if p < 0),
                'nodes_with_increase': sum(1 for p in pcts if p > 0),
                'avg_change_pct': sum(pcts) / len(pcts) if pcts else 0
            }
        }