Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torch.utils.cpp_extension import load | |
| import os | |
| import time | |
| import random | |
| import math | |
| from torch.utils.checkpoint import checkpoint | |
| from torch.autograd import Function | |
| from functools import partial | |
| import warnings | |
| # curr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extension") | |
| # src_files = ['tdp.cu', 'torch_extension.cpp'] | |
| # src_files = [os.path.join(curr_path, file) for file in src_files] | |
| # tdp = load('tdp', src_files, verbose = True) | |
| # import tdp | |
def exported_tdp(param0, param1, weight, bias, times, custom=True):
    """Evaluate the time-dependent parameter (TDP) interpolation for each time.

    For every entry ``t`` of ``times`` the result is, element-wise,
    ``param1 + sigmoid(t * weight + bias) * (param0 - param1)``.

    Args:
        param0, param1, weight, bias: tensors with the same number of
            elements. They are flattened before computing, and the output is
            reshaped back to ``param0``'s original shape.
        times: tensor of times; its first dimension becomes the leading
            output dimension (the ``TDP`` path asserts it is 1-D).
        custom: if True, use the custom kernel (``TDP``) whenever it can
            handle the inputs; otherwise always use ``tdp_torch``.

    Returns:
        Tensor of shape ``(times.shape[0], *param0.shape)``.

    Warns:
        UserWarning whenever the pure-PyTorch fallback is used.
    """
    original_shape = param0.shape
    param0 = param0.reshape(-1)
    param1 = param1.reshape(-1)
    weight = weight.reshape(-1)
    bias = bias.reshape(-1)
    # The kernel path goes through tdp_cuda, which only implements fp16/fp32
    # (it raises NotImplementedError otherwise), and TDP.forward asserts an
    # even element count. It is a CUDA extension, so presumably it cannot
    # consume CPU tensors either. Fall back to PyTorch instead of failing.
    use_kernel = (
        custom
        and param0.is_cuda
        and param0.dtype in (torch.half, torch.float)
        and param0.shape[0] % 2 == 0
    )
    if use_kernel:
        result = TDP.apply(param0, param1, weight, bias, times)
    else:
        warnings.warn(f'Using slower tdp_torch implementation for a tensor with shape {param0.shape}')
        result = tdp_torch(param0, param1, weight, bias, times)
    return result.reshape(times.shape[0], *original_shape)
class TDP(Function):
    """Autograd function backed by the custom ``tdp`` forward/backward kernels.

    ``param0``, ``param1``, ``weight`` and ``bias`` must be 1-D tensors of the
    same even length, and ``times`` must be 1-D (all asserted in ``forward``).
    The forward output is a ``(len(times), len(param0))`` tensor produced by
    ``tdp_cuda``. The backward returns gradients for the four parameter
    tensors and ``None`` for ``times``.
    """

    # PyTorch's custom-Function contract: forward/backward are static methods
    # that receive the context object explicitly as their first argument.
    @staticmethod
    def forward(ctx, param0, param1, weight, bias, times):
        """Validate the inputs, save them for backward and run ``tdp_cuda``."""
        assert param0.shape[0] % 2 == 0
        # Hand the extension contiguous tensors (presumably it indexes them
        # as flat buffers).
        param0 = param0.contiguous()
        param1 = param1.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        times = times.contiguous()
        assert param0.shape[0] == param1.shape[0] and param0.shape[0] == weight.shape[0] and param0.shape[0] == bias.shape[0]
        assert param0.dim() == 1 and param1.dim() == 1 and weight.dim() == 1 and bias.dim() == 1 and times.dim() == 1
        ctx.save_for_backward(param0, param1, weight, bias, times)
        return tdp_cuda(param0, param1, weight, bias, times)

    @staticmethod
    def backward(ctx, g_result):
        """Return gradients for (param0, param1, weight, bias); ``times`` gets none."""
        g_result = g_result.contiguous()
        param0, param1, weight, bias, times = ctx.saved_tensors
        g_param0, g_param1, g_weight, g_bias = backward_tdp_cuda(param0, param1, weight, bias, times, g_result)
        return g_param0, g_param1, g_weight, g_bias, None
