File size: 921 Bytes
6140064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Noise schedules for masked discrete diffusion.

alpha(t) is the retention probability: alpha(0)=1 (clean), alpha(1)=0 (fully masked).
"""

from __future__ import annotations
from typing import Callable

import jax.numpy as jnp

# Signature shared by the schedule functions and their derivatives in this
# module: takes an array of times t and returns an array.
ScheduleFn = Callable[[jnp.ndarray], jnp.ndarray]


def linear_schedule(t: jnp.ndarray) -> jnp.ndarray:
    """Linear retention schedule: alpha(t) = 1 - t.

    Described as the default in MDLM / ReMDM.

    Args:
        t: Diffusion time(s).

    Returns:
        The retention probability ``1 - t``, elementwise.
    """
    retained = 1.0 - t
    return retained


def linear_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
    """Time derivative of the linear schedule: d(1 - t)/dt = -1.

    Args:
        t: Diffusion time(s); only its shape and dtype are used.

    Returns:
        An array shaped and typed like ``t``, filled with -1.
    """
    slope = -1.0  # constant slope of alpha(t) = 1 - t
    return jnp.full_like(t, fill_value=slope)


def cosine_schedule(t: jnp.ndarray) -> jnp.ndarray:
    """Cosine retention schedule: alpha(t) = cos(pi * t / 2).

    Args:
        t: Diffusion time(s).

    Returns:
        ``cos(pi * t / 2)``, elementwise.

    NOTE(review): at t = 1 the result is only approximately 0, because
    pi / 2 is not exactly representable in floating point.
    """
    # Map t onto a quarter turn, so t = 1 lands at pi / 2.
    phase = t * jnp.pi / 2.0
    return jnp.cos(phase)


def cosine_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
    """Time derivative of the cosine schedule.

    d/dt cos(pi * t / 2) = -(pi / 2) * sin(pi * t / 2).

    Args:
        t: Diffusion time(s).

    Returns:
        ``-(pi / 2) * sin(pi * t / 2)``, elementwise.
    """
    phase = t * jnp.pi / 2.0
    # The chain rule contributes the inner derivative pi / 2; the sign
    # comes from d cos(x)/dx = -sin(x).
    scale = -(jnp.pi / 2.0)
    return scale * jnp.sin(phase)


# Registry of the schedules defined in this module, keyed by name.  Each
# entry is a pair: (alpha(t), d alpha / dt), in that order.
SCHEDULE_MAP: dict[str, tuple[ScheduleFn, ScheduleFn]] = {}
SCHEDULE_MAP["linear"] = (linear_schedule, linear_schedule_deriv)
SCHEDULE_MAP["cosine"] = (cosine_schedule, cosine_schedule_deriv)