from typing import Optional, Tuple

from .broadcast import BroadcastModule, BroadcastBackward


class BroadcastOps:
    """Operations with automatic broadcasting support."""

    def __init__(self, driver):
        self.driver = driver
        self.broadcast_module = BroadcastModule(driver)
        self.backward_module = BroadcastBackward(driver)

    def binary_op(self, op_name: str, a_name: str, b_name: str,
                  save_shapes: bool = True) -> Tuple[str, Optional[Tuple[str, str]]]:
        """
        Perform a binary operation with automatic broadcasting.

        Returns:
            - Result tensor name
            - Tuple of (broadcasted_a, broadcasted_b) names if save_shapes=True,
              otherwise None
        """
        # Broadcast both inputs to their common shape.
        broadcasted = self.broadcast_module.binary_op_broadcast(a_name, b_name)

        if op_name == "add":
            result = self.driver.add(broadcasted[0], broadcasted[1])
        elif op_name == "mul":
            result = self.driver.mul(broadcasted[0], broadcasted[1])
        elif op_name == "div":
            result = self.driver.div(broadcasted[0], broadcasted[1])
        else:
            raise ValueError(f"Unsupported operation: {op_name}")

        result_name = f"{op_name}_result_{id(result)}"
        self.driver.create_tensor(result_name, result)

        if save_shapes:
            return result_name, broadcasted
        return result_name, None

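    # A note on shape semantics: binary_op_broadcast is assumed to follow
    # NumPy-style broadcasting rules -- shapes are aligned from the trailing
    # dimension, and each dimension pair must be equal or contain a 1. A
    # minimal sketch of that rule (illustrative only; the real logic lives
    # in BroadcastModule):
    #
    #     from itertools import zip_longest
    #
    #     def broadcast_shape(a_shape, b_shape):
    #         dims = []
    #         for x, y in zip_longest(reversed(a_shape), reversed(b_shape),
    #                                 fillvalue=1):
    #             if x != y and 1 not in (x, y):
    #                 raise ValueError(f"incompatible: {a_shape} vs {b_shape}")
    #             dims.append(max(x, y))
    #         return tuple(reversed(dims))
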
    def backward_binary_op(self, op_name: str, grad_output_name: str,
                           original_shapes: Tuple[Tuple[int, ...], Tuple[int, ...]],
                           broadcasted: Tuple[str, str]) -> Tuple[Optional[str], Optional[str]]:
        """
        Compute gradients for a binary operation with broadcasting.

        Returns gradients for both inputs (each may be None if not required).
        """
        # Driver ops are assumed to accept tensor names or the tensor handles
        # they return interchangeably.
        grad_a = None
        grad_b = None

        if op_name == "add":
            # d(a + b)/da = d(a + b)/db = 1: each input receives the output
            # gradient, reduced back to its original (pre-broadcast) shape.
            grad_a = self.backward_module.reduce_gradient(grad_output_name, original_shapes[0])
            grad_b = self.backward_module.reduce_gradient(grad_output_name, original_shapes[1])

        elif op_name == "mul":
            # d(a * b)/da = b and d(a * b)/db = a.
            grad_a = self.backward_module.reduce_gradient(
                self.driver.mul(grad_output_name, broadcasted[1]),
                original_shapes[0]
            )
            grad_b = self.backward_module.reduce_gradient(
                self.driver.mul(grad_output_name, broadcasted[0]),
                original_shapes[1]
            )

        elif op_name == "div":
            # d(a / b)/da = 1 / b and d(a / b)/db = -a / b**2.
            b_squared = self.driver.mul(broadcasted[1], broadcasted[1])
            grad_a = self.backward_module.reduce_gradient(
                self.driver.div(grad_output_name, broadcasted[1]),
                original_shapes[0]
            )
            grad_b = self.backward_module.reduce_gradient(
                self.driver.mul(
                    grad_output_name,
                    self.driver.div(
                        self.driver.mul(broadcasted[0], -1.0),
                        b_squared
                    )
                ),
                original_shapes[1]
            )

        return grad_a, grad_b

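# Gradient reduction undoes broadcasting: wherever an input was expanded to a
# larger shape, the incoming gradient is summed back down. A minimal NumPy
# sketch of what reduce_gradient is assumed to do (the actual implementation
# lives in BroadcastBackward):
#
#     import numpy as np
#
#     def reduce_gradient(grad: np.ndarray, original_shape) -> np.ndarray:
#         # Sum away leading dimensions added by broadcasting.
#         while grad.ndim > len(original_shape):
#             grad = grad.sum(axis=0)
#         # Sum over dimensions that were size 1 and got expanded.
#         for axis, size in enumerate(original_shape):
#             if size == 1 and grad.shape[axis] != 1:
#                 grad = grad.sum(axis=axis, keepdims=True)
#         return grad

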
class BroadcastTensor:
    """Tensor wrapper with broadcasting support."""

    def __init__(self, name: str, driver, requires_grad: bool = False):
        self.name = name
        self.driver = driver
        self.requires_grad = requires_grad
        # Name of the gradient tensor; assigned once a gradient is first
        # accumulated in backward().
        self.grad_name: Optional[str] = None
        self.broadcast_ops = BroadcastOps(driver)
        self._ctx = None

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.driver.get_tensor(self.name).shape

    def _create_ctx(self, op_name: str, a: 'BroadcastTensor', b: 'BroadcastTensor',
                    result_name: str, broadcasted: Tuple[str, str]):
        """Create context for the backward pass: record the operation and the
        input tensors that produced this tensor, so gradients can be routed
        back to them."""
        if a.requires_grad or (b is not None and b.requires_grad):
            self._ctx = {
                'op': op_name,
                'inputs': (a, b),
                'shapes': (a.shape, b.shape if b is not None else None),
                'broadcasted': broadcasted,
                'result_name': result_name
            }

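    # For example, for c = a + b with a.shape == (2, 1, 4) and
    # b.shape == (3, 1), c._ctx would look roughly like this (the broadcast
    # names are whatever BroadcastModule produced):
    #
    #     {'op': 'add', 'inputs': (a, b), 'shapes': ((2, 1, 4), (3, 1)),
    #      'broadcasted': ('<a_broadcast>', '<b_broadcast>'),
    #      'result_name': 'add_result_...'}
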
    def backward(self, grad_name: Optional[str] = None):
        """Execute backward pass with broadcasting support."""
        if not self.requires_grad:
            return

        if grad_name is None:
            # Seed the backward pass with a gradient of ones matching this
            # tensor's shape.
            grad_name = f"{self.name}_ones_grad"
            self.driver.create_tensor(
                grad_name,
                self.driver.ones_like(self.name)
            )

        # Store or accumulate the incoming gradient on this tensor.
        if self.grad_name is None:
            self.grad_name = grad_name
        else:
            self.driver.add_(self.grad_name, grad_name)

        if not self._ctx:
            return  # leaf tensor: nothing further to propagate

        grad_a, grad_b = self.broadcast_ops.backward_binary_op(
            self._ctx['op'], grad_name, self._ctx['shapes'], self._ctx['broadcasted']
        )

        # Route each input gradient to the tensor that produced it.
        a, b = self._ctx['inputs']
        if grad_a is not None and a.requires_grad:
            a.backward(grad_a)
        if b is not None and grad_b is not None and b.requires_grad:
            b.backward(grad_b)

    def __add__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
        result_name, broadcasted = self.broadcast_ops.binary_op(
            "add", self.name, other.name
        )
        result = BroadcastTensor(
            result_name,
            self.driver,
            requires_grad=self.requires_grad or other.requires_grad
        )
        result._create_ctx("add", self, other, result_name, broadcasted)
        return result

    def __mul__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
        result_name, broadcasted = self.broadcast_ops.binary_op(
            "mul", self.name, other.name
        )
        result = BroadcastTensor(
            result_name,
            self.driver,
            requires_grad=self.requires_grad or other.requires_grad
        )
        result._create_ctx("mul", self, other, result_name, broadcasted)
        return result

    def __truediv__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
        result_name, broadcasted = self.broadcast_ops.binary_op(
            "div", self.name, other.name
        )
        result = BroadcastTensor(
            result_name,
            self.driver,
            requires_grad=self.requires_grad or other.requires_grad
        )
        result._create_ctx("div", self, other, result_name, broadcasted)
        return result


"""
|
|
|
# Initialize
|
|
|
driver = YourDriver()
|
|
|
|
|
|
# Create tensors with broadcasting support
|
|
|
a = BroadcastTensor("tensor_a", driver, requires_grad=True) # shape: (2, 1, 4)
|
|
|
b = BroadcastTensor("tensor_b", driver, requires_grad=True) # shape: (3, 1)
|
|
|
|
|
|
# Operations with automatic broadcasting
|
|
|
c = a + b # shape: (2, 3, 4)
|
|
|
d = c * a # broadcasting happens automatically
|
|
|
|
|
|
# Backward pass with proper gradient broadcasting
|
|
|
d.backward()
|
|
|
"""
|
|
|
|
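# After d.backward(), each participating tensor's gradient is stored under its
# grad_name and can be fetched through the driver, e.g.:
#
#     grad_a = driver.get_tensor(a.grad_name)  # same shape as a: (2, 1, 4)
#     grad_b = driver.get_tensor(b.grad_name)  # same shape as b: (3, 1)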