def backward_tdp_torch(param0, param1, weight, bias, times, g_result):
    """Reference (pure PyTorch) gradient of ``tdp_torch``.

    With ``s = sigmoid(times[:, None] * weight + bias)`` and
    ``result = param1 + s * (param0 - param1)``, propagate the upstream
    gradient ``g_result`` (one row per time) back to ``param0``, ``param1``,
    ``weight`` and ``bias``, summing over the time dimension.

    Returns:
        Tuple ``(g_param0, g_param1, g_weight, g_bias)``.
    """
    t_col = times.unsqueeze(1)
    gate = torch.sigmoid(t_col * weight.unsqueeze(0) + bias.unsqueeze(0))

    grad_p0 = (gate * g_result).sum(0)
    grad_p1 = ((1 - gate) * g_result).sum(0)

    # Chain rule through the sigmoid: d s / d a = s * (1 - s).
    delta = param0.unsqueeze(0) - param1.unsqueeze(0)
    grad_pre = delta * g_result * gate * (1 - gate)

    grad_w = (grad_pre * t_col).sum(0)
    grad_b = grad_pre.sum(0)
    return grad_p0, grad_p1, grad_w, grad_b
def backward_tdp_cuda(param0, param1, weight, bias, times, g_result):
    """Compute tdp gradients with the compiled extension's backward kernel.

    Allocates four uninitialised output buffers shaped and typed like
    ``param0``. The ``tdp`` extension fills them in place, and they are
    returned as ``(g_param0, g_param1, g_weight, g_bias)``. Only fp16 and fp32
    are supported; any other dtype raises ``NotImplementedError``.

    NOTE(review): ``tdp`` is not imported in the visible source (its load and
    import lines are commented out) -- confirm the extension is in scope.
    """
    grads = tuple(torch.empty_like(param0) for _ in range(4))
    kernel_name = {
        torch.half: "backward_tdp_fp16",
        torch.float: "backward_tdp_fp32",
    }.get(param0.dtype)
    if kernel_name is None:
        raise NotImplementedError
    kernel = getattr(tdp, kernel_name)
    kernel(param0, param1, weight, bias, times, g_result, *grads)
    return grads
def tdp_torch(param0, param1, weight, bias, times):
    """Pure-PyTorch reference for the time-dependent parameter interpolation.

    For 1-D inputs, returns a ``(len(times), len(param0))`` tensor whose row
    ``i`` is ``param1 + sigmoid(times[i] * weight + bias) * (param0 - param1)``.
    """
    logits = torch.addcmul(bias.unsqueeze(0), times.unsqueeze(1), weight.unsqueeze(0))
    gate = logits.sigmoid()
    delta = param0.unsqueeze(0) - param1.unsqueeze(0)
    return torch.addcmul(param1.unsqueeze(0), gate, delta)
def tdp_cuda(param0, param1, weight, bias, times):
    """Run the compiled ``tdp`` forward kernel.

    Allocates an uninitialised ``(times.shape[0], param0.shape[0])`` output on
    ``param0``'s device with ``param0``'s dtype, and lets the extension fill it.
    Only fp16 and fp32 are supported; other dtypes raise ``NotImplementedError``.

    NOTE(review): ``tdp`` is not imported in the visible source (its load and
    import lines are commented out) -- confirm the extension is in scope.
    """
    out = torch.empty(
        times.shape[0],
        param0.shape[0],
        dtype=param0.dtype,
        device=param0.device,
    )
    kernel_name = {torch.half: "tdp_fp16", torch.float: "tdp_fp32"}.get(param0.dtype)
    if kernel_name is None:
        raise NotImplementedError
    getattr(tdp, kernel_name)(param0, param1, weight, bias, times, out)
    return out
def corrcoef(x, y):
    """Return the Pearson correlation of the flattened ``x`` and ``y`` (as fp32)."""
    rows = [t.reshape(-1).float() for t in (x, y)]
    matrix = torch.corrcoef(torch.stack(rows))
    return matrix[0, 1]
def tdp_cuda_unit_test():
    """Compare ``tdp_cuda`` with ``tdp_torch`` on random CUDA inputs.

    The fp32 reference is compared against the kernel run in fp32 and then
    in fp16 (upcast to fp32). Each comparison prints the correlation
    coefficient and the maximum absolute error.
    """
    print("***** tdp_cuda_unit_test *****")
    batch_size = random.randrange(1, 128)
    num_params = random.randrange(1, 1000000) * 2  # always even
    print("batch_size", batch_size, "num_params", num_params)
    param0, param1, weight, bias = (torch.randn(num_params).cuda() for _ in range(4))
    times = torch.rand(batch_size).cuda()
    fp32_inputs = (param0, param1, weight, bias, times)

    reference = tdp_torch(*fp32_inputs)
    # .float() on an fp32 tensor is a no-op, so the first pass is plain fp32.
    for cast in (torch.Tensor.float, torch.Tensor.half):
        out = tdp_cuda(*(cast(x) for x in fp32_inputs)).float()
        print(corrcoef(reference, out), (reference - out).abs().max())
