# Uploaded by YYYYYYUUU, commit 7b95dc2 (verified):
#   "Add core reproduction code (binarization layers, PTv3, superpoint ops, min-repro pack)"
# Original file size: 4.67 kB.
# pointcept/models/quantization/binary_layers.py (fixes the mul_cuda uint32 bug)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn import Parameter
from torch.utils.cpp_extension import load
import os
# --- JIT compilation of the custom binary GEMM extension (best effort) ---
# The extension is optional: on any build/load failure the layers below fall
# back to a pure-PyTorch simulation of the binary matmul.
# CUDA_AVAILABLE means "the custom extension loaded", not merely that PyTorch
# sees a GPU.
backend_path = os.path.join(os.path.dirname(__file__), 'backend')
CUDA_AVAILABLE = False
# Bind the name up front so it exists (as None) even when compilation fails.
binary_gemm_cuda = None
try:
    binary_gemm_cuda = load(
        name='binary_gemm_cuda_v4',  # bump when backend sources change (JIT cache)
        sources=[
            os.path.join(backend_path, 'binary_gemm.cpp'),
            os.path.join(backend_path, 'binary_gemm_kernel.cu'),
        ],
        verbose=True,
        extra_cflags=['-O3'],
        # nvcc flag: tolerate host compilers newer than nvcc officially supports.
        extra_cuda_cflags=['-O3', '-allow-unsupported-compiler'],
    )
    print("INFO: Successfully loaded custom CUDA binary_gemm operator.")
    CUDA_AVAILABLE = True
except Exception as e:  # deliberate catch-all: any failure -> simulation fallback
    binary_gemm_cuda = None
    print("="*40)
    print("WARNING: Failed to compile custom CUDA binary_gemm operator.")
    print("The model will fall back to the slower, PyTorch-based simulation mode.")
    print(f"Error details: {e}")
    print("="*40)
