import time

import numpy as np

from helium.jit import JITCompiler
from helium.broadcast import BroadcastModule  # noqa: F401 -- kept; may be needed for import side effects


def test_jit_compilation():
    """Test JIT compilation with various operations.

    Builds a numpy-backed dummy driver, compiles a linear+ReLU function with
    the helium JITCompiler, and compares outputs/timings of the standard and
    JIT-compiled execution paths across several input sizes.
    """

    class DummyDriver:
        """In-memory tensor store exposing JIT-compatible numpy operations."""

        def __init__(self):
            self.tensors = {}  # name -> np.ndarray
            self.counter = 0   # monotonically increasing suffix for auto-names

        def create_tensor(self, name, data):
            """Store *data* under *name* and return the name."""
            self.tensors[name] = np.array(data)
            return name

        def get_tensor(self, name):
            """Look up a stored tensor by name (raises KeyError if absent)."""
            return self.tensors[name]

        def _store(self, op_name, result, output=None):
            # Shared bookkeeping for every op: write under the caller-supplied
            # name, or generate a unique "<op>_result_<counter>" auto-name.
            if output:
                self.tensors[output] = result
                return output
            name = f"{op_name}_result_{self.counter}"
            self.counter += 1
            self.tensors[name] = result
            return name

        def matmul(self, a, b, output=None):
            """Matrix product of stored tensors *a* @ *b*."""
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            return self._store("matmul", result, output)

        def add(self, a, b, output=None):
            """Elementwise sum of stored tensors (numpy broadcasting applies)."""
            result = self.get_tensor(a) + self.get_tensor(b)
            return self._store("add", result, output)

        def relu(self, x, output=None):
            """Elementwise max(0, x) of a stored tensor."""
            result = np.maximum(0, self.get_tensor(x))
            return self._store("relu", result, output)

        def fused_matmul_add(self, a, b, bias, output=None):
            """Optimized matmul + add operation."""
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            result += self.get_tensor(bias)
            return self._store("fused_matmul_add", result, output)

        def fused_add_relu(self, a, b, output=None):
            """Optimized add + relu operation."""
            result = self.get_tensor(a) + self.get_tensor(b)
            result = np.maximum(0, result)
            return self._store("fused_add_relu", result, output)

    driver = DummyDriver()

    def linear_relu(x_name: str, weight_name: str, bias_name: str) -> str:
        """Function implementing linear layer with ReLU."""
        # Standard (unfused) implementation: matmul -> bias add -> relu.
        matmul_result = driver.matmul(x_name, weight_name)
        bias_result = driver.add(matmul_result, bias_name)
        return driver.relu(bias_result)

    def _timed(fn, *args, **kwargs):
        # Return (result, elapsed seconds). perf_counter is monotonic and
        # high-resolution, unlike time.time().
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        return result, time.perf_counter() - start

    def _speedup(standard_time, jit_time):
        # Guard against a zero denominator when the timer resolves a run to 0.
        return standard_time / jit_time if jit_time > 0 else float("inf")

    # Initialize JIT compiler
    compiler = JITCompiler(driver)

    # Create example inputs
    batch_size = 32
    in_features = 64
    out_features = 128
    x = np.random.randn(batch_size, in_features)
    weight = np.random.randn(in_features, out_features)
    bias = np.random.randn(out_features)

    # Store tensors in driver
    x_name = driver.create_tensor("x", x)
    weight_name = driver.create_tensor("weight", weight)
    bias_name = driver.create_tensor("bias", bias)

    # Compile the function
    example_inputs = {
        "x_name": x_name,
        "weight_name": weight_name,
        "bias_name": bias_name,
    }
    compiled_fn = compiler.compile(linear_relu, example_inputs)

    print("Running standard vs JIT compiled versions...")

    # Run standard version
    standard_result, standard_time = _timed(linear_relu, x_name, weight_name, bias_name)

    # Run JIT compiled version
    jit_out, jit_time = _timed(
        compiled_fn,
        x_name=x_name,
        weight_name=weight_name,
        bias_name=bias_name,
    )
    jit_result = jit_out["output"]

    # Compare results
    standard_output = driver.get_tensor(standard_result)
    jit_output = driver.get_tensor(jit_result)

    print("\nResults:")
    print(f"Standard execution time: {standard_time:.6f} seconds")
    print(f"JIT execution time: {jit_time:.6f} seconds")
    print(f"Speedup: {_speedup(standard_time, jit_time):.2f}x")
    print(f"Max difference in outputs: {np.max(np.abs(standard_output - jit_output))}")

    # Show optimizations
    print("\nOptimizations applied:")
    print("1. Operation fusion:")
    print("   - Matmul + Add -> fused_matmul_add")
    print("   - Add + ReLU -> fused_add_relu")
    print("2. Memory reuse:")
    print("   - Intermediate tensors reuse memory slots")
    print("3. Operation reordering:")
    print("   - Independent operations can run in parallel")

    # Test with different input sizes
    print("\nTesting with different input sizes...")
    sizes = [(16, 32, 64), (64, 128, 256), (128, 256, 512)]
    for batch, in_dim, out_dim in sizes:
        print(f"\nInput size: batch={batch}, in_features={in_dim}, out_features={out_dim}")

        # Create new inputs
        x = np.random.randn(batch, in_dim)
        weight = np.random.randn(in_dim, out_dim)
        bias = np.random.randn(out_dim)

        x_name = driver.create_tensor(f"x_{batch}", x)
        weight_name = driver.create_tensor(f"weight_{batch}", weight)
        bias_name = driver.create_tensor(f"bias_{batch}", bias)

        # Time standard execution
        standard_result, standard_time = _timed(linear_relu, x_name, weight_name, bias_name)

        # Recompile for the new example inputs, then time JIT execution
        example_inputs = {
            "x_name": x_name,
            "weight_name": weight_name,
            "bias_name": bias_name,
        }
        compiled_fn = compiler.compile(linear_relu, example_inputs)

        jit_out, jit_time = _timed(
            compiled_fn,
            x_name=x_name,
            weight_name=weight_name,
            bias_name=bias_name,
        )
        jit_result = jit_out["output"]

        print(f"Standard time: {standard_time:.6f} seconds")
        print(f"JIT time: {jit_time:.6f} seconds")
        print(f"Speedup: {_speedup(standard_time, jit_time):.2f}x")


if __name__ == "__main__":
    test_jit_compilation()