File size: 1,320 Bytes
85ba398 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
batch norm done in fp32 (for fp16 training)
"""
import torch
import torch.nn as nn
class Fp32BatchNorm(nn.Module):
    """Batch norm that always computes in fp32, for fp16/mixed-precision training.

    Wraps ``nn.BatchNorm1d`` (or ``nn.SyncBatchNorm`` when ``sync=True`` and
    there is more than one distributed worker), casts the input up to fp32
    before normalizing, and casts the output back to the input's dtype.
    """

    def __init__(self, sync=False, *args, **kwargs):
        """
        Args:
            sync: use ``nn.SyncBatchNorm`` across distributed workers;
                silently downgraded to plain ``nn.BatchNorm1d`` when the
                global world size is 1 (sync would be pointless and needs a
                process group).
            *args, **kwargs: forwarded to the underlying batch-norm module
                (e.g. ``num_features``).
        """
        super().__init__()

        if sync:
            # Imported lazily so non-distributed use never touches fairseq's
            # distributed utilities.
            from fairseq.distributed import utils

            if utils.get_global_world_size() == 1:
                sync = False

        if sync:
            self.bn = nn.SyncBatchNorm(*args, **kwargs)
        else:
            self.bn = nn.BatchNorm1d(*args, **kwargs)

        self.sync = sync

    def forward(self, input):
        """Normalize ``input`` in fp32 and return it in ``input``'s dtype."""
        # Lazily convert the wrapped module back to fp32 if something
        # (e.g. a model-wide ``.half()``) downcast its buffers/parameters.
        if self.bn.running_mean.dtype != torch.float:
            if self.sync:
                self.bn.running_mean = self.bn.running_mean.float()
                self.bn.running_var = self.bn.running_var.float()
                if self.bn.affine:
                    try:
                        self.bn.weight = self.bn.weight.float()
                        self.bn.bias = self.bn.bias.float()
                    except Exception:
                        # nn.Module may refuse to overwrite a Parameter with
                        # a plain Tensor; fall back to converting the whole
                        # module in place. (Was a bare ``except:``, which
                        # also swallowed KeyboardInterrupt/SystemExit.)
                        self.bn.float()
            else:
                self.bn.float()

        output = self.bn(input.float())
        return output.type_as(input)
|