def backward_tdp_cuda_unit_test():
    """Compare ``backward_tdp_cuda`` with ``backward_tdp_torch`` on random CUDA inputs.

    The fp32 reference gradients are compared against the kernel run in fp32
    and then in fp16. Each of the four gradients prints its correlation
    coefficient and maximum absolute error.
    """
    print("***** backward_tdp_cuda_unit_test *****")
    batch_size = random.randrange(1, 128)
    num_params = random.randrange(1, 100000) * 2  # always even
    print("batch_size", batch_size, "num_params", num_params)
    param0, param1, weight, bias = (torch.randn(num_params).cuda() for _ in range(4))
    times = torch.rand(batch_size).cuda()
    g_result = torch.randn(batch_size, num_params).cuda()
    fp32_inputs = (param0, param1, weight, bias, times, g_result)

    references = backward_tdp_torch(*fp32_inputs)
    # .float() on an fp32 tensor is a no-op, so the first pass is plain fp32.
    for cast in (torch.Tensor.float, torch.Tensor.half):
        outputs = backward_tdp_cuda(*(cast(x) for x in fp32_inputs))
        for ref, out in zip(references, outputs):
            print(corrcoef(ref, out), (ref - out).abs().max())
def autograd_unit_test():
    """Check ``TDP.apply`` against ``tdp_torch`` + autograd on random CUDA inputs.

    Both paths get identical inputs (same seed) and the same loss
    ``mean((out - 1.5) ** 2)``. For the forward output and each of the four
    parameter gradients, prints the correlation coefficient and the maximum
    absolute error.
    """
    print("***** autograd_unit_test *****")
    batch_size = random.randrange(1, 128)
    num_params = random.randrange(1, 100000) * 2
    print("batch_size", batch_size, "num_params", num_params)

    def get_outputs(fn):
        """Run ``fn`` forward and backward; return (out, g_param0, g_param1, g_weight, g_bias)."""
        torch.manual_seed(1)
        # Create the leaves directly on the GPU. Calling .cuda() on a
        # requires_grad CPU tensor yields a non-leaf copy: its .grad is only
        # kept via retain_grad(), and gradients also flow back to the hidden
        # CPU tensor.
        params = [
            torch.randn(num_params, device="cuda", requires_grad=True)
            for _ in range(4)
        ]
        times = torch.rand(batch_size, device="cuda")
        out = fn(*params, times)
        loss = ((out - 1.5) ** 2).mean()
        loss.backward()
        return (out,) + tuple(p.grad for p in params)

    refs = get_outputs(tdp_torch)
    outs = get_outputs(TDP.apply)
    for r, o in zip(refs, outs):
        print(corrcoef(r, o), (r - o).abs().max())
def exported_tdp_unit_test():
    """Check ``exported_tdp`` with ``custom=True`` against ``custom=False``.

    Both runs get identical inputs (same seed) and the same loss
    ``mean((out - 1.5) ** 2)``. For the forward output and each of the four
    parameter gradients, prints the correlation coefficient and the maximum
    absolute error.
    """
    print("***** exported_tdp_unit_test *****")
    batch_size = random.randrange(1, 128)
    num_params = random.randrange(1, 100000) * 2
    print("batch_size", batch_size, "num_params", num_params)

    def get_outputs(fn):
        """Run ``fn`` forward and backward; return (out, g_param0, g_param1, g_weight, g_bias)."""
        torch.manual_seed(1)
        # Create the leaves directly on the GPU. Calling .cuda() on a
        # requires_grad CPU tensor yields a non-leaf copy: its .grad is only
        # kept via retain_grad(), and gradients also flow back to the hidden
        # CPU tensor.
        params = [
            torch.randn(num_params, device="cuda", requires_grad=True)
            for _ in range(4)
        ]
        times = torch.rand(batch_size, device="cuda")
        out = fn(*params, times)
        loss = ((out - 1.5) ** 2).mean()
        loss.backward()
        return (out,) + tuple(p.grad for p in params)

    refs = get_outputs(partial(exported_tdp, custom=False))
    outs = get_outputs(partial(exported_tdp, custom=True))
    for r, o in zip(refs, outs):
        print(corrcoef(r, o), (r - o).abs().max())
