test
Browse files
mar.py
CHANGED
|
@@ -309,6 +309,8 @@ class MARBert(nn.Module):
|
|
| 309 |
mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
|
| 310 |
|
| 311 |
# sample token latents for this step
|
|
|
|
|
|
|
| 312 |
z = z[mask_to_pred.nonzero(as_tuple=True)]
|
| 313 |
print(z.shape)
|
| 314 |
# cfg schedule follow Muse
|
|
|
|
| 309 |
mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
|
| 310 |
|
| 311 |
# sample token latents for this step
|
| 312 |
+
print(z.shape)
|
| 313 |
+
print("-----------")
|
| 314 |
z = z[mask_to_pred.nonzero(as_tuple=True)]
|
| 315 |
print(z.shape)
|
| 316 |
# cfg schedule follow Muse
|