klemenk committed on
Commit
a68b041
·
verified ·
1 Parent(s): df8c03d

Update modeling_speech_encoder.py

Browse files
Files changed (1) hide show
  1. modeling_speech_encoder.py +18 -1
modeling_speech_encoder.py CHANGED
@@ -11,7 +11,24 @@ import torchaudio
11
  from transformers import PreTrainedModel
12
 
13
  from .configuration_speech_encoder import SpeechEncoderConfig
14
- from .collater_utils import wrap_bos_eos
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  # ----------------------------
 
11
  from transformers import PreTrainedModel
12
 
13
  from .configuration_speech_encoder import SpeechEncoderConfig
14
+
15
+
16
+ def wrap_bos_eos(
17
+ units: torch.Tensor,
18
+ durations: torch.Tensor,
19
+ f0: torch.Tensor | None,
20
+ dense_features: torch.Tensor,
21
+ bos: torch.Tensor,
22
+ eos: torch.Tensor,
23
+ ):
24
+ # bos/eos are 1-element tensors on the right device/dtype
25
+ one = durations.new_ones(1)
26
+ units = torch.cat([bos.to(units.device), units, eos.to(units.device)], dim=0)
27
+ durations = torch.cat([one, durations, one], dim=0)
28
+ if f0 is not None:
29
+ # pad f0 with edge values
30
+ f0 = torch.cat([f0[:1], f0, f0[-1:]], dim=0)
31
+ return units, durations, f0, dense_features
32
 
33
 
34
  # ----------------------------