# yoso/benchmarks/benchmark.py
# Author: danieldk (HF Staff) — "Benchmarks uploaded using `kernels`." (commit e32aaaa, verified)
import torch
from kernels.benchmark import Benchmark
def lsh_weighted_cumulation_reference(
    query_mask: torch.Tensor,
    query_hash_code: torch.Tensor,
    query_weight: torch.Tensor,
    key_mask: torch.Tensor,
    key_hash_code: torch.Tensor,
    key_weight: torch.Tensor,
    value: torch.Tensor,
    hashtable_capacity: int,
) -> torch.Tensor:
    """Pure-PyTorch reference for the YOSO ``lsh_weighted_cumulation`` kernel.

    For every weight dimension ``w``, each key deposits
    ``key_mask * key_weight[..., w] * value`` into a hashtable bucket per hash
    function; each query gathers the bucket selected by its own hash code,
    averages over the hash functions, and scales by
    ``query_mask * query_weight[..., w]``. The per-``w`` results are summed.

    Args:
        query_mask: ``[batch, num_query]`` 0/1 mask (any numeric dtype).
        query_hash_code: ``[batch, num_query, num_hash_f]`` integer codes in
            ``[0, hashtable_capacity)``.
        query_weight: ``[batch, num_query, weight_dim]`` float weights.
        key_mask: ``[batch, num_key]`` 0/1 mask.
        key_hash_code: ``[batch, num_key, num_hash_f]`` integer codes.
        key_weight: ``[batch, num_key, weight_dim]`` float weights.
        value: ``[batch, num_key, value_dim]`` float values.
        hashtable_capacity: number of buckets per hashtable (``1 << hash_code_len``).

    Returns:
        ``[batch, num_query, value_dim]`` tensor with ``value``'s device/dtype.
    """
    batch_size, num_query, num_hash_f = query_hash_code.shape
    _, num_key, value_dim = value.shape
    weight_dim = query_weight.shape[2]
    device = value.device
    dtype = value.dtype

    output = torch.zeros(batch_size, num_query, value_dim, device=device, dtype=dtype)

    for b in range(batch_size):
        for weight_idx in range(weight_dim):
            k_mask = key_mask[b, :].float()  # [num_key]
            k_weight_val = key_weight[b, :, weight_idx]  # [num_key]
            # Weighted key values are identical for every hash function, so
            # compute them once here instead of once per hash function.
            weighted_values = (
                k_mask.unsqueeze(-1) * k_weight_val.unsqueeze(-1) * value[b]
            )  # [num_key, value_dim]

            # Build one hashtable per hash function by scatter-adding each
            # key's weighted value into the bucket named by its hash code.
            hashtables = torch.zeros(
                num_hash_f, hashtable_capacity, value_dim, device=device, dtype=dtype
            )
            for h in range(num_hash_f):
                k_hash = key_hash_code[b, :, h].long()  # [num_key]
                k_hash_expanded = k_hash.unsqueeze(-1).expand(-1, value_dim)
                hashtables[h].scatter_add_(0, k_hash_expanded, weighted_values)

            # Each query gathers its bucket from every hashtable and sums.
            q_mask = query_mask[b, :].float()  # [num_query]
            q_weight_val = query_weight[b, :, weight_idx]  # [num_query]
            sum_val = torch.zeros(num_query, value_dim, device=device, dtype=dtype)
            for h in range(num_hash_f):
                q_hash = query_hash_code[b, :, h].long()  # [num_query]
                sum_val += hashtables[h][q_hash]  # [num_query, value_dim]

            # Apply query-side mask/weight and average over hash functions.
            output[b] += (
                q_mask.unsqueeze(-1) * q_weight_val.unsqueeze(-1) * sum_val / num_hash_f
            )
    return output
class YosoBenchmark(Benchmark):
    """Benchmarks the YOSO ``lsh_weighted_cumulation`` kernel on a base and a
    large problem size, verifying against the pure-PyTorch reference.

    The original ``setup``/``setup_large``, ``benchmark_*``, and ``verify_*``
    pairs were near-identical copies differing only in the size constants;
    the shared logic now lives in private helpers.
    """

    seed: int = 42

    def _setup_tensors(
        self, batch_size: int, num_query: int, num_key: int, dim: int
    ) -> None:
        """Create inputs, precomputed hash codes, and the output buffer.

        ``dim`` is used both as the query/key vector width and as
        ``value_dim``, matching the original per-size setups.
        """
        self.num_hash_f = 32
        self.hash_code_len = 9
        self.weight_dim = self.num_hash_f
        self.value_dim = dim
        self.hashtable_capacity = 1 << self.hash_code_len
        self.query_mask = torch.ones(
            batch_size, num_query, device=self.device, dtype=torch.int32
        )
        self.query_vector = torch.randn(
            batch_size, num_query, dim, device=self.device, dtype=torch.float32
        )
        self.key_mask = torch.ones(
            batch_size, num_key, device=self.device, dtype=torch.int32
        )
        self.key_vector = torch.randn(
            batch_size, num_key, dim, device=self.device, dtype=torch.float32
        )
        self.value = torch.randn(
            batch_size, num_key, self.value_dim, device=self.device, dtype=torch.float32
        )
        self.query_weight = torch.randn(
            batch_size,
            num_query,
            self.weight_dim,
            device=self.device,
            dtype=torch.float32,
        )
        self.key_weight = torch.randn(
            batch_size,
            num_key,
            self.weight_dim,
            device=self.device,
            dtype=torch.float32,
        )
        # Pre-compute hash codes so the cumulation benchmarks time only
        # lsh_weighted_cumulation, not fast_hash.
        hash_result = self.kernel.fast_hash(
            self.query_mask,
            self.query_vector,
            self.key_mask,
            self.key_vector,
            self.num_hash_f,
            self.hash_code_len,
            True,
            1,
        )
        self.query_hash_code = hash_result[0]
        self.key_hash_code = hash_result[1]
        self.out = torch.empty(
            batch_size,
            num_query,
            self.value_dim,
            device=self.device,
            dtype=torch.float32,
        )

    def _run_kernel(self) -> None:
        # One timed invocation of the kernel under benchmark.
        self.out = self.kernel.lsh_weighted_cumulation(
            self.query_mask,
            self.query_hash_code,
            self.query_weight,
            self.key_mask,
            self.key_hash_code,
            self.key_weight,
            self.value,
            self.hashtable_capacity,
            True,
            1,
        )

    def _run_reference(self) -> torch.Tensor:
        # Ground truth from the pure-PyTorch reference implementation.
        return lsh_weighted_cumulation_reference(
            self.query_mask,
            self.query_hash_code,
            self.query_weight,
            self.key_mask,
            self.key_hash_code,
            self.key_weight,
            self.value,
            self.hashtable_capacity,
        )

    def setup(self):
        """Prepare tensors for the base problem size."""
        self._setup_tensors(batch_size=2, num_query=128, num_key=128, dim=64)

    def benchmark_base(self):
        """Time the kernel on the base problem size."""
        self._run_kernel()

    def verify_base(self) -> torch.Tensor:
        """Return the reference result for the base problem size."""
        return self._run_reference()

    def setup_large(self):
        """Prepare tensors for the large problem size."""
        self._setup_tensors(batch_size=4, num_query=512, num_key=512, dim=128)

    def benchmark_large(self):
        """Time the kernel on the large problem size."""
        self._run_kernel()

    def verify_large(self) -> torch.Tensor:
        """Return the reference result for the large problem size."""
        return self._run_reference()