Xsmos commited on
Commit
d66ea1b
·
verified ·
1 Parent(s): c250830
Files changed (1) hide show
  1. diffusion.py +22 -21
diffusion.py CHANGED
@@ -302,6 +302,12 @@ class TrainConfig:
302
  # @dataclass
303
  class DDPM21CM:
304
  def __init__(self, config):
 
 
 
 
 
 
305
  # config = TrainConfig()
306
  # date = datetime.datetime.now().strftime("%m%d-%H%M")
307
  config.run_name = datetime.datetime.now().strftime("%m%d-%H%M") # the unique name of each experiment
@@ -549,17 +555,12 @@ if __name__ == "__main__":
549
  # torch.cuda.set_device(0)
550
 
551
  # %%
552
- print(
553
- "torch.cuda.is_available() =", torch.cuda.is_available(),
554
- "torch.cuda.device_count() =", torch.cuda.device_count(),
555
- "torch.cuda.is_initialized() =", torch.cuda.is_initialized()
556
- )
557
  # print(torch.cuda.__dir__())
558
 
559
  # %%
560
  # print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
561
  # print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
562
- print("torch.cuda.current_device() =", torch.cuda.current_device())
563
  # print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
564
  # print("torch.cuda.get_device_properties(torch.cuda.device) =", torch.cuda.get_device_properties(torch.cuda.device))
565
  # print('here')
@@ -573,26 +574,26 @@ print("torch.cuda.current_device() =", torch.cuda.current_device())
573
  # # Sampling
574
 
575
  # %%
576
- if __name__ == "__main__":
577
- # num_image_list = [1600,3200,6400,12800,25600]
578
- num_image_list = [1000]
579
- # num_image_list = [3200,6400,12800,25600]
580
- # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
581
- repeat = 2
582
- config = TrainConfig()
583
- for i, num_image in enumerate(num_image_list):
584
- config.num_image = num_image
585
- ddpm21cm = DDPM21CM(config)
586
 
587
- ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor([4.4, 131.341]), repeat=repeat)
588
 
589
- # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.6, 19.037)), repeat=repeat)
590
 
591
- # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.699, 30)), repeat=repeat)
592
 
593
- # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.477, 200)), repeat=repeat)
594
 
595
- # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.8, 131.341)), repeat=repeat)
596
 
597
  # %%
598
  # ls -lth outputs | head
 
302
  # @dataclass
303
  class DDPM21CM:
304
  def __init__(self, config):
305
+ print(
306
+ "torch.cuda.is_available() =", torch.cuda.is_available(),
307
+ "torch.cuda.device_count() =", torch.cuda.device_count(),
308
+ "torch.cuda.is_initialized() =", torch.cuda.is_initialized(),
309
+ "torch.cuda.current_device() =", torch.cuda.current_device()
310
+ )
311
  # config = TrainConfig()
312
  # date = datetime.datetime.now().strftime("%m%d-%H%M")
313
  config.run_name = datetime.datetime.now().strftime("%m%d-%H%M") # the unique name of each experiment
 
555
  # torch.cuda.set_device(0)
556
 
557
  # %%
 
 
 
 
 
558
  # print(torch.cuda.__dir__())
559
 
560
  # %%
561
  # print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
562
  # print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
563
+ # print("torch.cuda.current_device() =", torch.cuda.current_device())
564
  # print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
565
  # print("torch.cuda.get_device_properties(torch.cuda.device) =", torch.cuda.get_device_properties(torch.cuda.device))
566
  # print('here')
 
574
  # # Sampling
575
 
576
  # %%
577
+ # if __name__ == "__main__":
578
+ # # num_image_list = [1600,3200,6400,12800,25600]
579
+ # num_image_list = [1000]
580
+ # # num_image_list = [3200,6400,12800,25600]
581
+ # # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
582
+ # repeat = 2
583
+ # config = TrainConfig()
584
+ # for i, num_image in enumerate(num_image_list):
585
+ # config.num_image = num_image
586
+ # ddpm21cm = DDPM21CM(config)
587
 
588
+ # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor([4.4, 131.341]), repeat=repeat)
589
 
590
+ # # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.6, 19.037)), repeat=repeat)
591
 
592
+ # # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.699, 30)), repeat=repeat)
593
 
594
+ # # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.477, 200)), repeat=repeat)
595
 
596
+ # # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.8, 131.341)), repeat=repeat)
597
 
598
  # %%
599
  # ls -lth outputs | head