| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.autograd import Function |
| from torch.nn import Parameter |
| from torch.utils.cpp_extension import load |
| import os |
|
|
| |
# Directory that holds the C++/CUDA sources of the custom binary GEMM operator.
backend_path = os.path.join(os.path.dirname(__file__), 'backend')

# Becomes True only after the JIT build below succeeds; on failure the
# name `binary_gemm_cuda` is left undefined and callers must check this flag.
CUDA_AVAILABLE = False
try:
    binary_gemm_cuda = load(
        name='binary_gemm_cuda_v4',
        sources=[
            os.path.join(backend_path, source_name)
            for source_name in ('binary_gemm.cpp', 'binary_gemm_kernel.cu')
        ],
        extra_cflags=['-O3'],
        extra_cuda_cflags=['-O3', '-allow-unsupported-compiler'],
        verbose=True,
    )
    print("INFO: Successfully loaded custom CUDA binary_gemm operator.")
    CUDA_AVAILABLE = True
except Exception as build_error:
    # Deliberate best-effort: any build/load failure degrades to the
    # PyTorch simulation path instead of aborting the import.
    print("=" * 40)
    print("WARNING: Failed to compile custom CUDA binary_gemm operator.")
    print("The model will fall back to the slower, PyTorch-based simulation mode.")
    print(f"Error details: {build_error}")
    print("=" * 40)
|
|
| |
| def pack_bits(tensor): |
| tensor_binary = (tensor > 0).to(torch.uint8) |
| M, N = tensor_binary.shape |
| N_packed = (N + 31) // 32 |
| padding = N_packed * 32 - N |
| if padding > 0: tensor_binary = F.pad(tensor_binary, (0, padding), 'constant', 0) |
| tensor_reshaped = tensor_binary.view(M, N_packed, 32) |
| powers_of_2 = (2 ** torch.arange(32, dtype=torch.int64, device=tensor.device)).view(1, 1, 32) |
| packed_tensor = torch.sum(tensor_reshaped.to(torch.int64) * powers_of_2, dim=2).to(torch.uint32) |
| return packed_tensor.contiguous() |
|
|
| |
class BinaryQuantize(Function):
    """Sign binarization with a clipped straight-through gradient estimator.

    Forward returns ``torch.sign(input)``, so exact zeros stay 0. Backward
    passes ``grad_output`` through unchanged where ``-1 <= input <= 1`` and
    zeroes it where ``input > 1`` or ``input < -1``.
    """

    @staticmethod
    def forward(ctx, input):
        # The raw input is needed in backward to build the clipping mask.
        ctx.save_for_backward(input)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        # Using gt/lt (not abs() <= 1) keeps NaN inputs' gradients unmasked.
        saturated = saved_input.gt(1) | saved_input.lt(-1)
        return grad_output.masked_fill(saturated, 0)
|
|
| |
class BiLinearLSR(nn.Linear):
    """Binarized linear layer with a learnable scalar output scale.

    The weight is mean-centred and both it and the input are binarized with
    ``BinaryQuantize`` (sign forward, clipped straight-through backward). The
    binary product is multiplied by ``self.scale``, a 1-element parameter that
    is initialised lazily on the first training-mode forward, while it still
    equals 0.0. It is set to the ratio between the full-precision and binary
    output standard deviations.

    In eval mode, when the custom extension compiled (``CUDA_AVAILABLE``) and
    the input is a CUDA tensor, the product is computed by
    ``binary_gemm_cuda.forward`` on bit-packed operands. Otherwise ``F.linear``
    on the binarized tensors is used.

    NOTE(review): ``pack_bits`` encodes ``x > 0`` as bit 1, so exact zeros are
    treated like negatives on the CUDA path. On the simulation path
    ``sign(0) == 0``. Confirm this mismatch is acceptable.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # Learnable scalar output scale; 0.0 doubles as the "not initialised
        # yet" sentinel tested in forward().
        self.register_parameter('scale', Parameter(torch.zeros(1)))

    def _init_scale(self, input):
        """Set ``self.scale`` so the binary output std matches the full-precision one."""
        # Pure statistics: the value is written through ``.data``, so no
        # autograd graph is needed. Skipping it saves memory on the first step.
        with torch.no_grad():
            centered_weight = (self.weight - self.weight.mean()).to(input.device)
            binary_centered_weight = torch.sign(centered_weight)
            full_precision_output_std = F.linear(input, centered_weight).std()
            binary_output_std = F.linear(torch.sign(input), binary_centered_weight).std()
            scale_factor = full_precision_output_std / (binary_output_std + 1e-8)
            if not torch.isfinite(scale_factor):
                # e.g. .std() of a single-element output is NaN; fall back to
                # the ratio of the weight standard deviations.
                scale_factor = (centered_weight.std() / binary_centered_weight.std()).float()
            self.scale.data.fill_(scale_factor.item())

    def _packed_binary_matmul(self, input_2d):
        """Return the unscaled binary product of a 2-D input via the CUDA kernel."""
        binary_input = BinaryQuantize.apply(input_2d)
        centered_weight = (self.weight - self.weight.mean()).to(input_2d.device)
        binary_weight = BinaryQuantize.apply(centered_weight)
        packed_input = pack_bits(binary_input)
        # Weight is packed row-wise as (out_features, words), then transposed
        # for the kernel.
        packed_weight_transposed = pack_bits(binary_weight).transpose(0, 1).contiguous()
        # Un-padded bit count; presumably lets the kernel ignore pad bits.
        original_N = binary_input.shape[1]
        return binary_gemm_cuda.forward(packed_input, packed_weight_transposed, original_N)

    def forward(self, input):
        """Apply the binarized linear map (plus bias) to ``input``.

        Args:
            input: tensor whose last dimension is ``in_features``.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        if self.training and self.scale.item() == 0.0:
            self._init_scale(input)

        # The packed kernel can only consume CUDA tensors; CPU inputs take the
        # simulation path even when the extension compiled.
        if not self.training and CUDA_AVAILABLE and input.is_cuda:
            if input.dim() == 2:
                output = self._packed_binary_matmul(input)
            else:
                # pack_bits/the kernel take 2-D operands: flatten leading dims
                # and restore them afterwards. Assumes the kernel returns
                # (rows, out_features) -- TODO confirm against the extension.
                flat_output = self._packed_binary_matmul(input.reshape(-1, input.shape[-1]))
                output = flat_output.reshape(*input.shape[:-1], flat_output.shape[-1])
            output = output * self.scale.to(input.device)
        else:
            centered_weight = (self.weight - self.weight.mean()).to(input.device)
            binary_weight = BinaryQuantize.apply(centered_weight)
            binary_input = BinaryQuantize.apply(input)
            scaled_weight = binary_weight * self.scale.to(input.device)
            output = F.linear(binary_input, scaled_weight)

        if self.bias is not None:
            output += self.bias.to(input.device)

        return output