File size: 1,444 Bytes
7d6a683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""
scheduler.py  — Fixed & Upgraded
==================================
Changes:
  1. T=64 (was 16). More timesteps = richer denoising curriculum per epoch.
  2. alpha at t=0 is EXACTLY 1.0 — fixes Bug 2 (final-step re-noise).
  3. sample_timestep samples [0, T-1] including t=0, so model trains on
     fully-clean inputs (learns the identity at t=0 explicitly).
"""
import torch, math

class OptimizedCosineScheduler:
    """Cosine schedule of cumulative alpha values, indexed by timestep.

    The schedule is a 1-D float32 tensor of length ``num_timesteps`` stored
    in ``alphas_cumprod``. Entry 0 is exactly 1.0 and the last entry is at
    most 0.001; entries in between follow a squared-cosine curve.

    NOTE(review): ``mask_token_id`` is only stored here, never used by this
    class — presumably consumed by the forward/reverse process elsewhere.
    """

    def __init__(self, cfg, device=None):
        """Read the schedule settings from ``cfg`` and precompute the table.

        Args:
            cfg: Nested mapping providing ``cfg['model']['diffusion_steps']``
                (number of timesteps, must be >= 1) and
                ``cfg['diffusion']['mask_token_id']``.
            device: Device on which ``alphas_cumprod`` is stored. Any falsy
                value (including ``None``) selects the CPU.

        Raises:
            ValueError: If ``diffusion_steps`` is smaller than 1.
        """
        self.num_timesteps = cfg['model']['diffusion_steps']
        self.mask_token_id = cfg['diffusion']['mask_token_id']
        self.device = device or torch.device('cpu')
        self.alphas_cumprod = self._build_schedule().to(self.device)

    def _build_schedule(self):
        """Return the length-``num_timesteps`` cosine schedule (on CPU).

        Raises:
            ValueError: If ``num_timesteps`` is smaller than 1.
        """
        T = self.num_timesteps
        # Without this guard T == 0 divides by zero below and then fails
        # with an opaque IndexError when writing into an empty tensor.
        if T < 1:
            raise ValueError(f"diffusion_steps must be >= 1, got {T!r}")

        steps = torch.arange(T + 1, dtype=torch.float32)
        # Squared cosine with a 0.008 offset, evaluated at T + 1 points.
        f_t = torch.cos((steps / T + 0.008) / 1.008 * math.pi / 2) ** 2
        # Normalise by the value at step 0, then drop that first point so
        # the table has exactly T entries.
        alphas_bar = (f_t / f_t[0])[1:]
        # Index 0 is forced to exactly 1.0 rather than the curve's value.
        alphas_bar[0] = 1.0
        # Cap the final entry at 0.001.
        alphas_bar[-1] = alphas_bar[-1].clamp(max=0.001)
        return alphas_bar

    def sample_timestep(self, batch_size):
        """Draw ``batch_size`` timesteps uniformly from ``[0, T-1]``.

        Index 0 is included, so the entry whose alpha is exactly 1.0 can be
        sampled.

        Returns:
            Integer tensor of shape ``(batch_size,)``. It is created by
            ``torch.randint`` without a device argument, i.e. it is NOT
            moved to ``self.device``; ``get_alpha`` handles the transfer.
        """
        return torch.randint(0, self.num_timesteps, (batch_size,))

    def get_alpha(self, t):
        """Look up the schedule value for timestep index/indices ``t``.

        Args:
            t: Tensor of timestep indices on any device, or a plain Python
                int / sequence of ints. Values are cast to int64 before
                indexing.

        Returns:
            Tensor of schedule values with the same shape as ``t``, on the
            device that holds ``alphas_cumprod``.
        """
        if not torch.is_tensor(t):
            # Accept plain ints / lists as well as tensors.
            t = torch.as_tensor(t)
        table = self.alphas_cumprod
        return table[t.to(table.device).long()]