Spaces:
Runtime error
Runtime error
denoise
Browse files
app.py
CHANGED
|
@@ -122,6 +122,22 @@ ab_t[0] = 1
|
|
| 122 |
# construct model
|
| 123 |
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
# sample quickly using DDIM
|
| 126 |
@torch.no_grad()
|
| 127 |
def sample_ddim(n_sample, n=20):
|
|
|
|
| 122 |
# construct model
|
| 123 |
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
|
| 124 |
|
| 125 |
+
# define sampling function for DDIM
|
| 126 |
+
# removes the noise using ddim
|
| 127 |
+
def denoise_ddim(x, t, t_prev, pred_noise):
|
| 128 |
+
ab = ab_t[t]
|
| 129 |
+
ab_prev = ab_t[t_prev]
|
| 130 |
+
|
| 131 |
+
x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
|
| 132 |
+
dir_xt = (1 - ab_prev).sqrt() * pred_noise
|
| 133 |
+
|
| 134 |
+
return x0_pred + dir_xt
|
| 135 |
+
|
| 136 |
+
# load in model weights and set to eval mode
|
| 137 |
+
nn_model.load_state_dict(torch.load(f"{save_dir}/model_31.pth", map_location=device))
|
| 138 |
+
nn_model.eval()
|
| 139 |
+
print("Loaded in Model without context")
|
| 140 |
+
|
| 141 |
# sample quickly using DDIM
|
| 142 |
@torch.no_grad()
|
| 143 |
def sample_ddim(n_sample, n=20):
|