# File size: 13,304 Bytes
# 518db7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
"""
Fused Triton kernels for mHC (manifold-constrained Hyper-Connections).

Fuses rearrange + einsum operations to reduce kernel launch overhead:
- Baseline: 9 kernels per mHC module (2 rearranges + 3 einsums + 4 softmax/sinkhorn ops)
- Fused: 3-4 kernels per mHC module (fused stream mixing + separate sinkhorn)

Performance target: 3-8x speedup for 48-layer model (96 mHC modules).
"""

import torch
import triton  # type: ignore[import-untyped]
import triton.language as tl  # type: ignore[import-untyped]
from typing import Optional

@triton.jit
def fused_stream_mixing_kernel(  # type: ignore[no-untyped-def]
    # Input pointers
    x_ptr,  # Residual input (batch, seq_len, num_streams, stream_dim)
    transformed_ptr,  # Transformed input (batch, seq_len, num_streams, stream_dim)
    H_res_ptr,  # Doubly stochastic matrix (num_streams, num_streams)
    H_pre_ptr,  # Pre-mixing matrix (num_streams, num_streams)
    H_post_ptr,  # Post-mixing matrix (num_streams, num_streams)
    # Output pointer
    output_ptr,  # Output (batch, seq_len, num_streams, stream_dim)
    # Dimensions
    batch_size: tl.constexpr,
    seq_len: tl.constexpr,
    num_streams: tl.constexpr,
    stream_dim: tl.constexpr,
    # Block size
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """
    Fused kernel for width_connection stream mixing.

    Fuses the following operations:
    1. residual_mixed = einsum(H_res, x_streams, "n m, b s n d -> b s m d")
    2. pre_mixed = einsum(H_pre, transformed_streams, "n m, b s n d -> b s m d")
    3. post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
    4. output = residual_mixed + post_mixed

    Written out per output element (contracted index first in H_res/H_pre,
    exactly as the einsum patterns above dictate):

        residual_mixed[.., m, d] = sum_n H_res[n, m]  * x[.., n, d]
        pre_mixed[.., m, d]      = sum_n H_pre[n, m]  * transformed[.., n, d]
        post_mixed[.., n, d]     = sum_m H_post[m, n] * pre_mixed[.., m, d]

    All tensors are addressed as dense row-major buffers; the caller must
    pass contiguous tensors. Each program handles BLOCK_SIZE consecutive
    flattened (batch, seq) positions; positions past batch_size * seq_len
    are masked out. Accumulation is done in float32 and the result is cast
    to the output element type on store.
    """
    pid = tl.program_id(0)
    total_positions = batch_size * seq_len

    # Flattened (batch, seq) positions handled by this program. Promote to
    # int64 so the element offsets below cannot wrap for large tensors.
    offsets = pid.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Mask for valid positions. It also makes any surplus program a no-op,
    # so no early return is needed.
    mask = offsets < total_positions

    # For a contiguous (batch, seq, stream, dim) tensor,
    #   batch_idx * seq_len * S * D + seq_idx * S * D == flat_position * S * D
    # so the base offset is derived directly from the flat position.
    position_stride = num_streams * stream_dim
    base = offsets * position_stride

    # Process each (output stream, feature) pair
    for s_out in range(num_streams):
        for d in range(stream_dim):
            residual_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
            post_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

            # Step 1: residual_mixed[b, s, s_out, d]
            # = sum over s_in: H_res[s_in, s_out] * x[b, s, s_in, d]
            for s_in in range(num_streams):
                h_res_val = tl.load(
                    H_res_ptr + s_in * num_streams + s_out
                ).to(tl.float32)
                x_val = tl.load(
                    x_ptr + base + s_in * stream_dim + d,
                    mask=mask,
                    other=0.0,
                ).to(tl.float32)
                residual_acc += h_res_val * x_val

            # Steps 2 & 3: post_mixed[b, s, s_out, d]
            # = sum over s_mid: H_post[s_mid, s_out] * pre_mixed[b, s, s_mid, d]
            for s_mid in range(num_streams):
                # pre_mixed[b, s, s_mid, d]
                # = sum over s_in: H_pre[s_in, s_mid] * transformed[b, s, s_in, d]
                # The accumulator must start as a block tensor: a Python
                # scalar would change type on the first loop iteration.
                pre_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
                for s_in in range(num_streams):
                    h_pre_val = tl.load(
                        H_pre_ptr + s_in * num_streams + s_mid
                    ).to(tl.float32)
                    transformed_val = tl.load(
                        transformed_ptr + base + s_in * stream_dim + d,
                        mask=mask,
                        other=0.0,
                    ).to(tl.float32)
                    pre_acc += h_pre_val * transformed_val

                h_post_val = tl.load(
                    H_post_ptr + s_mid * num_streams + s_out
                ).to(tl.float32)
                post_acc += h_post_val * pre_acc

            # Step 4: Combine residual and post-mixed, then store
            # output[b, s, s_out, d]
            tl.store(
                output_ptr + base + s_out * stream_dim + d,
                residual_acc + post_acc,
                mask=mask,
            )

def fused_width_connection_triton(
    x: torch.Tensor,
    transformed: torch.Tensor,
    H_res: torch.Tensor,
    H_pre: torch.Tensor,
    H_post: torch.Tensor,
) -> torch.Tensor:
    """
    Fused Triton implementation of width_connection.

    Args:
        x: Residual input (batch, seq_len, num_streams, stream_dim)
        transformed: Transformed input (batch, seq_len, num_streams, stream_dim)
        H_res: Doubly stochastic matrix (num_streams, num_streams)
        H_pre: Pre-mixing matrix (num_streams, num_streams)
        H_post: Post-mixing matrix (num_streams, num_streams)

    Returns:
        output: Mixed features (batch, seq_len, num_streams, stream_dim)

    Raises:
        ValueError: If x is not 4-dimensional (from shape unpacking).
        AssertionError: If x is not a contiguous CUDA tensor, if any shape
            does not match, or if an operand lives on a different device
            than x.
    """
    batch_size, seq_len, num_streams, stream_dim = x.shape

    # Validate inputs
    assert x.is_cuda, "Input must be on GPU"
    assert x.is_contiguous(), "Input must be contiguous"
    assert transformed.shape == x.shape, "Shape mismatch"
    assert H_res.shape == (num_streams, num_streams), "H_res shape mismatch"
    assert H_pre.shape == (num_streams, num_streams), "H_pre shape mismatch"
    assert H_post.shape == (num_streams, num_streams), "H_post shape mismatch"
    for name, operand in (
        ("transformed", transformed),
        ("H_res", H_res),
        ("H_pre", H_pre),
        ("H_post", H_post),
    ):
        assert operand.device == x.device, (
            f"{name} must be on the same device as x"
        )

    # The kernel indexes every buffer as dense row-major memory, so strided
    # views (e.g. a transposed mixing matrix) must be materialised first.
    # .contiguous() is a no-op for tensors that are already contiguous.
    transformed = transformed.contiguous()
    H_res = H_res.contiguous()
    H_pre = H_pre.contiguous()
    H_post = H_post.contiguous()

    # Allocate output
    output = torch.empty_like(x)

    # Nothing to compute for empty inputs; avoid a zero-sized grid launch.
    if output.numel() == 0:
        return output

    # Launch kernel: one program per BLOCK_SIZE flattened (batch, seq) positions
    total_positions = batch_size * seq_len
    BLOCK_SIZE = 128
    grid = (triton.cdiv(total_positions, BLOCK_SIZE),)

    # Launch on the device that owns the tensors, which may differ from the
    # current default CUDA device.
    with torch.cuda.device(x.device):
        fused_stream_mixing_kernel[grid](
            x, transformed, H_res, H_pre, H_post, output,
            batch_size, seq_len, num_streams, stream_dim,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return output

def compare_with_einops_reference(
    x: torch.Tensor,
    transformed: torch.Tensor,
    H_res: torch.Tensor,
    H_pre: torch.Tensor,
    H_post: torch.Tensor,
    atol: float = 1e-5,
    rtol: float = 1e-4,
) -> tuple[bool, float, Optional[str]]:
    """
    Check the fused Triton kernel against the unfused einops computation.

    Args:
        x: Residual input (batch, seq_len, num_streams, stream_dim)
        transformed: Transformed input (batch, seq_len, num_streams, stream_dim)
        H_res: Doubly stochastic matrix (num_streams, num_streams)
        H_pre: Pre-mixing matrix (num_streams, num_streams)
        H_post: Post-mixing matrix (num_streams, num_streams)
        atol: Absolute tolerance passed to torch.allclose
        rtol: Relative tolerance passed to torch.allclose

    Returns:
        matches: True if both outputs agree within tolerance
        max_diff: Largest absolute element-wise difference
        error_msg: Diagnostic report on mismatch, None otherwise
    """
    from einops import einsum  # type: ignore[import-not-found]

    # Output of the fused kernel under test
    actual = fused_width_connection_triton(x, transformed, H_res, H_pre, H_post)

    # Unfused reference: three einsums followed by the residual add
    residual_mixed = einsum(H_res, x, "n m, b s n d -> b s m d")
    pre_mixed = einsum(H_pre, transformed, "n m, b s n d -> b s m d")
    post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
    expected = residual_mixed + post_mixed

    matches = torch.allclose(actual, expected, atol=atol, rtol=rtol)
    abs_err = (actual - expected).abs()
    max_diff = abs_err.max().item()

    if matches:
        return matches, max_diff, None

    # Build a multi-line report describing how far apart the results are
    mean_diff = abs_err.mean().item()
    report_lines = [
        "Triton kernel output does not match einops reference!",
        f"Max diff: {max_diff:.6e} (atol={atol}, rtol={rtol})",
        f"Mean diff: {mean_diff:.6e}",
        f"Triton output range: [{actual.min().item():.4f}, "
        f"{actual.max().item():.4f}]",
        f"Einops output range: [{expected.min().item():.4f}, "
        f"{expected.max().item():.4f}]",
    ]
    return matches, max_diff, "\n".join(report_lines)

def benchmark_kernel_speedup(
    batch_size: int = 16,
    seq_len: int = 128,
    dim: int = 512,
    num_streams: int = 8,
    num_warmup: int = 10,
    num_iters: int = 100,
) -> tuple[float, float, float]:
    """
    Time the fused Triton kernel against the unfused einops computation.

    Random inputs are generated on the GPU, both implementations are warmed
    up together, and then each is timed over num_iters iterations with
    host wall-clock time, synchronizing the device before reading the clock.

    Args:
        batch_size: Batch size
        seq_len: Sequence length
        dim: Hidden dimension (split as dim // num_streams per stream)
        num_streams: Number of streams
        num_warmup: Warmup iterations
        num_iters: Benchmark iterations

    Returns:
        einops_time_ms: Average einops time (milliseconds)
        triton_time_ms: Average Triton time (milliseconds)
        speedup: Speedup factor (einops_time / triton_time)
    """
    from einops import einsum  # type: ignore[import-not-found]
    import time

    assert torch.cuda.is_available(), "GPU required for benchmarking"
    device = torch.device("cuda")

    # Random activations, then three row-softmaxed mixing matrices
    activation_shape = (batch_size, seq_len, num_streams, dim // num_streams)
    x = torch.randn(*activation_shape, device=device)
    transformed = torch.randn(*activation_shape, device=device)
    H_res, H_pre, H_post = (
        torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
        for _ in range(3)
    )

    # Warm up both code paths before any timing starts
    for _ in range(num_warmup):
        _ = fused_width_connection_triton(x, transformed, H_res, H_pre, H_post)
        residual_mixed = einsum(H_res, x, "n m, b s n d -> b s m d")
        pre_mixed = einsum(H_pre, transformed, "n m, b s n d -> b s m d")
        post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
        _ = residual_mixed + post_mixed

    torch.cuda.synchronize()

    # Time the einops reference
    tic = time.perf_counter()
    for _ in range(num_iters):
        residual_mixed = einsum(H_res, x, "n m, b s n d -> b s m d")
        pre_mixed = einsum(H_pre, transformed, "n m, b s n d -> b s m d")
        post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
        einops_out = residual_mixed + post_mixed
    torch.cuda.synchronize()
    einops_ms = (time.perf_counter() - tic) / num_iters * 1000

    # Time the fused Triton kernel
    tic = time.perf_counter()
    for _ in range(num_iters):
        triton_out = fused_width_connection_triton(
            x, transformed, H_res, H_pre, H_post
        )
    torch.cuda.synchronize()
    triton_ms = (time.perf_counter() - tic) / num_iters * 1000

    return einops_ms, triton_ms, einops_ms / triton_ms

def _main() -> int:
    """Run a quick correctness check and benchmark of the fused kernel.

    Returns:
        Process exit status: 0 when the check passes or CUDA is unavailable
        (test skipped), 1 when the kernel output does not match the reference.
    """
    print("Testing fused mHC Triton kernel...")

    if not torch.cuda.is_available():
        print("CUDA not available, skipping test")
        return 0

    device = torch.device('cuda')

    # Test configuration
    batch_size = 4
    seq_len = 32
    num_streams = 8
    stream_dim = 64

    # Generate test inputs
    x = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    transformed = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    H_res = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_pre = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_post = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)

    # Compare with reference
    matches, max_diff, error_msg = compare_with_einops_reference(
        x, transformed, H_res, H_pre, H_post
    )

    if not matches:
        print("✗ Correctness test FAILED")
        print(error_msg)
        return 1
    print(f"✓ Correctness test PASSED (max diff: {max_diff:.6e})")

    # Benchmark
    print("\nBenchmarking...")
    einops_time, triton_time, speedup = benchmark_kernel_speedup()
    print(f"Einops time: {einops_time:.3f} ms")
    print(f"Triton time: {triton_time:.3f} ms")
    print(f"Speedup: {speedup:.2f}x")

    print("\n✓ All tests passed!")
    return 0


if __name__ == "__main__":
    # SystemExit is used instead of the exit() helper, which is only
    # injected by the site module and is absent under `python -S`.
    raise SystemExit(_main())