Spaces:
Runtime error
Runtime error
sample_ddpm_context
Browse files
app.py
CHANGED
|
@@ -198,7 +198,7 @@ def denoise_add_noise(x, t, pred_noise, z=None):
|
|
| 198 |
|
| 199 |
# sample using standard algorithm
|
| 200 |
@torch.no_grad()
|
| 201 |
-
def sample_ddpm(n_sample, save_rate=20):
|
| 202 |
# x_T ~ N(0, 1), sample initial noise
|
| 203 |
samples = torch.randn(n_sample, 3, height, height).to(device)
|
| 204 |
|
|
@@ -221,6 +221,30 @@ def sample_ddpm(n_sample, save_rate=20):
|
|
| 221 |
intermediate = np.stack(intermediate)
|
| 222 |
return samples, intermediate
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
def greet(input):
|
| 225 |
steps = int(input)
|
| 226 |
|
|
@@ -233,7 +257,7 @@ def greet(input):
|
|
| 233 |
ctx = torch.from_numpy(mtx_2d).to(device=device).float()
|
| 234 |
|
| 235 |
#samples, intermediate = sample_ddim_context(32, ctx, n=steps)
|
| 236 |
-
samples, intermediate =
|
| 237 |
|
| 238 |
#samples, intermediate = sample_ddim(32, n=steps)
|
| 239 |
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
|
|
|
| 198 |
|
| 199 |
# sample using standard algorithm
|
| 200 |
@torch.no_grad()
|
| 201 |
+
def sample_ddpm(n_sample, context, save_rate=20):
|
| 202 |
# x_T ~ N(0, 1), sample initial noise
|
| 203 |
samples = torch.randn(n_sample, 3, height, height).to(device)
|
| 204 |
|
|
|
|
| 221 |
intermediate = np.stack(intermediate)
|
| 222 |
return samples, intermediate
|
| 223 |
|
| 224 |
+
@torch.no_grad()
|
| 225 |
+
def sample_ddpm_context(n_sample, save_rate=20):
|
| 226 |
+
# x_T ~ N(0, 1), sample initial noise
|
| 227 |
+
samples = torch.randn(n_sample, 3, height, height).to(device)
|
| 228 |
+
|
| 229 |
+
# array to keep track of generated steps for plotting
|
| 230 |
+
intermediate = []
|
| 231 |
+
for i in range(timesteps, 0, -1):
|
| 232 |
+
print(f'sampling timestep {i:3d}', end='\r')
|
| 233 |
+
|
| 234 |
+
# reshape time tensor
|
| 235 |
+
t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
|
| 236 |
+
|
| 237 |
+
# sample some random noise to inject back in. For i = 1, don't add back in noise
|
| 238 |
+
z = torch.randn_like(samples) if i > 1 else 0
|
| 239 |
+
|
| 240 |
+
eps = nn_model(samples, t, c=context) # predict noise e_(x_t,t)
|
| 241 |
+
samples = denoise_add_noise(samples, i, eps, z)
|
| 242 |
+
if i % save_rate ==0 or i==timesteps or i<8:
|
| 243 |
+
intermediate.append(samples.detach().cpu().numpy())
|
| 244 |
+
|
| 245 |
+
intermediate = np.stack(intermediate)
|
| 246 |
+
return samples, intermediate
|
| 247 |
+
|
| 248 |
def greet(input):
|
| 249 |
steps = int(input)
|
| 250 |
|
|
|
|
| 257 |
ctx = torch.from_numpy(mtx_2d).to(device=device).float()
|
| 258 |
|
| 259 |
#samples, intermediate = sample_ddim_context(32, ctx, n=steps)
|
| 260 |
+
samples, intermediate = sample_ddpm_context(32, ctx, steps)
|
| 261 |
|
| 262 |
#samples, intermediate = sample_ddim(32, n=steps)
|
| 263 |
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|