Spaces:
Build error
Build error
Warvito
commited on
Commit
·
1d1c1e7
1
Parent(s):
6686e5d
Try fix cpu only
Browse files- models/ddim.py +3 -3
models/ddim.py
CHANGED
|
@@ -68,8 +68,8 @@ class DDIMSampler(object):
|
|
| 68 |
|
| 69 |
def register_buffer(self, name, attr):
|
| 70 |
if type(attr) == torch.Tensor:
|
| 71 |
-
if attr.device != torch.device("
|
| 72 |
-
attr = attr.to(torch.device("
|
| 73 |
setattr(self, name, attr)
|
| 74 |
|
| 75 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
|
@@ -77,7 +77,7 @@ class DDIMSampler(object):
|
|
| 77 |
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
| 78 |
alphas_cumprod = self.model.alphas_cumprod
|
| 79 |
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
| 80 |
-
to_torch = lambda x: x.clone().detach().to(torch.float32).to(torch.device("
|
| 81 |
|
| 82 |
self.register_buffer('betas', to_torch(self.model.betas))
|
| 83 |
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
|
|
|
| 68 |
|
| 69 |
def register_buffer(self, name, attr):
|
| 70 |
if type(attr) == torch.Tensor:
|
| 71 |
+
if attr.device != torch.device("cpu"):
|
| 72 |
+
attr = attr.to(torch.device("cpu"))
|
| 73 |
setattr(self, name, attr)
|
| 74 |
|
| 75 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
|
|
|
| 77 |
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
| 78 |
alphas_cumprod = self.model.alphas_cumprod
|
| 79 |
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
| 80 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(torch.device("cpu"))
|
| 81 |
|
| 82 |
self.register_buffer('betas', to_torch(self.model.betas))
|
| 83 |
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|