tungman commited on
Commit
346cb38
·
verified ·
1 Parent(s): 4c2b102

Delete TransformerSinglestep.py

Browse files
Files changed (1) hide show
  1. TransformerSinglestep.py +0 -50
TransformerSinglestep.py DELETED
@@ -1,50 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- class LearnablePositionalEncoding(nn.Module):
6
- def __init__(self, d_model, max_len=10000000):
7
- super().__init__()
8
- self.pos_embedding = nn.Embedding(max_len, d_model)
9
- self._init_weights()
10
-
11
- def _init_weights(self):
12
- # initialization ที่เหมาะสมสำหรับ positional encoding
13
- nn.init.uniform_(self.pos_embedding.weight, -0.1, 0.1)
14
-
15
- def forward(self, x):
16
- # x.shape = [batch_size, seq_len, d_model]
17
- seq_len = x.size(1)
18
- pos = torch.arange(0, seq_len, device=x.device).unsqueeze(0) # [1, seq_len]
19
- pos_embed = self.pos_embedding(pos) # [1, seq_len, d_model]
20
- return x + pos_embed
21
-
22
-
23
- class TransformerSingleStep(nn.Module):
24
- def __init__(self, input_size, d_model=64, nhead=4, num_layers=2, dropout=0.1, max_len=500):
25
- super().__init__()
26
- self.input_fc = nn.Linear(input_size, d_model)
27
- self.tanh = nn.Tanh()
28
- self.pos_encoder = LearnablePositionalEncoding(d_model, max_len=max_len)
29
-
30
- encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=128,
31
- dropout=dropout, batch_first=True)
32
- self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
33
- self.output = nn.Linear(d_model, 2)
34
-
35
- def forward(self, x):
36
- # x.shape = [batch_size, seq_len, input_size]
37
- x = self.input_fc(x)
38
- x = self.tanh(x)
39
- x = self.pos_encoder(x)
40
- x = self.transformer(x) # [B, seq_len, d_model]
41
-
42
- # ใช้เฉพาะ output ของ time step สุดท้าย
43
- last_output = x[:, -1, :] # [B, d_model]
44
-
45
- # output layer แบ่ง high/low
46
- out = self.output(last_output) # [B, 2] → col0=high, col1=low
47
- pred_high = out[:, 0].unsqueeze(1) # [B,1]
48
- pred_low = out[:, 1].unsqueeze(1) # [B,1]
49
-
50
- return pred_high, pred_low