File size: 6,847 Bytes
7a0c684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import numpy as np
from helium.jit import JITCompiler
from helium.broadcast import BroadcastModule

def test_jit_compilation():
    """Exercise JITCompiler against a NumPy-backed dummy driver.

    Builds an in-memory driver exposing matmul/add/relu plus fused
    variants, JIT-compiles a linear+ReLU function, and compares the
    standard vs. compiled execution paths for numerical agreement and
    speed, then repeats the timing over several input sizes.

    Side effects: prints a benchmark report to stdout.
    Returns: None.
    """
    import time  # timing only; kept function-local as in the original

    class DummyDriver:
        """Minimal named-tensor driver backed by a dict of numpy arrays.

        Every op stores its result under an explicit *output* handle or
        an auto-generated one, and returns the handle (a str).
        """

        def __init__(self):
            self.tensors = {}  # handle (str) -> np.ndarray
            self.counter = 0   # monotonically increasing suffix for auto names

        def create_tensor(self, name, data):
            """Store *data* (copied via np.array) under *name*; return the handle."""
            self.tensors[name] = np.array(data)
            return name

        def get_tensor(self, name):
            """Look up a stored array by handle (raises KeyError if unknown)."""
            return self.tensors[name]

        def _finish(self, result, prefix, output):
            """Store *result* under *output*, or under an auto-generated
            '<prefix>_result_<counter>' handle, and return the handle used.

            Shared tail of every op method — removes five copies of the
            same store-or-autoname logic.
            """
            if output:
                self.tensors[output] = result
                return output
            name = f"{prefix}_result_{self.counter}"
            self.counter += 1
            self.tensors[name] = result
            return name

        def matmul(self, a, b, output=None):
            """Matrix product of tensors *a* @ *b*."""
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            return self._finish(result, "matmul", output)

        def add(self, a, b, output=None):
            """Elementwise (broadcasting) sum of tensors *a* + *b*."""
            result = self.get_tensor(a) + self.get_tensor(b)
            return self._finish(result, "add", output)

        def relu(self, x, output=None):
            """Elementwise max(0, x)."""
            result = np.maximum(0, self.get_tensor(x))
            return self._finish(result, "relu", output)

        def fused_matmul_add(self, a, b, bias, output=None):
            """Optimized matmul + add operation (single fused kernel stand-in)."""
            result = np.matmul(self.get_tensor(a), self.get_tensor(b))
            result += self.get_tensor(bias)
            return self._finish(result, "fused_matmul_add", output)

        def fused_add_relu(self, a, b, output=None):
            """Optimized add + relu operation (single fused kernel stand-in)."""
            result = np.maximum(0, self.get_tensor(a) + self.get_tensor(b))
            return self._finish(result, "fused_add_relu", output)

    driver = DummyDriver()

    def linear_relu(x_name: str, weight_name: str, bias_name: str) -> str:
        """Linear layer with ReLU, unfused reference path: relu(x @ W + b)."""
        matmul_result = driver.matmul(x_name, weight_name)
        bias_result = driver.add(matmul_result, bias_name)
        return driver.relu(bias_result)

    def _timed(fn, *args, **kwargs):
        """Run fn(*args, **kwargs); return (result, elapsed_seconds).

        Uses perf_counter, which is monotonic and high-resolution —
        time.time() can go backwards on clock adjustment and has coarse
        resolution on some platforms.
        """
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        return result, time.perf_counter() - start

    def _speedup(standard_time, jit_time):
        """Format the speedup ratio; guard against a zero denominator
        (a sub-resolution JIT run would otherwise raise ZeroDivisionError)."""
        return f"{standard_time / jit_time:.2f}x" if jit_time > 0 else "n/a"

    compiler = JITCompiler(driver)

    # Example inputs for the initial correctness + timing comparison.
    batch_size, in_features, out_features = 32, 64, 128
    x = np.random.randn(batch_size, in_features)
    weight = np.random.randn(in_features, out_features)
    bias = np.random.randn(out_features)

    x_name = driver.create_tensor("x", x)
    weight_name = driver.create_tensor("weight", weight)
    bias_name = driver.create_tensor("bias", bias)

    example_inputs = {
        "x_name": x_name,
        "weight_name": weight_name,
        "bias_name": bias_name,
    }
    compiled_fn = compiler.compile(linear_relu, example_inputs)

    print("Running standard vs JIT compiled versions...")

    standard_result, standard_time = _timed(
        linear_relu, x_name, weight_name, bias_name)
    jit_raw, jit_time = _timed(
        compiled_fn,
        x_name=x_name, weight_name=weight_name, bias_name=bias_name)
    jit_result = jit_raw["output"]  # compiled fn returns a dict of handles

    standard_output = driver.get_tensor(standard_result)
    jit_output = driver.get_tensor(jit_result)

    print("\nResults:")
    print(f"Standard execution time: {standard_time:.6f} seconds")
    print(f"JIT execution time: {jit_time:.6f} seconds")
    print(f"Speedup: {_speedup(standard_time, jit_time)}")
    print(f"Max difference in outputs: {np.max(np.abs(standard_output - jit_output))}")

    # Describe the transformations the compiler is expected to apply.
    print("\nOptimizations applied:")
    print("1. Operation fusion:")
    print("   - Matmul + Add -> fused_matmul_add")
    print("   - Add + ReLU -> fused_add_relu")
    print("2. Memory reuse:")
    print("   - Intermediate tensors reuse memory slots")
    print("3. Operation reordering:")
    print("   - Independent operations can run in parallel")

    print("\nTesting with different input sizes...")
    sizes = [(16, 32, 64), (64, 128, 256), (128, 256, 512)]
    for batch, in_dim, out_dim in sizes:
        print(f"\nInput size: batch={batch}, in_features={in_dim}, out_features={out_dim}")

        x = np.random.randn(batch, in_dim)
        weight = np.random.randn(in_dim, out_dim)
        bias = np.random.randn(out_dim)

        x_name = driver.create_tensor(f"x_{batch}", x)
        weight_name = driver.create_tensor(f"weight_{batch}", weight)
        bias_name = driver.create_tensor(f"bias_{batch}", bias)

        standard_result, standard_time = _timed(
            linear_relu, x_name, weight_name, bias_name)

        # Recompile against the new tensor handles before timing the JIT path.
        compiled_fn = compiler.compile(linear_relu, {
            "x_name": x_name,
            "weight_name": weight_name,
            "bias_name": bias_name,
        })
        jit_raw, jit_time = _timed(
            compiled_fn,
            x_name=x_name, weight_name=weight_name, bias_name=bias_name)
        jit_raw["output"]  # keep the handle lookup for parity with the first run

        print(f"Standard time: {standard_time:.6f} seconds")
        print(f"JIT time: {jit_time:.6f} seconds")
        print(f"Speedup: {_speedup(standard_time, jit_time)}")

# Run the JIT demo/benchmark only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    test_jit_compilation()