drixo commited on
Commit
7434d70
·
verified ·
1 Parent(s): 692ad67

Create positional_encoding.py

Browse files
Files changed (1) hide show
  1. positional_encoding.py +20 -0
positional_encoding.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ class FramePositionalEncoding(nn.Module):
6
+ def __init__(self, d_model, max_len=4096):
7
+ super().__init__()
8
+ pe = torch.zeros(max_len, d_model)
9
+ position = torch.arange(0, max_len).unsqueeze(1)
10
+ div_term = torch.exp(
11
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
12
+ )
13
+
14
+ pe[:, 0::2] = torch.sin(position * div_term)
15
+ pe[:, 1::2] = torch.cos(position * div_term)
16
+
17
+ self.register_buffer("pe", pe.unsqueeze(0))
18
+
19
+ def forward(self, x):
20
+ return x + self.pe[:, :x.size(1)]