test
Browse files
mar.py
CHANGED
|
@@ -310,6 +310,7 @@ class MARBert(nn.Module):
|
|
| 310 |
|
| 311 |
# sample token latents for this step
|
| 312 |
z = z[mask_to_pred.nonzero(as_tuple=True)]
|
|
|
|
| 313 |
# cfg schedule follow Muse
|
| 314 |
if cfg_schedule == "linear":
|
| 315 |
cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
|
|
|
|
| 310 |
|
| 311 |
# sample token latents for this step
|
| 312 |
z = z[mask_to_pred.nonzero(as_tuple=True)]
|
| 313 |
+
print(z.shape)
|
| 314 |
# cfg schedule follow Muse
|
| 315 |
if cfg_schedule == "linear":
|
| 316 |
cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
|