## Model Structure
```python
class CausalTimeConv2d(nn.Conv2d):
    """Dilated causal convolution along the last (width/time) axis.

    Input and output layout: ``[B, C, H, W]`` with ``H`` = stocks and
    ``W`` = time.

    The underlying ``nn.Conv2d`` uses ``kernel_size=(1, kernel_size)``,
    ``dilation=(1, dilation)`` and ``padding=(0, 0)``. Each row (stock) is
    therefore convolved independently and only along time. Causality comes
    from ``forward``: it left-pads the time axis by
    ``(kernel_size - 1) * dilation`` zeros. As a result the output keeps
    width ``W``, and output step ``t`` depends only on input steps ``<= t``.

    Args:
        in_channel: Number of input channels ``C``.
        out_channel: Number of output channels.
        kernel_size: Taps along the time axis.
        dilation: Spacing between taps along the time axis.
        bias: Whether the convolution has a bias term.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int = 4,
        dilation: int = 1,
        bias: bool = False,
    ) -> None:
        super().__init__(
            in_channel,
            out_channel,
            kernel_size=(1, kernel_size),
            stride=(1, 1),
            padding=(0, 0),  # no symmetric padding: it would leak future steps
            dilation=(1, dilation),
            bias=bias,
        )
        # Left padding on the time axis that keeps the output width equal to
        # the input width while only looking backwards.
        self.pad_w = (kernel_size - 1) * dilation

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Left-pad the time axis, then apply the convolution.

        ``input`` keeps the parent ``nn.Conv2d.forward`` parameter name.
        """
        if self.pad_w > 0:
            # F.pad's tuple starts at the last dim: (W_left, W_right, H_top, H_bottom).
            input = F.pad(input, (self.pad_w, 0, 0, 0))
        return super().forward(input)
class ParallelTCNBlock(nn.Module):
    """Residual block of two causal time convolutions, applied per stock.

    Pipeline on ``x`` of layout ``[B, C, S, T]``:
    ``conv1 -> ReLU -> conv2 -> ReLU -> dropout``, plus a shortcut, followed
    by a final ReLU.

    - The shortcut is ``x`` itself when ``in_channel == out_channel``.
      Otherwise it is a bias-free 1x1 ``nn.Conv2d`` that matches the channel
      count.
    - Both convolutions are ``CausalTimeConv2d``, which keep the time width
      ``T``. This is why the branch and the shortcut can be added without
      extra padding.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int = 4,
        dilation: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        # Creation order (conv1, relu1, conv2, relu2, drop, down) fixes the
        # parameter registration order and the RNG draws used for init.
        self.conv1 = CausalTimeConv2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=False,
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = CausalTimeConv2d(
            out_channel,
            out_channel,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=False,
        )
        self.relu2 = nn.ReLU(inplace=True)
        if dropout > 0:
            self.drop = nn.Dropout(dropout)
        else:
            self.drop = nn.Identity()
        if in_channel == out_channel:
            self.down = nn.Identity()
        else:
            # 1x1 projection so the shortcut has out_channel channels.
            self.down = nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(branch(x) + shortcut(x))``; the ``[*, *, S, T]`` extent is preserved."""
        shortcut = self.down(x)
        # The in-place ReLUs act on fresh conv outputs, never on x itself.
        hidden = self.relu1(self.conv1(x))
        hidden = self.drop(self.relu2(self.conv2(hidden)))
        return torch.relu_(hidden + shortcut)
class TCN(nn.Module):
    """Parallel TCN over inputs of layout ``[B, T, S, F]``.

    - Permutes the input to ``[B, F, S, T]``: features become channels and
      time becomes the last axis.
    - Applies ``e_layers`` ``ParallelTCNBlock``s.
      - Block ``i`` uses causal kernels ``(1, kernel_size)`` with dilation
        ``2**i``.
      - Each stock (row) is processed independently but in parallel.
    - Takes the features at the last time step and projects them per stock
      through ``Linear(d_model, d_ff) -> GELU -> Linear(d_ff, c_out)``.

    Each block holds two convolutions with the same dilation. The receptive
    field is therefore ``1 + 2 * (kernel_size - 1) * (2**e_layers - 1)``
    time steps. Shorter histories are implicitly zero-padded on the left.

    Returns ``[B, S, c_out]``, squeezed to ``[B, S]`` when ``c_out == 1``.
    """

    def __init__(
        self,
        enc_in: int,
        c_out: int,
        d_model: int,
        d_ff: int,
        e_layers: int,
        kernel_size: int = 4,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        blocks = [
            ParallelTCNBlock(
                enc_in if i == 0 else d_model,  # first block maps F -> d_model
                d_model,
                kernel_size=kernel_size,
                dilation=2**i,  # exponentially growing dilation
                dropout=dropout,
            )
            for i in range(e_layers)
        ]
        self.tcn = nn.Sequential(*blocks)
        self.proj = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=True),
            nn.GELU(),
            nn.Linear(d_ff, c_out, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[B, T, S, F]`` to ``[B, S, c_out]`` (``[B, S]`` if ``c_out == 1``).

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        # Validate rank explicitly instead of unpacking the shape. Unpacking
        # into a local named `F` would shadow the torch.nn.functional alias.
        if x.dim() != 4:
            raise ValueError(
                f"TCN expects a 4-D input [B, T, S, F], got shape {tuple(x.shape)}"
            )
        x = x.permute(0, 3, 2, 1).contiguous()  # [B, F, S, T]
        y = self.tcn(x)  # [B, d_model, S, T]
        last = y[..., -1]  # last time step -> [B, d_model, S]
        out = self.proj(last.transpose(1, 2))  # [B, S, c_out]
        return out.squeeze(-1)  # [B, S] if c_out == 1
```
## Model Config
```yaml
enc_in: 8
c_out: 1
d_model: 64
d_ff: 64
e_layers: 2
kernel_size: 4
dropout: 0.0
```