import ctypes
from typing import Any, Dict, List, TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from .autograd import Tensor


class ComputationalNode:
    def __init__(self, op_name: str, inputs: List['Tensor'], output: 'Tensor', driver):
        self.op_name = op_name
        self.node_id = f"node_{id(self)}"
        self.driver = driver
        # Store input and output tensor names in driver storage
        input_names = [inp.name for inp in inputs]
        self.driver.create_tensor(f"{self.node_id}_inputs", np.array(input_names))
        self.driver.create_tensor(f"{self.node_id}_output", np.array(output.name))
        self.gradient_fn = None
        self.saved_tensor_names = {}  # Store tensor names instead of data

    def save_for_backward(self, **tensors):
        """Save tensor data needed for the backward pass in driver storage."""
        for key, tensor in tensors.items():
            tensor_name = f"{self.node_id}_saved_{key}"
            if isinstance(tensor, np.ndarray):
                self.driver.create_tensor(tensor_name, tensor)
            else:
                # For non-tensor data such as shapes, axes, etc.
                self.driver.create_tensor(tensor_name, np.array(tensor))
            self.saved_tensor_names[key] = tensor_name

    def get_saved_tensor(self, key):
        """Retrieve a saved tensor from driver storage."""
        return self.driver.get_tensor(self.saved_tensor_names[key])

    def get_inputs(self):
        """Get the input tensor names from driver storage."""
        return self.driver.get_tensor(f"{self.node_id}_inputs")

    def get_output(self):
        """Get the output tensor name from driver storage."""
        return self.driver.get_tensor(f"{self.node_id}_output").item()


class ComputeGraph:
    def __init__(self, driver=None):
        self.driver = driver
        self.graph_id = f"graph_{id(self)}"
        # Python-side registries so the node ids and tensor names persisted in
        # the driver can be resolved back to live objects during backward().
        self._nodes_by_id: Dict[str, ComputationalNode] = {}
        self._tensors_by_name: Dict[str, 'Tensor'] = {}
        # Keep registered forward/backward functions alive: only their id()s
        # are persisted, and the ctypes round-trip in track_operation would
        # dereference freed memory if they were garbage collected.
        self._registered_fns: List[Any] = []
        self.requires_grad = set()  # Small enough to keep in Python
        self.is_training = True
        # Defer driver-backed storage until a driver is attached, so that
        # constructing the module-level GLOBAL_GRAPH does not crash.
        if self.driver is not None:
            self._init_storage()

    def _init_storage(self):
        # Store node ids and op mappings in driver storage
        self.driver.create_tensor(f"{self.graph_id}_nodes", np.array([]))
        self.driver.create_tensor(f"{self.graph_id}_grad_fns", np.array([]))

    def set_driver(self, driver):
        """Attach a driver after construction; required before tracking ops."""
        self.driver = driver
        self._init_storage()

    def get_node_by_id(self, node_id: str) -> ComputationalNode:
        return self._nodes_by_id[str(node_id)]

    def get_tensor_by_name(self, name: str) -> 'Tensor':
        return self._tensors_by_name[str(name)]

    def register_operation(self, op_name: str, forward_fn: Any, backward_fn: Any):
        """Register a new operation with its forward and backward functions."""
        fn_name = f"{self.graph_id}_fn_{op_name}"
        self._registered_fns.append((forward_fn, backward_fn))
        self.driver.create_tensor(fn_name, np.array([id(forward_fn), id(backward_fn)]))
        # Update the op list
        ops = list(self.driver.get_tensor(f"{self.graph_id}_grad_fns"))
        ops.append(fn_name)
        self.driver.create_tensor(f"{self.graph_id}_grad_fns", np.array(ops))

    def track_operation(self, op_name: str, inputs: List['Tensor'], output: 'Tensor') -> None:
        """Record an operation in the computational graph."""
        if not self.is_training:
            return
        if any(inp.requires_grad for inp in inputs):
            node = ComputationalNode(op_name, inputs, output, self.driver)
            fn_name = f"{self.graph_id}_fn_{op_name}"
            fn_ids = self.driver.get_tensor(fn_name)
            # Resurrect the backward function from its stored object id
            node.gradient_fn = ctypes.cast(int(fn_ids[1]), ctypes.py_object).value
            # Register the node and its tensors for later lookup by name
            self._nodes_by_id[node.node_id] = node
            for tensor in [*inputs, output]:
                self._tensors_by_name[tensor.name] = tensor
            # Add the node to the graph
            nodes = list(self.driver.get_tensor(f"{self.graph_id}_nodes"))
            nodes.append(node.node_id)
            self.driver.create_tensor(f"{self.graph_id}_nodes", np.array(nodes))

    def backward(self, loss_tensor: 'Tensor', retain_graph: bool = False):
        """Execute the backward pass through the computational graph."""
        nodes = self.driver.get_tensor(f"{self.graph_id}_nodes")
        if len(nodes) == 0:
            return
        # Seed the gradient of the loss with ones: d(loss)/d(loss) = 1
        grad_id = f"{self.graph_id}_grads"
        self.driver.create_tensor(f"{grad_id}_{loss_tensor.name}",
                                  np.ones_like(loss_tensor.data()))
        # Topological-sort state lives in driver storage
        visited_id = f"{self.graph_id}_visited"
        self.driver.create_tensor(visited_id, np.array([]))
        topo_id = f"{self.graph_id}_topo"
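        # Depth-first traversal: a node enters the topological order only
        # after every node producing one of its inputs has been visited, so
        # reversing the order later walks the graph from outputs to inputs.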
        self.driver.create_tensor(topo_id, np.array([]))

        def build_topo(node):
            visited = set(self.driver.get_tensor(visited_id))
            if node.node_id in visited:
                return
            visited = list(visited)
            visited.append(node.node_id)
            self.driver.create_tensor(visited_id, np.array(visited))
            # Visit the producers of this node's inputs first
            for input_name in node.get_inputs():
                input_tensor = self.get_tensor_by_name(input_name)
                if input_tensor.requires_grad:
                    for n_id in nodes:
                        n = self.get_node_by_id(n_id)
                        if n.get_output() == input_name:
                            build_topo(n)
            topo_order = list(self.driver.get_tensor(topo_id))
            topo_order.append(node.node_id)
            self.driver.create_tensor(topo_id, np.array(topo_order))

        # Build the topological ordering, starting from the loss node
        for node_id in reversed(nodes):
            node = self.get_node_by_id(node_id)
            if node.get_output() == loss_tensor.name:
                build_topo(node)

        # Execute backward passes in reverse topological order
        topo_order = self.driver.get_tensor(topo_id)
        for node_id in reversed(topo_order):
            node = self.get_node_by_id(node_id)
            grad_output = self.driver.get_tensor(f"{grad_id}_{node.get_output()}")
            grad_inputs = node.gradient_fn(
                grad_output,
                **{k: node.get_saved_tensor(k) for k in node.saved_tensor_names})
            if not isinstance(grad_inputs, tuple):
                grad_inputs = (grad_inputs,)
            # Accumulate gradients for each input that requires them
            for grad_input, input_name in zip(grad_inputs, node.get_inputs()):
                input_tensor = self.get_tensor_by_name(input_name)
                if input_tensor.requires_grad:
                    grad_key = f"{grad_id}_{input_name}"
                    if self.driver.tensor_exists(grad_key):
                        existing_grad = self.driver.get_tensor(grad_key)
                        self.driver.create_tensor(grad_key, existing_grad + grad_input)
                    else:
                        self.driver.create_tensor(grad_key, grad_input)

        # Copy the accumulated gradients back onto the tensors
        for node_id in nodes:
            node = self.get_node_by_id(node_id)
            for input_name in node.get_inputs():
                input_tensor = self.get_tensor_by_name(input_name)
                if input_tensor.requires_grad:
                    grad_key = f"{grad_id}_{input_name}"
                    if self.driver.tensor_exists(grad_key):
                        input_tensor.set_grad(self.driver.get_tensor(grad_key))

        if not retain_graph:
            self.clear()

    def clear(self):
        """Clear the computational graph."""
        self._nodes_by_id.clear()
        self.driver.create_tensor(f"{self.graph_id}_nodes", np.array([]))

    def no_grad(self):
        """Context manager to disable gradient tracking."""
        return NoGrad(self)


class NoGrad:
    def __init__(self, graph: ComputeGraph):
        self.graph = graph
        self.prev_state = None

    def __enter__(self):
        self.prev_state = self.graph.is_training
        self.graph.is_training = False

    def __exit__(self, *args):
        self.graph.is_training = self.prev_state


# Global compute graph instance. A driver must be attached via set_driver()
# before any operations are registered or tracked.
GLOBAL_GRAPH = ComputeGraph()


def get_compute_graph() -> ComputeGraph:
    return GLOBAL_GRAPH
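

# --- Usage sketch (illustrative, not part of the public API) ----------------
# A minimal smoke test, assuming a dict-backed driver and a tiny Tensor stub.
# Both are hypothetical stand-ins: the real driver backend and the Tensor
# class from .autograd are what this module actually expects.
if __name__ == "__main__":
    class DictDriver:
        """Hypothetical in-memory driver exposing the calls used above."""

        def __init__(self):
            self._store = {}

        def create_tensor(self, name, data):
            self._store[name] = np.asarray(data)

        def get_tensor(self, name):
            return self._store[name]

        def tensor_exists(self, name):
            return name in self._store

    class FakeTensor:
        """Hypothetical stand-in for autograd.Tensor."""

        def __init__(self, name, value, requires_grad=False):
            self.name = name
            self._value = np.asarray(value, dtype=float)
            self.requires_grad = requires_grad
            self.grad = None

        def data(self):
            return self._value

        def set_grad(self, grad):
            self.grad = grad

    graph = ComputeGraph(DictDriver())
    # y = x * x, so dy/dx = 2x
    graph.register_operation(
        "square",
        lambda x: x * x,
        lambda grad_output, x: grad_output * 2 * x,
    )
    x = FakeTensor("x", [3.0], requires_grad=True)
    y = FakeTensor("y", x.data() * x.data())
    graph.track_operation("square", [x], y)
    # Save the forward input that the backward function needs
    node_ids = graph.driver.get_tensor(f"{graph.graph_id}_nodes")
    graph.get_node_by_id(node_ids[-1]).save_for_backward(x=x.data())
    graph.backward(y)
    print(x.grad)  # expected: [6.]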