Spaces:
Runtime error
Runtime error
ddpm
Browse files
app.py
CHANGED
|
@@ -188,11 +188,45 @@ def sample_ddim_context(n_sample, context, n=20):
|
|
| 188 |
intermediate = np.stack(intermediate)
|
| 189 |
return samples, intermediate
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
def greet(input):
|
| 192 |
steps = int(input)
|
| 193 |
#samples, intermediate = sample_ddim(32, n=steps)
|
| 194 |
-
ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
| 195 |
-
samples, intermediate = sample_ddim_context(32, ctx, steps)
|
|
|
|
| 196 |
#response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
|
| 197 |
#response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
|
| 198 |
#response = im.fromarray(intermediate[24][0][1]).convert("RGB")
|
|
|
|
| 188 |
intermediate = np.stack(intermediate)
|
| 189 |
return samples, intermediate
|
| 190 |
|
| 191 |
+
# helper function; removes the predicted noise (but adds some noise back in to avoid collapse)
|
| 192 |
+
def denoise_add_noise(x, t, pred_noise, z=None):
|
| 193 |
+
if z is None:
|
| 194 |
+
z = torch.randn_like(x)
|
| 195 |
+
noise = b_t.sqrt()[t] * z
|
| 196 |
+
mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
|
| 197 |
+
return mean + noise
|
| 198 |
+
|
| 199 |
+
# sample using standard algorithm
|
| 200 |
+
@torch.no_grad()
|
| 201 |
+
def sample_ddpm(n_sample, save_rate=20):
|
| 202 |
+
# x_T ~ N(0, 1), sample initial noise
|
| 203 |
+
samples = torch.randn(n_sample, 3, height, height).to(device)
|
| 204 |
+
|
| 205 |
+
# array to keep track of generated steps for plotting
|
| 206 |
+
intermediate = []
|
| 207 |
+
for i in range(timesteps, 0, -1):
|
| 208 |
+
print(f'sampling timestep {i:3d}', end='\r')
|
| 209 |
+
|
| 210 |
+
# reshape time tensor
|
| 211 |
+
t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
|
| 212 |
+
|
| 213 |
+
# sample some random noise to inject back in. For i = 1, don't add back in noise
|
| 214 |
+
z = torch.randn_like(samples) if i > 1 else 0
|
| 215 |
+
|
| 216 |
+
eps = nn_model(samples, t) # predict noise e_(x_t,t)
|
| 217 |
+
samples = denoise_add_noise(samples, i, eps, z)
|
| 218 |
+
if i % save_rate ==0 or i==timesteps or i<8:
|
| 219 |
+
intermediate.append(samples.detach().cpu().numpy())
|
| 220 |
+
|
| 221 |
+
intermediate = np.stack(intermediate)
|
| 222 |
+
return samples, intermediate
|
| 223 |
+
|
| 224 |
def greet(input):
|
| 225 |
steps = int(input)
|
| 226 |
#samples, intermediate = sample_ddim(32, n=steps)
|
| 227 |
+
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
| 228 |
+
#samples, intermediate = sample_ddim_context(32, ctx, steps)
|
| 229 |
+
samples, intermediate = sample_ddpm(32, )
|
| 230 |
#response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
|
| 231 |
#response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
|
| 232 |
#response = im.fromarray(intermediate[24][0][1]).convert("RGB")
|