Xsmos committed on
Commit
7b0e0c4
·
verified ·
1 Parent(s): 0eafba4
Files changed (2) hide show
  1. context_unet.py +3 -3
  2. diffusion.py +13 -5
context_unet.py CHANGED
@@ -318,7 +318,7 @@ class ContextUnet(nn.Module):
318
  encoder_channels = None,
319
  dim = 2,
320
  stride = (2,2),
321
- use_fp16 = False,
322
  ):
323
  super().__init__()
324
 
@@ -351,8 +351,8 @@ class ContextUnet(nn.Module):
351
 
352
  # self.n_param = n_param
353
  self.model_channels = model_channels
354
- self.use_fp16 = use_fp16
355
- self.dtype = torch.float16 if self.use_fp16 else torch.float32
356
 
357
  self.token_embedding = nn.Linear(n_param, model_channels * 4)
358
 
 
318
  encoder_channels = None,
319
  dim = 2,
320
  stride = (2,2),
321
+ dtype = torch.float32,
322
  ):
323
  super().__init__()
324
 
 
351
 
352
  # self.n_param = n_param
353
  self.model_channels = model_channels
354
+ # self.use_fp16 = use_fp16
355
+ self.dtype = dtype#torch.float16 if self.use_fp16 else torch.float32
356
 
357
  self.token_embedding = nn.Linear(n_param, model_channels * 4)
358
 
diffusion.py CHANGED
@@ -96,7 +96,7 @@ def ddp_setup(rank: int, world_size: int):
96
 
97
  # %%
98
  class DDPMScheduler(nn.Module):
99
- def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu'):
100
  super().__init__()
101
 
102
  beta_1, beta_T = betas
@@ -112,6 +112,8 @@ class DDPMScheduler(nn.Module):
112
  self.alpha_t = 1 - self.beta_t
113
  # self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
114
  self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
 
 
115
 
116
  def add_noise(self, clean_images):
117
  shape = clean_images.shape
@@ -280,7 +282,8 @@ class TrainConfig:
280
  # params = params
281
  # data_dir = './data' # data directory
282
 
283
- use_fp16 = True
 
284
  mixed_precision = "fp16"
285
  gradient_accumulation_steps = 1
286
 
@@ -317,10 +320,10 @@ class DDPM21CM:
317
  # # print("shape_loaded =", self.shape_loaded)
318
  # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
319
  # del dataset
320
- self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)
321
 
322
  # initialize the unet
323
- self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, use_fp16=config.use_fp16)
324
 
325
  # nn_model = ContextUnet(n_param=1, image_size=28)
326
  self.nn_model.train()
@@ -344,7 +347,7 @@ class DDPM21CM:
344
  if config.ema:
345
  self.ema = EMA(config.ema_rate)
346
  if config.resume and os.path.exists(config.resume):
347
- self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, use_fp16=config.use_fp16).to(config.device)
348
  self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])
349
  print(f"resumed ema_model from {config.resume}")
350
  else:
@@ -433,6 +436,9 @@ class DDPM21CM:
433
  # print(f"device {torch.cuda.current_device()}, x[:,0,:2,0,0] =", x[:,0,:2,0,0])
434
  with self.accelerator.accumulate(self.nn_model):
435
  x = x.to(self.config.device)
 
 
 
436
  xt, noise, ts = self.ddpm.add_noise(x)
437
 
438
  if self.config.guide_w == -1:
@@ -440,6 +446,8 @@ class DDPM21CM:
440
  else:
441
  c = c.to(self.config.device)
442
  noise_pred = self.nn_model(xt, ts, c)
 
 
443
 
444
  loss = F.mse_loss(noise, noise_pred)
445
  self.accelerator.backward(loss)
 
96
 
97
  # %%
98
  class DDPMScheduler(nn.Module):
99
+ def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu', dtype=torch.float32):
100
  super().__init__()
101
 
102
  beta_1, beta_T = betas
 
112
  self.alpha_t = 1 - self.beta_t
113
  # self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
114
  self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
115
+ # self.use_fp16 = use_fp16
116
+ self.dtype = dtype#torch.float16 if self.use_fp16 else torch.float32
117
 
118
  def add_noise(self, clean_images):
119
  shape = clean_images.shape
 
282
  # params = params
283
  # data_dir = './data' # data directory
284
 
285
+ use_fp16 = False
286
+ dtype = torch.float16 if use_fp16 else torch.float32
287
  mixed_precision = "fp16"
288
  gradient_accumulation_steps = 1
289
 
 
320
  # # print("shape_loaded =", self.shape_loaded)
321
  # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
322
  # del dataset
323
+ self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, dtype=config.dtype)
324
 
325
  # initialize the unet
326
+ self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype)
327
 
328
  # nn_model = ContextUnet(n_param=1, image_size=28)
329
  self.nn_model.train()
 
347
  if config.ema:
348
  self.ema = EMA(config.ema_rate)
349
  if config.resume and os.path.exists(config.resume):
350
+ self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype).to(config.device)
351
  self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])
352
  print(f"resumed ema_model from {config.resume}")
353
  else:
 
436
  # print(f"device {torch.cuda.current_device()}, x[:,0,:2,0,0] =", x[:,0,:2,0,0])
437
  with self.accelerator.accumulate(self.nn_model):
438
  x = x.to(self.config.device)
439
+ print("x = x.to(self.config.device), x.dtype =", x.dtype)
440
+ x = x.to(self.config.dtype)
441
+ print("x = x.to(self.dtype), x.dtype =", x.dtype)
442
  xt, noise, ts = self.ddpm.add_noise(x)
443
 
444
  if self.config.guide_w == -1:
 
446
  else:
447
  c = c.to(self.config.device)
448
  noise_pred = self.nn_model(xt, ts, c)
449
+
450
+ print("noise_pred = self.nn_model(xt, ts, c), noise_pred.dtype =", noise_pred.dtype)
451
 
452
  loss = F.mse_loss(noise, noise_pred)
453
  self.accelerator.backward(loss)