1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
970 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import torch
# ATen's private fused op: addmm (beta * self + alpha * mat1 @ mat2) followed
# by an activation -- ReLU by default, GELU when ``use_gelu=True``.
addmm_act_op = torch.ops.aten._addmm_activation


def addmm_act(activation, linear: torch.nn.Linear, mat1: torch.Tensor) -> torch.Tensor:
    """Apply ``activation(linear(mat1))`` through one fused bfloat16 op.

    Inference-only helper: the layer's parameters are detached, everything is
    cast to ``torch.bfloat16`` and handed to ``_addmm_activation``.

    Args:
        activation: Selects the fused activation. Must be one of
            ``torch.nn.functional.relu`` / ``torch.nn.ReLU`` or
            ``torch.nn.functional.gelu`` / ``torch.nn.GELU`` (the function or
            the class itself, not a module instance).
        linear: Layer providing ``weight`` of shape ``(out, in)`` and an
            optional ``bias``. A missing bias (``None``) is treated as zeros.
        mat1: Input whose last dimension is the layer's input width. Leading
            dimensions are flattened for the matmul and restored afterwards;
            non-contiguous inputs are accepted.

    Returns:
        Tensor of shape ``(*mat1.shape[:-1], out)`` computed from bfloat16
        operands.

    Raises:
        ValueError: If autograd is enabled, or ``activation`` is not one of
            the supported ReLU/GELU spellings.
    """
    if torch.is_grad_enabled():
        raise ValueError("Expected grad to be disabled.")

    # Validate the activation before doing any casts so bad calls fail cheaply.
    # NOTE(review): which GELU flavour (exact vs. tanh approximation) is used
    # is decided by the backend kernel -- confirm it matches the module this
    # call replaces.
    if activation in (torch.nn.functional.relu, torch.nn.ReLU):
        use_gelu = False
    elif activation in (torch.nn.functional.gelu, torch.nn.GELU):
        use_gelu = True
    else:
        raise ValueError(f"Unexpected activation {activation}")

    weight = linear.weight.detach().to(torch.bfloat16)
    if linear.bias is None:
        # The op always needs a `self` operand; a zero bias keeps the math
        # identical to a bias-free linear layer.
        bias = weight.new_zeros(weight.shape[0])
    else:
        bias = linear.bias.detach().to(torch.bfloat16)
    mat1 = mat1.to(torch.bfloat16)

    # reshape (not view): the input may be non-contiguous, e.g. after a
    # transpose/permute; reshape still returns a view whenever one is possible.
    mat1_flat = mat1.reshape(-1, mat1.shape[-1])

    # weight is (out, in); the op expects mat2 as (in, out).
    out = addmm_act_op(
        bias, mat1_flat, weight.t(), beta=1, alpha=1, use_gelu=use_gelu
    )
    # Restore the caller's leading dimensions.
    return out.view((*mat1.shape[:-1], out.shape[-1]))