Update README.md
Browse files
README.md
CHANGED
|
@@ -69,30 +69,6 @@ class ParallelTCNBlock(nn.Module):
|
|
| 69 |
return torch.relu_(y + res)
|
| 70 |
|
| 71 |
|
| 72 |
-
class TCNComp(nn.Module):
    """Temporal-convolution compressor over a [B, T, S, F] input.

    Stacks ``e_layers`` dilated ``ParallelTCNBlock``s (dilation doubling per
    layer: 1, 2, 4, ...) into a single ``nn.Sequential``. The first block maps
    ``enc_in`` channels to ``d_model``; every later block is d_model -> d_model.
    """

    def __init__(self, enc_in, d_model, e_layers, kernel_size=4, dropout=0.0):
        super().__init__()
        # Build the dilated stack in one pass; layer 0 adapts the input width.
        layers = [
            ParallelTCNBlock(
                enc_in if idx == 0 else d_model,
                d_model,
                kernel_size=kernel_size,
                dilation=2**idx,
                dropout=dropout,
            )
            for idx in range(e_layers)
        ]
        self.tcn = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the TCN stack and expose both the full sequence and its last step.

        Args:
            x: input of shape [B, T, S, F] (batch, time, series, features).

        Returns:
            A pair ``(tcn_out, last)`` where ``tcn_out`` has shape
            [B * S, T, d_model] and ``last`` has shape [B, S, d_model].
        """
        B, T, S, _ = x.shape
        # Channels-first layout for Conv2d-style blocks: [B, F, S, T].
        feat = x.permute(0, 3, 2, 1).contiguous()
        encoded = self.tcn(feat)  # [B, d_model, S, T]
        # Flatten batch and series dims so each series is its own sequence.
        tcn_out = encoded.permute(0, 2, 3, 1).reshape(B * S, T, -1)
        # Final time step per series, back to [B, S, d_model].
        last = encoded[:, :, :, -1].transpose(1, 2)
        return tcn_out, last
class TCN(nn.Module):
|
| 97 |
"""
|
| 98 |
Parallel TCN over [B, T, S, F]:
|
|
|
|
| 69 |
return torch.relu_(y + res)
|
| 70 |
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
class TCN(nn.Module):
|
| 73 |
"""
|
| 74 |
Parallel TCN over [B, T, S, F]:
|