"""
Reference implementation for Triangle Multiplicative Update (TriMul) Triton kernel.
Core operation for AlphaFold3, Chai, Protenix protein structure models.
Same test cases, benchmarks, generate_input, ref_kernel, and check_implementation.
"""
import math
import torch
from torch import nn, einsum
# ---------------------------------------------------------------------------
# Scoring and benchmark configuration (read by shared_eval.py)
# ---------------------------------------------------------------------------
SCORE_SCALE = 3000.0
# TriMul uses CUDA-event timing, a 0.1% relative error tolerance, and a
# 120 s wall-clock timeout.
BENCH_USE_CUDA_EVENTS = True
BENCH_REL_ERROR = 0.001
BENCH_WALL_TIMEOUT_NS = 120e9
BENCH_NO_GRAD = False
BENCH_MAX_REPEATS = 100
BENCH_MAX_TIME_NS = 10e9
BENCH_WARMUP_STYLE = 'tiny_benchmark'
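# A harness reading this module might consume these constants roughly as in
# the hypothetical sketch below. `load_module` and the attribute handling are
# assumptions for illustration, not the actual shared_eval.py API:
#
#     ref = load_module("reference.py")
#     timeout_ns = getattr(ref, "BENCH_WALL_TIMEOUT_NS", 180e9)
#     for case in ref.BENCHMARK_CASES:
#         data = ref.generate_input(**case)
#         ...  # time the submission against ref.ref_kernel(data)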
# ---------------------------------------------------------------------------
# Test / benchmark cases: full set from the discovered task.yml
# ---------------------------------------------------------------------------
TEST_CASES = [
{"seqlen": 32, "bs": 1, "dim": 128, "hiddendim": 128, "seed": 9371, "nomask": True, "distribution": "normal"},
{"seqlen": 32, "bs": 1, "dim": 128, "hiddendim": 128, "seed": 1092, "nomask": False, "distribution": "normal"},
{"seqlen": 64, "bs": 2, "dim": 256, "hiddendim": 128, "seed": 2291, "nomask": True, "distribution": "normal"},
{"seqlen": 64, "bs": 2, "dim": 256, "hiddendim": 128, "seed": 210284, "nomask": False, "distribution": "normal"},
{"seqlen": 128, "bs": 1, "dim": 768, "hiddendim": 128, "seed": 81934, "nomask": True, "distribution": "normal"},
{"seqlen": 256, "bs": 1, "dim": 128, "hiddendim": 128, "seed": 1932, "nomask": True, "distribution": "normal"},
{"seqlen": 256, "bs": 1, "dim": 128, "hiddendim": 128, "seed": 10432, "nomask": False, "distribution": "normal"},
{"seqlen": 768, "bs": 2, "dim": 128, "hiddendim": 128, "seed": 731, "nomask": True, "distribution": "normal"},
{"seqlen": 1024, "bs": 1, "dim": 384, "hiddendim": 128, "seed": 53121, "nomask": False, "distribution": "normal"},
{"seqlen": 1024, "bs": 1, "dim": 768, "hiddendim": 128, "seed": 31, "nomask": True, "distribution": "normal"},
{"seqlen": 1024, "bs": 1, "dim": 768, "hiddendim": 128, "seed": 4921, "nomask": False, "distribution": "normal"},
{"seqlen": 32, "bs": 1, "dim": 128, "hiddendim": 128, "seed": 937321, "nomask": True, "distribution": "cauchy"},
{"seqlen": 64, "bs": 2, "dim": 256, "hiddendim": 128, "seed": 2291, "nomask": True, "distribution": "cauchy"},
{"seqlen": 128, "bs": 1, "dim": 768, "hiddendim": 128, "seed": 8134, "nomask": True, "distribution": "cauchy"},
{"seqlen": 256, "bs": 1, "dim": 128, "hiddendim": 128, "seed": 932, "nomask": True, "distribution": "cauchy"},
{"seqlen": 768, "bs": 2, "dim": 128, "hiddendim": 128, "seed": 31, "nomask": True, "distribution": "cauchy"},
{"seqlen": 1024, "bs": 1, "dim": 384, "hiddendim": 128, "seed": 5321, "nomask": False, "distribution": "cauchy"},
{"seqlen": 1024, "bs": 1, "dim": 768, "hiddendim": 128, "seed": 491, "nomask": False, "distribution": "cauchy"},
]
BENCHMARK_CASES = [
{"seqlen": 256, "bs": 2, "dim": 128, "hiddendim": 128, "seed": 9371, "nomask": True, "distribution": "normal"},
{"seqlen": 768, "bs": 1, "dim": 128, "hiddendim": 128, "seed": 381, "nomask": True, "distribution": "cauchy"},
{"seqlen": 256, "bs": 2, "dim": 384, "hiddendim": 128, "seed": 2301, "nomask": False, "distribution": "normal"},
{"seqlen": 512, "bs": 1, "dim": 128, "hiddendim": 128, "seed": 12819, "nomask": True, "distribution": "normal"},
{"seqlen": 1024, "bs": 1, "dim": 128, "hiddendim": 128, "seed": 381, "nomask": True, "distribution": "cauchy"},
{"seqlen": 768, "bs": 1, "dim": 384, "hiddendim": 128, "seed": 481, "nomask": False, "distribution": "normal"},
{"seqlen": 1024, "bs": 1, "dim": 384, "hiddendim": 128, "seed": 23291, "nomask": True, "distribution": "normal"},
]
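# Each case dict maps one-to-one onto generate_input's keyword arguments, so a
# single case can be materialized directly, e.g.:
#
#     data = generate_input(**TEST_CASES[0])
#     out = ref_kernel(data)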
# ---------------------------------------------------------------------------
# Reference kernel
# ---------------------------------------------------------------------------
class _TriMul(nn.Module):
def __init__(self, dim, hidden_dim, device="cuda"):
super().__init__()
self.norm = nn.LayerNorm(dim, device=device)
self.left_proj = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.right_proj = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.left_gate = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.right_gate = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.out_gate = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.to_out_norm = nn.LayerNorm(hidden_dim, device=device)
self.to_out = nn.Linear(hidden_dim, dim, bias=False, device=device)
def forward(self, x, mask):
x = self.norm(x)
left = self.left_proj(x)
right = self.right_proj(x)
mask = mask.unsqueeze(-1)
left = left * mask
right = right * mask
left = left * self.left_gate(x).sigmoid()
right = right * self.right_gate(x).sigmoid()
out_gate = self.out_gate(x).sigmoid()
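        # Triangle update: out[..., i, j, d] = sum_k left[..., i, k, d] *
        # right[..., j, k, d], i.e. an independent (seqlen x seqlen) matmul
        # per hidden channel d (the "outgoing edges" form of the update).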
out = einsum('... i k d, ... j k d -> ... i j d', left, right)
out = self.to_out_norm(out)
out = out * out_gate
return self.to_out(out)
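# For reference, the einsum in forward() is equivalent to a batched matmul
# with the hidden dimension folded into the batch. A minimal sketch, not used
# by the harness (shapes assume left/right of (bs, seqlen, seqlen, hidden)):
#
#     lp = left.permute(0, 3, 1, 2)        # (bs, d, i, k)
#     rp = right.permute(0, 3, 2, 1)       # (bs, d, k, j)
#     out = (lp @ rp).permute(0, 2, 3, 1)  # (bs, i, j, d)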
def ref_kernel(data):
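    # Disable TF32 so the reference matmuls run in full fp32 precision; the
    # original backend settings are restored in the `finally` block below.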
old_matmul = torch.backends.cuda.matmul.allow_tf32
old_cudnn = torch.backends.cudnn.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
try:
input_tensor, mask, weights, config = data
trimul = _TriMul(dim=config["dim"], hidden_dim=config["hidden_dim"],
device=input_tensor.device)
trimul.norm.weight = nn.Parameter(weights['norm.weight'])
trimul.norm.bias = nn.Parameter(weights['norm.bias'])
trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight'])
trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight'])
trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight'])
trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight'])
trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight'])
trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight'])
trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias'])
trimul.to_out.weight = nn.Parameter(weights['to_out.weight'])
return trimul(input_tensor, mask)
finally:
torch.backends.cuda.matmul.allow_tf32 = old_matmul
torch.backends.cudnn.allow_tf32 = old_cudnn
def generate_input(seqlen, bs, dim, hiddendim, seed, nomask, distribution="normal"):
hidden_dim = hiddendim
config = {"hidden_dim": hidden_dim, "dim": dim}
gen = torch.Generator(device='cuda')
gen.manual_seed(seed)
if distribution == "cauchy":
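        # Inverse-CDF sampling: for U ~ Uniform(0, 1), tan(pi * (U - 0.5)) is
        # standard Cauchy; the factor 2.0 sets the scale, giving heavy-tailed
        # inputs.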
u = torch.empty((bs, seqlen, seqlen, dim), device="cuda", dtype=torch.float32)
u.uniform_(0.0, 1.0, generator=gen)
input_tensor = 2.0 * torch.tan(math.pi * (u - 0.5))
else:
input_tensor = torch.randn(
(bs, seqlen, seqlen, dim), device='cuda', dtype=torch.float32, generator=gen
).contiguous()
if nomask:
mask = torch.ones(bs, seqlen, seqlen, device="cuda")
else:
mask = torch.randint(0, 2, (bs, seqlen, seqlen), device="cuda", generator=gen).float()
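    # Note: the weights below are drawn from the global CUDA RNG rather than
    # from `gen`, so only the input tensor and mask are controlled by `seed`.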
weights = {
"norm.weight": torch.randn(dim, device="cuda"),
"norm.bias": torch.randn(dim, device="cuda"),
"left_proj.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"right_proj.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"left_gate.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"right_gate.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"out_gate.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"to_out_norm.weight": torch.randn(hidden_dim, device="cuda"),
"to_out_norm.bias": torch.randn(hidden_dim, device="cuda"),
"to_out.weight": torch.randn(dim, hidden_dim, device="cuda") / math.sqrt(dim),
}
return (input_tensor, mask, weights, config)
def check_implementation(data, submission_output, rtol=2e-2, atol=2e-2):
old_matmul = torch.backends.cuda.matmul.allow_tf32
old_cudnn = torch.backends.cudnn.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
try:
ref_output = ref_kernel(data)
if ref_output.shape != submission_output.shape:
return False, f"Shape mismatch: {ref_output.shape} vs {submission_output.shape}"
if torch.allclose(ref_output.float(), submission_output.float(), rtol=rtol, atol=atol):
return True, "Match"
diff = torch.abs(ref_output.float() - submission_output.float())
return False, f"max_diff={diff.max().item():.6f}, avg_diff={diff.mean().item():.6f}"
finally:
torch.backends.cuda.matmul.allow_tf32 = old_matmul
torch.backends.cudnn.allow_tf32 = old_cudnn
# ---------------------------------------------------------------------------
# Self-contained reference code for Modal remote execution
# ---------------------------------------------------------------------------
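# This string mirrors the module above verbatim so it can be shipped as
# source to a remote worker; keep the two copies in sync when editing either.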
MODAL_REFERENCE_CODE = r'''
import math
import torch
from torch import nn, einsum
class _TriMul(nn.Module):
def __init__(self, dim, hidden_dim, device="cuda"):
super().__init__()
self.norm = nn.LayerNorm(dim, device=device)
self.left_proj = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.right_proj = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.left_gate = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.right_gate = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.out_gate = nn.Linear(dim, hidden_dim, bias=False, device=device)
self.to_out_norm = nn.LayerNorm(hidden_dim, device=device)
self.to_out = nn.Linear(hidden_dim, dim, bias=False, device=device)
def forward(self, x, mask):
x = self.norm(x)
left = self.left_proj(x)
right = self.right_proj(x)
mask = mask.unsqueeze(-1)
left = left * mask
right = right * mask
left = left * self.left_gate(x).sigmoid()
right = right * self.right_gate(x).sigmoid()
out_gate = self.out_gate(x).sigmoid()
out = einsum('... i k d, ... j k d -> ... i j d', left, right)
out = self.to_out_norm(out)
out = out * out_gate
return self.to_out(out)
def ref_kernel(data):
old_matmul = torch.backends.cuda.matmul.allow_tf32
old_cudnn = torch.backends.cudnn.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
try:
input_tensor, mask, weights, config = data
trimul = _TriMul(dim=config["dim"], hidden_dim=config["hidden_dim"],
device=input_tensor.device)
trimul.norm.weight = nn.Parameter(weights['norm.weight'])
trimul.norm.bias = nn.Parameter(weights['norm.bias'])
trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight'])
trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight'])
trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight'])
trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight'])
trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight'])
trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight'])
trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias'])
trimul.to_out.weight = nn.Parameter(weights['to_out.weight'])
return trimul(input_tensor, mask)
finally:
torch.backends.cuda.matmul.allow_tf32 = old_matmul
torch.backends.cudnn.allow_tf32 = old_cudnn
def generate_input(seqlen, bs, dim, hiddendim, seed, nomask, distribution="normal"):
hidden_dim = hiddendim
config = {"hidden_dim": hidden_dim, "dim": dim}
gen = torch.Generator(device='cuda')
gen.manual_seed(seed)
if distribution == "cauchy":
u = torch.empty((bs, seqlen, seqlen, dim), device="cuda", dtype=torch.float32)
u.uniform_(0.0, 1.0, generator=gen)
input_tensor = 2.0 * torch.tan(math.pi * (u - 0.5))
else:
input_tensor = torch.randn(
(bs, seqlen, seqlen, dim), device='cuda', dtype=torch.float32, generator=gen
).contiguous()
if nomask:
mask = torch.ones(bs, seqlen, seqlen, device="cuda")
else:
mask = torch.randint(0, 2, (bs, seqlen, seqlen), device="cuda", generator=gen).float()
weights = {
"norm.weight": torch.randn(dim, device="cuda"),
"norm.bias": torch.randn(dim, device="cuda"),
"left_proj.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"right_proj.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"left_gate.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"right_gate.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"out_gate.weight": torch.randn(hidden_dim, dim, device="cuda") / math.sqrt(hidden_dim),
"to_out_norm.weight": torch.randn(hidden_dim, device="cuda"),
"to_out_norm.bias": torch.randn(hidden_dim, device="cuda"),
"to_out.weight": torch.randn(dim, hidden_dim, device="cuda") / math.sqrt(dim),
}
return (input_tensor, mask, weights, config)
def check_implementation(data, submission_output, rtol=2e-2, atol=2e-2):
old_matmul = torch.backends.cuda.matmul.allow_tf32
old_cudnn = torch.backends.cudnn.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
try:
ref_output = ref_kernel(data)
if ref_output.shape != submission_output.shape:
return False, f"Shape mismatch: {ref_output.shape} vs {submission_output.shape}"
if torch.allclose(ref_output.float(), submission_output.float(), rtol=rtol, atol=atol):
return True, "Match"
diff = torch.abs(ref_output.float() - submission_output.float())
return False, f"max_diff={diff.max().item():.6f}, avg_diff={diff.mean().item():.6f}"
finally:
torch.backends.cuda.matmul.allow_tf32 = old_matmul
torch.backends.cudnn.allow_tf32 = old_cudnn
'''
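if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the evaluation harness invokes
    # generate_input / ref_kernel / check_implementation itself).
    if torch.cuda.is_available():
        data = generate_input(**TEST_CASES[0])
        out = ref_kernel(data)
        ok, msg = check_implementation(data, out)
        print(f"self-check: {ok} ({msg})")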