File size: 529 Bytes
14f592d
 
 
 
17c0b30
14f592d
 
 
17c0b30
 
14f592d
 
 
 
17c0b30
14f592d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn

class UstaLayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale.

    Each vector along the last axis of the input is shifted to zero mean and
    scaled to unit (biased, i.e. population) variance, then multiplied
    element-wise by ``weight``. There is no additive bias term.

    Args:
        embedding_dim: Size of the last input dimension; also the length of
            ``weight``.
        eps: Small constant added to the variance before the square root so
            the division never hits zero.
        device: Device on which ``weight`` is allocated.
    """

    # Reduced-precision float dtypes whose statistics are computed in float32.
    _LOW_PRECISION_DTYPES = (torch.float16, torch.bfloat16)

    def __init__(self, embedding_dim, eps=1e-5, device="cpu"):
        super().__init__()
        self.eps = eps
        # Per-feature scale, initialised to 1 so the layer starts out as a
        # pure normalization.
        self.weight = nn.Parameter(torch.ones(embedding_dim, device=device))
        # Kept as a public attribute; not read anywhere inside this class.
        self.device = device

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply ``weight``.

        Args:
            x: Floating-point tensor whose last dimension has size
                ``embedding_dim``.

        Returns:
            Tensor with the same shape as ``x``; its dtype follows PyTorch
            type promotion between ``x.dtype`` and ``weight.dtype``.
        """
        input_dtype = x.dtype
        # float16 overflows above ~65504, so squared deviations can turn the
        # variance into inf and silently zero the whole output; bfloat16 has
        # only an 8-bit mantissa for the reductions. Compute the statistics
        # in float32 for these dtypes and cast back afterwards.
        if input_dtype in self._LOW_PRECISION_DTYPES:
            x = x.float()
        mean = x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized_x = (x - mean) / torch.sqrt(variance + self.eps)
        # Cast back (a no-op for float32/float64) so the result dtype is the
        # same as multiplying weight by an input-dtype tensor.
        return self.weight * normalized_x.to(input_dtype)