def tdp_cuda_profile():
    """Benchmark the forward pass: ``tdp_torch`` vs ``tdp_cuda``, in fp32 then fp16.

    Each printed number is the elapsed seconds for 100 calls, measured after
    10 warm-up calls. The GPU is synchronized before the clock starts and
    before it stops.
    """
    print("***** tdp_cuda_profile *****")

    def profiler(fn, args):
        """Return the seconds taken by 100 calls of ``fn(*args)`` after a warm-up."""
        for _ in range(10):
            fn(*args)
        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time() is a
        # wall clock that can jump and is not meant for interval timing.
        t0 = time.perf_counter()
        for _ in range(100):
            fn(*args)
        torch.cuda.synchronize()
        return time.perf_counter() - t0

    batch_size = 16
    num_params = 1024 * 1024
    print("batch_size", batch_size, "num_params", num_params)
    param0 = torch.randn(num_params).cuda()
    param1 = torch.randn(num_params).cuda()
    weight = torch.randn(num_params).cuda()
    bias = torch.randn(num_params).cuda()
    times = torch.rand(batch_size).cuda()
    fp32_args = (param0, param1, weight, bias, times)
    fp16_args = tuple(x.half() for x in fp32_args)
    print("ref", profiler(tdp_torch, fp32_args))
    print("cuda", profiler(tdp_cuda, fp32_args))
    print("ref", profiler(tdp_torch, fp16_args))
    print("cuda", profiler(tdp_cuda, fp16_args))
def backward_tdp_cuda_profile():
    """Benchmark the backward pass: ``backward_tdp_torch`` vs ``backward_tdp_cuda``.

    Runs in fp32 then fp16. Each printed number is the elapsed seconds for
    100 calls, measured after 10 warm-up calls. The GPU is synchronized
    before the clock starts and before it stops.
    """
    print("***** backward_tdp_cuda_profile *****")

    def profiler(fn, args):
        """Return the seconds taken by 100 calls of ``fn(*args)`` after a warm-up."""
        for _ in range(10):
            fn(*args)
        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time() is a
        # wall clock that can jump and is not meant for interval timing.
        t0 = time.perf_counter()
        for _ in range(100):
            fn(*args)
        torch.cuda.synchronize()
        return time.perf_counter() - t0

    batch_size = 16
    num_params = 1024 * 1024
    print("batch_size", batch_size, "num_params", num_params)
    param0 = torch.randn(num_params).cuda()
    param1 = torch.randn(num_params).cuda()
    weight = torch.randn(num_params).cuda()
    bias = torch.randn(num_params).cuda()
    times = torch.rand(batch_size).cuda()
    g_result = torch.randn(batch_size, num_params).cuda()
    fp32_args = (param0, param1, weight, bias, times, g_result)
    fp16_args = tuple(x.half() for x in fp32_args)
    print("ref", profiler(backward_tdp_torch, fp32_args))
    print("cuda", profiler(backward_tdp_cuda, fp32_args))
    print("ref", profiler(backward_tdp_torch, fp16_args))
    print("cuda", profiler(backward_tdp_cuda, fp16_args))
def autogad_profile():
    """Benchmark forward+backward: ``tdp_torch`` vs ``TDP.apply``, in fp32 then fp16.

    Each iteration runs ``fn(*args).mean().backward()``. Each printed number
    is the elapsed seconds for 100 iterations, measured after 10 warm-up
    iterations. The GPU is synchronized before the clock starts and before
    it stops.
    """
    print("***** autogad_profile *****")

    def profiler(fn, args):
        """Return the seconds taken by 100 forward+backward passes after a warm-up."""
        for _ in range(10):
            fn(*args).mean().backward()
        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time() is a
        # wall clock that can jump and is not meant for interval timing.
        t0 = time.perf_counter()
        for _ in range(100):
            fn(*args).mean().backward()
        torch.cuda.synchronize()
        return time.perf_counter() - t0

    batch_size = 16
    num_params = 1024 * 1024
    print("batch_size", batch_size, "num_params", num_params)
    # Allocate the Parameters on the GPU. nn.Parameter(...).cuda() would
    # return a non-leaf copy, so every timed backward would also copy the
    # gradients back to, and accumulate them on, the CPU Parameter.
    param0 = nn.Parameter(torch.randn(num_params, device="cuda"))
    param1 = nn.Parameter(torch.randn(num_params, device="cuda"))
    weight = nn.Parameter(torch.randn(num_params, device="cuda"))
    bias = nn.Parameter(torch.randn(num_params, device="cuda"))
    times = torch.rand(batch_size).cuda()
    print("ref", profiler(tdp_torch, (param0, param1, weight, bias, times)))
    print("cuda", profiler(TDP.apply, (param0, param1, weight, bias, times)))
    print("ref", profiler(tdp_torch, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
    print("cuda", profiler(TDP.apply, (param0.half(), param1.half(), weight.half(), bias.half(), times.half())))
if __name__ == "__main__":
    # Correctness checks first, then the timing benchmarks, in a fixed order.
    for step in (
        tdp_cuda_unit_test,
        backward_tdp_cuda_unit_test,
        autograd_unit_test,
        exported_tdp_unit_test,
        tdp_cuda_profile,
        backward_tdp_cuda_profile,
        autogad_profile,
    ):
        step()