0715-1309
Browse files- context_unet.py +3 -3
- diffusion.py +13 -5
context_unet.py
CHANGED
|
@@ -318,7 +318,7 @@ class ContextUnet(nn.Module):
|
|
| 318 |
encoder_channels = None,
|
| 319 |
dim = 2,
|
| 320 |
stride = (2,2),
|
| 321 |
-
|
| 322 |
):
|
| 323 |
super().__init__()
|
| 324 |
|
|
@@ -351,8 +351,8 @@ class ContextUnet(nn.Module):
|
|
| 351 |
|
| 352 |
# self.n_param = n_param
|
| 353 |
self.model_channels = model_channels
|
| 354 |
-
self.use_fp16 = use_fp16
|
| 355 |
-
self.dtype = torch.float16 if self.use_fp16 else torch.float32
|
| 356 |
|
| 357 |
self.token_embedding = nn.Linear(n_param, model_channels * 4)
|
| 358 |
|
|
|
|
| 318 |
encoder_channels = None,
|
| 319 |
dim = 2,
|
| 320 |
stride = (2,2),
|
| 321 |
+
dtype = torch.float32,
|
| 322 |
):
|
| 323 |
super().__init__()
|
| 324 |
|
|
|
|
| 351 |
|
| 352 |
# self.n_param = n_param
|
| 353 |
self.model_channels = model_channels
|
| 354 |
+
# self.use_fp16 = use_fp16
|
| 355 |
+
self.dtype = dtype#torch.float16 if self.use_fp16 else torch.float32
|
| 356 |
|
| 357 |
self.token_embedding = nn.Linear(n_param, model_channels * 4)
|
| 358 |
|
diffusion.py
CHANGED
|
@@ -96,7 +96,7 @@ def ddp_setup(rank: int, world_size: int):
|
|
| 96 |
|
| 97 |
# %%
|
| 98 |
class DDPMScheduler(nn.Module):
|
| 99 |
-
def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu'):
|
| 100 |
super().__init__()
|
| 101 |
|
| 102 |
beta_1, beta_T = betas
|
|
@@ -112,6 +112,8 @@ class DDPMScheduler(nn.Module):
|
|
| 112 |
self.alpha_t = 1 - self.beta_t
|
| 113 |
# self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
|
| 114 |
self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
|
|
|
|
|
|
|
| 115 |
|
| 116 |
def add_noise(self, clean_images):
|
| 117 |
shape = clean_images.shape
|
|
@@ -280,7 +282,8 @@ class TrainConfig:
|
|
| 280 |
# params = params
|
| 281 |
# data_dir = './data' # data directory
|
| 282 |
|
| 283 |
-
use_fp16 =
|
|
|
|
| 284 |
mixed_precision = "fp16"
|
| 285 |
gradient_accumulation_steps = 1
|
| 286 |
|
|
@@ -317,10 +320,10 @@ class DDPM21CM:
|
|
| 317 |
# # print("shape_loaded =", self.shape_loaded)
|
| 318 |
# self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
|
| 319 |
# del dataset
|
| 320 |
-
self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)
|
| 321 |
|
| 322 |
# initialize the unet
|
| 323 |
-
self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride,
|
| 324 |
|
| 325 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 326 |
self.nn_model.train()
|
|
@@ -344,7 +347,7 @@ class DDPM21CM:
|
|
| 344 |
if config.ema:
|
| 345 |
self.ema = EMA(config.ema_rate)
|
| 346 |
if config.resume and os.path.exists(config.resume):
|
| 347 |
-
self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride,
|
| 348 |
self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])
|
| 349 |
print(f"resumed ema_model from {config.resume}")
|
| 350 |
else:
|
|
@@ -433,6 +436,9 @@ class DDPM21CM:
|
|
| 433 |
# print(f"device {torch.cuda.current_device()}, x[:,0,:2,0,0] =", x[:,0,:2,0,0])
|
| 434 |
with self.accelerator.accumulate(self.nn_model):
|
| 435 |
x = x.to(self.config.device)
|
|
|
|
|
|
|
|
|
|
| 436 |
xt, noise, ts = self.ddpm.add_noise(x)
|
| 437 |
|
| 438 |
if self.config.guide_w == -1:
|
|
@@ -440,6 +446,8 @@ class DDPM21CM:
|
|
| 440 |
else:
|
| 441 |
c = c.to(self.config.device)
|
| 442 |
noise_pred = self.nn_model(xt, ts, c)
|
|
|
|
|
|
|
| 443 |
|
| 444 |
loss = F.mse_loss(noise, noise_pred)
|
| 445 |
self.accelerator.backward(loss)
|
|
|
|
| 96 |
|
| 97 |
# %%
|
| 98 |
class DDPMScheduler(nn.Module):
|
| 99 |
+
def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu', dtype=torch.float32):
|
| 100 |
super().__init__()
|
| 101 |
|
| 102 |
beta_1, beta_T = betas
|
|
|
|
| 112 |
self.alpha_t = 1 - self.beta_t
|
| 113 |
# self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
|
| 114 |
self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
|
| 115 |
+
# self.use_fp16 = use_fp16
|
| 116 |
+
self.dtype = dtype#torch.float16 if self.use_fp16 else torch.float32
|
| 117 |
|
| 118 |
def add_noise(self, clean_images):
|
| 119 |
shape = clean_images.shape
|
|
|
|
| 282 |
# params = params
|
| 283 |
# data_dir = './data' # data directory
|
| 284 |
|
| 285 |
+
use_fp16 = False
|
| 286 |
+
dtype = torch.float16 if use_fp16 else torch.float32
|
| 287 |
mixed_precision = "fp16"
|
| 288 |
gradient_accumulation_steps = 1
|
| 289 |
|
|
|
|
| 320 |
# # print("shape_loaded =", self.shape_loaded)
|
| 321 |
# self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
|
| 322 |
# del dataset
|
| 323 |
+
self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, dtype=config.dtype)
|
| 324 |
|
| 325 |
# initialize the unet
|
| 326 |
+
self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype)
|
| 327 |
|
| 328 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 329 |
self.nn_model.train()
|
|
|
|
| 347 |
if config.ema:
|
| 348 |
self.ema = EMA(config.ema_rate)
|
| 349 |
if config.resume and os.path.exists(config.resume):
|
| 350 |
+
self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype).to(config.device)
|
| 351 |
self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])
|
| 352 |
print(f"resumed ema_model from {config.resume}")
|
| 353 |
else:
|
|
|
|
| 436 |
# print(f"device {torch.cuda.current_device()}, x[:,0,:2,0,0] =", x[:,0,:2,0,0])
|
| 437 |
with self.accelerator.accumulate(self.nn_model):
|
| 438 |
x = x.to(self.config.device)
|
| 439 |
+
print("x = x.to(self.config.device), x.dtype =", x.dtype)
|
| 440 |
+
x = x.to(self.config.dtype)
|
| 441 |
+
print("x = x.to(self.dtype), x.dtype =", x.dtype)
|
| 442 |
xt, noise, ts = self.ddpm.add_noise(x)
|
| 443 |
|
| 444 |
if self.config.guide_w == -1:
|
|
|
|
| 446 |
else:
|
| 447 |
c = c.to(self.config.device)
|
| 448 |
noise_pred = self.nn_model(xt, ts, c)
|
| 449 |
+
|
| 450 |
+
print("noise_pred = self.nn_model(xt, ts, c), noise_pred.dtype =", noise_pred.dtype)
|
| 451 |
|
| 452 |
loss = F.mse_loss(noise, noise_pred)
|
| 453 |
self.accelerator.backward(loss)
|