| import os |
| import time |
|
|
| import pandas as pd |
| import torch |
| import triton |
| from profiling import clear_memory, current_memory, memory_measure |
|
|
| from boltz.model.layers.pairformer import PairformerLayer |
|
|
| |
# cuEquivariance runtime knobs: use the library's default kernel config and
# skip ahead-of-time autotuning — presumably keeps startup fast and settings
# reproducible (TODO confirm against cuEquivariance docs).
os.environ["CUEQ_DEFAULT_CONFIG"] = "1"
os.environ["CUEQ_DISABLE_AOT_TUNING"] = "1"


# Benchmark configuration.
C_S = 384  # single (sequence) representation channel width
C_Z = 128  # pair representation channel width
BATCH_SIZE = 1
INFERENCE = False  # False: time forward+backward; True: forward only
SEQ_LEN = [128, 256, 384, 512, 768]  # sequence lengths to sweep
PRECISION = torch.bfloat16  # autocast / tensor dtype
COMPILE = False  # wrap the model in torch.compile
device = "cuda:0"
# Gradients are only needed when the backward pass is benchmarked.
torch.set_grad_enabled(not INFERENCE)
|
|
| |
# Instantiate a single Pairformer layer (v2 variant) and move it to the GPU.
model = PairformerLayer(C_S, C_Z, v2=True)
model.cuda()
if COMPILE:
    # Static shapes per benchmarked size; fullgraph surfaces graph breaks early.
    model = torch.compile(model, fullgraph=True, dynamic=False)


if INFERENCE:
    model.eval()
|
|
|
|
def fwd(
    model,
    s,
    z,
    mask,
    pair_mask,
    use_cuequiv_mul=False,
    use_cuequiv_attn=False,
):
    """Run one forward pass and discard the outputs.

    Used when only the cost of the call matters (inference timing and
    forward-memory measurement), so nothing is returned.
    """
    model(
        s, z, mask, pair_mask,
        use_cuequiv_mul=use_cuequiv_mul,
        use_cuequiv_attn=use_cuequiv_attn,
    )
|
|
|
|
def backward(
    model,
    s,
    z,
    mask,
    pair_mask,
    use_cuequiv_mul=False,
    use_cuequiv_attn=False,
):
    """Run one forward pass followed by a full backward pass.

    A scalar loss is formed by summing both outputs so that gradients flow
    through the single and pair representations alike.
    """
    s_out, z_out = model(
        s,
        z,
        mask,
        pair_mask,
        use_cuequiv_mul=use_cuequiv_mul,
        use_cuequiv_attn=use_cuequiv_attn,
    )
    loss = s_out.sum() + z_out.sum()
    loss.backward()
|
|
|
|
def speed(func, its=10, warmup=10):
    """Time GPU work launched by ``func``, returning mean seconds per call.

    Parameters
    ----------
    func : callable
        Zero-argument callable that launches the work to be timed.
    its : int
        Number of timed iterations to average over.
    warmup : int
        Untimed iterations run first (kernel compilation / cache warmup).

    Returns
    -------
    float
        Average wall-clock seconds per call over the timed iterations.
    """
    for _ in range(warmup):
        func()
    # Kernels launch asynchronously: synchronize so warmup work does not
    # bleed into the timed region, and again so all timed work has finished
    # before the clock is read.
    torch.cuda.synchronize()
    # perf_counter is monotonic and higher-resolution than time.time(),
    # which can jump backwards on wall-clock adjustments.
    start = time.perf_counter()
    for _ in range(its):
        func()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / its
|
|
|
|
| |
# (use_cuequiv_attn, use_cuequiv_mul) kernel flags for each benchmarked variant.
PROVIDER_FLAGS = {
    "Default": (False, False),
    "TriAttn": (True, False),
    "Trimul": (False, True),
    "TriAttn+Trimul": (True, True),
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],
        x_vals=SEQ_LEN,
        line_arg="provider",
        line_vals=list(PROVIDER_FLAGS),
        line_names=list(PROVIDER_FLAGS),
        plot_name="performance",
        args={},
    )
)
def benchmark(size, provider):
    """Measure per-sample latency of one Pairformer layer call.

    Parameters
    ----------
    size : int
        Sequence length; the pair tensor is (BATCH_SIZE, size, size, C_Z).
    provider : str
        One of ``PROVIDER_FLAGS``, selecting which cuEquivariance kernels
        to enable.

    Returns
    -------
    float
        Seconds per sample (total call time divided by BATCH_SIZE).

    Raises
    ------
    KeyError
        If ``provider`` is not a known variant (the original if/elif chain
        instead left ``ms`` unbound and raised NameError at return).
    """
    clear_memory(device)

    s = torch.randn(
        (BATCH_SIZE, size, C_S),
        device=device,
        dtype=PRECISION,
        requires_grad=False,
    )
    z = torch.randn(
        (BATCH_SIZE, size, size, C_Z),
        device=device,
        dtype=PRECISION,
        requires_grad=False,
    )
    # Masks are consumed as float32 regardless of the benchmark precision;
    # build them directly in float32 instead of allocating a bfloat16
    # intermediate and converting it.
    mask = torch.ones((BATCH_SIZE, size), device=device, dtype=torch.float32)
    pair_mask = torch.ones(
        (BATCH_SIZE, size, size), device=device, dtype=torch.float32
    )

    with torch.autocast("cuda", dtype=PRECISION):
        fn = fwd if INFERENCE else backward
        use_attn, use_mul = PROVIDER_FLAGS[provider]
        ms = speed(
            lambda: fn(
                model,
                s,
                z,
                mask,
                pair_mask,
                use_cuequiv_mul=use_mul,
                use_cuequiv_attn=use_attn,
            )
        )

    # Normalize to per-sample time.
    return ms / BATCH_SIZE
|
|
|
|
print("Speed")
benchmark.run(print_data=True, show_plots=False)


# Baseline allocation before any benchmark tensors exist; memory numbers
# below are reported relative to this.
start_mem = current_memory(device)


rows = []
for size in SEQ_LEN:
    print(size)
    s = torch.randn(
        (BATCH_SIZE, size, C_S),
        device=device,
        dtype=PRECISION,
        requires_grad=False,
    )
    z = torch.randn(
        (BATCH_SIZE, size, size, C_Z),
        device=device,
        dtype=PRECISION,
        requires_grad=False,
    )
    mask = torch.ones(
        (BATCH_SIZE, size),
        device=device,
        dtype=PRECISION,
        requires_grad=False,
    ).float()
    pair_mask = torch.ones(
        (BATCH_SIZE, size, size),
        device=device,
        dtype=PRECISION,
        requires_grad=False,
    ).float()

    row = {"size": size}
    with torch.autocast("cuda", dtype=PRECISION):
        # Peak forward-pass memory for each kernel variant, in the same
        # order as the speed benchmark's providers.
        for label, use_mul, use_attn in (
            ("Default", False, False),
            ("TriAttn", False, True),
            ("Trimul", True, False),
            ("TriAttn+Trimul", True, True),
        ):
            peak = memory_measure(
                lambda: fwd(
                    model,
                    s,
                    z,
                    mask,
                    pair_mask,
                    use_cuequiv_mul=use_mul,
                    use_cuequiv_attn=use_attn,
                ),
                device=device,
            )
            row[label] = peak - start_mem
    rows.append(row)


print("Memory")
print(pd.DataFrame(rows))
|
|