0710-1457
Browse files- diffusion.py +2 -1
diffusion.py
CHANGED
|
@@ -329,8 +329,9 @@ class DDPM21CM:
|
|
| 329 |
print(f"resumed nn_model from {config.resume}")
|
| 330 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 331 |
self.nn_model.train()
|
| 332 |
-
print("self.ddpm.device =", self.ddpm.device)
|
| 333 |
self.nn_model.to(self.ddpm.device)
|
|
|
|
| 334 |
# print("nn_model.device =", ddpm.device)
|
| 335 |
# number of parameters to be trained
|
| 336 |
self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
|
|
|
|
| 329 |
print(f"resumed nn_model from {config.resume}")
|
| 330 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 331 |
self.nn_model.train()
|
| 332 |
+
# print("self.ddpm.device =", self.ddpm.device)
|
| 333 |
self.nn_model.to(self.ddpm.device)
|
| 334 |
+
self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])
|
| 335 |
# print("nn_model.device =", ddpm.device)
|
| 336 |
# number of parameters to be trained
|
| 337 |
self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
|