Xsmos commited on
Commit
fb10d7e
·
verified ·
1 Parent(s): 97a4229
Files changed (1) hide show
  1. diffusion.py +2 -1
diffusion.py CHANGED
@@ -329,8 +329,9 @@ class DDPM21CM:
329
  print(f"resumed nn_model from {config.resume}")
330
  # nn_model = ContextUnet(n_param=1, image_size=28)
331
  self.nn_model.train()
332
- print("self.ddpm.device =", self.ddpm.device)
333
  self.nn_model.to(self.ddpm.device)
 
334
  # print("nn_model.device =", ddpm.device)
335
  # number of parameters to be trained
336
  self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
 
329
  print(f"resumed nn_model from {config.resume}")
330
  # nn_model = ContextUnet(n_param=1, image_size=28)
331
  self.nn_model.train()
332
+ # print("self.ddpm.device =", self.ddpm.device)
333
  self.nn_model.to(self.ddpm.device)
334
+ self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])
335
  # print("nn_model.device =", ddpm.device)
336
  # number of parameters to be trained
337
  self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())