Xsmos commited on
Commit
c250830
·
verified ·
1 Parent(s): d130d8c
Files changed (2) hide show
  1. context_unet.py +1 -1
  2. diffusion.py +6 -3
context_unet.py CHANGED
@@ -516,7 +516,7 @@ class ContextUnet(nn.Module):
516
 
517
  def forward(self, x, timesteps, y=None):
518
  hs = []
519
- print("device of timesteps, self.model_channels:", timesteps.device, self.model_channels)
520
  emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
521
  if y != None:
522
  text_outputs = self.token_embedding(y.float())
 
516
 
517
  def forward(self, x, timesteps, y=None):
518
  hs = []
519
+ # print("device of timesteps, self.model_channels:", timesteps.device, self.model_channels)
520
  emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
521
  if y != None:
522
  text_outputs = self.token_embedding(y.float())
diffusion.py CHANGED
@@ -549,12 +549,15 @@ if __name__ == "__main__":
549
  # torch.cuda.set_device(0)
550
 
551
  # %%
552
- print("torch.cuda.is_available() =", torch.cuda.is_available())
553
- print("torch.cuda.device_count() =", torch.cuda.device_count())
 
 
 
554
  # print(torch.cuda.__dir__())
555
 
556
  # %%
557
- print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
558
  # print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
559
  print("torch.cuda.current_device() =", torch.cuda.current_device())
560
  # print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
 
549
  # torch.cuda.set_device(0)
550
 
551
  # %%
552
+ print(
553
+ "torch.cuda.is_available() =", torch.cuda.is_available(),
554
+ "torch.cuda.device_count() =", torch.cuda.device_count(),
555
+ "torch.cuda.is_initialized() =", torch.cuda.is_initialized()
556
+ )
557
  # print(torch.cuda.__dir__())
558
 
559
  # %%
560
+ # print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
561
  # print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
562
  print("torch.cuda.current_device() =", torch.cuda.current_device())
563
  # print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())