import numpy as np
import ctypes
from typing import Any, List, TYPE_CHECKING
if TYPE_CHECKING:
from .autograd import Tensor
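# Minimal in-memory stand-in for the hardware driver this module expects.
# This is a sketch inferred from the three calls made below (create_tensor,
# get_tensor, tensor_exists); the real driver lives outside this file and
# its API may differ.
class _InMemoryDriver:
    def __init__(self):
        self._store = {}
    def create_tensor(self, name, data):
        # create_tensor doubles as "overwrite": the graph code below
        # re-creates tensors such as the node list on every update
        self._store[name] = np.asarray(data)
    def get_tensor(self, name):
        return self._store[name]
    def tensor_exists(self, name):
        return name in self._store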
class ComputationalNode:
    """A single recorded operation; its metadata lives in driver storage, keyed by node_id."""
    def __init__(self, op_name: str, inputs: List['Tensor'], output: 'Tensor', driver):
        self.op_name = op_name
        self.node_id = f"node_{id(self)}"
        self.driver = driver
        # Store input and output tensor names in driver
input_names = [inp.name for inp in inputs]
self.driver.create_tensor(f"{self.node_id}_inputs", np.array(input_names))
self.driver.create_tensor(f"{self.node_id}_output", np.array(output.name))
self.gradient_fn = None
self.saved_tensor_names = {} # Store tensor names instead of data
def save_for_backward(self, **tensors):
"""Save tensor data needed for backward pass in driver storage"""
for key, tensor in tensors.items():
tensor_name = f"{self.node_id}_saved_{key}"
if isinstance(tensor, np.ndarray):
self.driver.create_tensor(tensor_name, tensor)
else:
# For non-tensor data like shapes, axes etc.
self.driver.create_tensor(tensor_name, np.array(tensor))
self.saved_tensor_names[key] = tensor_name
def get_saved_tensor(self, key):
"""Retrieve saved tensor from driver storage"""
return self.driver.get_tensor(self.saved_tensor_names[key])
def get_inputs(self):
"""Get input tensor names from driver storage"""
return self.driver.get_tensor(f"{self.node_id}_inputs")
def get_output(self):
"""Get output tensor name from driver storage"""
return self.driver.get_tensor(f"{self.node_id}_output").item()
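# Hypothetical illustration of the save_for_backward contract: backward() calls
# each node's gradient_fn with grad_output plus every saved tensor by keyword,
# so a multiply op that did save_for_backward(a=..., b=...) would pair with a
# backward function shaped like this ("a" and "b" are illustrative names):
def _example_mul_backward(grad_output, a, b):
    # d(a*b)/da = b and d(a*b)/db = a; return one gradient per input, in order
    return grad_output * b, grad_output * a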
class ComputeGraph:
    """Driver-backed operation tape with reverse-mode gradient support."""
    def __init__(self, driver=None):
        # Fall back to the in-memory sketch above so a bare ComputeGraph()
        # (see GLOBAL_GRAPH below) does not crash on the create_tensor calls
        self.driver = driver if driver is not None else _InMemoryDriver()
        self.graph_id = f"graph_{id(self)}"
        # Store node IDs in driver
        self.driver.create_tensor(f"{self.graph_id}_nodes", np.array([]))
        # Store op mappings in driver
        self.driver.create_tensor(f"{self.graph_id}_grad_fns", np.array([]))
        # Python-side registries: function, node and tensor objects cannot be
        # recovered from driver storage by name or id alone
        self._fn_refs = {}
        self._nodes_by_id = {}
        self._tensors_by_name = {}
        self.requires_grad = set()  # Small enough to keep in Python
        self.is_training = True
    def register_operation(self, op_name: str, forward_fn: Any, backward_fn: Any):
        """Register a new operation with its forward and backward functions"""
        fn_name = f"{self.graph_id}_fn_{op_name}"
        # Keep real references alongside the stored ids so the function
        # objects stay alive for the ctypes.cast in track_operation
        self._fn_refs[fn_name] = (forward_fn, backward_fn)
        self.driver.create_tensor(fn_name, np.array([id(forward_fn), id(backward_fn)]))
        # Update op list
        ops = list(self.driver.get_tensor(f"{self.graph_id}_grad_fns"))
        ops.append(fn_name)
        self.driver.create_tensor(f"{self.graph_id}_grad_fns", np.array(ops))
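    # Example wiring (hypothetical names): graph.register_operation("mul",
    # mul_forward, mul_backward) makes track_operation("mul", ...) able to
    # recover mul_backward for the backward pass.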
    def track_operation(self, op_name: str, inputs: List['Tensor'], output: 'Tensor') -> None:
        """Record an operation in the computational graph"""
        if not self.is_training:
            return
        if any(inp.requires_grad for inp in inputs):
            node = ComputationalNode(op_name, inputs, output, self.driver)
            fn_name = f"{self.graph_id}_fn_{op_name}"
            fn_ids = self.driver.get_tensor(fn_name)
            # Recover the backward function from its stored id; safe only
            # because register_operation keeps a live reference to it
            node.gradient_fn = ctypes.cast(int(fn_ids[1]), ctypes.py_object).value
            # Register the node and its tensors so backward() can resolve ids
            # and names from driver storage back to live objects
            self._nodes_by_id[node.node_id] = node
            for t in list(inputs) + [output]:
                self._tensors_by_name[t.name] = t
            # Add node to graph
            nodes = list(self.driver.get_tensor(f"{self.graph_id}_nodes"))
            nodes.append(node.node_id)
            self.driver.create_tensor(f"{self.graph_id}_nodes", np.array(nodes))
    def get_node_by_id(self, node_id: str) -> ComputationalNode:
        """Resolve a node id from driver storage to its live node object"""
        return self._nodes_by_id[node_id]
    def get_tensor_by_name(self, name: str) -> 'Tensor':
        """Resolve a tensor name from driver storage to its tensor object"""
        return self._tensors_by_name[name]
def backward(self, loss_tensor: 'Tensor', retain_graph: bool = False):
"""Execute backward pass through the computational graph"""
nodes = self.driver.get_tensor(f"{self.graph_id}_nodes")
if len(nodes) == 0:
return
# Initialize gradients in driver storage
grad_id = f"{self.graph_id}_grads"
self.driver.create_tensor(f"{grad_id}_{loss_tensor.name}", np.ones_like(loss_tensor.data()))
# Topological sort using driver storage
visited_id = f"{self.graph_id}_visited"
self.driver.create_tensor(visited_id, np.array([]))
topo_id = f"{self.graph_id}_topo"
self.driver.create_tensor(topo_id, np.array([]))
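        # Depth-first search from the loss node: a node is appended to the
        # topo list only after the producers of its inputs, so the list is a
        # topological order; backward then walks it in reverse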
def build_topo(node):
visited = set(self.driver.get_tensor(visited_id))
if node.node_id in visited:
return
visited = list(visited)
visited.append(node.node_id)
self.driver.create_tensor(visited_id, np.array(visited))
for input_name in node.get_inputs():
input_tensor = self.get_tensor_by_name(input_name)
if input_tensor.requires_grad:
for n_id in nodes:
n = self.get_node_by_id(n_id)
if n.get_output() == input_name:
build_topo(n)
topo_order = list(self.driver.get_tensor(topo_id))
topo_order.append(node.node_id)
self.driver.create_tensor(topo_id, np.array(topo_order))
# Build topological ordering
for node_id in reversed(nodes):
node = self.get_node_by_id(node_id)
if node.get_output() == loss_tensor.name:
build_topo(node)
# Execute backward passes in topological order
topo_order = self.driver.get_tensor(topo_id)
for node_id in reversed(topo_order):
node = self.get_node_by_id(node_id)
grad_output = self.driver.get_tensor(f"{grad_id}_{node.get_output()}")
grad_inputs = node.gradient_fn(grad_output, **{k: node.get_saved_tensor(k)
for k in node.saved_tensor_names})
if not isinstance(grad_inputs, tuple):
grad_inputs = (grad_inputs,)
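            # Accumulate rather than overwrite, so a tensor consumed by
            # several operations sums the gradient contributions from each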
for grad_input, input_name in zip(grad_inputs, node.get_inputs()):
input_tensor = self.get_tensor_by_name(input_name)
if input_tensor.requires_grad:
grad_key = f"{grad_id}_{input_name}"
if self.driver.tensor_exists(grad_key):
existing_grad = self.driver.get_tensor(grad_key)
self.driver.create_tensor(grad_key, existing_grad + grad_input)
else:
self.driver.create_tensor(grad_key, grad_input)
# Update gradients in tensors
for node_id in nodes:
node = self.get_node_by_id(node_id)
for input_name in node.get_inputs():
input_tensor = self.get_tensor_by_name(input_name)
if input_tensor.requires_grad:
grad_key = f"{grad_id}_{input_name}"
if self.driver.tensor_exists(grad_key):
input_tensor.set_grad(self.driver.get_tensor(grad_key))
        if not retain_graph:
            # Clear graph: node list in driver storage plus live node objects
            self.driver.create_tensor(f"{self.graph_id}_nodes", np.array([]))
            self._nodes_by_id.clear()
    def clear(self):
        """Clear the computational graph"""
        self.driver.create_tensor(f"{self.graph_id}_nodes", np.array([]))
        self._nodes_by_id.clear()
def no_grad(self):
"""Context manager to disable gradient computation"""
return NoGrad(self)
class NoGrad:
def __init__(self, graph: ComputeGraph):
self.graph = graph
self.prev_state = None
def __enter__(self):
self.prev_state = self.graph.is_training
self.graph.is_training = False
def __exit__(self, *args):
self.graph.is_training = self.prev_state
# Global compute graph instance
GLOBAL_GRAPH = ComputeGraph()
def get_compute_graph() -> ComputeGraph:
return GLOBAL_GRAPH
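# Intended usage, as a sketch (Tensor comes from .autograd; its constructor is
# not shown in this module, and mul_forward/mul_backward are hypothetical):
#
#   graph = get_compute_graph()
#   graph.register_operation("mul", mul_forward, mul_backward)
#   # ... the forward pass calls graph.track_operation("mul", [a, b], out) ...
#   graph.backward(loss)        # writes gradients back via Tensor.set_grad
#   with graph.no_grad():
#       ...                     # inference: no operations are tracked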