Spaces:
Paused
Paused
Anton Forsman
commited on
Commit
·
06ebfb2
1
Parent(s):
94539f6
fix device issue
Browse files
unet.py
CHANGED
|
@@ -411,6 +411,10 @@ class ConditionalUnet(nn.Module):
|
|
| 411 |
|
| 412 |
self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
|
| 413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
def forward(self, x, t, cond=None):
|
| 415 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
| 416 |
cond = cond.unsqueeze(0)
|
|
|
|
| 411 |
|
| 412 |
self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
|
| 413 |
|
| 414 |
+
def to(self, device):
|
| 415 |
+
self.device = device
|
| 416 |
+
return super().to(device)
|
| 417 |
+
|
| 418 |
def forward(self, x, t, cond=None):
|
| 419 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
| 420 |
cond = cond.unsqueeze(0)
|