Update modeling_conformer.py
Browse files- modeling_conformer.py +4 -3
modeling_conformer.py
CHANGED
|
@@ -239,15 +239,16 @@ encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv
|
|
| 239 |
t , s = x
|
| 240 |
start_token_segment = self.config.languages.index(self.language) * self.joint.out_features
|
| 241 |
all_tokens, all_starts, all_ends = [], [], []
|
|
|
|
| 242 |
for tokens, starts, (s, e) in zip(t,s, ts):
|
| 243 |
tokens += start_token_segment
|
| 244 |
starts += s
|
| 245 |
all_tokens.append(tokens)
|
| 246 |
all_starts.append(starts)
|
| 247 |
all_ends.append(torch.cat([starts[1:], e[None]]))
|
| 248 |
-
all_tokens.append(torch.tensor([-1]))
|
| 249 |
-
all_starts.append(torch.tensor([e]))
|
| 250 |
-
all_ends.append(torch.tensor([e + 0.005]))
|
| 251 |
return [srt.Subtitle(i,timedelta(seconds=float(st)),timedelta(seconds=float(en)),"<line>" if tok == -1 else self.config.vocab[int(tok)]) for i, (tok, st, en) in enumerate(zip(torch.cat(all_tokens), torch.cat(all_starts), torch.cat(all_ends)), 1)]
|
| 252 |
|
| 253 |
|
|
|
|
| 239 |
t , s = x
|
| 240 |
start_token_segment = self.config.languages.index(self.language) * self.joint.out_features
|
| 241 |
all_tokens, all_starts, all_ends = [], [], []
|
| 242 |
+
device = t[0].device
|
| 243 |
for tokens, starts, (s, e) in zip(t,s, ts):
|
| 244 |
tokens += start_token_segment
|
| 245 |
starts += s
|
| 246 |
all_tokens.append(tokens)
|
| 247 |
all_starts.append(starts)
|
| 248 |
all_ends.append(torch.cat([starts[1:], e[None]]))
|
| 249 |
+
all_tokens.append(torch.tensor([-1],device=device))
|
| 250 |
+
all_starts.append(torch.tensor([e],device=device))
|
| 251 |
+
all_ends.append(torch.tensor([e + 0.005],device=device))
|
| 252 |
return [srt.Subtitle(i,timedelta(seconds=float(st)),timedelta(seconds=float(en)),"<line>" if tok == -1 else self.config.vocab[int(tok)]) for i, (tok, st, en) in enumerate(zip(torch.cat(all_tokens), torch.cat(all_starts), torch.cat(all_ends)), 1)]
|
| 253 |
|
| 254 |
|