DevaFlow / diffusion /forward_process.py
bhsinghgrid's picture
Add files using upload-large-folder tool
7d6a683 verified
raw
history blame contribute delete
807 Bytes
"""
forward_process.py — Verified Correct (no changes needed)
===========================================================
Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly
so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
"""
import torch
class AbsorbingForwardProcess:
def __init__(self, scheduler, mask_id=0, pad_id=1):
self.scheduler = scheduler
self.mask_id = mask_id
self.pad_id = pad_id
def q_sample(self, x_0, t):
alpha_t = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1)
r = torch.rand(x_0.shape, device=x_0.device)
x_t = x_0.clone()
x_t[r > alpha_t] = self.mask_id
x_t[x_0 == self.pad_id] = self.pad_id # PAD stays PAD always
return x_0, x_t