ishanjmukherjee commited on
Commit
67959c9
·
1 Parent(s): 5ce9741

Remove more custom device management

Browse files
Files changed (1) hide show
  1. model.py +4 -2
model.py CHANGED
@@ -490,8 +490,10 @@ class ParallelGatedConvBlock(nn.Module):
490
 
491
  normalized = self.pre_norm(x)
492
  normalized = self.pad_to_multiple(normalized)
493
- with torch.cuda.device(x.device):
494
- projected = self.projections(normalized)
 
 
495
 
496
  if isinstance(projected, tuple):
497
  projected = projected[0]
 
490
 
491
  normalized = self.pre_norm(x)
492
  normalized = self.pad_to_multiple(normalized)
493
+ # Ishan: comment out this vestige of manual device management
494
+ # with torch.cuda.device(x.device):
495
+ # projected = self.projections(normalized)
496
+ projected = self.projections(normalized)
497
 
498
  if isinstance(projected, tuple):
499
  projected = projected[0]