0710-1056
Browse files- diffusion.py +22 -21
diffusion.py
CHANGED
|
@@ -302,6 +302,12 @@ class TrainConfig:
|
|
| 302 |
# @dataclass
|
| 303 |
class DDPM21CM:
|
| 304 |
def __init__(self, config):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
# config = TrainConfig()
|
| 306 |
# date = datetime.datetime.now().strftime("%m%d-%H%M")
|
| 307 |
config.run_name = datetime.datetime.now().strftime("%m%d-%H%M") # the unique name of each experiment
|
|
@@ -549,17 +555,12 @@ if __name__ == "__main__":
|
|
| 549 |
# torch.cuda.set_device(0)
|
| 550 |
|
| 551 |
# %%
|
| 552 |
-
print(
|
| 553 |
-
"torch.cuda.is_available() =", torch.cuda.is_available(),
|
| 554 |
-
"torch.cuda.device_count() =", torch.cuda.device_count(),
|
| 555 |
-
"torch.cuda.is_initialized() =", torch.cuda.is_initialized()
|
| 556 |
-
)
|
| 557 |
# print(torch.cuda.__dir__())
|
| 558 |
|
| 559 |
# %%
|
| 560 |
# print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
|
| 561 |
# print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
|
| 562 |
-
print("torch.cuda.current_device() =", torch.cuda.current_device())
|
| 563 |
# print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
|
| 564 |
# print("torch.cuda.get_device_properties(torch.cuda.device) =", torch.cuda.get_device_properties(torch.cuda.device))
|
| 565 |
# print('here')
|
|
@@ -573,26 +574,26 @@ print("torch.cuda.current_device() =", torch.cuda.current_device())
|
|
| 573 |
# # Sampling
|
| 574 |
|
| 575 |
# %%
|
| 576 |
-
if __name__ == "__main__":
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
|
| 587 |
-
|
| 588 |
|
| 589 |
-
|
| 590 |
|
| 591 |
-
|
| 592 |
|
| 593 |
-
|
| 594 |
|
| 595 |
-
|
| 596 |
|
| 597 |
# %%
|
| 598 |
# ls -lth outputs | head
|
|
|
|
| 302 |
# @dataclass
|
| 303 |
class DDPM21CM:
|
| 304 |
def __init__(self, config):
|
| 305 |
+
print(
|
| 306 |
+
"torch.cuda.is_available() =", torch.cuda.is_available(),
|
| 307 |
+
"torch.cuda.device_count() =", torch.cuda.device_count(),
|
| 308 |
+
"torch.cuda.is_initialized() =", torch.cuda.is_initialized(),
|
| 309 |
+
"torch.cuda.current_device() =", torch.cuda.current_device()
|
| 310 |
+
)
|
| 311 |
# config = TrainConfig()
|
| 312 |
# date = datetime.datetime.now().strftime("%m%d-%H%M")
|
| 313 |
config.run_name = datetime.datetime.now().strftime("%m%d-%H%M") # the unique name of each experiment
|
|
|
|
| 555 |
# torch.cuda.set_device(0)
|
| 556 |
|
| 557 |
# %%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
# print(torch.cuda.__dir__())
|
| 559 |
|
| 560 |
# %%
|
| 561 |
# print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
|
| 562 |
# print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
|
| 563 |
+
# print("torch.cuda.current_device() =", torch.cuda.current_device())
|
| 564 |
# print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
|
| 565 |
# print("torch.cuda.get_device_properties(torch.cuda.device) =", torch.cuda.get_device_properties(torch.cuda.device))
|
| 566 |
# print('here')
|
|
|
|
| 574 |
# # Sampling
|
| 575 |
|
| 576 |
# %%
|
| 577 |
+
# if __name__ == "__main__":
|
| 578 |
+
# # num_image_list = [1600,3200,6400,12800,25600]
|
| 579 |
+
# num_image_list = [1000]
|
| 580 |
+
# # num_image_list = [3200,6400,12800,25600]
|
| 581 |
+
# # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 582 |
+
# repeat = 2
|
| 583 |
+
# config = TrainConfig()
|
| 584 |
+
# for i, num_image in enumerate(num_image_list):
|
| 585 |
+
# config.num_image = num_image
|
| 586 |
+
# ddpm21cm = DDPM21CM(config)
|
| 587 |
|
| 588 |
+
# ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor([4.4, 131.341]), repeat=repeat)
|
| 589 |
|
| 590 |
+
# # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.6, 19.037)), repeat=repeat)
|
| 591 |
|
| 592 |
+
# # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.699, 30)), repeat=repeat)
|
| 593 |
|
| 594 |
+
# # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.477, 200)), repeat=repeat)
|
| 595 |
|
| 596 |
+
# # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.8, 131.341)), repeat=repeat)
|
| 597 |
|
| 598 |
# %%
|
| 599 |
# ls -lth outputs | head
|