# pointcept/models/quantization/binary_layers.py
# (Fix for the mul_cuda uint32 bug: bit packing now multiplies in int64 and
# only casts the final packed words to uint32.)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn import Parameter
from torch.utils.cpp_extension import load
import os

# --- JIT Compilation with absolute safety ---
# Compile the custom binary GEMM extension at import time. Any failure (missing
# toolchain, missing sources, compile error, ...) is caught on purpose so the
# module stays importable; CUDA_AVAILABLE then stays False and BiLinearLSR
# uses the PyTorch simulation path.
backend_path = os.path.join(os.path.dirname(__file__), 'backend')
CUDA_AVAILABLE = False
try:
    binary_gemm_cuda = load(
        name='binary_gemm_cuda_v4',  # bump when backend sources change (JIT cache)
        sources=[os.path.join(backend_path, 'binary_gemm.cpp'),
                 os.path.join(backend_path, 'binary_gemm_kernel.cu')],
        verbose=True,
        extra_cflags=['-O3'],
        extra_cuda_cflags=['-O3', '-allow-unsupported-compiler'],
    )
    print("INFO: Successfully loaded custom CUDA binary_gemm operator.")
    CUDA_AVAILABLE = True
except Exception as e:
    print("="*40)
    print("WARNING: Failed to compile custom CUDA binary_gemm operator.")
    print("The model will fall back to the slower, PyTorch-based simulation mode.")
    print(f"Error details: {e}")
    print("="*40)


# --- Python-side packing utility (int64 multiply; fixes the uint32 mul bug) ---
def pack_bits(tensor):
    """Pack the signs of a 2-D tensor into ``uint32`` words along the last dim.

    Within every group of 32 consecutive columns, column ``j`` maps to bit
    ``j`` of the word (least-significant bit first). Entries ``> 0`` become
    bit 1; every other entry (negative *and* exactly zero) becomes bit 0.
    The column count is zero-padded up to a multiple of 32, so padded
    positions are bit 0.

    Args:
        tensor: 2-D tensor of shape (M, N).

    Returns:
        Contiguous ``torch.uint32`` tensor of shape (M, ceil(N / 32)).
    """
    tensor_binary = (tensor > 0).to(torch.uint8)
    M, N = tensor_binary.shape
    N_packed = (N + 31) // 32
    padding = N_packed * 32 - N
    if padding > 0:
        tensor_binary = F.pad(tensor_binary, (0, padding), 'constant', 0)
    tensor_reshaped = tensor_binary.view(M, N_packed, 32)
    # The weighted sum is done in int64 (the file header attributes a
    # "mul_cuda uint32" bug to multiplying in uint32). A full word sums to at
    # most 2**32 - 1, which fits in int64 and casts losslessly to uint32.
    powers_of_2 = (2 ** torch.arange(32, dtype=torch.int64, device=tensor.device)).view(1, 1, 32)
    packed_tensor = torch.sum(tensor_reshaped.to(torch.int64) * powers_of_2, dim=2).to(torch.uint32)
    return packed_tensor.contiguous()


# --- Core Binarization Function ---
class BinaryQuantize(Function):
    """Sign binarisation with a clipped straight-through gradient estimator.

    Forward returns ``torch.sign(input)`` (note: ``sign(0) == 0``). Backward
    passes the incoming gradient through unchanged where ``-1 <= input <= 1``
    and zeroes it where ``input > 1`` or ``input < -1``.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.gt(1)] = 0
        grad_input[input.lt(-1)] = 0
        return grad_input


# --- BiLinearLSR with a dual-engine (simulation / packed CUDA) system ---
class BiLinearLSR(nn.Linear):
    """Binary ``nn.Linear`` with a learned scalar output scale.

    The weight is centred (``weight - weight.mean()``), both the centred
    weight and the input are binarised with ``BinaryQuantize``, the binary
    product is multiplied by ``self.scale`` and the optional bias is added.

    Engines:
      * simulation (training mode, or whenever the packed engine is not
        usable): dense ``F.linear`` on the binarised tensors;
      * packed (eval mode, custom op compiled, input on a CUDA device):
        bit-packs both operands with ``pack_bits`` and calls
        ``binary_gemm_cuda.forward``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(BiLinearLSR, self).__init__(in_features, out_features, bias=bias)
        # Scalar output scale; the value 0.0 marks it as "not initialised yet",
        # and the first training-mode forward pass fills it in.
        self.register_parameter('scale', Parameter(torch.zeros(1)))

    def forward(self, input):
        """Apply the binary linear layer.

        Args:
            input: tensor of shape (*, in_features).

        Returns:
            Tensor of shape (*, out_features).
        """
        if self.training and self.scale.item() == 0.0:
            self._init_scale(input)

        # --- Engine selection: real acceleration or simulation? ---
        # The packed kernel is only valid for CUDA tensors; previously a CPU
        # input in eval mode was sent to the CUDA op whenever it had compiled.
        if not self.training and CUDA_AVAILABLE and input.is_cuda:
            output = self._forward_packed(input)
        else:
            # Simulation mode (training and fallback).
            centered_weight = (self.weight - self.weight.mean()).to(input.device)
            binary_weight = BinaryQuantize.apply(centered_weight)
            binary_input = BinaryQuantize.apply(input)
            scaled_weight = binary_weight * self.scale.to(input.device)
            output = F.linear(binary_input, scaled_weight)

        if self.bias is not None:
            output += self.bias.to(input.device)
        return output

    def _init_scale(self, input):
        """Initialise ``self.scale`` from the std ratio of FP vs. binary outputs."""
        # Only the scalar value is kept (via .item()/fill_), so no autograd
        # graph is needed for this one-off statistic.
        with torch.no_grad():
            centered_weight = (self.weight - self.weight.mean()).to(input.device)
            full_precision_output_std = F.linear(input, centered_weight).std()
            binary_centered_weight = torch.sign(centered_weight)
            binary_output_std = F.linear(torch.sign(input), binary_centered_weight).std()
            scale_factor = full_precision_output_std / (binary_output_std + 1e-8)
            # std() of a single element is NaN; fall back to a weight-only ratio.
            if torch.isnan(scale_factor) or torch.isinf(scale_factor):
                scale_factor = (centered_weight.std() / binary_centered_weight.std()).float()
            self.scale.data.fill_(scale_factor.item())

    def _forward_packed(self, input):
        """Run the bit-packed CUDA GEMM on ``input`` (any leading dims).

        NOTE(review): exact zeros differ between engines. ``BinaryQuantize``
        maps 0 to 0 in simulation, while ``pack_bits`` encodes 0 the same way as
        a negative value; presumably the kernel treats bit 0 as -1 and uses
        ``original_N`` to discount the padding bits — confirm in
        binary_gemm_kernel.cu.
        """
        leading_shape = input.shape[:-1]
        # pack_bits is 2-D only, so fold all leading dims into rows.
        input_2d = input.reshape(-1, input.shape[-1])
        binary_input = BinaryQuantize.apply(input_2d)
        centered_weight = (self.weight - self.weight.mean()).to(input.device)
        binary_weight = BinaryQuantize.apply(centered_weight)
        packed_input = pack_bits(binary_input)
        packed_weight = pack_bits(binary_weight)
        packed_weight_transposed = packed_weight.transpose(0, 1).contiguous()
        original_N = binary_input.shape[1]
        output = binary_gemm_cuda.forward(packed_input, packed_weight_transposed, original_N)
        output = output * self.scale.to(input.device)
        # Assumes the kernel returns (rows, out_features) — TODO confirm.
        if input.dim() != 2:
            output = output.reshape(*leading_shape, output.shape[-1])
        return output