import ctypes
from typing import Any, Dict, List, TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from .autograd import Tensor


class ComputationalNode:
    def __init__(self, op_name: str, inputs: List['Tensor'], output: 'Tensor', driver):
        self.op_name = op_name
        self.node_id = f"node_{id(self)}"
        self.driver = driver

        # Persist the input/output tensor names in driver storage so the
        # graph topology can be reconstructed during the backward pass.
        input_names = [inp.name for inp in inputs]
        self.driver.create_tensor(f"{self.node_id}_inputs", np.array(input_names))
        self.driver.create_tensor(f"{self.node_id}_output", np.array(output.name))

        self.gradient_fn = None
        self.saved_tensor_names: Dict[str, str] = {}

    def save_for_backward(self, **tensors):
        """Save tensor data needed for the backward pass in driver storage."""
        for key, tensor in tensors.items():
            tensor_name = f"{self.node_id}_saved_{key}"
            if isinstance(tensor, np.ndarray):
                self.driver.create_tensor(tensor_name, tensor)
            else:
                # Coerce scalars and array-likes to an ndarray before storing.
                self.driver.create_tensor(tensor_name, np.array(tensor))
            self.saved_tensor_names[key] = tensor_name

    def get_saved_tensor(self, key):
        """Retrieve a saved tensor from driver storage."""
        return self.driver.get_tensor(self.saved_tensor_names[key])

    def get_inputs(self):
        """Get input tensor names from driver storage."""
        return self.driver.get_tensor(f"{self.node_id}_inputs")

    def get_output(self):
        """Get the output tensor name from driver storage."""
        return self.driver.get_tensor(f"{self.node_id}_output").item()


class ComputeGraph:
    def __init__(self, driver=None):
        self.graph_id = f"graph_{id(self)}"
        self.requires_grad = set()
        self.is_training = True
        # In-process registries backing get_node_by_id / get_tensor_by_name.
        self._node_registry: Dict[str, ComputationalNode] = {}
        self._tensor_registry: Dict[str, 'Tensor'] = {}
        self.driver = None
        # The module-level GLOBAL_GRAPH is constructed without a driver, so
        # driver-backed bookkeeping is deferred until one is attached.
        if driver is not None:
            self.attach_driver(driver)

    def attach_driver(self, driver) -> None:
        """Attach a storage driver and create the graph's bookkeeping tensors."""
        self.driver = driver
        self.driver.create_tensor(f"{self.graph_id}_nodes", np.array([]))
        self.driver.create_tensor(f"{self.graph_id}_grad_fns", np.array([]))

    def register_operation(self, op_name: str, forward_fn: Any, backward_fn: Any):
        """Register a new operation with its forward and backward functions."""
        # The callables are stored by id() and cast back to objects in
        # track_operation, so both functions must outlive the graph.
        fn_name = f"{self.graph_id}_fn_{op_name}"
        self.driver.create_tensor(fn_name, np.array([id(forward_fn), id(backward_fn)]))

        ops = list(self.driver.get_tensor(f"{self.graph_id}_grad_fns"))
        ops.append(fn_name)
        self.driver.create_tensor(f"{self.graph_id}_grad_fns", np.array(ops))

    def track_operation(self, op_name: str, inputs: List['Tensor'], output: 'Tensor') -> None:
        """Record an operation in the computational graph."""
        if not self.is_training:
            return

        if any(inp.requires_grad for inp in inputs):
            node = ComputationalNode(op_name, inputs, output, self.driver)
            fn_name = f"{self.graph_id}_fn_{op_name}"
            fn_ids = self.driver.get_tensor(fn_name)
            # Recover the backward callable from the id() stored by
            # register_operation (valid only while that object is alive).
            node.gradient_fn = ctypes.cast(int(fn_ids[1]), ctypes.py_object).value

            # Register the node and its tensors so backward() can look them up.
            self._node_registry[node.node_id] = node
            for tensor in list(inputs) + [output]:
                self._tensor_registry[tensor.name] = tensor

            nodes = list(self.driver.get_tensor(f"{self.graph_id}_nodes"))
            nodes.append(node.node_id)
            self.driver.create_tensor(f"{self.graph_id}_nodes", np.array(nodes))

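    # Lookup helpers referenced by backward(). A minimal sketch backed by the
    # in-process registries populated in track_operation(); a driver-side
    # index keyed by name would work just as well.
    def get_node_by_id(self, node_id: str) -> ComputationalNode:
        """Return the tracked node with the given node_id."""
        return self._node_registry[node_id]

    def get_tensor_by_name(self, name: str) -> 'Tensor':
        """Return a tensor previously seen by track_operation()."""
        return self._tensor_registry[name]
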
    def backward(self, loss_tensor: 'Tensor', retain_graph: bool = False):
        """Execute the backward pass through the computational graph."""
        nodes = self.driver.get_tensor(f"{self.graph_id}_nodes")
        if len(nodes) == 0:
            return

        # Seed the gradient of the loss with ones (dL/dL = 1).
        grad_id = f"{self.graph_id}_grads"
        self.driver.create_tensor(f"{grad_id}_{loss_tensor.name}", np.ones_like(loss_tensor.data()))

        # Bookkeeping tensors for the topological sort.
        visited_id = f"{self.graph_id}_visited"
        self.driver.create_tensor(visited_id, np.array([]))
        topo_id = f"{self.graph_id}_topo"
        self.driver.create_tensor(topo_id, np.array([]))

        def build_topo(node):
            # Depth-first post-order walk over producer nodes; the visited
            # set and ordering live in driver storage like all other state.
            visited = set(self.driver.get_tensor(visited_id))
            if node.node_id in visited:
                return

            visited = list(visited)
            visited.append(node.node_id)
            self.driver.create_tensor(visited_id, np.array(visited))

            # Recurse into every node that produced one of our inputs.
            for input_name in node.get_inputs():
                input_tensor = self.get_tensor_by_name(input_name)
                if input_tensor.requires_grad:
                    for n_id in nodes:
                        n = self.get_node_by_id(n_id)
                        if n.get_output() == input_name:
                            build_topo(n)

            topo_order = list(self.driver.get_tensor(topo_id))
            topo_order.append(node.node_id)
            self.driver.create_tensor(topo_id, np.array(topo_order))

        # Build the topological order starting from the node that produced
        # the loss tensor.
        for node_id in reversed(nodes):
            node = self.get_node_by_id(node_id)
            if node.get_output() == loss_tensor.name:
                build_topo(node)

        # Propagate gradients in reverse topological order, applying each
        # node's backward function to its saved tensors.
        topo_order = self.driver.get_tensor(topo_id)
        for node_id in reversed(topo_order):
            node = self.get_node_by_id(node_id)
            grad_output = self.driver.get_tensor(f"{grad_id}_{node.get_output()}")
            grad_inputs = node.gradient_fn(grad_output, **{k: node.get_saved_tensor(k)
                                                           for k in node.saved_tensor_names})

            if not isinstance(grad_inputs, tuple):
                grad_inputs = (grad_inputs,)

            for grad_input, input_name in zip(grad_inputs, node.get_inputs()):
                input_tensor = self.get_tensor_by_name(input_name)
                if input_tensor.requires_grad:
                    grad_key = f"{grad_id}_{input_name}"
                    if self.driver.tensor_exists(grad_key):
                        # Accumulate when a tensor feeds multiple operations.
                        existing_grad = self.driver.get_tensor(grad_key)
                        self.driver.create_tensor(grad_key, existing_grad + grad_input)
                    else:
                        self.driver.create_tensor(grad_key, grad_input)

        # Copy the accumulated gradients back onto the input tensors.
        for node_id in nodes:
            node = self.get_node_by_id(node_id)
            for input_name in node.get_inputs():
                input_tensor = self.get_tensor_by_name(input_name)
                if input_tensor.requires_grad:
                    grad_key = f"{grad_id}_{input_name}"
                    if self.driver.tensor_exists(grad_key):
                        input_tensor.set_grad(self.driver.get_tensor(grad_key))

        if not retain_graph:
            # Free the recorded graph so the next forward pass starts fresh.
            self.clear()

    def clear(self):
        """Clear the computational graph."""
        self._node_registry.clear()
        if self.driver is not None:
            self.driver.create_tensor(f"{self.graph_id}_nodes", np.array([]))

    def no_grad(self):
        """Context manager to disable gradient computation."""
        return NoGrad(self)


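# Usage sketch: operations executed under the context manager are not
# recorded, mirroring the usual autograd convention, e.g.
#
#     with get_compute_graph().no_grad():
#         ...  # forward-only computation; nothing is tracked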
class NoGrad:
    def __init__(self, graph: ComputeGraph):
        self.graph = graph
        self.prev_state = None

    def __enter__(self):
        self.prev_state = self.graph.is_training
        self.graph.is_training = False
        return self

    def __exit__(self, *args):
        self.graph.is_training = self.prev_state


# Module-level graph shared across the package. It is created without a
# driver; call GLOBAL_GRAPH.attach_driver(...) before tracking operations.
GLOBAL_GRAPH = ComputeGraph()


def get_compute_graph() -> ComputeGraph:
    return GLOBAL_GRAPH
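

if __name__ == "__main__":
    # Minimal smoke test. InMemoryDriver and _DemoTensor are hypothetical
    # stand-ins, not part of this package: the driver implements just the
    # interface used above (create_tensor / get_tensor / tensor_exists), and
    # _DemoTensor mimics the .autograd.Tensor attributes this module touches
    # (name, requires_grad, data(), set_grad()).
    class InMemoryDriver:
        def __init__(self):
            self._store = {}

        def create_tensor(self, name, array):
            self._store[name] = np.asarray(array)

        def get_tensor(self, name):
            return self._store[name]

        def tensor_exists(self, name):
            return name in self._store

    class _DemoTensor:
        def __init__(self, name, values, requires_grad=False):
            self.name = name
            self._values = np.asarray(values, dtype=float)
            self.requires_grad = requires_grad
            self.grad = None

        def data(self):
            return self._values

        def set_grad(self, grad):
            self.grad = grad

    graph = ComputeGraph(driver=InMemoryDriver())

    def mul_backward(grad_output, a, b):
        # d(a*b)/da = b and d(a*b)/db = a.
        return grad_output * b, grad_output * a

    graph.register_operation("mul", forward_fn=np.multiply, backward_fn=mul_backward)

    a = _DemoTensor("a", [2.0, 3.0], requires_grad=True)
    b = _DemoTensor("b", [4.0, 5.0], requires_grad=True)
    out = _DemoTensor("out", a.data() * b.data(), requires_grad=True)

    graph.track_operation("mul", [a, b], out)
    # An op implementation would save its operands itself; do it by hand here.
    last_node_id = graph.driver.get_tensor(f"{graph.graph_id}_nodes")[-1]
    graph.get_node_by_id(last_node_id).save_for_backward(a=a.data(), b=b.data())

    graph.backward(out)
    print(a.grad)  # [4. 5.]
    print(b.grad)  # [2. 3.]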