| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This software may be used and distributed in accordance with | |
| # the terms of the DINOv3 License Agreement. | |
| import torch.nn as nn | |
class LayerNorm2D(nn.Module):
    """Apply a normalization layer across the channel axis of a channels-first map.

    The wrapped normalization module is run on a channels-last view of the
    input. The result is then permuted back to the original channels-first
    layout.

    Args:
        normalized_shape: Passed as the sole argument to ``norm_layer``. With
            the default ``nn.LayerNorm`` this is presumably the channel count C,
            since normalization acts on the trailing (channel) axis.
        norm_layer: Factory for the normalization module (default
            ``nn.LayerNorm``). If ``None``, an ``nn.Identity`` is used instead,
            so ``forward`` returns its input values unchanged.
    """

    def __init__(self, normalized_shape, norm_layer=nn.LayerNorm):
        super().__init__()
        if norm_layer is None:
            # Normalization disabled: the permute round trip becomes a no-op.
            self.ln = nn.Identity()
        else:
            self.ln = norm_layer(normalized_shape)

    def forward(self, x):
        """Normalize ``x`` over its channel axis.

        Args:
            x: 4D tensor in (N, C, H, W) layout.

        Returns:
            Tensor in (N, C, H, W) layout. It is produced by a ``permute``, so
            it is a view that may be non-contiguous.
        """
        # (N, C, H, W) -> (N, H, W, C): move channels last, where the norm acts.
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.ln(channels_last)
        # (N, H, W, C) -> (N, C, H, W): restore the caller's layout.
        return normed.permute(0, 3, 1, 2)