File size: 921 Bytes
"""Noise schedules for masked discrete diffusion.

alpha(t) is the retention probability at time t in [0, 1]:
alpha(0) = 1 (clean) and alpha(1) = 0 (fully masked).
"""
from __future__ import annotations
from typing import Callable
import jax.numpy as jnp
# Signature shared by schedules and their derivatives: array in, array out.
ScheduleFn = Callable[[jnp.ndarray], jnp.ndarray]
def linear_schedule(t: jnp.ndarray) -> jnp.ndarray:
    """Linear schedule alpha(t) = 1 - t (the default in MDLM / ReMDM).

    Args:
        t: Diffusion time(s).

    Returns:
        ``1 - t``, computed elementwise.
    """
    alpha = 1.0 - t
    return alpha
def linear_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
    """Derivative of the linear schedule: d alpha / dt = -1 for every t.

    Args:
        t: Diffusion time(s); only its shape and dtype are used.

    Returns:
        An array shaped and typed like ``t`` with every entry set to -1.
    """
    slope = -1.0
    return jnp.full_like(t, fill_value=slope)
def cosine_schedule(t: jnp.ndarray) -> jnp.ndarray:
    """Cosine schedule alpha(t) = cos(pi * t / 2).

    Args:
        t: Diffusion time(s).

    Returns:
        ``cos(pi * t / 2)``, computed elementwise.
    """
    # Same operation order as the closed form: scale by pi, then halve.
    angle = t * jnp.pi / 2.0
    return jnp.cos(angle)
def cosine_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
    """Derivative of the cosine schedule: -(pi / 2) * sin(pi * t / 2).

    Args:
        t: Diffusion time(s).

    Returns:
        d alpha / dt for alpha(t) = cos(pi * t / 2), computed elementwise.
    """
    half_pi = jnp.pi / 2.0
    angle = t * jnp.pi / 2.0
    return -half_pi * jnp.sin(angle)
# Registry of schedules by name; each value is the pair
# (alpha(t) function, d alpha / dt function).
SCHEDULE_MAP: dict[str, tuple[ScheduleFn, ScheduleFn]] = {
    "linear": (linear_schedule, linear_schedule_deriv),
    "cosine": (cosine_schedule, cosine_schedule_deriv),
}
|