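"""Fused LayerNorm backed by the ``fastfold_layer_norm_cuda`` CUDA extension.

The extension is imported directly if it has already been built; otherwise it
is JIT-compiled from the sources under ``kernel/`` the first time this module
is imported.
"""
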
import importlib
import numbers
import os
import sys
import time

import torch
from torch.nn.parameter import Parameter

sys.path.append(os.path.dirname(__file__))

try:
    # Reuse a previously built extension if it is already importable.
    fastfold_layer_norm_cuda = importlib.import_module("fastfold_layer_norm_cuda")
except ImportError:
    # Fall back to JIT-compiling the CUDA kernel sources at import time.
    from protenix.model.layer_norm.torch_ext_compile import compile

    current_dir = os.path.dirname(__file__)
    fastfold_layer_norm_cuda = compile(
        name="fastfold_layer_norm_cuda",
        sources=[
            os.path.join(f"{current_dir}/kernel", file)
            for file in ["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"]
        ],
        extra_include_paths=[f"{current_dir}/kernel"],
        build_directory=current_dir,
    )


class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Autograd wrapper around the fused LayerNorm CUDA kernel (affine variant)."""

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        d = input.dtype
        if d is torch.bfloat16:
            # Disable autocast and cast weight/bias to bfloat16 so all kernel
            # arguments share the input dtype.
            with torch.cuda.amp.autocast(enabled=False):
                ctx.normalized_shape = normalized_shape
                ctx.eps = eps
                input_ = input.contiguous()
                weight_ = weight.contiguous().to(dtype=d)
                bias_ = bias.contiguous().to(dtype=d)
                output, mean, invvar = fastfold_layer_norm_cuda.forward_affine(
                    input_, ctx.normalized_shape, weight_, bias_, ctx.eps
                )
                ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        else:
            ctx.normalized_shape = normalized_shape
            ctx.eps = eps
            input_ = input.contiguous()
            weight_ = weight.contiguous()
            bias_ = bias.contiguous()
            output, mean, invvar = fastfold_layer_norm_cuda.forward_affine(
                input_, ctx.normalized_shape, weight_, bias_, ctx.eps
            )
            ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        d = grad_output.dtype
        if d is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                input_, weight_, bias_, mean, invvar = ctx.saved_tensors
                grad_input = grad_weight = grad_bias = None
                grad_input, grad_weight, grad_bias = (
                    fastfold_layer_norm_cuda.backward_affine(
                        grad_output.contiguous(),
                        mean,
                        invvar,
                        input_,
                        ctx.normalized_shape,
                        weight_.to(dtype=d),
                        bias_.to(dtype=d),
                        ctx.eps,
                    )
                )
        else:
            input_, weight_, bias_, mean, invvar = ctx.saved_tensors
            grad_input = grad_weight = grad_bias = None
            grad_input, grad_weight, grad_bias = (
                fastfold_layer_norm_cuda.backward_affine(
                    grad_output.contiguous(),
                    mean,
                    invvar,
                    input_,
                    ctx.normalized_shape,
                    weight_,
                    bias_,
                    ctx.eps,
                )
            )

        # No gradients are needed for normalized_shape and eps.
        return grad_input, grad_weight, grad_bias, None, None


class FusedLayerNorm(torch.nn.Module):
    """LayerNorm module that dispatches to the fused CUDA kernel."""

    def __init__(self, normalized_shape, eps=1e-5):
        super(FusedLayerNorm, self).__init__()

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.ones(*normalized_shape))
        self.bias = Parameter(torch.zeros(*normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, input):
        return self.kernel_forward(input)

    def kernel_forward(self, input):
        return FusedLayerNormAffineFunction.apply(
            input, self.weight, self.bias, self.normalized_shape, self.eps
        )
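

if __name__ == "__main__":
    # Minimal usage sketch, assuming a CUDA device and a successfully built
    # extension; the shapes and sizes below are illustrative only.
    if torch.cuda.is_available():
        layer = FusedLayerNorm(normalized_shape=128).cuda()
        x = torch.randn(4, 64, 128, device="cuda")
        y = layer(x)  # normalized over the last dimension; same shape as x
        print(y.shape)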