Update modeling_conformer.py
Browse files- modeling_conformer.py +2 -1
modeling_conformer.py
CHANGED
|
@@ -232,7 +232,8 @@ encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv
|
|
| 232 |
k = (b | ~v).view(1, -1, 1)
|
| 233 |
h = (torch.where(k, h[0], h2[0]), torch.where(k, h[1], h2[1]))
|
| 234 |
self.cache_length = None
|
| 235 |
-
|
|
|
|
| 236 |
|
| 237 |
def make_srt(self, x, ts):
|
| 238 |
t , s = x
|
|
|
|
| 232 |
k = (b | ~v).view(1, -1, 1)
|
| 233 |
h = (torch.where(k, h[0], h2[0]), torch.where(k, h[1], h2[1]))
|
| 234 |
self.cache_length = None
|
| 235 |
+
device = next(self.parameters()).device
|
| 236 |
+
return [torch.tensor(i,device=device) for i in tok], [torch.tensor(i,device=device) for i in st]
|
| 237 |
|
| 238 |
def make_srt(self, x, ts):
|
| 239 |
t , s = x
|