0710-1050
Browse files- context_unet.py +1 -1
- diffusion.py +6 -3
context_unet.py
CHANGED
|
@@ -516,7 +516,7 @@ class ContextUnet(nn.Module):
|
|
| 516 |
|
| 517 |
def forward(self, x, timesteps, y=None):
|
| 518 |
hs = []
|
| 519 |
-
print("device of timesteps, self.model_channels:", timesteps.device, self.model_channels)
|
| 520 |
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
| 521 |
if y != None:
|
| 522 |
text_outputs = self.token_embedding(y.float())
|
|
|
|
| 516 |
|
| 517 |
def forward(self, x, timesteps, y=None):
|
| 518 |
hs = []
|
| 519 |
+
# print("device of timesteps, self.model_channels:", timesteps.device, self.model_channels)
|
| 520 |
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
| 521 |
if y != None:
|
| 522 |
text_outputs = self.token_embedding(y.float())
|
diffusion.py
CHANGED
|
@@ -549,12 +549,15 @@ if __name__ == "__main__":
|
|
| 549 |
# torch.cuda.set_device(0)
|
| 550 |
|
| 551 |
# %%
|
| 552 |
-
print(
|
| 553 |
-
|
|
|
|
|
|
|
|
|
|
| 554 |
# print(torch.cuda.__dir__())
|
| 555 |
|
| 556 |
# %%
|
| 557 |
-
print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
|
| 558 |
# print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
|
| 559 |
print("torch.cuda.current_device() =", torch.cuda.current_device())
|
| 560 |
# print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
|
|
|
|
| 549 |
# torch.cuda.set_device(0)
|
| 550 |
|
| 551 |
# %%
|
| 552 |
+
print(
|
| 553 |
+
"torch.cuda.is_available() =", torch.cuda.is_available(),
|
| 554 |
+
"torch.cuda.device_count() =", torch.cuda.device_count(),
|
| 555 |
+
"torch.cuda.is_initialized() =", torch.cuda.is_initialized()
|
| 556 |
+
)
|
| 557 |
# print(torch.cuda.__dir__())
|
| 558 |
|
| 559 |
# %%
|
| 560 |
+
# print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
|
| 561 |
# print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
|
| 562 |
print("torch.cuda.current_device() =", torch.cuda.current_device())
|
| 563 |
# print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
|