0715-1314
Browse files- diffusion.py +2 -2
diffusion.py
CHANGED
|
@@ -282,7 +282,7 @@ class TrainConfig:
|
|
| 282 |
# params = params
|
| 283 |
# data_dir = './data' # data directory
|
| 284 |
|
| 285 |
-
use_fp16 =
|
| 286 |
dtype = torch.float16 if use_fp16 else torch.float32
|
| 287 |
mixed_precision = "fp16"
|
| 288 |
gradient_accumulation_steps = 1
|
|
@@ -437,7 +437,7 @@ class DDPM21CM:
|
|
| 437 |
with self.accelerator.accumulate(self.nn_model):
|
| 438 |
x = x.to(self.config.device)
|
| 439 |
print("x = x.to(self.config.device), x.dtype =", x.dtype)
|
| 440 |
-
x = x.to(self.config.dtype)
|
| 441 |
print("x = x.to(self.dtype), x.dtype =", x.dtype)
|
| 442 |
xt, noise, ts = self.ddpm.add_noise(x)
|
| 443 |
|
|
|
|
| 282 |
# params = params
|
| 283 |
# data_dir = './data' # data directory
|
| 284 |
|
| 285 |
+
use_fp16 = True
|
| 286 |
dtype = torch.float16 if use_fp16 else torch.float32
|
| 287 |
mixed_precision = "fp16"
|
| 288 |
gradient_accumulation_steps = 1
|
|
|
|
| 437 |
with self.accelerator.accumulate(self.nn_model):
|
| 438 |
x = x.to(self.config.device)
|
| 439 |
print("x = x.to(self.config.device), x.dtype =", x.dtype)
|
| 440 |
+
# x = x.to(self.config.dtype)
|
| 441 |
print("x = x.to(self.dtype), x.dtype =", x.dtype)
|
| 442 |
xt, noise, ts = self.ddpm.add_noise(x)
|
| 443 |
|