## Source code based on publicly released dilated CNN models as found in
## SimTS model: https://github.com/xingyu617/SimTS_Representation_Learning/blob/main/models/dilation.py
## and
## TS2Vec repo: https://github.com/zhihanyue/ts2vec/blob/main/models/dilated_conv.py
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
def init_weights(m):
    """
    Initialize supported layer types in place.

    Conv2d layers get Kaiming-normal weights and zero biases; Linear layers get
    zero biases (their weights keep PyTorch's default initialization). Modules
    of any other type are left untouched, so this is safe to use with
    ``module.apply(init_weights)``.

    Relevant reading material:
    https://pytorch.org/docs/stable/nn.init.html
    https://github.com/pytorch/vision/blob/309bd7a1512ad9ff0e9729fbdad043cb3472e4cb/torchvision/models/densenet.py#L203

    Args:
        m: nn.Module. Layer to initialize in place.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        # Guard against bias-free convolutions (bias=False): m.bias is None
        # there, and the previous unconditional fill raised AttributeError.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
class SamePadConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
    ):
        """Padded convolution to ensure same sized input and output."""
        super().__init__()
        # Effective span of the dilated kernel along the time (last) axis.
        self.receptive_field = (kernel_size - 1) * dilation + 1
        pad = self.receptive_field // 2
        # Time runs along the last axis, so kernel/pad/stride/dilation are
        # all trivial (1 or 0) on the penultimate axis.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            (1, kernel_size),
            padding=(0, pad),
            stride=(1, stride),
            dilation=(1, dilation),
            groups=groups,
        )
        init_weights(self.conv)
        # An even receptive field pads one step too many; trim it in forward().
        self.remove = int(self.receptive_field % 2 == 0)

    def forward(self, x):
        out = self.conv(x)
        if self.remove == 0:
            return out
        # Drop the surplus trailing time step introduced by symmetric padding.
        return out[:, :, :, : -self.remove]
class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        dilation,
        final=False,
        enable_checkpointing=False,
    ):
        """
        Convolutional block implementation.
        Consists of two convolution layers followed by a residual stream.
        Args:
            in_channels: int. Input channel count.
            out_channels: int. Output channel count.
            kernel_size: int. Convolution kernel size.
            stride: int. Convolution stride size.
            dilation: int. Convolution dilation amount.
            final: bool. This is the final convolutional block in the stack. Only relevant for
                using a projection head for the residual stream.
            enable_checkpointing: bool. Enable checkpointing of the intermediate weights if
                desired. Default False.
        """
        super().__init__()
        self.enable_checkpointing = enable_checkpointing
        self.conv1 = SamePadConv(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )
        self.conv2 = SamePadConv(
            out_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )
        # A 1x1 projection shortcut is needed whenever the residual cannot be
        # added verbatim: channel counts differ, this is the final block, or
        # the two strided convs shrink the time axis (hence stride**2 here).
        self.projector = (
            nn.Conv2d(
                in_channels, out_channels, kernel_size=(1, 1), stride=(1, stride**2),
            )
            if in_channels != out_channels or final or stride != 1
            else None
        )
        if self.projector is not None:
            init_weights(self.projector)

    def _forward_mini_block(self, x: torch.Tensor, block_num: int):
        """Run conv `block_num` (1 or 2), then LayerNorm over time, then GELU."""
        x = self.conv1(x) if block_num == 1 else self.conv2(x)
        # Normalize over the trailing (time) axis only.
        x = F.layer_norm(x, (x.shape[-1],))
        x = F.gelu(x)
        return x

    def forward(self, x: torch.Tensor):
        """Apply both conv mini-blocks and add the (optionally projected) residual."""
        residual = x if self.projector is None else self.projector(x)
        if self.enable_checkpointing:
            # Trade compute for memory: activations are recomputed on backward.
            x = checkpoint(self._forward_mini_block, x, 1, use_reentrant=False)
            x = checkpoint(self._forward_mini_block, x, 2, use_reentrant=False)
        else:
            x = self._forward_mini_block(x, block_num=1)
            x = self._forward_mini_block(x, block_num=2)
        return x + residual
class DilatedConvEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        channels,
        kernel_size,
        stride=1,
        enable_checkpointing=False,
    ):
        """Dilated CNN implementation. See ConvBlock for argument definitions.

        Args:
            channels: sequence of int. Output channel count for each ConvBlock;
                block ``i`` uses dilation ``2**i``, so the receptive field
                grows exponentially with depth. The last block is marked
                ``final`` so its residual stream gets a projection head.
        """
        super().__init__()
        self.enable_checkpointing = enable_checkpointing
        self.net = nn.ModuleList(
            [
                ConvBlock(
                    channels[i - 1] if i > 0 else in_channels,
                    channels[i],
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=2**i,
                    final=(i == len(channels) - 1),
                    enable_checkpointing=enable_checkpointing,
                )
                for i in range(len(channels))
            ]
        )

    def forward(self, x: torch.Tensor):
        """Apply each ConvBlock in sequence and return the final activations."""
        for layer in self.net:
            x = layer(x)
        return x
class TSEncoder2D(nn.Module):
    def __init__(
        self,
        input_dims,
        output_dims,
        hidden_dims=64,
        depth=10,
        kernel_size=3,
        stride=1,
        enable_checkpointing=False,
    ):
        """
        Time-series encoder built on a dilated CNN stack.

        Original source implementation:
        TS2Vec Encoder: https://github.com/zhihanyue/ts2vec/blob/main/models/encoder.py
        See ConvBlock function for argument definitions.

        Args:
            input_dims: int. Input channel count fed to the first ConvBlock.
            output_dims: int. Channel count of the final ConvBlock.
            hidden_dims: int. Channel count of the ``depth`` intermediate blocks.
            depth: int. Number of hidden ConvBlocks (total blocks = depth + 1).
        """
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.enable_checkpointing = enable_checkpointing
        # `depth` hidden blocks followed by one output block.
        self.feature_extractor = DilatedConvEncoder(
            input_dims,
            [hidden_dims] * depth + [output_dims],
            kernel_size=kernel_size,
            stride=stride,
            enable_checkpointing=self.enable_checkpointing,
        )

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: torch.Tensor of shape (1, 1, B * T * D, N) with time (N) along the last axis.
                Note: the additional (1, 1) for the first two axies is to use 2D convs for
                1D convolution operations.
                Note: B=Batch, T=Number of segments, D=Channels.
        Returns:
            Temporal encoded version of the input tensor of shape (1, 1, B * T * D, N)
        """
        return self.feature_extractor(x)