| import torch |
|
|
| from kernels.benchmark import Benchmark |
|
|
|
|
def lsh_weighted_cumulation_reference(
    query_mask: torch.Tensor,
    query_hash_code: torch.Tensor,
    query_weight: torch.Tensor,
    key_mask: torch.Tensor,
    key_hash_code: torch.Tensor,
    key_weight: torch.Tensor,
    value: torch.Tensor,
    hashtable_capacity: int,
) -> torch.Tensor:
    """Pure-PyTorch reference for hash-table based weighted cumulation.

    For every batch element and every weight channel, the masked and
    weighted values of all keys are accumulated into one hash table per
    hash function (bucket = the key's hash code). Each query then reads
    the bucket selected by its own hash code from every table, the reads
    are averaged over the hash functions, scaled by the query mask and
    query weight, and summed over the weight channels.

    Args:
        query_mask: Indexed as ``[batch, query]``; multiplies the query
            contribution (0 disables a query).
        query_hash_code: ``(batch, num_query, num_hash_f)`` bucket indices.
        query_weight: Indexed as ``[batch, query, weight_idx]``.
        key_mask: Indexed as ``[batch, key]``; multiplies the key
            contribution (0 disables a key).
        key_hash_code: Indexed as ``[batch, key, hash_f]``; bucket indices
            in ``[0, hashtable_capacity)``.
        key_weight: Indexed as ``[batch, key, weight_idx]``.
        value: ``(batch, num_key, value_dim)`` values to accumulate.
        hashtable_capacity: Number of buckets in each hash table.

    Returns:
        Tensor of shape ``(batch, num_query, value_dim)`` with the dtype
        and device of ``value``.
    """
    batch_size, num_query, num_hash_f = query_hash_code.shape
    _, _, value_dim = value.shape
    weight_dim = query_weight.shape[2]
    device = value.device
    dtype = value.dtype

    output = torch.zeros(batch_size, num_query, value_dim, device=device, dtype=dtype)

    # scatter_add_ and advanced indexing need int64 indices; convert once
    # instead of once per (weight_idx, hash function) pair.
    key_buckets = key_hash_code.long()
    query_buckets = query_hash_code.long()

    # Masks only ever act as 0/1 multipliers, so casting them to the value
    # dtype is exact and avoids needless type promotion.
    key_keep = key_mask.to(dtype)
    query_keep = query_mask.to(dtype)

    for b in range(batch_size):
        for weight_idx in range(weight_dim):
            # The weighted values are identical for every hash function, so
            # compute them once per weight channel. The cast keeps the
            # source dtype equal to the hashtable dtype, which scatter_add_
            # requires even when the weights are wider than `value`.
            key_scale = key_keep[b] * key_weight[b, :, weight_idx]
            weighted_values = (key_scale.unsqueeze(-1) * value[b]).to(dtype)

            hashtables = torch.zeros(
                num_hash_f, hashtable_capacity, value_dim, device=device, dtype=dtype
            )
            for h in range(num_hash_f):
                bucket_index = key_buckets[b, :, h].unsqueeze(-1).expand(-1, value_dim)
                hashtables[h].scatter_add_(0, bucket_index, weighted_values)

            # Each query reads its bucket from every table; the reads are summed.
            sum_val = torch.zeros(num_query, value_dim, device=device, dtype=dtype)
            for h in range(num_hash_f):
                sum_val += hashtables[h][query_buckets[b, :, h]]

            query_scale = query_keep[b] * query_weight[b, :, weight_idx]
            output[b] += (query_scale.unsqueeze(-1) * sum_val / num_hash_f).to(dtype)

    return output
