Update modeling_lstm_seq2seq_en_hi.py
Browse files
modeling_lstm_seq2seq_en_hi.py
CHANGED
|
@@ -77,6 +77,9 @@ class Encoder(nn.Module):
|
|
| 77 |
total_length=input_ids.size(1)
|
| 78 |
)
|
| 79 |
|
|
|
|
|
|
|
|
|
|
| 80 |
return LSTMEncoderOutput(
|
| 81 |
last_hidden_state=outputs,
|
| 82 |
hidden_state=hidden,
|
|
@@ -289,16 +292,20 @@ class Seq2SeqHFModel(PreTrainedModel):
|
|
| 289 |
|
| 290 |
def _merge_bidir_state(self, state, bridge_layer):
|
| 291 |
num_directions = 2
|
| 292 |
-
|
|
|
|
| 293 |
num_layers = num_layers_times_dirs // num_directions
|
| 294 |
-
|
| 295 |
-
state = state.view(num_layers, num_directions,
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
merged =
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
def forward(
|
| 304 |
self,
|
|
|
|
| 77 |
total_length=input_ids.size(1)
|
| 78 |
)
|
| 79 |
|
| 80 |
+
hidden = hidden.transpose(0, 1).contiguous() # (batch, layers*dirs, hidden)
|
| 81 |
+
cell = cell.transpose(0, 1).contiguous() # (batch, layers*dirs, hidden)
|
| 82 |
+
|
| 83 |
return LSTMEncoderOutput(
|
| 84 |
last_hidden_state=outputs,
|
| 85 |
hidden_state=hidden,
|
|
|
|
| 292 |
|
| 293 |
def _merge_bidir_state(self, state, bridge_layer):
    """Fuse each layer's forward and backward LSTM state into a single state.

    Args:
        state: Batch-first tensor of shape ``(batch, num_layers * 2, hidden)``
            in which every layer's forward state is immediately followed by
            its backward state.
        bridge_layer: Callable applied over the last dimension of the
            ``(batch, num_layers, 2 * hidden)`` concatenated states.

    Returns:
        A contiguous tensor of shape ``(num_layers, batch, out)``, where
        ``out`` is the feature size produced by ``bridge_layer``, after a
        ``tanh`` non-linearity.

    Raises:
        ValueError: If ``state`` is not 3-D or its second dimension cannot be
            split evenly into the two directions.
    """
    num_directions = 2

    if state.dim() != 3:
        raise ValueError(
            "state must have shape (batch, num_layers * num_directions, hidden); "
            f"got {tuple(state.size())}"
        )
    batch_size, num_layers_times_dirs, hidden_dim = state.size()
    if num_layers_times_dirs % num_directions != 0:
        raise ValueError(
            f"state dimension 1 ({num_layers_times_dirs}) is not divisible by "
            f"num_directions ({num_directions})"
        )
    num_layers = num_layers_times_dirs // num_directions

    # The forward/backward rows of a layer sit next to each other, so folding
    # the direction axis into the feature axis yields exactly
    # cat([forward, backward], dim=-1) without slicing and re-concatenating.
    merged = state.reshape(batch_size, num_layers, num_directions * hidden_dim)
    merged = torch.tanh(bridge_layer(merged))  # (batch, num_layers, out)

    # Return layer-first, as the shape comment on the original return states.
    return merged.transpose(0, 1).contiguous()  # (num_layers, batch, out)
|
| 309 |
|
| 310 |
def forward(
|
| 311 |
self,
|