Spaces:
Paused
Paused
Anton Forsman commited on
Commit ·
06ebfb2
1
Parent(s): 94539f6
fix device issue
Browse files
unet.py
CHANGED
|
@@ -411,6 +411,10 @@ class ConditionalUnet(nn.Module):
|
|
| 411 |
|
| 412 |
self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
|
| 413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
def forward(self, x, t, cond=None):
|
| 415 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
| 416 |
cond = cond.unsqueeze(0)
|
|
|
|
| 411 |
|
| 412 |
self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
|
| 413 |
|
| 414 |
+
def to(self, device):
|
| 415 |
+
self.device = device
|
| 416 |
+
return super().to(device)
|
| 417 |
+
|
| 418 |
def forward(self, x, t, cond=None):
|
| 419 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
| 420 |
cond = cond.unsqueeze(0)
|