# --- Python-side packing utility (修复uint32 mul bug) ---
def pack_bits(tensor):
tensor_binary = (tensor > 0).to(torch.uint8)
M, N = tensor_binary.shape
N_packed = (N + 31) // 32
padding = N_packed * 32 - N
if padding > 0: tensor_binary = F.pad(tensor_binary, (0, padding), 'constant', 0)
tensor_reshaped = tensor_binary.view(M, N_packed, 32)
powers_of_2 = (2 ** torch.arange(32, dtype=torch.int64, device=tensor.device)).view(1, 1, 32) # 用int64,避免to uint32 mul
packed_tensor = torch.sum(tensor_reshaped.to(torch.int64) * powers_of_2, dim=2).to(torch.uint32) # 【修复】int64 mul + sum,再to uint32
return packed_tensor.contiguous()
# --- Core binarization function ---
class BinaryQuantize(Function):
    """Sign binarization with a clipped straight-through estimator (STE).

    Forward: ``torch.sign(x)``. Exact zeros stay 0, so the output is in
    ``{-1, 0, 1}``.
    Backward: the incoming gradient passes through unchanged where
    ``-1 <= x <= 1`` and is set to 0 where ``x > 1`` or ``x < -1``.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        # Gradient is blocked only strictly outside [-1, 1].
        outside_clip_range = saved_input.gt(1) | saved_input.lt(-1)
        return grad_output.masked_fill(outside_clip_range, 0)
# --- BiLinearLSR: binary linear layer with two execution engines ---
class BiLinearLSR(nn.Linear):
    """Binary linear layer with a learnable scalar output scale (LSR).

    Both the input and the mean-centred weight are binarized with
    ``BinaryQuantize`` (sign forward, clipped STE backward). The binary product
    is multiplied by the learnable scalar ``self.scale``, and the bias (if any)
    is then added.

    ``self.scale`` starts at 0, which means "not yet initialized". On a
    training forward pass where it is still 0, it is set from the current batch
    so that the binary output's std roughly matches the full-precision output's.

    Engines:
      * Real acceleration: used in eval mode, for CUDA input, when the custom
        extension loaded. Inputs/weights are bit-packed with ``pack_bits`` and
        multiplied by ``binary_gemm_cuda.forward``.
      * Simulation: used for training, CPU input, or when the extension is
        missing. ``F.linear`` runs on the sign tensors.

    NOTE(review): ``pack_bits`` maps values ``<= 0`` to bit 0, whereas
    ``BinaryQuantize`` (``torch.sign``) maps exact zeros to 0. Inputs
    containing exact zeros (e.g. after a ReLU) can therefore give different
    results in the two engines. Confirm how the kernel decodes bit 0.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # Learnable scalar output scale. The value 0.0 means "uninitialized";
        # forward() initializes it from the first suitable training batch.
        self.register_parameter('scale', Parameter(torch.zeros(1)))

    def _centered_weight(self, device):
        """Return the weight minus its global mean, moved to ``device``."""
        return (self.weight - self.weight.mean()).to(device)

    @torch.no_grad()
    def _init_scale(self, input):
        """Initialize ``self.scale`` from one batch (LSR).

        The scale becomes std(full-precision output) / std(binary output),
        using the mean-centred weight for both. This runs without autograd
        because only the resulting scalar is kept.
        """
        centered_weight = self._centered_weight(input.device)
        binary_centered_weight = torch.sign(centered_weight)
        full_precision_output_std = F.linear(input, centered_weight).std()
        binary_output_std = F.linear(torch.sign(input), binary_centered_weight).std()
        scale_factor = full_precision_output_std / (binary_output_std + 1e-8)
        if not torch.isfinite(scale_factor):
            # e.g. std() of a single-element output is NaN: use the
            # weight-only ratio instead.
            scale_factor = (centered_weight.std() / binary_centered_weight.std()).float()
        if not torch.isfinite(scale_factor):
            # Still degenerate. Leave the scale at 0 so the next training batch
            # retries, rather than permanently storing NaN/inf.
            return
        self.scale.data.fill_(scale_factor.item())

    def _packed_linear(self, binary_input, binary_weight):
        """Multiply sign tensors with the custom bit-packed CUDA GEMM.

        Leading dims of ``binary_input`` are flattened into rows for the kernel
        and restored afterwards.
        NOTE(review): assumes ``binary_gemm_cuda.forward(packed_input,
        packed_weight_transposed, in_features)`` returns a
        ``(rows, out_features)`` tensor. Confirm against the kernel.
        """
        in_features = binary_input.shape[-1]
        num_rows = binary_input.shape[:-1].numel()
        packed_input = pack_bits(binary_input.reshape(num_rows, in_features))
        packed_weight_transposed = pack_bits(binary_weight).transpose(0, 1).contiguous()
        output = binary_gemm_cuda.forward(packed_input, packed_weight_transposed, in_features)
        return output.reshape(*binary_input.shape[:-1], output.shape[-1])

    def forward(self, input):
        """Apply the binary linear transform to ``input`` of shape ``(*, in_features)``."""
        # .item() synchronizes with the device. Only a scale of exactly 0.0
        # triggers (re)initialization.
        if self.training and self.scale.item() == 0.0:
            self._init_scale(input)

        binary_input = BinaryQuantize.apply(input)
        binary_weight = BinaryQuantize.apply(self._centered_weight(input.device))
        scale = self.scale.to(input.device)

        if not self.training and CUDA_AVAILABLE and input.is_cuda:
            # Real acceleration: bit-packed GEMM (not differentiable).
            output = self._packed_linear(binary_input, binary_weight) * scale
        else:
            # Simulation (training, CPU input, or no extension): dense matmul
            # on the sign tensors.
            output = F.linear(binary_input, binary_weight * scale)

        if self.bias is not None:
            output += self.bias.to(input.device)
        return output