kd13 commited on
Commit
f9af18f
·
verified ·
1 Parent(s): 041c5c8

Update modeling_lstm_seq2seq_en_hi.py

Browse files
Files changed (1) hide show
  1. modeling_lstm_seq2seq_en_hi.py +16 -9
modeling_lstm_seq2seq_en_hi.py CHANGED
@@ -77,6 +77,9 @@ class Encoder(nn.Module):
77
  total_length=input_ids.size(1)
78
  )
79
 
 
 
 
80
  return LSTMEncoderOutput(
81
  last_hidden_state=outputs,
82
  hidden_state=hidden,
@@ -289,16 +292,20 @@ class Seq2SeqHFModel(PreTrainedModel):
289
 
290
  def _merge_bidir_state(self, state, bridge_layer):
291
  num_directions = 2
292
- num_layers_times_dirs, batch_size, hidden_dim = state.size()
 
293
  num_layers = num_layers_times_dirs // num_directions
294
-
295
- state = state.view(num_layers, num_directions, batch_size, hidden_dim)
296
- forward_state = state[:, 0, :, :]
297
- backward_state = state[:, 1, :, :]
298
-
299
- merged = torch.cat([forward_state, backward_state], dim=-1)
300
- merged = bridge_layer(merged)
301
- return torch.tanh(merged)
 
 
 
302
 
303
  def forward(
304
  self,
 
77
  total_length=input_ids.size(1)
78
  )
79
 
80
+ hidden = hidden.transpose(0, 1).contiguous() # (batch, layers*dirs, hidden)
81
+ cell = cell.transpose(0, 1).contiguous() # (batch, layers*dirs, hidden)
82
+
83
  return LSTMEncoderOutput(
84
  last_hidden_state=outputs,
85
  hidden_state=hidden,
 
292
 
293
  def _merge_bidir_state(self, state, bridge_layer):
294
  num_directions = 2
295
+
296
+ batch_size, num_layers_times_dirs, hidden_dim = state.size()
297
  num_layers = num_layers_times_dirs // num_directions
298
+
299
+ state = state.view(batch_size, num_layers, num_directions, hidden_dim)
300
+
301
+ forward_state = state[:, :, 0, :] # (batch, num_layers, hidden)
302
+ backward_state = state[:, :, 1, :] # (batch, num_layers, hidden)
303
+
304
+ merged = torch.cat([forward_state, backward_state], dim=-1) # (batch, num_layers, 2*hidden)
305
+ merged = bridge_layer(merged) # (batch, num_layers, hidden)
306
+ merged = torch.tanh(merged)
307
+
308
+ return merged.transpose(0, 1).contiguous() # (num_layers, batch, hidden)
309
 
310
  def forward(
311
  self,