shethjenil commited on
Commit
fd5e502
·
verified ·
1 Parent(s): 62fc451

Update modeling_conformer.py

Browse files
Files changed (1) hide show
  1. modeling_conformer.py +4 -3
modeling_conformer.py CHANGED
@@ -239,15 +239,16 @@ encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv
239
  t , s = x
240
  start_token_segment = self.config.languages.index(self.language) * self.joint.out_features
241
  all_tokens, all_starts, all_ends = [], [], []
 
242
  for tokens, starts, (s, e) in zip(t,s, ts):
243
  tokens += start_token_segment
244
  starts += s
245
  all_tokens.append(tokens)
246
  all_starts.append(starts)
247
  all_ends.append(torch.cat([starts[1:], e[None]]))
248
- all_tokens.append(torch.tensor([-1]))
249
- all_starts.append(torch.tensor([e]))
250
- all_ends.append(torch.tensor([e + 0.005]))
251
  return [srt.Subtitle(i,timedelta(seconds=float(st)),timedelta(seconds=float(en)),"<line>" if tok == -1 else self.config.vocab[int(tok)]) for i, (tok, st, en) in enumerate(zip(torch.cat(all_tokens), torch.cat(all_starts), torch.cat(all_ends)), 1)]
252
 
253
 
 
239
  t , s = x
240
  start_token_segment = self.config.languages.index(self.language) * self.joint.out_features
241
  all_tokens, all_starts, all_ends = [], [], []
242
+ device = t[0].device
243
  for tokens, starts, (s, e) in zip(t,s, ts):
244
  tokens += start_token_segment
245
  starts += s
246
  all_tokens.append(tokens)
247
  all_starts.append(starts)
248
  all_ends.append(torch.cat([starts[1:], e[None]]))
249
+ all_tokens.append(torch.tensor([-1],device=device))
250
+ all_starts.append(torch.tensor([e],device=device))
251
+ all_ends.append(torch.tensor([e + 0.005],device=device))
252
  return [srt.Subtitle(i,timedelta(seconds=float(st)),timedelta(seconds=float(en)),"<line>" if tok == -1 else self.config.vocab[int(tok)]) for i, (tok, st, en) in enumerate(zip(torch.cat(all_tokens), torch.cat(all_starts), torch.cat(all_ends)), 1)]
253
 
254