DarthReca committed
Commit aa02dab · verified · 1 Parent(s): c85c48f

Create convlstm.py

Files changed (1)
  1. convlstm.py +209 -0
convlstm.py ADDED
@@ -0,0 +1,209 @@
# Copyright (c) 2022 Seyong Kim

from typing import Optional, Tuple, Union

import torch
from torch import Tensor, nn, sigmoid, tanh


class ConvGate(nn.Module):
    """Computes all four ConvLSTM gate pre-activations in a single convolution pass."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: Union[Tuple[int, int], int],
        padding: Union[Tuple[int, int], int],
        stride: Union[Tuple[int, int], int],
        bias: bool,
    ):
        super().__init__()
        self.conv_x = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_channels * 4,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        self.conv_h = nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels * 4,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        self.bn2d = nn.BatchNorm2d(hidden_channels * 4)

    def forward(self, x: Tensor, hidden_state: Tensor) -> Tensor:
        gated = self.conv_x(x) + self.conv_h(hidden_state)
        return self.bn2d(gated)


class ConvLSTMCell(nn.Module):
    def __init__(
        self, in_channels, hidden_channels, kernel_size, padding, stride, bias
    ):
        super().__init__()
        # To inspect the model structure with tools such as torchinfo, the custom
        # module needs to be wrapped in nn.ModuleList
        self.gates = nn.ModuleList(
            [ConvGate(in_channels, hidden_channels, kernel_size, padding, stride, bias)]
        )

    def forward(
        self, x: Tensor, hidden_state: Tensor, cell_state: Tensor
    ) -> Tuple[Tensor, Tensor]:
        gated = self.gates[0](x, hidden_state)
        # Split the stacked pre-activations into input, forget, cell and output gates
        i_gated, f_gated, c_gated, o_gated = gated.chunk(4, dim=1)

        i_gated = sigmoid(i_gated)
        f_gated = sigmoid(f_gated)
        o_gated = sigmoid(o_gated)

        # c_t = f * c_{t-1} + i * tanh(g);  h_t = o * tanh(c_t)
        cell_state = f_gated.mul(cell_state) + i_gated.mul(tanh(c_gated))
        hidden_state = o_gated.mul(tanh(cell_state))

        return hidden_state, cell_state


class ConvLSTM(nn.Module):
    """ConvLSTM module."""

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        padding,
        stride,
        bias,
        batch_first,
        bidirectional,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.bidirectional = bidirectional
        self.batch_first = batch_first

        # To inspect the model structure with tools such as torchinfo, the custom
        # module needs to be wrapped in nn.ModuleList
        self.conv_lstm_cells = nn.ModuleList(
            [
                ConvLSTMCell(
                    in_channels, hidden_channels, kernel_size, padding, stride, bias
                )
            ]
        )

        if self.bidirectional:
            # A second, independent cell processes the sequence in reverse time order
            self.conv_lstm_cells.append(
                ConvLSTMCell(
                    in_channels, hidden_channels, kernel_size, padding, stride, bias
                )
            )

        self.batch_size = None
        self.seq_len = None
        self.height = None
        self.width = None

    def forward(
        self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # Shape of x: (B, T, C, H, W) if batch_first, else (T, B, C, H, W)
        x = self._check_shape(x)
        hidden_state, cell_state, backward_hidden_state, backward_cell_state = (
            self.init_state(x, state)
        )

        output, hidden_state, cell_state = self._forward(
            self.conv_lstm_cells[0], x, hidden_state, cell_state
        )

        if self.bidirectional:
            x = torch.flip(x, [1])
            backward_output, backward_hidden_state, backward_cell_state = self._forward(
                self.conv_lstm_cells[1], x, backward_hidden_state, backward_cell_state
            )
            # Flip the backward outputs back so both directions share the same
            # time order, then concatenate along the channel dimension
            backward_output = torch.flip(backward_output, [1])

            output = torch.cat([output, backward_output], dim=-3)
            hidden_state = torch.cat([hidden_state, backward_hidden_state], dim=-1)
            cell_state = torch.cat([cell_state, backward_cell_state], dim=-1)
        return output, (hidden_state, cell_state)

    def _forward(self, lstm_cell, x, hidden_state, cell_state):
        outputs = []
        for time_step in range(self.seq_len):
            x_t = x[:, time_step, :, :, :]
            hidden_state, cell_state = lstm_cell(x_t, hidden_state, cell_state)
            # Keep the graph attached so gradients can flow through every time step
            outputs.append(hidden_state)
        output = torch.stack(outputs, dim=1)
        return output, hidden_state, cell_state

    def _check_shape(self, x: Tensor) -> Tensor:
        if self.batch_first:
            batch_size, self.seq_len = x.shape[0], x.shape[1]
        else:
            batch_size, self.seq_len = x.shape[1], x.shape[0]
            # Move the batch dimension first: (T, B, ...) -> (B, T, ...)
            x = torch.swapaxes(x, 0, 1)

        self.height = x.shape[-2]
        self.width = x.shape[-1]

        dim = len(x.shape)

        if dim == 4:
            # Add a singleton channel dimension: (B, T, H, W) -> (B, T, 1, H, W)
            x = x.reshape(batch_size, self.seq_len, 1, self.height, self.width)
        elif dim <= 3:
            raise ValueError(
                f"Got a {dim}-dimensional tensor; expected 4 or 5 dimensions"
            )

        return x

    def init_state(
        self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        # If no state is passed in, initialize the states with zeros
        backward_hidden_state, backward_cell_state = None, None

        if state is None:
            self.batch_size = x.shape[0]
            hidden_state, cell_state = self._init_state(x.dtype, x.device)

            if self.bidirectional:
                backward_hidden_state, backward_cell_state = self._init_state(
                    x.dtype, x.device
                )
        else:
            if self.bidirectional:
                # Split the given state into its forward and backward halves
                hidden_state, backward_hidden_state = state[0].chunk(2, dim=-1)
                cell_state, backward_cell_state = state[1].chunk(2, dim=-1)
            else:
                hidden_state, cell_state = state

        return hidden_state, cell_state, backward_hidden_state, backward_cell_state

    def _init_state(self, dtype, device) -> Tuple[Tensor, Tensor]:
        # Plain zero tensors broadcast over the batch dimension; creating them here
        # avoids re-registering buffers on every forward pass
        hidden_state = torch.zeros(
            (1, self.hidden_channels, self.height, self.width),
            dtype=dtype,
            device=device,
        )
        cell_state = torch.zeros_like(hidden_state)
        return hidden_state, cell_state
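
For reference, a minimal usage sketch (not part of the commit; the channel counts, tensor sizes, and hyperparameter values below are illustrative assumptions, and it assumes the file above is importable as convlstm). Note that kernel_size=3 with padding=1 and stride=1 keeps the spatial size unchanged, which the recurrence requires so that conv_h can be applied to the hidden state at every step.

# Hypothetical usage sketch: a bidirectional ConvLSTM over a batch of
# four 8-frame, 3-channel, 64x64 sequences. Values are assumptions.
import torch
from convlstm import ConvLSTM

model = ConvLSTM(
    in_channels=3,
    hidden_channels=16,
    kernel_size=3,
    padding=1,
    stride=1,
    bias=True,
    batch_first=True,
    bidirectional=True,
)
x = torch.randn(4, 8, 3, 64, 64)  # (B, T, C, H, W)
output, (hidden_state, cell_state) = model(x)
# output: (4, 8, 32, 64, 64) -- forward and backward halves concatenated on channels
print(output.shape, hidden_state.shape, cell_state.shape)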