facehuggingjay commited on
Commit
7b63b18
·
verified ·
1 Parent(s): 3fb5067

Update BidirectionalTranslation/models/cycle_ganstft_model.py

Browse files
BidirectionalTranslation/models/cycle_ganstft_model.py CHANGED
@@ -16,8 +16,8 @@ class CycleGANSTFTModel(BaseModel):
16
  self.interchnnls = 4
17
  use_noise = False
18
  self.half_size = opt.batch_size //2
19
- self.device=opt.local_rank
20
- self.gpu_ids=[self.device]
21
  self.local_rank = opt.local_rank
22
  self.cropsize = opt.crop_size
23
 
@@ -100,4 +100,4 @@ class CycleGANSTFTModel(BaseModel):
100
  k += 1
101
  z = torch.clamp(z, -tvalue, tvalue)
102
 
103
- return z.detach().to(self.device)
 
16
  self.interchnnls = 4
17
  use_noise = False
18
  self.half_size = opt.batch_size //2
19
+ self.device = torch.device('cuda:{}'.format(opt.local_rank)) if torch.cuda.is_available() else torch.device('cpu')
20
+ self.gpu_ids = [opt.local_rank] if torch.cuda.is_available() else []
21
  self.local_rank = opt.local_rank
22
  self.cropsize = opt.crop_size
23
 
 
100
  k += 1
101
  z = torch.clamp(z, -tvalue, tvalue)
102
 
103
+ return z.detach().to(self.device)