Spaces:
Runtime error
Runtime error
ctx
Browse files
app.py
CHANGED
|
@@ -223,7 +223,17 @@ def sample_ddpm(n_sample, save_rate=20):
|
|
| 223 |
|
| 224 |
def greet(input):
|
| 225 |
steps = int(input)
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
| 228 |
#samples, intermediate = sample_ddim_context(32, ctx, steps)
|
| 229 |
#samples, intermediate = sample_ddpm(steps)
|
|
|
|
| 223 |
|
| 224 |
def greet(input):
|
| 225 |
steps = int(input)
|
| 226 |
+
|
| 227 |
+
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
| 228 |
+
|
| 229 |
+
# hero, non-hero, food, spell, side-facing
|
| 230 |
+
shape = (32, 5)
|
| 231 |
+
mtx_2d = np.ones(shape) * one_hot_enc
|
| 232 |
+
ctx = mtx_2d.to(device=device).float()
|
| 233 |
+
|
| 234 |
+
samples, intermediate = sample_ddim_ctx(32, ctx, n=steps)
|
| 235 |
+
|
| 236 |
+
#samples, intermediate = sample_ddim(32, n=steps)
|
| 237 |
#ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
|
| 238 |
#samples, intermediate = sample_ddim_context(32, ctx, steps)
|
| 239 |
#samples, intermediate = sample_ddpm(steps)
|