INV / helium /broadcast_ops.py
Fred808's picture
Upload 256 files
7a0c684 verified
from typing import Optional, List, Tuple
from .broadcast import BroadcastModule, BroadcastBackward
class BroadcastOps:
"""Operations with automatic broadcasting support"""
def __init__(self, driver):
self.driver = driver
self.broadcast_module = BroadcastModule(driver)
self.backward_module = BroadcastBackward(driver)
def binary_op(self, op_name: str, a_name: str, b_name: str,
save_shapes: bool = True) -> Tuple[str, Optional[Tuple[str, str]]]:
"""
Perform binary operation with automatic broadcasting.
Returns:
- Result tensor name
- Tuple of (broadcasted_a, broadcasted_b) if save_shapes=True
"""
# Get original shapes for backward pass
if save_shapes:
a_shape = self.driver.get_tensor(a_name).shape
b_shape = self.driver.get_tensor(b_name).shape
# Broadcast tensors
broadcasted = self.broadcast_module.binary_op_broadcast(a_name, b_name)
# Perform operation
if op_name == "add":
result = self.driver.add(broadcasted[0], broadcasted[1])
elif op_name == "mul":
result = self.driver.mul(broadcasted[0], broadcasted[1])
elif op_name == "div":
result = self.driver.div(broadcasted[0], broadcasted[1])
else:
raise ValueError(f"Unsupported operation: {op_name}")
result_name = f"{op_name}_result_{id(result)}"
self.driver.create_tensor(result_name, result)
if save_shapes:
return result_name, broadcasted
return result_name, None
def backward_binary_op(self, op_name: str, grad_output_name: str,
original_shapes: Tuple[Tuple[int, ...], Tuple[int, ...]],
broadcasted: Tuple[str, str]) -> Tuple[Optional[str], Optional[str]]:
"""
Compute gradients for binary operation with broadcasting.
Returns gradients for both inputs (may be None if not required).
"""
grad_a = None
grad_b = None
if op_name == "add":
# For addition, just reduce the gradient back to original shapes
grad_a = self.backward_module.reduce_gradient(grad_output_name, original_shapes[0])
grad_b = self.backward_module.reduce_gradient(grad_output_name, original_shapes[1])
elif op_name == "mul":
# For multiplication, multiply by the other tensor then reduce
grad_a = self.backward_module.reduce_gradient(
self.driver.mul(grad_output_name, broadcasted[1]),
original_shapes[0]
)
grad_b = self.backward_module.reduce_gradient(
self.driver.mul(grad_output_name, broadcasted[0]),
original_shapes[1]
)
elif op_name == "div":
# For division, more complex gradients
b_squared = self.driver.mul(broadcasted[1], broadcasted[1])
grad_a = self.backward_module.reduce_gradient(
self.driver.div(grad_output_name, broadcasted[1]),
original_shapes[0]
)
grad_b = self.backward_module.reduce_gradient(
self.driver.mul(
grad_output_name,
self.driver.div(
self.driver.mul(broadcasted[0], -1.0),
b_squared
)
),
original_shapes[1]
)
return grad_a, grad_b
class BroadcastTensor:
"""Tensor wrapper with broadcasting support"""
def __init__(self, name: str, driver, requires_grad: bool = False):
self.name = name
self.driver = driver
self.requires_grad = requires_grad
self.grad_name = None if not requires_grad else f"{name}_grad"
self.broadcast_ops = BroadcastOps(driver)
self._ctx = None
@property
def shape(self) -> Tuple[int, ...]:
return self.driver.get_tensor(self.name).shape
def _create_ctx(self, op_name: str, other: 'BroadcastTensor',
result_name: str, broadcasted: Tuple[str, str]):
"""Create context for backward pass"""
if self.requires_grad or (other and other.requires_grad):
self._ctx = {
'op': op_name,
'shapes': (self.shape, other.shape if other else None),
'broadcasted': broadcasted,
'self_name': self.name,
'other_name': other.name if other else None,
'result_name': result_name
}
def backward(self, grad_name: Optional[str] = None):
"""Execute backward pass with broadcasting support"""
if not self.requires_grad or not self._ctx:
return
if grad_name is None:
# Create ones gradient
grad_name = f"{self.name}_ones_grad"
self.driver.create_tensor(
grad_name,
self.driver.ones_like(self.name)
)
op = self._ctx['op']
shapes = self._ctx['shapes']
broadcasted = self._ctx['broadcasted']
grad_self, grad_other = self.broadcast_ops.backward_binary_op(
op, grad_name, shapes, broadcasted
)
# Accumulate gradients
if grad_self is not None and self.requires_grad:
if self.grad_name is None:
self.grad_name = grad_self
else:
self.driver.add_(self.grad_name, grad_self)
# Propagate to other tensor if needed
other = self._ctx['other_name']
if other and grad_other is not None:
other.backward(grad_other)
def __add__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
result_name, broadcasted = self.broadcast_ops.binary_op(
"add", self.name, other.name
)
result = BroadcastTensor(
result_name,
self.driver,
requires_grad=self.requires_grad or other.requires_grad
)
result._create_ctx("add", other, result_name, broadcasted)
return result
def __mul__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
result_name, broadcasted = self.broadcast_ops.binary_op(
"mul", self.name, other.name
)
result = BroadcastTensor(
result_name,
self.driver,
requires_grad=self.requires_grad or other.requires_grad
)
result._create_ctx("mul", other, result_name, broadcasted)
return result
def __truediv__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
result_name, broadcasted = self.broadcast_ops.binary_op(
"div", self.name, other.name
)
result = BroadcastTensor(
result_name,
self.driver,
requires_grad=self.requires_grad or other.requires_grad
)
result._create_ctx("div", other, result_name, broadcasted)
return result
# Example usage:
"""
# Initialize
driver = YourDriver()
# Create tensors with broadcasting support
a = BroadcastTensor("tensor_a", driver, requires_grad=True) # shape: (2, 1, 4)
b = BroadcastTensor("tensor_b", driver, requires_grad=True) # shape: (3, 1)
# Operations with automatic broadcasting
c = a + b # shape: (2, 3, 4)
d = c * a # broadcasting happens automatically
# Backward pass with proper gradient broadcasting
d.backward()
"""