# Source: INV/helium/examples/test_broadcast.py
# Uploaded by Fred808 ("Upload 256 files", commit 7a0c684, verified).
import numpy as np
from helium.broadcast_ops import BroadcastTensor
from helium.broadcast import BroadcastModule, BroadcastBackward
def test_broadcasting():
    """Exercise BroadcastTensor arithmetic and backward passes on mixed shapes.

    Four scenarios run against an in-memory, NumPy-backed stand-in driver:
    addition of (2, 3) with (3,), multiplication of (2, 1, 4) with (3, 1),
    a chained add/multiply followed by ``backward()``, and division of
    (4, 1, 3) by (1, 2, 1) followed by ``backward()``.

    Each scenario prints the shapes involved and then asserts them, so the
    function fails under pytest instead of only printing what the shape
    "should be".

    Raises:
        AssertionError: if a forward result differs from the NumPy reference
            or a gradient does not have the shape of its tensor.
    """

    # Initialize dummy driver for testing
    class DummyDriver:
        """Minimal driver stub that keeps NumPy arrays in a dict keyed by name."""

        def __init__(self):
            self.tensors = {}  # tensor name -> np.ndarray
            # Not read by this stub itself; presumably available to the
            # library for generating names -- TODO confirm.
            self.counter = 0

        def create_tensor(self, name, data):
            """Store ``data`` as an array under ``name`` and return ``name``."""
            self.tensors[name] = np.array(data)
            return name

        def get_tensor(self, name):
            """Return the array stored under ``name`` (KeyError if absent)."""
            return self.tensors[name]

        def delete_tensor(self, name):
            """Remove ``name`` if present; a missing name is ignored."""
            if name in self.tensors:
                del self.tensors[name]

        def tensor_exists(self, name):
            """Return True when ``name`` is currently stored."""
            return name in self.tensors

        def add(self, a, b):
            """Return the NumPy sum of the tensors named ``a`` and ``b``."""
            return self.get_tensor(a) + self.get_tensor(b)

        def mul(self, a, b):
            """Return the NumPy product of the tensors named ``a`` and ``b``."""
            return self.get_tensor(a) * self.get_tensor(b)

        def div(self, a, b):
            """Return the NumPy quotient of the tensors named ``a`` and ``b``."""
            return self.get_tensor(a) / self.get_tensor(b)

        def sum(self, x, axis=None, keepdims=False):
            """Return ``np.sum`` of the tensor named ``x`` over ``axis``."""
            return np.sum(self.get_tensor(x), axis=axis, keepdims=keepdims)

        def ones_like(self, x):
            """Return an array of ones shaped like the tensor named ``x``."""
            return np.ones_like(self.get_tensor(x))

        def broadcast_to(self, x, shape):
            """Return the tensor named ``x`` broadcast to ``shape``."""
            return np.broadcast_to(self.get_tensor(x), shape)

        def reshape(self, x, shape):
            """Return the tensor named ``x`` reshaped to ``shape``."""
            return self.get_tensor(x).reshape(shape)

        def add_(self, a, b):
            """Add the tensor named ``b`` into the tensor named ``a`` in place."""
            self.tensors[a] += self.get_tensor(b)

    # Create driver instance
    driver = DummyDriver()

    # Test Case 1: Basic Broadcasting
    print("Test Case 1: Basic Broadcasting")
    a = np.array([[1, 2, 3],
                  [4, 5, 6]])  # shape: (2, 3)
    b = np.array([10, 20, 30])  # shape: (3,)
    tensor_a = BroadcastTensor("a", driver, requires_grad=True)
    tensor_b = BroadcastTensor("b", driver, requires_grad=True)
    driver.create_tensor("a", a)
    driver.create_tensor("b", b)

    # Test addition
    c = tensor_a + tensor_b
    result = driver.get_tensor(c.name)
    print(f"Shape a: {a.shape}")
    print(f"Shape b: {b.shape}")
    print(f"Result shape: {result.shape}")
    print("Result:")
    print(result)
    print()
    assert result.shape == (2, 3), f"a + b has shape {result.shape}, expected (2, 3)"
    assert np.allclose(result, a + b), "a + b differs from the NumPy reference"

    # Test Case 2: Complex Broadcasting
    print("Test Case 2: Complex Broadcasting")
    x = np.random.randn(2, 1, 4)  # shape: (2, 1, 4)
    y = np.random.randn(3, 1)  # shape: (3, 1)
    tensor_x = BroadcastTensor("x", driver, requires_grad=True)
    tensor_y = BroadcastTensor("y", driver, requires_grad=True)
    driver.create_tensor("x", x)
    driver.create_tensor("y", y)

    # Test multiplication
    z = tensor_x * tensor_y
    result = driver.get_tensor(z.name)
    print(f"Shape x: {x.shape}")
    print(f"Shape y: {y.shape}")
    print(f"Result shape: {result.shape}")
    print("Result shape should be (2, 3, 4)")
    print()
    assert result.shape == (2, 3, 4), f"x * y has shape {result.shape}, expected (2, 3, 4)"
    assert np.allclose(result, x * y), "x * y differs from the NumPy reference"

    # Test Case 3: Gradient Broadcasting
    print("Test Case 3: Gradient Broadcasting")
    m = np.random.randn(2, 1)  # shape: (2, 1)
    n = np.random.randn(3)  # shape: (3,)
    tensor_m = BroadcastTensor("m", driver, requires_grad=True)
    tensor_n = BroadcastTensor("n", driver, requires_grad=True)
    driver.create_tensor("m", m)
    driver.create_tensor("n", n)

    # Forward pass
    p = tensor_m + tensor_n  # shape will be (2, 3)
    q = p * tensor_m  # involves more broadcasting

    # Backward pass
    q.backward()

    # Check gradient shapes
    m_grad = driver.get_tensor(tensor_m.grad_name)
    n_grad = driver.get_tensor(tensor_n.grad_name)
    print(f"Original m shape: {m.shape}")
    print(f"m gradient shape: {m_grad.shape}")
    print(f"Original n shape: {n.shape}")
    print(f"n gradient shape: {n_grad.shape}")
    print()
    # A gradient must be reduced back to the shape of the tensor it belongs to.
    assert m_grad.shape == m.shape, f"m gradient has shape {m_grad.shape}, expected {m.shape}"
    assert n_grad.shape == n.shape, f"n gradient has shape {n_grad.shape}, expected {n.shape}"

    # Test Case 4: Division with Broadcasting
    print("Test Case 4: Division with Broadcasting")
    u = np.random.randn(4, 1, 3)  # shape: (4, 1, 3)
    v = np.random.randn(1, 2, 1)  # shape: (1, 2, 1)
    tensor_u = BroadcastTensor("u", driver, requires_grad=True)
    tensor_v = BroadcastTensor("v", driver, requires_grad=True)
    driver.create_tensor("u", u)
    driver.create_tensor("v", v)

    # Test division
    w = tensor_u / tensor_v
    result = driver.get_tensor(w.name)
    print(f"Shape u: {u.shape}")
    print(f"Shape v: {v.shape}")
    print(f"Result shape: {result.shape}")
    print("Result shape should be (4, 2, 3)")
    assert result.shape == (4, 2, 3), f"u / v has shape {result.shape}, expected (4, 2, 3)"
    assert np.allclose(result, u / v), "u / v differs from the NumPy reference"

    # Test backward pass
    w.backward()
    u_grad = driver.get_tensor(tensor_u.grad_name)
    v_grad = driver.get_tensor(tensor_v.grad_name)
    print(f"u gradient shape: {u_grad.shape}")
    print(f"v gradient shape: {v_grad.shape}")
    assert u_grad.shape == u.shape, f"u gradient has shape {u_grad.shape}, expected {u.shape}"
    assert v_grad.shape == v.shape, f"v gradient has shape {v_grad.shape}, expected {v.shape}"
# Script entry point: run the broadcasting checks only when this file is
# executed directly, not when it is imported (e.g. by a test collector).
if __name__ == "__main__":
    test_broadcasting()