from typing import Optional, Tuple

from .broadcast import BroadcastModule, BroadcastBackward


class BroadcastOps:
    """Operations with automatic broadcasting support."""

    def __init__(self, driver):
        self.driver = driver
        self.broadcast_module = BroadcastModule(driver)
        self.backward_module = BroadcastBackward(driver)

    def binary_op(self, op_name: str, a_name: str, b_name: str,
                  save_shapes: bool = True) -> Tuple[str, Optional[Tuple[str, str]]]:
        """
        Perform a binary operation with automatic broadcasting.

        Returns:
            - Result tensor name
            - Tuple of (broadcasted_a, broadcasted_b) tensor names if
              save_shapes=True, otherwise None
        """
        # The original input shapes needed by the backward pass are recorded
        # by the caller (see BroadcastTensor._create_ctx).

        # Broadcast both operands to a common shape
        broadcasted = self.broadcast_module.binary_op_broadcast(a_name, b_name)

        # Perform the operation on the broadcasted operands
        if op_name == "add":
            result = self.driver.add(broadcasted[0], broadcasted[1])
        elif op_name == "mul":
            result = self.driver.mul(broadcasted[0], broadcasted[1])
        elif op_name == "div":
            result = self.driver.div(broadcasted[0], broadcasted[1])
        else:
            raise ValueError(f"Unsupported operation: {op_name}")

        result_name = f"{op_name}_result_{id(result)}"
        self.driver.create_tensor(result_name, result)

        if save_shapes:
            return result_name, broadcasted
        return result_name, None

    def backward_binary_op(self, op_name: str, grad_output_name: str,
                           original_shapes: Tuple[Tuple[int, ...], Tuple[int, ...]],
                           broadcasted: Tuple[str, str]) -> Tuple[Optional[str], Optional[str]]:
        """
        Compute gradients for a binary operation with broadcasting.

        Returns gradients for both inputs (either may be None if not required).
        """
        grad_a = None
        grad_b = None

        if op_name == "add":
            # d(a + b)/da = d(a + b)/db = 1: just reduce the upstream
            # gradient back to each input's original shape
            grad_a = self.backward_module.reduce_gradient(grad_output_name, original_shapes[0])
            grad_b = self.backward_module.reduce_gradient(grad_output_name, original_shapes[1])
        elif op_name == "mul":
            # d(a * b)/da = b and d(a * b)/db = a: multiply the upstream
            # gradient by the other (broadcasted) operand, then reduce
            grad_a = self.backward_module.reduce_gradient(
                self.driver.mul(grad_output_name, broadcasted[1]),
                original_shapes[0]
            )
            grad_b = self.backward_module.reduce_gradient(
                self.driver.mul(grad_output_name, broadcasted[0]),
                original_shapes[1]
            )
        elif op_name == "div":
            # d(a / b)/da = 1 / b and d(a / b)/db = -a / b**2
            b_squared = self.driver.mul(broadcasted[1], broadcasted[1])
            grad_a = self.backward_module.reduce_gradient(
                self.driver.div(grad_output_name, broadcasted[1]),
                original_shapes[0]
            )
            grad_b = self.backward_module.reduce_gradient(
                self.driver.mul(
                    grad_output_name,
                    self.driver.div(
                        # assumes the driver's mul accepts a scalar operand
                        self.driver.mul(broadcasted[0], -1.0),
                        b_squared
                    )
                ),
                original_shapes[1]
            )

        return grad_a, grad_b
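

# `BroadcastBackward.reduce_gradient` is imported from .broadcast and its
# implementation is not shown in this module. The helper below is a minimal
# NumPy sketch of the semantics the backward pass above relies on: sum the
# upstream gradient over every axis that broadcasting prepended or stretched
# from size 1, then reshape to the input's original shape. It is illustrative
# only (the name `_reduce_gradient_sketch` is hypothetical) and is not part
# of the driver API.
def _reduce_gradient_sketch(grad, original_shape):
    import numpy as np  # local import: used only by this illustration

    grad = np.asarray(grad)
    # Sum away leading axes that broadcasting prepended
    while grad.ndim > len(original_shape):
        grad = grad.sum(axis=0)
    # Sum over axes that were stretched from size 1
    for axis, size in enumerate(original_shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad.reshape(original_shape)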


class BroadcastTensor:
    """Tensor wrapper with broadcasting support."""

    def __init__(self, name: str, driver, requires_grad: bool = False):
        self.name = name
        self.driver = driver
        self.requires_grad = requires_grad
        self.grad_name: Optional[str] = None  # set on the first backward pass
        self.broadcast_ops = BroadcastOps(driver)
        self._ctx = None

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.driver.get_tensor(self.name).shape

    def _create_ctx(self, op_name: str, a: 'BroadcastTensor', b: 'BroadcastTensor',
                    broadcasted: Tuple[str, str]):
        """Record the context the backward pass needs (called on the result)."""
        if a.requires_grad or (b is not None and b.requires_grad):
            self._ctx = {
                'op': op_name,
                # Keep the input tensors themselves, not just their names,
                # so gradients can be accumulated on and propagated to them
                'inputs': (a, b),
                'shapes': (a.shape, b.shape if b is not None else None),
                'broadcasted': broadcasted,
            }

    def backward(self, grad_name: Optional[str] = None):
        """Execute the backward pass with broadcasting support."""
        if not self.requires_grad or self._ctx is None:
            return

        if grad_name is None:
            # Seed with a gradient of ones matching this tensor's shape
            grad_name = f"{self.name}_ones_grad"
            self.driver.create_tensor(grad_name, self.driver.ones_like(self.name))

        ctx = self._ctx
        grad_a, grad_b = self.broadcast_ops.backward_binary_op(
            ctx['op'], grad_name, ctx['shapes'], ctx['broadcasted']
        )

        # Accumulate gradients on the inputs and keep propagating
        for tensor, grad in zip(ctx['inputs'], (grad_a, grad_b)):
            if tensor is None or grad is None or not tensor.requires_grad:
                continue
            if tensor.grad_name is None:
                tensor.grad_name = grad
            else:
                self.driver.add_(tensor.grad_name, grad)  # in-place accumulation
            if tensor._ctx is not None:
                tensor.backward(grad)

    def _binary(self, op_name: str, other: 'BroadcastTensor') -> 'BroadcastTensor':
        result_name, broadcasted = self.broadcast_ops.binary_op(op_name, self.name, other.name)
        result = BroadcastTensor(
            result_name, self.driver,
            requires_grad=self.requires_grad or other.requires_grad
        )
        result._create_ctx(op_name, self, other, broadcasted)
        return result

    def __add__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
        return self._binary("add", other)

    def __mul__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
        return self._binary("mul", other)

    def __truediv__(self, other: 'BroadcastTensor') -> 'BroadcastTensor':
        return self._binary("div", other)


# Example usage:
"""
# Initialize
driver = YourDriver()

# Create tensors with broadcasting support
a = BroadcastTensor("tensor_a", driver, requires_grad=True)  # shape: (2, 1, 4)
b = BroadcastTensor("tensor_b", driver, requires_grad=True)  # shape: (3, 1)

# Operations with automatic broadcasting
c = a + b  # shape: (2, 3, 4)
d = c * a  # broadcasting happens automatically

# Backward pass with proper gradient broadcasting
d.backward()
"""
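

# Sanity check (illustrative, NumPy only): verifies the division gradients
# used in `backward_binary_op` -- d(a/b)/da = 1/b and d(a/b)/db = -a/b**2,
# combined with the reduction sketched in `_reduce_gradient_sketch` -- against
# a finite-difference estimate. It never touches the driver; all shapes and
# values here are hypothetical.
if __name__ == "__main__":
    import numpy as np

    a = np.random.rand(2, 1, 4)     # broadcasts against b to (2, 3, 4)
    b = np.random.rand(3, 1) + 1.0  # offset keeps the division stable
    g = np.ones(np.broadcast_shapes(a.shape, b.shape))  # upstream gradient
    ab, bb = np.broadcast_arrays(a, b)

    grad_a = _reduce_gradient_sketch(g / bb, a.shape)
    grad_b = _reduce_gradient_sketch(g * (-ab) / bb ** 2, b.shape)
    assert grad_a.shape == a.shape and grad_b.shape == b.shape

    # Finite-difference spot check of d(sum(a / b))/db at one element
    eps = 1e-6
    db = np.zeros_like(b)
    db[1, 0] = eps
    numeric = ((a / (b + db)).sum() - (a / b).sum()) / eps
    assert abs(numeric - grad_b[1, 0]) < 1e-3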