import numpy as np
from helium.jit import JITCompiler
from helium.broadcast import BroadcastModule
def test_jit_compilation():
    """Exercise JITCompiler against an in-memory dummy driver.

    Builds a linear-layer-plus-ReLU pipeline out of named-tensor driver ops,
    compiles it with JITCompiler, verifies the compiled output matches the
    eager version, and reports eager-vs-JIT timings across several input
    sizes.  Results and timings are printed; nothing is returned.
    """
    import time  # hoisted from mid-function so the whole body can use it

    class DummyDriver:
        """Minimal tensor store exposing the op surface the JIT expects.

        Every op reads its inputs by name, computes with NumPy, and stores
        the result under either the caller-supplied ``output`` name or a
        fresh auto-generated one; the chosen name is returned either way.
        The ``fused_*`` ops exist so the compiler has fusion targets.
        """

        def __init__(self):
            self.tensors = {}  # name -> np.ndarray
            self.counter = 0   # suffix source for auto-generated result names

        def create_tensor(self, name, data):
            """Register `data` under `name` and return the name."""
            self.tensors[name] = np.array(data)
            return name

        def get_tensor(self, name):
            """Look up a stored tensor by name (KeyError if absent)."""
            return self.tensors[name]

        def _store(self, op, result, output):
            """Store `result` under `output`, or under a fresh
            ``<op>_result_<counter>`` name when no output was requested.
            Returns the name used.  Factors out the identical bookkeeping
            that was previously duplicated in every op."""
            if output:
                self.tensors[output] = result
                return output
            name = f"{op}_result_{self.counter}"
            self.counter += 1
            self.tensors[name] = result
            return name

        def matmul(self, a, b, output=None):
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            return self._store("matmul", result, output)

        def add(self, a, b, output=None):
            result = self.get_tensor(a) + self.get_tensor(b)
            return self._store("add", result, output)

        def relu(self, x, output=None):
            result = np.maximum(0, self.get_tensor(x))
            return self._store("relu", result, output)

        def fused_matmul_add(self, a, b, bias, output=None):
            """Optimized matmul + add operation."""
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            result += self.get_tensor(bias)
            return self._store("fused_matmul_add", result, output)

        def fused_add_relu(self, a, b, output=None):
            """Optimized add + relu operation."""
            result = self.get_tensor(a) + self.get_tensor(b)
            result = np.maximum(0, result)
            return self._store("fused_add_relu", result, output)

    driver = DummyDriver()

    def linear_relu(x_name: str, weight_name: str, bias_name: str) -> str:
        """Eager linear layer + ReLU composed from individual driver ops."""
        matmul_result = driver.matmul(x_name, weight_name)
        bias_result = driver.add(matmul_result, bias_name)
        return driver.relu(bias_result)

    def speedup(base: float, other: float) -> str:
        """Format base/other as a speedup string.  Guards against
        other == 0.0, which happens when the timed region is faster than
        the timer's resolution (previously a ZeroDivisionError)."""
        return f"{base / other:.2f}x" if other > 0 else "n/a (too fast to measure)"

    compiler = JITCompiler(driver)

    # Example inputs used to trace/compile the function.
    batch_size = 32
    in_features = 64
    out_features = 128
    x = np.random.randn(batch_size, in_features)
    weight = np.random.randn(in_features, out_features)
    bias = np.random.randn(out_features)

    x_name = driver.create_tensor("x", x)
    weight_name = driver.create_tensor("weight", weight)
    bias_name = driver.create_tensor("bias", bias)

    example_inputs = {
        "x_name": x_name,
        "weight_name": weight_name,
        "bias_name": bias_name,
    }
    compiled_fn = compiler.compile(linear_relu, example_inputs)
    print("Running standard vs JIT compiled versions...")

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which matters for the sub-millisecond regions timed here.
    start_time = time.perf_counter()
    standard_result = linear_relu(x_name, weight_name, bias_name)
    standard_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    jit_result = compiled_fn(
        x_name=x_name,
        weight_name=weight_name,
        bias_name=bias_name,
    )["output"]
    jit_time = time.perf_counter() - start_time

    # Compare eager vs compiled outputs.
    standard_output = driver.get_tensor(standard_result)
    jit_output = driver.get_tensor(jit_result)
    print("\nResults:")
    print(f"Standard execution time: {standard_time:.6f} seconds")
    print(f"JIT execution time: {jit_time:.6f} seconds")
    print(f"Speedup: {speedup(standard_time, jit_time)}")
    print(f"Max difference in outputs: {np.max(np.abs(standard_output - jit_output))}")

    # Show optimizations
    print("\nOptimizations applied:")
    print("1. Operation fusion:")
    print(" - Matmul + Add -> fused_matmul_add")
    print(" - Add + ReLU -> fused_add_relu")
    print("2. Memory reuse:")
    print(" - Intermediate tensors reuse memory slots")
    print("3. Operation reordering:")
    print(" - Independent operations can run in parallel")

    # Re-run the eager/JIT comparison at a few sizes to show scaling.
    print("\nTesting with different input sizes...")
    sizes = [(16, 32, 64), (64, 128, 256), (128, 256, 512)]
    for batch, in_dim, out_dim in sizes:
        print(f"\nInput size: batch={batch}, in_features={in_dim}, out_features={out_dim}")
        x = np.random.randn(batch, in_dim)
        weight = np.random.randn(in_dim, out_dim)
        bias = np.random.randn(out_dim)
        x_name = driver.create_tensor(f"x_{batch}", x)
        weight_name = driver.create_tensor(f"weight_{batch}", weight)
        bias_name = driver.create_tensor(f"bias_{batch}", bias)

        # Time standard execution
        start_time = time.perf_counter()
        standard_result = linear_relu(x_name, weight_name, bias_name)
        standard_time = time.perf_counter() - start_time

        # Time JIT execution (recompiled for the new example inputs)
        example_inputs = {
            "x_name": x_name,
            "weight_name": weight_name,
            "bias_name": bias_name,
        }
        compiled_fn = compiler.compile(linear_relu, example_inputs)
        start_time = time.perf_counter()
        jit_result = compiled_fn(
            x_name=x_name,
            weight_name=weight_name,
            bias_name=bias_name,
        )["output"]
        jit_time = time.perf_counter() - start_time

        print(f"Standard time: {standard_time:.6f} seconds")
        print(f"JIT time: {jit_time:.6f} seconds")
        print(f"Speedup: {speedup(standard_time, jit_time)}")
# Run the JIT compilation demo when executed as a script.
if __name__ == "__main__":
    test_jit_compilation()