|
|
|
|
class YosoBenchmark(Benchmark):
    """Benchmark of ``kernel.lsh_weighted_cumulation`` against a PyTorch reference.

    Two workloads are defined, ``base`` and ``large``; each has a
    ``setup*`` method that builds the inputs, a ``benchmark_*`` method that
    runs the kernel and a ``verify_*`` method that returns the reference
    result. The workloads differ only in tensor sizes, so the shared work
    lives in private helpers.

    ``self.device`` and ``self.kernel`` are not defined here; they are
    presumably provided by the ``Benchmark`` base class.
    """

    seed: int = 42

    def _prepare_inputs(
        self,
        batch_size: int,
        num_query: int,
        num_key: int,
        dim: int,
        num_hash_f: int = 32,
        hash_code_len: int = 9,
    ) -> None:
        """Create masks, vectors, weights, hash codes and the output buffer."""
        self.num_hash_f = num_hash_f
        self.hash_code_len = hash_code_len
        # One weight channel per hash function.
        self.weight_dim = self.num_hash_f
        self.value_dim = dim
        # Hash codes are hash_code_len bits wide, so 2**hash_code_len buckets.
        self.hashtable_capacity = 1 << self.hash_code_len

        float_opts = {"device": self.device, "dtype": torch.float32}
        mask_opts = {"device": self.device, "dtype": torch.int32}

        # The order of the randn calls is kept fixed so seeded runs draw the
        # same random numbers for the same tensors.
        self.query_mask = torch.ones(batch_size, num_query, **mask_opts)
        self.query_vector = torch.randn(batch_size, num_query, dim, **float_opts)
        self.key_mask = torch.ones(batch_size, num_key, **mask_opts)
        self.key_vector = torch.randn(batch_size, num_key, dim, **float_opts)
        self.value = torch.randn(batch_size, num_key, self.value_dim, **float_opts)
        self.query_weight = torch.randn(
            batch_size, num_query, self.weight_dim, **float_opts
        )
        self.key_weight = torch.randn(
            batch_size, num_key, self.weight_dim, **float_opts
        )

        # Hash codes come from the kernel's own hashing routine. The trailing
        # positional arguments (True, 1) are presumably `use_cuda` and
        # `version` -- confirm against the kernel's signature.
        hash_result = self.kernel.fast_hash(
            self.query_mask,
            self.query_vector,
            self.key_mask,
            self.key_vector,
            self.num_hash_f,
            self.hash_code_len,
            True,
            1,
        )
        self.query_hash_code = hash_result[0]
        self.key_hash_code = hash_result[1]

        # Placeholder output; replaced by the kernel result in benchmark_*.
        self.out = torch.empty(batch_size, num_query, self.value_dim, **float_opts)

    def _run_kernel(self) -> None:
        """Run the kernel on the prepared inputs and store the result in ``self.out``."""
        # Trailing (True, 1): presumably `use_cuda` and `version` -- confirm
        # against the kernel's signature.
        self.out = self.kernel.lsh_weighted_cumulation(
            self.query_mask,
            self.query_hash_code,
            self.query_weight,
            self.key_mask,
            self.key_hash_code,
            self.key_weight,
            self.value,
            self.hashtable_capacity,
            True,
            1,
        )

    def _run_reference(self) -> torch.Tensor:
        """Return the PyTorch reference result for the prepared inputs."""
        return lsh_weighted_cumulation_reference(
            self.query_mask,
            self.query_hash_code,
            self.query_weight,
            self.key_mask,
            self.key_hash_code,
            self.key_weight,
            self.value,
            self.hashtable_capacity,
        )

    def setup(self):
        """Prepare the base workload: batch 2, 128 queries/keys, dim 64."""
        self._prepare_inputs(batch_size=2, num_query=128, num_key=128, dim=64)

    def benchmark_base(self):
        """Run the kernel on the base workload."""
        self._run_kernel()

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the base workload."""
        return self._run_reference()

    def setup_large(self):
        """Prepare the large workload: batch 4, 512 queries/keys, dim 128."""
        self._prepare_inputs(batch_size=4, num_query=512, num_key=512, dim=128)

    def benchmark_large(self):
        """Run the kernel on the large workload."""
        self._run_kernel()

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the large workload."""
        return self._run_reference()
|
|