Update gigaam_transformers.py
Browse files- gigaam_transformers.py +2 -1
gigaam_transformers.py
CHANGED
|
@@ -6,6 +6,7 @@ import torch.nn as nn
|
|
| 6 |
import torchaudio
|
| 7 |
from .encoder import ConformerEncoder
|
| 8 |
from torch import Tensor
|
|
|
|
| 9 |
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
|
| 10 |
from transformers.configuration_utils import PretrainedConfig
|
| 11 |
from transformers.feature_extraction_sequence_utils import \
|
|
@@ -445,4 +446,4 @@ class GigaAMRNNTHF(PreTrainedModel):
|
|
| 445 |
for i in range(b):
|
| 446 |
inseq = encoder_out[i, :, :].unsqueeze(1)
|
| 447 |
preds.append(self._greedy_decode(inseq, encoded_lengths[i]))
|
| 448 |
-
return
|
|
|
|
| 6 |
import torchaudio
|
| 7 |
from .encoder import ConformerEncoder
|
| 8 |
from torch import Tensor
|
| 9 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor
|
| 11 |
from transformers.configuration_utils import PretrainedConfig
|
| 12 |
from transformers.feature_extraction_sequence_utils import \
|
|
|
|
| 446 |
for i in range(b):
|
| 447 |
inseq = encoder_out[i, :, :].unsqueeze(1)
|
| 448 |
preds.append(self._greedy_decode(inseq, encoded_lengths[i]))
|
| 449 |
+
return pad_sequence(preds, batch_first=True, padding_value=self.config.blank_id)
|