Commit
·
67959c9
1
Parent(s):
5ce9741
Remove more custom device management
Browse files
model.py
CHANGED
|
@@ -490,8 +490,10 @@ class ParallelGatedConvBlock(nn.Module):
|
|
| 490 |
|
| 491 |
normalized = self.pre_norm(x)
|
| 492 |
normalized = self.pad_to_multiple(normalized)
|
| 493 |
-
|
| 494 |
-
|
|
|
|
|
|
|
| 495 |
|
| 496 |
if isinstance(projected, tuple):
|
| 497 |
projected = projected[0]
|
|
|
|
| 490 |
|
| 491 |
normalized = self.pre_norm(x)
|
| 492 |
normalized = self.pad_to_multiple(normalized)
|
| 493 |
+
# Ishan: comment out this vestige of manual device management
|
| 494 |
+
# with torch.cuda.device(x.device):
|
| 495 |
+
# projected = self.projections(normalized)
|
| 496 |
+
projected = self.projections(normalized)
|
| 497 |
|
| 498 |
if isinstance(projected, tuple):
|
| 499 |
projected = projected[0]
|