Spaces:
Sleeping
Sleeping
File size: 807 Bytes
"""
forward_process.py — Verified Correct (no changes needed)
===========================================================
Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly
so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
"""
import torch
class AbsorbingForwardProcess:
    """Forward (noising) process for absorbing-state (mask) diffusion.

    Each token of the clean input is independently replaced by ``mask_id``
    with probability ``1 - alpha_t``, where ``alpha_t`` is obtained from
    ``scheduler.get_alpha(t)``. Tokens equal to ``pad_id`` are never
    replaced.
    """

    def __init__(self, scheduler, mask_id=0, pad_id=1):
        """Store the noise scheduler and the special token ids.

        Args:
            scheduler: Object exposing ``get_alpha(t)``, which must return a
                tensor of keep-probabilities for the given timesteps.
            mask_id: Token id written into masked positions.
            pad_id: Token id that is always left untouched.
        """
        self.scheduler = scheduler
        self.mask_id = mask_id
        self.pad_id = pad_id

    def q_sample(self, x_0, t):
        """Sample a corrupted copy ``x_t`` of ``x_0`` at timestep ``t``.

        Args:
            x_0: Tensor of clean token ids. It is not modified.
                NOTE(review): the ``view(-1, 1)`` below assumes a
                ``(batch, seq_len)`` layout -- confirm against callers.
            t: Timesteps, forwarded unchanged to ``scheduler.get_alpha``.

        Returns:
            Tuple ``(x_0, x_t)``: the input tensor itself and a new tensor in
            which the selected non-PAD positions hold ``mask_id``.
        """
        # Column vector of keep-probabilities so it broadcasts across each row.
        alpha_t = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1)
        noise = torch.rand(x_0.shape, device=x_0.device)

        # torch.rand draws from [0, 1). With `>=` a position is masked with
        # probability exactly 1 - alpha: never when alpha == 1 (draws are
        # always < 1) and always when alpha == 0 (a draw of exactly 0.0 would
        # slip through a strict `>`).
        # PAD positions are excluded so that PAD always stays PAD.
        corrupt = (noise >= alpha_t) & (x_0 != self.pad_id)

        # Out-of-place fill: returns a new tensor and leaves x_0 untouched.
        x_t = x_0.masked_fill(corrupt, self.mask_id)
        return x_0, x_t