import time

import numpy as np

from helium.jit import JITCompiler
from helium.broadcast import BroadcastModule


def test_jit_compilation():
    """Test JIT compilation with various operations"""

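    # DummyDriver is a stand-in for a real device driver: it keeps tensors in a
    # dict keyed by name and evaluates every op eagerly with numpy, so the test
    # can run without any hardware backend.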
    class DummyDriver:
        def __init__(self):
            self.tensors = {}  # tensor name -> numpy array
            self.counter = 0   # used to generate unique result names

        def create_tensor(self, name, data):
            self.tensors[name] = np.array(data)
            return name

        def get_tensor(self, name):
            return self.tensors[name]

        def matmul(self, a, b, output=None):
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            if output:
                self.tensors[output] = result
                return output
            name = f"matmul_result_{self.counter}"
            self.counter += 1
            self.tensors[name] = result
            return name

        def add(self, a, b, output=None):
            result = self.get_tensor(a) + self.get_tensor(b)
            if output:
                self.tensors[output] = result
                return output
            name = f"add_result_{self.counter}"
            self.counter += 1
            self.tensors[name] = result
            return name

        def relu(self, x, output=None):
            result = np.maximum(0, self.get_tensor(x))
            if output:
                self.tensors[output] = result
                return output
            name = f"relu_result_{self.counter}"
            self.counter += 1
            self.tensors[name] = result
            return name

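        # The two fused_* methods below mimic fused kernels. The assumption in
        # this test is that the JIT compiler's fusion pass rewrites matmul+add
        # and add+relu sequences into these single calls when the driver
        # exposes them.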
        def fused_matmul_add(self, a, b, bias, output=None):
            """Optimized matmul + add operation"""
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            result += self.get_tensor(bias)
            if output:
                self.tensors[output] = result
                return output
            name = f"fused_matmul_add_result_{self.counter}"
            self.counter += 1
            self.tensors[name] = result
            return name

        def fused_add_relu(self, a, b, output=None):
            """Optimized add + relu operation"""
            result = self.get_tensor(a) + self.get_tensor(b)
            result = np.maximum(0, result)
            if output:
                self.tensors[output] = result
                return output
            name = f"fused_add_relu_result_{self.counter}"
            self.counter += 1
            self.tensors[name] = result
            return name

    driver = DummyDriver()

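    # Eager reference path: three separate driver calls (matmul, add, relu).
    # This is the function handed to the JIT compiler below.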
    def linear_relu(x_name: str, weight_name: str, bias_name: str) -> str:
        """Function implementing linear layer with ReLU"""
        matmul_result = driver.matmul(x_name, weight_name)
        bias_result = driver.add(matmul_result, bias_name)
        return driver.relu(bias_result)

    compiler = JITCompiler(driver)

    batch_size = 32
    in_features = 64
    out_features = 128

    x = np.random.randn(batch_size, in_features)
    weight = np.random.randn(in_features, out_features)
    bias = np.random.randn(out_features)

    x_name = driver.create_tensor("x", x)
    weight_name = driver.create_tensor("weight", weight)
    bias_name = driver.create_tensor("bias", bias)

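    # The example-input dict is keyed by linear_relu's parameter names;
    # presumably the compiler uses it to trace one concrete call. compiled_fn
    # is later invoked with the same keyword arguments.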
    example_inputs = {
        "x_name": x_name,
        "weight_name": weight_name,
        "bias_name": bias_name
    }

    compiled_fn = compiler.compile(linear_relu, example_inputs)

    print("Running standard vs JIT compiled versions...")

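    # Rough wall-clock comparison: each path is timed over a single call, so the
    # numbers are indicative rather than a rigorous benchmark.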
    start_time = time.time()
    standard_result = linear_relu(x_name, weight_name, bias_name)
    standard_time = time.time() - start_time

    start_time = time.time()
    jit_result = compiled_fn(
        x_name=x_name,
        weight_name=weight_name,
        bias_name=bias_name
    )["output"]
    jit_time = time.time() - start_time

    standard_output = driver.get_tensor(standard_result)
    jit_output = driver.get_tensor(jit_result)

print("\nResults:")
|
|
|
print(f"Standard execution time: {standard_time:.6f} seconds")
|
|
|
print(f"JIT execution time: {jit_time:.6f} seconds")
|
|
|
print(f"Speedup: {standard_time/jit_time:.2f}x")
|
|
|
print(f"Max difference in outputs: {np.max(np.abs(standard_output - jit_output))}")
|
|
|
|
|
|
|
|
|
print("\nOptimizations applied:")
|
|
|
print("1. Operation fusion:")
|
|
|
print(" - Matmul + Add -> fused_matmul_add")
|
|
|
print(" - Add + ReLU -> fused_add_relu")
|
|
|
print("2. Memory reuse:")
|
|
|
print(" - Intermediate tensors reuse memory slots")
|
|
|
print("3. Operation reordering:")
|
|
|
print(" - Independent operations can run in parallel")
|
|
|
|
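    # Illustrative sanity check: on the dummy driver both the fused kernel and
    # the unfused matmul -> add path are plain numpy, so their results should
    # match exactly.
    fused_name = driver.fused_matmul_add(x_name, weight_name, bias_name)
    unfused_name = driver.add(driver.matmul(x_name, weight_name), bias_name)
    assert np.allclose(driver.get_tensor(fused_name), driver.get_tensor(unfused_name))
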
    print("\nTesting with different input sizes...")

    sizes = [(16, 32, 64), (64, 128, 256), (128, 256, 512)]

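    # Each configuration registers fresh tensors under size-specific names and
    # recompiles linear_relu for them; compilation runs outside the timed region.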
    for batch, in_dim, out_dim in sizes:
        print(f"\nInput size: batch={batch}, in_features={in_dim}, out_features={out_dim}")

        x = np.random.randn(batch, in_dim)
        weight = np.random.randn(in_dim, out_dim)
        bias = np.random.randn(out_dim)

        x_name = driver.create_tensor(f"x_{batch}", x)
        weight_name = driver.create_tensor(f"weight_{batch}", weight)
        bias_name = driver.create_tensor(f"bias_{batch}", bias)

        start_time = time.time()
        standard_result = linear_relu(x_name, weight_name, bias_name)
        standard_time = time.time() - start_time

        example_inputs = {
            "x_name": x_name,
            "weight_name": weight_name,
            "bias_name": bias_name
        }
        compiled_fn = compiler.compile(linear_relu, example_inputs)

        start_time = time.time()
        jit_result = compiled_fn(
            x_name=x_name,
            weight_name=weight_name,
            bias_name=bias_name
        )["output"]
        jit_time = time.time() - start_time

        print(f"Standard time: {standard_time:.6f} seconds")
        print(f"JIT time: {jit_time:.6f} seconds")
        print(f"Speedup: {standard_time/jit_time:.2f}x")


if __name__ == "__main__":
    test_jit_compilation()