JensLundsgaard commited on
Commit
c43cf99
·
verified ·
1 Parent(s): 939f16b

Upload raffael_conv_lstm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. raffael_conv_lstm.py +163 -0
raffael_conv_lstm.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ConvLSTM Implementation
3
+ True convolutional LSTM for spatiotemporal data processing
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    Replaces the matrix multiplications of a vanilla LSTM cell with 2-D
    convolutions so the hidden and cell states keep their spatial layout.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # "Same" padding so h/c keep the spatial size of the input.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # A single convolution emits the pre-activations of all four gates
        # (input, forget, output, candidate), hidden_dim channels each.
        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """Advance the cell one time step.

        Args:
            input_tensor: input frame, shape (B, input_dim, H, W).
            cur_state: (h, c) pair, each of shape (B, hidden_dim, H, W).

        Returns:
            (h_next, c_next) — updated hidden and cell states.
        """
        h_prev, c_prev = cur_state

        # Gate pre-activations from the concatenated input + hidden state.
        stacked = torch.cat([input_tensor, h_prev], dim=1)
        gate_pre = self.conv(stacked)
        pre_i, pre_f, pre_o, pre_g = torch.split(gate_pre, self.hidden_dim, dim=1)

        in_gate = torch.sigmoid(pre_i)
        forget_gate = torch.sigmoid(pre_f)
        out_gate = torch.sigmoid(pre_o)
        candidate = torch.tanh(pre_g)

        # Standard LSTM state update, element-wise over the feature maps.
        c_next = forget_gate * c_prev + in_gate * candidate
        h_next = out_gate * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return a zero-initialized (h, c) pair on the weights' device."""
        height, width = image_size
        device = self.conv.weight.device
        shape = (batch_size, self.hidden_dim, height, width)
        return torch.zeros(*shape, device=device), torch.zeros(*shape, device=device)
61
+
62
+
63
class ConvLSTM(nn.Module):
    """
    Multi-layer ConvLSTM.

    Stacks one or more ConvLSTMCell layers and unrolls them over the time
    dimension of a 5-D input tensor.

    Args:
        input_dim: number of channels in each input frame.
        hidden_dim: hidden channels per layer — an int is broadcast to all
            layers, or a list/tuple with exactly one entry per layer.
        kernel_size: conv kernel, an int or a (kh, kw) tuple, shared by all
            layers.
        num_layers: number of stacked ConvLSTM layers.
        batch_first: if True input is (B, T, C, H, W), else (T, B, C, H, W).
        bias: whether the gate convolutions use a bias term.
        return_all_layers: if False, only the last layer's outputs and
            states are returned (as one-element lists).

    Raises:
        ValueError: if hidden_dim is a sequence whose length != num_layers.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        kernel_size,
        num_layers=1,
        batch_first=True,
        bias=True,
        return_all_layers=False
    ):
        super(ConvLSTM, self).__init__()

        self.input_dim = input_dim
        # Broadcast a scalar hidden_dim to every layer; validate list input
        # eagerly so a mismatch fails here instead of deep inside forward().
        if isinstance(hidden_dim, int):
            self.hidden_dim = [hidden_dim] * num_layers
        else:
            self.hidden_dim = list(hidden_dim)
            if len(self.hidden_dim) != num_layers:
                raise ValueError(
                    f"hidden_dim has {len(self.hidden_dim)} entries "
                    f"but num_layers is {num_layers}"
                )
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        # Layer i consumes the hidden channels of layer i-1 (or the raw
        # input channels for the first layer).
        cell_list = []
        for i in range(self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(
                ConvLSTMCell(
                    input_dim=cur_input_dim,
                    hidden_dim=self.hidden_dim[i],
                    kernel_size=self.kernel_size,
                    bias=self.bias
                )
            )
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Args:
            input_tensor: (B, T, C, H, W) if batch_first else (T, B, C, H, W)
            hidden_state: optional list of per-layer (h, c) pairs;
                zero-initialized when omitted.

        Returns:
            layer_output_list: per-layer stacked outputs, each (B, T, C, H, W)
            last_state_list: per-layer final [h, c] pairs

            Both lists hold only the last layer unless return_all_layers is
            True. NOTE: outputs are always batch-first, even when the input
            was given time-first (batch_first=False).
        """
        if not self.batch_first:
            # (T, B, C, H, W) -> (B, T, C, H, W); processing is batch-first.
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        batch_size, seq_len, _, height, width = input_tensor.size()

        # Initialize hidden state lazily so the spatial size can vary per call.
        if hidden_state is None:
            hidden_state = self._init_hidden(batch_size=batch_size,
                                             image_size=(height, width))

        layer_output_list = []
        last_state_list = []

        # Each layer consumes the full output sequence of the previous one.
        cur_layer_input = input_tensor
        for layer_idx in range(self.num_layers):
            h, c = hidden_state[layer_idx]
            output_inner = []

            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](
                    input_tensor=cur_layer_input[:, t, :, :, :],
                    cur_state=[h, c]
                )
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)  # (B, T, C, H, W)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        """Zero-initialized (h, c) states for every layer."""
        return [cell.init_hidden(batch_size, image_size)
                for cell in self.cell_list]
163
+