tungman committed on
Commit
17082ab
·
verified ·
1 Parent(s): 4bd6312

Upload TransformerSinglestep.py

Browse files
Files changed (1) hide show
  1. TransformerSinglestep.py +50 -0
TransformerSinglestep.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class LearnablePositionalEncoding(nn.Module):
    """Add a learned per-position embedding to a batch of sequences.

    The module owns an ``nn.Embedding`` table with ``max_len`` rows of size
    ``d_model``; row ``i`` is the vector added to time step ``i``.

    Note that the table holds ``max_len * d_model`` parameters, so the default
    ``max_len`` of 10,000,000 allocates a very large table. Pass the longest
    sequence length you actually need.

    Args:
        d_model: Size of each position vector (must match the last dimension
            of the input passed to ``forward``).
        max_len: Number of positions in the table, i.e. the longest sequence
            that ``forward`` accepts.
    """

    def __init__(self, d_model, max_len=10000000):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the position table uniformly in [-0.1, 0.1]."""
        # Small-range uniform initialisation chosen for the positional encoding.
        nn.init.uniform_(self.pos_embedding.weight, -0.1, 0.1)

    def forward(self, x):
        """Return ``x`` with the positional embeddings added.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, d_model]``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            IndexError: If ``seq_len`` is larger than the number of positions
                in the embedding table.
        """
        seq_len = x.size(1)
        max_len = self.pos_embedding.num_embeddings
        # Fail early with an explicit message instead of letting the embedding
        # lookup fail on an out-of-range position index.
        if seq_len > max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len={max_len} "
                "of the positional embedding table"
            )
        # Positions 0..seq_len-1, shaped [1, seq_len] so the result broadcasts
        # over the batch dimension.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        return x + self.pos_embedding(positions)  # [1, seq_len, d_model] added
21
+
22
+
23
class TransformerSingleStep(nn.Module):
    """Transformer encoder that predicts one (high, low) pair per sequence.

    The input features are projected to ``d_model``, passed through ``tanh``,
    given learned positional encodings, and run through a stack of
    ``nn.TransformerEncoderLayer`` blocks. Only the encoder output at the last
    time step is fed to the final linear layer, which produces two values.

    Args:
        input_size: Number of features per time step in the input.
        d_model: Width of the transformer representation.
        nhead: Number of attention heads in each encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability used inside each encoder layer.
        max_len: Number of positions in the learned positional table, i.e.
            the longest supported sequence length.
        dim_feedforward: Hidden width of the feed-forward sub-layer in each
            encoder layer.
    """

    def __init__(self, input_size, d_model=64, nhead=4, num_layers=2,
                 dropout=0.1, max_len=500, dim_feedforward=128):
        super().__init__()
        # Submodule names and creation order are kept stable so existing
        # checkpoints and seeded initialisation keep working.
        self.input_fc = nn.Linear(input_size, d_model)
        self.tanh = nn.Tanh()
        self.pos_encoder = LearnablePositionalEncoding(d_model, max_len=max_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # inputs are [batch, seq, feature]
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.output = nn.Linear(d_model, 2)

    def forward(self, x):
        """Predict the high and low values from a batch of sequences.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, input_size]``.

        Returns:
            Tuple ``(pred_high, pred_low)``; each tensor has shape ``[B, 1]``.
        """
        x = self.input_fc(x)
        x = self.tanh(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)  # [B, seq_len, d_model]

        # Use only the output of the last time step.
        last_output = x[:, -1, :]  # [B, d_model]

        # The output layer yields two columns: col0 = high, col1 = low.
        out = self.output(last_output)  # [B, 2]
        # Width-1 slices keep the trailing dimension, giving [B, 1] each.
        pred_high = out[:, 0:1]
        pred_low = out[:, 1:2]

        return pred_high, pred_low