shethjenil commited on
Commit
62fc451
·
verified ·
1 Parent(s): ae4543e

Update modeling_conformer.py

Browse files
Files changed (1) hide show
  1. modeling_conformer.py +2 -1
modeling_conformer.py CHANGED
@@ -232,7 +232,8 @@ encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv
232
  k = (b | ~v).view(1, -1, 1)
233
  h = (torch.where(k, h[0], h2[0]), torch.where(k, h[1], h2[1]))
234
  self.cache_length = None
235
- return [torch.tensor(i) for i in tok], [torch.tensor(i) for i in st]
 
236
 
237
  def make_srt(self, x, ts):
238
  t , s = x
 
232
  k = (b | ~v).view(1, -1, 1)
233
  h = (torch.where(k, h[0], h2[0]), torch.where(k, h[1], h2[1]))
234
  self.cache_length = None
235
+ device = next(self.parameters()).device
236
+ return [torch.tensor(i,device=device) for i in tok], [torch.tensor(i,device=device) for i in st]
237
 
238
  def make_srt(self, x, ts):
239
  t , s = x