Abner0803 committed on
Commit
0e05651
·
verified ·
1 Parent(s): 0bb2c9d

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +150 -0
README.md ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
## Model Structure

```python
class CausalTimeConv2d(nn.Conv2d):
    """Causal convolution along the time axis, applied per stock.

    Input: [B, C=in_ch, H=stocks, W=time]
    kernel_size=(1, ksz), dilation=(1, dil), padding=(0, 0)  # important!

    The kernel spans only the width (time) dimension, so every stock row
    is convolved independently. Left-padding by (ksz - 1) * dil before the
    convolution keeps the output width equal to the input width and makes
    the output at time t depend only on inputs at times <= t.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int = 4,
        dilation: int = 1,
        bias: bool = False,
    ) -> None:
        super().__init__(
            in_channel,
            out_channel,
            kernel_size=(1, kernel_size),
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, dilation),
            bias=bias,
        )
        # Left (past-side) padding needed so output width == input width.
        self.pad_w = (kernel_size - 1) * dilation

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # F.pad's last-dim pair is (left, right): pad only the past side
        # so the convolution stays causal.
        if self.pad_w > 0:
            input = F.pad(input, (self.pad_w, 0, 0, 0))

        return super().forward(input)
34
+
35
+
36
class ParallelTCNBlock(nn.Module):
    """Two causal convolutions with ReLU, dropout, and a residual add.

    Operates on [B, C, S, T]; the time width T is preserved by the causal
    convolutions, so the shortcut can be added with no extra padding.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int = 4,
        dilation: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.conv1 = CausalTimeConv2d(
            in_channel, out_channel, kernel_size, dilation, bias=False
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = CausalTimeConv2d(
            out_channel, out_channel, kernel_size, dilation, bias=False
        )
        self.relu2 = nn.ReLU(inplace=True)
        # Skip dropout entirely when the rate is zero.
        if dropout > 0:
            self.drop = nn.Dropout(dropout)
        else:
            self.drop = nn.Identity()
        # 1x1 projection so the shortcut matches the main path's channels.
        if in_channel != out_channel:
            self.down = nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)
        else:
            self.down = nn.Identity()

    def forward(self, x):  # x: [B, C, S, T]
        out = self.conv1(x)
        out = self.relu1(out)  # width T preserved
        out = self.conv2(out)
        out = self.relu2(out)  # width T preserved
        out = self.drop(out)
        shortcut = self.down(x)  # residual width must match; no padding here
        # Optional dev-time check for shape drift:
        # assert out.shape == shortcut.shape, f"{out.shape} vs {shortcut.shape}"
        return torch.relu_(out + shortcut)
70
+
71
+
72
class TCNComp(nn.Module):
    """Stack of dilated causal TCN blocks over [B, T, S, F].

    Returns both the full per-step feature sequence and the last-step
    features.
    """

    def __init__(self, enc_in, d_model, e_layers, kernel_size=4, dropout=0.0):
        super().__init__()
        # Dilation doubles each layer; only the first layer sees enc_in channels.
        self.tcn = nn.Sequential(
            *(
                ParallelTCNBlock(
                    enc_in if layer == 0 else d_model,
                    d_model,
                    kernel_size=kernel_size,
                    dilation=2**layer,
                    dropout=dropout,
                )
                for layer in range(e_layers)
            )
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T, S, _ = x.shape
        feat = self.tcn(x.permute(0, 3, 2, 1).contiguous())  # [B, d_model, S, T]
        per_step = feat.permute(0, 2, 3, 1).reshape(B * S, T, -1)  # [B*S, T, d_model]
        final = feat[:, :, :, -1].transpose(1, 2)  # [B, S, d_model]

        return per_step, final
94
+
95
+
96
class TCN(nn.Module):
    """
    Parallel TCN over [B, T, S, F]:
    - Converts to [B, F, S, T]
    - Applies dilated causal Conv2d with kernel (1, k) so each stock is independent but parallel
    - Takes the last time step (T) and projects to c_out
    """

    def __init__(
        self,
        enc_in: int,
        c_out: int,
        d_model: int,
        d_ff: int,
        e_layers: int,
        kernel_size: int = 4,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        blocks = []
        for i in range(e_layers):
            in_ch = enc_in if i == 0 else d_model
            dil = 2**i  # exponentially growing receptive field
            blocks.append(
                ParallelTCNBlock(
                    in_ch, d_model, kernel_size=kernel_size, dilation=dil, dropout=dropout
                )
            )
        self.tcn = nn.Sequential(*blocks)
        # Two-layer GELU head projecting last-step features to c_out.
        self.proj = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=True),
            nn.GELU(),
            nn.Linear(d_ff, c_out, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: do not unpack the feature dim into a local named `F` —
        # that would shadow the `torch.nn.functional as F` alias used
        # elsewhere in this module. No dim is needed here anyway.
        x = x.permute(0, 3, 2, 1).contiguous()  # [B, F, S, T]
        y = self.tcn(x)  # [B, d_model, S, T]
        last = y[:, :, :, -1]  # take last time step -> [B, d_model, S]
        out = self.proj(last.transpose(1, 2))  # [B, S, c_out]
        return out.squeeze(-1)  # [B, S] if c_out=1
```

## Model Config

```yaml
enc_in: 8
c_out: 1
d_model: 64
d_ff: 64
e_layers: 2
kernel_size: 4
dropout: 0.0
```