File size: 807 Bytes
27f26fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""
forward_process.py  — Verified Correct (no changes needed)
===========================================================
Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly
so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
"""
import torch

class AbsorbingForwardProcess:
    """Forward (noising) process for absorbing-state ("mask") diffusion.

    Each token of a clean sequence is independently replaced by ``mask_id``
    with probability ``1 - alpha_t``; tokens equal to ``pad_id`` are never
    replaced.
    """

    def __init__(self, scheduler, mask_id: int = 0, pad_id: int = 1):
        """Store the noise scheduler and the special token ids.

        Args:
            scheduler: Object exposing ``get_alpha(t)`` that returns a tensor
                of keep-probabilities for the timestep(s) ``t``.
            mask_id: Token id written into masked (absorbed) positions.
            pad_id: Token id of padding; such positions are left untouched.
        """
        self.scheduler = scheduler
        self.mask_id = mask_id
        self.pad_id = pad_id

    def q_sample(self, x_0: torch.Tensor, t):
        """Sample a corrupted sequence ``x_t`` from clean tokens ``x_0``.

        Args:
            x_0: Tensor of token ids whose leading axis is the batch axis.
                Any rank >= 1 is accepted; the usual case is (batch, seq).
            t: Timestep(s), passed through unchanged to
                ``scheduler.get_alpha``.
                NOTE(review): assumes ``get_alpha`` yields either a single
                value or one value per batch element -- confirm against the
                scheduler implementation.

        Returns:
            A tuple ``(x_0, x_t)``: the input tensor itself (not a copy) and
            a new tensor of the same shape and dtype in which non-PAD
            positions have been replaced by ``mask_id`` wherever the uniform
            draw exceeded ``alpha_t``.
        """
        alpha_t = self.scheduler.get_alpha(t).to(x_0.device)
        # One alpha per sample, followed by singleton axes so it broadcasts
        # over every non-batch axis of x_0. For a 2-D x_0 this is exactly the
        # former ``view(-1, 1)``; ``reshape`` additionally tolerates
        # non-contiguous scheduler outputs.
        alpha_t = alpha_t.reshape(-1, *([1] * (x_0.dim() - 1)))

        # r is uniform on [0, 1), so ``r > alpha_t`` holds with probability
        # 1 - alpha_t and can never hold when alpha_t == 1.0.
        r = torch.rand(x_0.shape, device=x_0.device)

        x_t = x_0.clone()
        x_t[r > alpha_t] = self.mask_id
        # Restore padding after masking so that PAD always stays PAD.
        x_t[x_0 == self.pad_id] = self.pad_id
        return x_0, x_t