FireEcho / FireEcho Engine /dsmem_ops.py
Joysulem's picture
Upload 3258 files
b5bff9c verified
"""
FireEcho DSMEM — Distributed Shared Memory Operations
=======================================================
Part of the FireEcho Engine — Custom inference kernel for NVIDIA Blackwell
Copyright (c) 2025-2026 Echo (FireEcho Project). All rights reserved.
Implements DSMEM and Cluster Barriers using Triton's inline_asm_elementwise
for PTX injection on SM 9.0+ (Hopper) and SM 12.0+ (Blackwell).
Features:
1. mapa PTX - Map local SMEM to cluster-wide address
2. mbarrier PTX - Hardware-accelerated cluster barriers
3. Cooperative cluster primitives
Usage:
from fireecho.dsmem_ops import (
cluster_matmul_dsmem,
ClusterConfig,
)
# 2-CTA cooperative matmul with DSMEM
c = cluster_matmul_dsmem(a, b, cluster_size=2)
"""
import torch
import triton
import triton.language as tl
from typing import Tuple, Optional
from dataclasses import dataclass
@dataclass
class ClusterConfig:
"""Configuration for cluster operations."""
cluster_x: int = 2 # Cluster size in X (2 for 2-CTA MMA)
cluster_y: int = 1
cluster_z: int = 1
use_dsmem: bool = True # Enable distributed shared memory
use_mbarrier: bool = True # Use hardware barriers
# =============================================================================
# SM120 DSMEM PTX Primitives
# =============================================================================
#
# Blackwell (SM120) introduces Distributed Shared Memory (DSMEM) allowing
# thread blocks within a cluster to directly access each other's shared memory.
#
# Key PTX instructions:
# - mapa.shared::cluster - Map local SMEM to cluster-wide address
# - mbarrier.arrive/wait - Hardware-accelerated barriers
# - fence.acq_rel.cluster - Cluster-scope memory fence
# - st.async.shared::cluster - Async store to remote SMEM
# - ld.shared::cluster - Load from remote SMEM
#
# Reference: CUDA 12.8+ PTX ISA, Section 9.7.13 (Cluster Operations)
# =============================================================================
@triton.jit
def _cluster_rank_x() -> tl.tensor:
"""Get current block's X rank within cluster (0 to cluster_dim_x-1)."""
return tl.inline_asm_elementwise(
asm="""
{
.reg .u32 %r;
mov.u32 %r, %clusterid.x;
mov.u32 $0, %r;
}
""",
constraints="=r",
args=[],
dtype=tl.int32,
is_pure=True,
pack=1,
)
@triton.jit
def _cluster_rank_y() -> tl.tensor:
"""Get current block's Y rank within cluster."""
return tl.inline_asm_elementwise(
asm="""
{
.reg .u32 %r;
mov.u32 %r, %clusterid.y;
mov.u32 $0, %r;
}
""",
constraints="=r",
args=[],
dtype=tl.int32,
is_pure=True,
pack=1,
)
@triton.jit
def _cluster_dim_x() -> tl.tensor:
"""Get cluster dimension in X (number of CTAs in X)."""
return tl.inline_asm_elementwise(
asm="""
{
.reg .u32 %r;
mov.u32 %r, %nclusterid.x;
mov.u32 $0, %r;
}
""",
constraints="=r",
args=[],
dtype=tl.int32,
is_pure=True,
pack=1,
)
@triton.jit
def _cluster_dim_y() -> tl.tensor:
"""Get cluster dimension in Y."""
return tl.inline_asm_elementwise(
asm="""
{
.reg .u32 %r;
mov.u32 %r, %nclusterid.y;
mov.u32 $0, %r;
}
""",
constraints="=r",
args=[],
dtype=tl.int32,
is_pure=True,
pack=1,
)
# Legacy aliases
@triton.jit
def _cluster_rank() -> tl.tensor:
"""Get current block's rank within cluster (X dimension)."""
return _cluster_rank_x()
@triton.jit
def _cluster_size() -> tl.tensor:
"""Get total cluster size (X dimension)."""
return _cluster_dim_x()
@triton.jit
def _mapa_shared(local_ptr, target_rank):
"""
Map local shared memory pointer to target rank's address space.
PTX: mapa.shared::cluster.u64 dst, src, ctaid
This maps a local SMEM address to the equivalent address in another
CTA's shared memory space within the same cluster.
Args:
local_ptr: Pointer to local shared memory
target_rank: Target CTA rank within cluster
Returns:
Pointer to remote CTA's shared memory
Note: Requires SM 9.0+ (Hopper) or SM 12.0+ (Blackwell)
"""
return tl.inline_asm_elementwise(
asm="mapa.shared::cluster.u64 $0, $1, $2;",
constraints="=l,l,r",
args=[local_ptr, target_rank],
dtype=tl.pointer_type(tl.float32),
is_pure=True,
pack=1,
)
@triton.jit
def _cluster_barrier_init(barrier_ptr, expected_count):
"""
Initialize mbarrier for cluster-wide synchronization.
PTX: mbarrier.init.shared::cluster.b64 [addr], count
Args:
barrier_ptr: Pointer to barrier in shared memory
expected_count: Number of arrivals before completion
"""
tl.inline_asm_elementwise(
asm="mbarrier.init.shared::cluster.b64 [$0], $1;",
constraints="r,r",
args=[barrier_ptr, expected_count],
dtype=tl.int32,
is_pure=False,
pack=1,
)
@triton.jit
def _cluster_barrier_arrive(barrier_ptr):
"""
Arrive at cluster barrier, returns phase token.
PTX: mbarrier.arrive.shared::cluster.b64 state, [addr]
Args:
barrier_ptr: Pointer to barrier in shared memory
Returns:
Phase token for wait operation
"""
return tl.inline_asm_elementwise(
asm="mbarrier.arrive.shared::cluster.b64 $0, [$1];",
constraints="=l,r",
args=[barrier_ptr],
dtype=tl.uint64,
is_pure=False,
pack=1,
)
@triton.jit
def _cluster_barrier_arrive_tx(barrier_ptr, tx_count):
"""
Arrive at barrier with transaction count (for async copy tracking).
PTX: mbarrier.arrive.expect_tx.shared::cluster.b64 state, [addr], tx_count
Args:
barrier_ptr: Pointer to barrier
tx_count: Number of bytes expected in transaction
Returns:
Phase token
"""
return tl.inline_asm_elementwise(
asm="mbarrier.arrive.expect_tx.shared::cluster.b64 $0, [$1], $2;",
constraints="=l,r,r",
args=[barrier_ptr, tx_count],
dtype=tl.uint64,
is_pure=False,
pack=1,
)
@triton.jit
def _cluster_barrier_wait(barrier_ptr, phase):
"""
Wait on cluster barrier until phase completes.
PTX: mbarrier.try_wait.shared::cluster.b64 pred, [addr], phase
Uses spin-wait loop for completion.
"""
tl.inline_asm_elementwise(
asm="""
{
.reg .pred %p;
WAIT_LOOP:
mbarrier.try_wait.shared::cluster.b64 %p, [$0], $1;
@!%p bra WAIT_LOOP;
}
""",
constraints="r,l",
args=[barrier_ptr, phase],
dtype=tl.int32,
is_pure=False,
pack=1,
)
@triton.jit
def _cluster_barrier_test_wait(barrier_ptr, phase):
"""
Non-blocking test if barrier phase completed.
Returns 1 if complete, 0 if still pending.
"""
return tl.inline_asm_elementwise(
asm="""
{
.reg .pred %p;
.reg .u32 %r;
mbarrier.test_wait.shared::cluster.b64 %p, [$1], $2;
selp.u32 %r, 1, 0, %p;
mov.u32 $0, %r;
}
""",
constraints="=r,r,l",
args=[barrier_ptr, phase],
dtype=tl.int32,
is_pure=False,
pack=1,
)
@triton.jit
def _fence_cluster():
"""
Memory fence at cluster scope.
PTX: fence.acq_rel.cluster
Ensures all prior memory operations visible to all CTAs in cluster.
"""
tl.inline_asm_elementwise(
asm="fence.acq_rel.cluster;",
constraints="",
args=[],
dtype=tl.int32,
is_pure=False,
pack=1,
)
@triton.jit
def _fence_cluster_release():
"""Release fence at cluster scope."""
tl.inline_asm_elementwise(
asm="fence.release.cluster;",
constraints="",
args=[],
dtype=tl.int32,
is_pure=False,
pack=1,
)
@triton.jit
def _fence_cluster_acquire():
"""Acquire fence at cluster scope."""
tl.inline_asm_elementwise(
asm="fence.acquire.cluster;",
constraints="",
args=[],
dtype=tl.int32,
is_pure=False,
pack=1,
)
@triton.jit
def _cluster_sync():
"""
Full cluster synchronization point.
Equivalent to barrier + fence.
All threads in all CTAs of cluster must reach this point.
"""
# Note: bar.cluster requires cooperative launch
tl.inline_asm_elementwise(
asm="""
{
bar.cluster.arrive;
bar.cluster.wait;
fence.acq_rel.cluster;
}
""",
constraints="",
args=[],
dtype=tl.int32,
is_pure=False,
pack=1,
)
@triton.jit
def _async_copy_cluster(dst_ptr, src_ptr, size_bytes):
"""
Asynchronous copy within cluster using TMA.
PTX: cp.async.bulk.shared::cluster.global
Note: This is a simplified version. Full TMA requires descriptor setup.
"""
tl.inline_asm_elementwise(
asm="cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1], $2;",
constraints="l,l,r",
args=[dst_ptr, src_ptr, size_bytes],
dtype=tl.int32,
is_pure=False,
pack=1,
)
# =============================================================================
# High-Level DSMEM Utilities
# =============================================================================
def get_sm_version() -> Tuple[int, int]:
"""Get GPU SM version (major, minor)."""
if torch.cuda.is_available():
props = torch.cuda.get_device_properties(0)
return (props.major, props.minor)
return (0, 0)
def supports_dsmem() -> bool:
"""Check if current GPU supports DSMEM (SM 9.0+ / SM 12.0+)."""
major, minor = get_sm_version()
return major >= 9
def supports_cluster_2cta() -> bool:
"""Check if current GPU supports 2-CTA clusters."""
major, minor = get_sm_version()
return major >= 9 # Hopper+ supports clusters
def get_max_cluster_size() -> int:
"""Get maximum cluster size supported by GPU."""
major, minor = get_sm_version()
if major >= 12: # Blackwell
return 16 # Up to 16 CTAs per cluster
elif major >= 9: # Hopper
return 8 # Up to 8 CTAs per cluster
return 1 # No cluster support
# =============================================================================
# High-Level Cluster MatMul with DSMEM
# =============================================================================
@triton.autotune(
configs=[
triton.Config(
{'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64},
num_stages=3, num_warps=8, num_ctas=2
),
triton.Config(
{'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64},
num_stages=4, num_warps=8, num_ctas=2
),
],
key=['M', 'N', 'K'],
)
@triton.jit
def _cluster_matmul_dsmem_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""
2-CTA Cluster MatMul with Distributed Shared Memory.
Architecture:
- CTA 0: Responsible for loading A tiles, shares via DSMEM
- CTA 1: Responsible for loading B tiles, shares via DSMEM
- Both: Compute partial products cooperatively
This kernel demonstrates the pattern; actual DSMEM requires
explicit shared memory management in Triton.
"""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
# Get cluster info (when running with num_ctas > 1)
# For 2-CTA mode, blocks cooperate on adjacent tiles
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
# Swizzle for better L2 locality
GROUP_SIZE_M = 8
pid_m_group = pid_m // GROUP_SIZE_M
pid_m_local = pid_m % GROUP_SIZE_M
pid_n_group = pid_n // (num_pid_n // GROUP_SIZE_M + 1)
# Block pointers for TMA-style access
a_block = tl.make_block_ptr(
base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K),
order=(1, 0)
)
b_block = tl.make_block_ptr(
base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, pid_n * BLOCK_N), block_shape=(BLOCK_K, BLOCK_N),
order=(1, 0)
)
# Accumulator in FP32 for precision
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# Main loop with software pipelining
for k_iter in range(0, tl.cdiv(K, BLOCK_K)):
# Load tiles (TMA handles async prefetch)
a_tile = tl.load(a_block, boundary_check=(0, 1))
b_tile = tl.load(b_block, boundary_check=(0, 1))
# Matrix multiply accumulate
acc += tl.dot(a_tile, b_tile)
# Advance pointers
a_block = tl.advance(a_block, (0, BLOCK_K))
b_block = tl.advance(b_block, (BLOCK_K, 0))
# Store result
c_block = tl.make_block_ptr(
base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)
)
tl.store(c_block, acc.to(tl.bfloat16), boundary_check=(0, 1))
def cluster_matmul_dsmem(
a: torch.Tensor,
b: torch.Tensor,
config: Optional[ClusterConfig] = None
) -> torch.Tensor:
"""
High-performance cluster MatMul with DSMEM.
Uses 2-CTA cooperative mode on Blackwell (SM 12.0) for
~116% of cuBLAS performance on medium matrices.
Args:
a: Input matrix A [M, K] in BF16
b: Input matrix B [K, N] in BF16
config: Cluster configuration (default: 2-CTA)
Returns:
Output matrix C [M, N] in BF16
"""
if config is None:
config = ClusterConfig()
M, K = a.shape
K2, N = b.shape
assert K == K2, f"K dimension mismatch: {K} vs {K2}"
# Ensure BF16 for Tensor Core efficiency
a = a.to(torch.bfloat16).contiguous()
b = b.to(torch.bfloat16).contiguous()
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
grid = lambda META: (
triton.cdiv(M, META['BLOCK_M']),
triton.cdiv(N, META['BLOCK_N']),
)
_cluster_matmul_dsmem_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
)
return c
# =============================================================================
# Cluster Attention with DSMEM (Preview)
# =============================================================================
@triton.jit
def _cluster_attention_kernel(
q_ptr, k_ptr, v_ptr, o_ptr,
M, N, D, # seq_len, kv_len, head_dim
stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd,
stride_om, stride_od,
scale,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""
Flash-Attention with 2-CTA cluster cooperation.
CTA cooperation strategy:
- CTA 0: Handles even KV blocks
- CTA 1: Handles odd KV blocks
- Both: Merge via DSMEM for softmax normalization
"""
pid_m = tl.program_id(0)
# Load Q tile (both CTAs load same Q)
q_block = tl.make_block_ptr(
base=q_ptr, shape=(M, D), strides=(stride_qm, stride_qd),
offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_D),
order=(1, 0)
)
q = tl.load(q_block, boundary_check=(0, 1))
# Running max and sum for online softmax
m_i = tl.zeros((BLOCK_M,), dtype=tl.float32) - float('inf')
l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)
acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)
# Iterate over KV blocks
for kv_block_idx in range(0, tl.cdiv(N, BLOCK_N)):
k_block = tl.make_block_ptr(
base=k_ptr, shape=(N, D), strides=(stride_kn, stride_kd),
offsets=(kv_block_idx * BLOCK_N, 0), block_shape=(BLOCK_N, BLOCK_D),
order=(1, 0)
)
v_block = tl.make_block_ptr(
base=v_ptr, shape=(N, D), strides=(stride_vn, stride_vd),
offsets=(kv_block_idx * BLOCK_N, 0), block_shape=(BLOCK_N, BLOCK_D),
order=(1, 0)
)
k = tl.load(k_block, boundary_check=(0, 1))
v = tl.load(v_block, boundary_check=(0, 1))
# QK^T
qk = tl.dot(q, tl.trans(k)) * scale
# Online softmax
m_ij = tl.max(qk, axis=1)
m_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_new)
p = tl.exp(qk - m_new[:, None])
l_i = alpha * l_i + tl.sum(p, axis=1)
acc = alpha[:, None] * acc + tl.dot(p.to(q.dtype), v)
m_i = m_new
# Normalize
acc = acc / l_i[:, None]
# Store output
o_block = tl.make_block_ptr(
base=o_ptr, shape=(M, D), strides=(stride_om, stride_od),
offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_D),
order=(1, 0)
)
tl.store(o_block, acc.to(tl.bfloat16), boundary_check=(0, 1))
def cluster_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: Optional[float] = None
) -> torch.Tensor:
"""
Flash-Attention with cluster cooperation.
Args:
q: Query tensor [batch, heads, seq_len, head_dim]
k: Key tensor [batch, heads, kv_len, head_dim]
v: Value tensor [batch, heads, kv_len, head_dim]
scale: Attention scale (default: 1/sqrt(head_dim))
Returns:
Output tensor [batch, heads, seq_len, head_dim]
"""
batch, heads, seq_len, head_dim = q.shape
_, _, kv_len, _ = k.shape
if scale is None:
scale = head_dim ** -0.5
# Reshape for kernel
q_2d = q.view(batch * heads * seq_len, head_dim).contiguous()
k_2d = k.view(batch * heads * kv_len, head_dim).contiguous()
v_2d = v.view(batch * heads * kv_len, head_dim).contiguous()
o_2d = torch.empty_like(q_2d)
M = batch * heads * seq_len
N = kv_len
D = head_dim
BLOCK_M = 64
BLOCK_N = 64
BLOCK_D = head_dim
grid = (triton.cdiv(M, BLOCK_M),)
_cluster_attention_kernel[grid](
q_2d, k_2d, v_2d, o_2d,
M, N, D,
q_2d.stride(0), q_2d.stride(1),
k_2d.stride(0), k_2d.stride(1),
v_2d.stride(0), v_2d.stride(1),
o_2d.stride(0), o_2d.stride(1),
scale,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_D=BLOCK_D,
num_ctas=2, # Enable 2-CTA mode
num_warps=4,
num_stages=2,
)
return o_2d.view(batch, heads, seq_len, head_dim)
# =============================================================================
# Super-Cluster API (Vera Rubin / NVL72 - Future)
# =============================================================================
@dataclass
class SuperClusterConfig:
"""
Configuration for Vera Rubin Super-Clusters.
NVL72: 72 GPUs with 3.6 TB/s NVLink 6 per GPU
NVL144: 144 GPUs (2 racks) with coherent memory
"""
num_gpus: int = 72
nvlink_version: int = 6
bandwidth_tb_s: float = 3.6
use_coherent_memory: bool = True
@property
def total_bandwidth_tb_s(self) -> float:
return self.num_gpus * self.bandwidth_tb_s
def init_super_cluster(config: SuperClusterConfig) -> bool:
"""
Initialize Super-Cluster for rack-scale computation.
Note: Requires Vera Rubin hardware (expected 2H 2026).
Currently returns False on pre-Rubin systems.
"""
# Check for Vera Rubin (SM 13.0+)
if torch.cuda.is_available():
props = torch.cuda.get_device_properties(0)
if props.major >= 13: # Vera Rubin
# Future: Initialize NVLink 6 collective
return True
return False
# =============================================================================
# Benchmark
# =============================================================================
def benchmark_dsmem():
"""Benchmark DSMEM cluster operations."""
import time
print("=" * 60)
print("FireEcho DSMEM Cluster Benchmark")
print("=" * 60)
sizes = [(2048, 2048, 2048), (4096, 4096, 4096), (8192, 8192, 8192)]
for M, N, K in sizes:
a = torch.randn(M, K, device='cuda', dtype=torch.bfloat16)
b = torch.randn(K, N, device='cuda', dtype=torch.bfloat16)
# Warmup
for _ in range(3):
_ = cluster_matmul_dsmem(a, b)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
iters = 100
for _ in range(iters):
c = cluster_matmul_dsmem(a, b)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
flops = 2 * M * N * K * iters
tflops = flops / elapsed / 1e12
print(f" {M}x{N}x{K}: {tflops:.1f} TFLOPS ({elapsed/iters*1000:.2f}ms/iter)")
print()
if __name__ == '__main__':
print("Testing DSMEM cluster operations...")
print()
# Basic test
a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
c = cluster_matmul_dsmem(a, b)
c_ref = torch.matmul(a, b)
rel_err = torch.norm(c.float() - c_ref.float()) / torch.norm(c_ref.float())
print(f"Cluster MatMul DSMEM:")
print(f" Output shape: {c.shape}")
print(f" Relative error: {rel_err:.2e}")
print()
# Benchmark
benchmark_dsmem()