Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
class FP32LayerNorm(nn.LayerNorm):
    """LayerNorm variant that always performs the normalization in float32.

    The input and the optional affine parameters are upcast with ``.float()``
    before calling ``F.layer_norm``; the result is then cast back to the dtype
    the caller passed in.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Normalize ``inputs`` in float32 and return it in its original dtype.

        Args:
            inputs: Tensor to normalize over ``self.normalized_shape``.

        Returns:
            The normalized tensor, cast back to ``inputs.dtype``.
        """
        # Affine parameters may be absent; only upcast the ones that exist.
        weight_fp32 = None if self.weight is None else self.weight.float()
        bias_fp32 = None if self.bias is None else self.bias.float()

        normalized = F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            weight_fp32,
            bias_fp32,
            self.eps,
        )
        # Hand the result back in whatever precision the caller used.
        return normalized.to(inputs.dtype)