import time

import torch
import torch.nn as nn

import quant_cuda

# Disable TF32 so the FP16 baseline is not silently run in TF32.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
print('Benchmarking OPT-175B FC2 matvec ...')

DEV = torch.device('cuda:0')

# OPT-175B uses hidden size 12288, so the FC2 weight is (4 * 12288) x 12288.
M = 12288 * 4
N = 12288

DTYPE = torch.half
mat = torch.randn((M, N), device=DEV, dtype=DTYPE)
vec = torch.randn((1, M), device=DEV, dtype=DTYPE)
mul = torch.zeros((1, N), device=DEV, dtype=DTYPE)
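# Back-of-envelope bound (an aside, not in the original script): this matvec
# is memory-bound, so streaming the FP16 weight once limits the runtime.
# The 1.5e12 B/s figure is an assumed A100-class HBM bandwidth.
print('ideal FP16 matvec time @ ~1.5 TB/s:', M * N * 2 / 1.5e12, 's')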
COUNT = 1000
tick = time.time()
for _ in range(COUNT):
    torch.matmul(vec, mat, out=mul)
    torch.cuda.synchronize()
print('FP16:', (time.time() - tick) / COUNT)
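# Optional cross-check (not in the original script): time the same loop with
# CUDA events, which measure on-device and avoid host-timer jitter; the
# result should roughly match the time.time() number above.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(COUNT):
    torch.matmul(vec, mat, out=mul)
end.record()
torch.cuda.synchronize()
print('FP16 (CUDA events):', start.elapsed_time(end) / 1000 / COUNT)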
DTYPE = torch.float
vec = vec.to(DTYPE)
mul = mul.to(DTYPE)

# Packed 3-bit weights: every 1024 weights along a column occupy 96 int32s
# (1024 * 3 bits = 3072 bits = 96 * 32 bits), hence M // 1024 * 96 rows.
mat = torch.randint(-1000000000, 1000000000, (M // 1024 * 96, N), device=DEV, dtype=torch.int)
scales = torch.randn(N, device=DEV, dtype=DTYPE)
zeros = torch.randn(N, device=DEV, dtype=DTYPE)
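# Rough footprint check (an aside, not in the original script): the packed
# 3-bit matrix should take ~16/3 ~ 5.3x less memory than the FP16 weight.
print('weight bytes, FP16 vs 3-bit packed:', M * N * 2, mat.numel() * 4)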
COUNT = 1000
tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant3matmul(vec, mat, mul, scales, zeros)
    torch.cuda.synchronize()
print('3bit:', (time.time() - tick) / COUNT)
COUNT = 1000
tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant3matmul_faster(vec, mat, mul, scales, zeros)
    torch.cuda.synchronize()
print('3bit:', (time.time() - tick) / COUNT, '(faster)')
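# Interpreting the numbers above: since the matvec is memory-bound, the 3-bit
# kernels can at best approach the ~16/3 ~ 5.3x reduction in weight traffic
# relative to FP16; unpacking overhead and the scales/zeros reads keep the
# measured speedup below that bound.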
print('Verifying kernel correctness ...')

M = 4 * 4096
N = 4096

layer = nn.Linear(M, N)
vec = torch.randn(M).to(DEV)
from quant import *

# Quantize the layer's weights: 3 bits, per-channel scales, asymmetric
# (sym=False), without MSE-based grid search for the clipping range.
quantizer = Quantizer()
quantizer.configure(3, perchannel=True, sym=False, mse=False)
quantizer.find_params(layer.weight.data, weight=True)
layer.weight.data = quantize(
    layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
)

# Pack the simulated-quantized weights into the 3-bit kernel format.
qlayer = Quant3Linear(layer.in_features, layer.out_features)
qlayer.pack(layer, quantizer.scale, quantizer.zero)

qlayer = qlayer.to(DEV)
layer = layer.to(DEV)
with torch.no_grad():
    print('Simu:', layer(vec))
    print('Kern:', qlayer(vec))
    qlayer.faster = True
    print('Kern:', qlayer(vec.half()), '(faster)')
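# A quantitative check (a sketch, not in the original script; assumes the
# Quant3Linear output dtype follows its input dtype): report the maximum
# absolute deviation of the kernels from the simulated quantized layer
# instead of eyeballing the printed tensors.
with torch.no_grad():
    ref = layer(vec)
    qlayer.faster = False
    print('max abs err (kernel):', (qlayer(vec) - ref).abs().max().item())
    qlayer.faster = True
    print('max abs err (faster):', (qlayer(vec.half()).float() - ref).abs().max().item())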