Xsmos commited on
Commit
ce05a08
·
verified ·
1 Parent(s): 7b0e0c4
Files changed (1) hide show
  1. diffusion.py +2 -2
diffusion.py CHANGED
@@ -282,7 +282,7 @@ class TrainConfig:
282
  # params = params
283
  # data_dir = './data' # data directory
284
 
285
- use_fp16 = False
286
  dtype = torch.float16 if use_fp16 else torch.float32
287
  mixed_precision = "fp16"
288
  gradient_accumulation_steps = 1
@@ -437,7 +437,7 @@ class DDPM21CM:
437
  with self.accelerator.accumulate(self.nn_model):
438
  x = x.to(self.config.device)
439
  print("x = x.to(self.config.device), x.dtype =", x.dtype)
440
- x = x.to(self.config.dtype)
441
  print("x = x.to(self.dtype), x.dtype =", x.dtype)
442
  xt, noise, ts = self.ddpm.add_noise(x)
443
 
 
282
  # params = params
283
  # data_dir = './data' # data directory
284
 
285
+ use_fp16 = True
286
  dtype = torch.float16 if use_fp16 else torch.float32
287
  mixed_precision = "fp16"
288
  gradient_accumulation_steps = 1
 
437
  with self.accelerator.accumulate(self.nn_model):
438
  x = x.to(self.config.device)
439
  print("x = x.to(self.config.device), x.dtype =", x.dtype)
440
+ # x = x.to(self.config.dtype)
441
  print("x = x.to(self.dtype), x.dtype =", x.dtype)
442
  xt, noise, ts = self.ddpm.add_noise(x)
443