| """ |
| forward_process.py — Verified Correct (no changes needed) |
| =========================================================== |
| Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly |
| so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler. |
| """ |
| import torch |
|
|
| class AbsorbingForwardProcess: |
| def __init__(self, scheduler, mask_id=0, pad_id=1): |
| self.scheduler = scheduler |
| self.mask_id = mask_id |
| self.pad_id = pad_id |
|
|
| def q_sample(self, x_0, t): |
| alpha_t = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1) |
| r = torch.rand(x_0.shape, device=x_0.device) |
| x_t = x_0.clone() |
| x_t[r > alpha_t] = self.mask_id |
| x_t[x_0 == self.pad_id] = self.pad_id |
| return x_0, x_t |