# helium/examples/test_jit.py — example/test script for Helium JIT compilation.
import numpy as np
from helium.jit import JITCompiler
from helium.broadcast import BroadcastModule
def test_jit_compilation():
    """Exercise JITCompiler end-to-end against an in-memory dummy driver.

    Builds a numpy-backed driver exposing basic ops (matmul/add/relu) plus
    fused variants the compiler can target, JIT-compiles a linear+ReLU
    function, compares outputs and timings against the eager version, and
    repeats the benchmark for several input sizes.
    """
    import time  # hoisted: used for all benchmark timing below

    class DummyDriver:
        """Minimal tensor driver: named numpy tensors + (fused) numpy ops.

        Each op stores its result under a caller-supplied name or an
        auto-generated unique name, and returns that name.
        """

        def __init__(self):
            self.tensors = {}   # name -> np.ndarray
            self.counter = 0    # monotonic suffix for auto-generated names

        def create_tensor(self, name, data):
            """Register *data* under *name* and return the name."""
            self.tensors[name] = np.array(data)
            return name

        def get_tensor(self, name):
            """Fetch a previously stored tensor by name (KeyError if absent)."""
            return self.tensors[name]

        def _store(self, result, prefix, output):
            """Store *result* under *output* or a fresh '{prefix}_result_N' name."""
            if output:
                self.tensors[output] = result
                return output
            name = f"{prefix}_result_{self.counter}"
            self.counter += 1
            self.tensors[name] = result
            return name

        def matmul(self, a, b, output=None):
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            return self._store(result, "matmul", output)

        def add(self, a, b, output=None):
            result = self.get_tensor(a) + self.get_tensor(b)
            return self._store(result, "add", output)

        def relu(self, x, output=None):
            result = np.maximum(0, self.get_tensor(x))
            return self._store(result, "relu", output)

        def fused_matmul_add(self, a, b, bias, output=None):
            """Optimized matmul + add operation."""
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            result += self.get_tensor(bias)
            return self._store(result, "fused_matmul_add", output)

        def fused_add_relu(self, a, b, output=None):
            """Optimized add + relu operation."""
            result = self.get_tensor(a) + self.get_tensor(b)
            result = np.maximum(0, result)
            return self._store(result, "fused_add_relu", output)

    driver = DummyDriver()

    def linear_relu(x_name: str, weight_name: str, bias_name: str) -> str:
        """Linear layer with ReLU, expressed as unfused driver ops.

        This eager form is what the JIT compiler traces and optimizes.
        """
        matmul_result = driver.matmul(x_name, weight_name)
        bias_result = driver.add(matmul_result, bias_name)
        return driver.relu(bias_result)

    def _speedup(baseline, optimized):
        """Safe speedup ratio; inf when the optimized time is below clock resolution."""
        return baseline / optimized if optimized > 0 else float("inf")

    compiler = JITCompiler(driver)

    # Example inputs used to trace/compile the function.
    batch_size = 32
    in_features = 64
    out_features = 128
    x = np.random.randn(batch_size, in_features)
    weight = np.random.randn(in_features, out_features)
    bias = np.random.randn(out_features)

    x_name = driver.create_tensor("x", x)
    weight_name = driver.create_tensor("weight", weight)
    bias_name = driver.create_tensor("bias", bias)

    example_inputs = {
        "x_name": x_name,
        "weight_name": weight_name,
        "bias_name": bias_name,
    }
    compiled_fn = compiler.compile(linear_relu, example_inputs)

    print("Running standard vs JIT compiled versions...")

    # perf_counter is the correct monotonic high-resolution benchmark clock
    # (time.time() has coarse resolution and can jump on wall-clock changes).
    start_time = time.perf_counter()
    standard_result = linear_relu(x_name, weight_name, bias_name)
    standard_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    jit_result = compiled_fn(
        x_name=x_name,
        weight_name=weight_name,
        bias_name=bias_name,
    )["output"]
    jit_time = time.perf_counter() - start_time

    # Compare numerical results of the two execution paths.
    standard_output = driver.get_tensor(standard_result)
    jit_output = driver.get_tensor(jit_result)
    print("\nResults:")
    print(f"Standard execution time: {standard_time:.6f} seconds")
    print(f"JIT execution time: {jit_time:.6f} seconds")
    print(f"Speedup: {_speedup(standard_time, jit_time):.2f}x")
    print(f"Max difference in outputs: {np.max(np.abs(standard_output - jit_output))}")

    # Describe the optimizations the JIT is expected to apply.
    print("\nOptimizations applied:")
    print("1. Operation fusion:")
    print("   - Matmul + Add -> fused_matmul_add")
    print("   - Add + ReLU -> fused_add_relu")
    print("2. Memory reuse:")
    print("   - Intermediate tensors reuse memory slots")
    print("3. Operation reordering:")
    print("   - Independent operations can run in parallel")

    # Re-run the comparison across a range of problem sizes.
    print("\nTesting with different input sizes...")
    sizes = [(16, 32, 64), (64, 128, 256), (128, 256, 512)]
    for batch, in_dim, out_dim in sizes:
        print(f"\nInput size: batch={batch}, in_features={in_dim}, out_features={out_dim}")
        x = np.random.randn(batch, in_dim)
        weight = np.random.randn(in_dim, out_dim)
        bias = np.random.randn(out_dim)
        x_name = driver.create_tensor(f"x_{batch}", x)
        weight_name = driver.create_tensor(f"weight_{batch}", weight)
        bias_name = driver.create_tensor(f"bias_{batch}", bias)

        start_time = time.perf_counter()
        standard_result = linear_relu(x_name, weight_name, bias_name)
        standard_time = time.perf_counter() - start_time

        example_inputs = {
            "x_name": x_name,
            "weight_name": weight_name,
            "bias_name": bias_name,
        }
        compiled_fn = compiler.compile(linear_relu, example_inputs)
        start_time = time.perf_counter()
        jit_result = compiled_fn(
            x_name=x_name,
            weight_name=weight_name,
            bias_name=bias_name,
        )["output"]
        jit_time = time.perf_counter() - start_time

        print(f"Standard time: {standard_time:.6f} seconds")
        print(f"JIT time: {jit_time:.6f} seconds")
        print(f"Speedup: {_speedup(standard_time, jit_time):.2f}x")
# Script entry point: run the JIT example when executed directly.
if __name__ == "__main__":
    test_jit_compilation()