| """Noise schedules for masked discrete diffusion. | |
| alpha(t) is the retention probability: alpha(0)=1 (clean), alpha(1)=0 (fully masked). | |
| """ | |
| from __future__ import annotations | |
| from typing import Callable | |
| import jax.numpy as jnp | |
# Signature shared by every schedule function and its derivative in this
# module: takes an array of times ``t`` and returns an array.
ScheduleFn = Callable[[jnp.ndarray], jnp.ndarray]
def linear_schedule(t: jnp.ndarray) -> jnp.ndarray:
    """Linear retention schedule: alpha(t) = 1 - t.

    Default in MDLM / ReMDM.

    Args:
        t: Diffusion time(s).

    Returns:
        ``1.0 - t``, computed elementwise.
    """
    retained = 1.0 - t
    return retained
def linear_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
    """Time derivative of the linear schedule: d alpha / dt = -1.

    Args:
        t: Diffusion time(s); only used as the template for the output.

    Returns:
        An array built with ``jnp.full_like(t, ...)`` whose every entry
        is -1.
    """
    slope = -1.0
    return jnp.full_like(t, slope)
def cosine_schedule(t: jnp.ndarray) -> jnp.ndarray:
    """Cosine retention schedule: alpha(t) = cos(pi * t / 2).

    Args:
        t: Diffusion time(s).

    Returns:
        ``cos(pi * t / 2)``, computed elementwise.
    """
    # Map t onto an angle: t = 0 gives 0 and t = 1 gives pi / 2.
    angle = t * jnp.pi / 2.0
    return jnp.cos(angle)
def cosine_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
    """Time derivative of the cosine schedule.

    d alpha / dt = -(pi / 2) * sin(pi * t / 2).

    Args:
        t: Diffusion time(s).

    Returns:
        The derivative evaluated elementwise at ``t``.
    """
    # Chain-rule factor from differentiating cos(pi * t / 2).
    scale = -(jnp.pi / 2.0)
    angle = t * jnp.pi / 2.0
    return scale * jnp.sin(angle)
# Registry of the schedules defined in this module, keyed by name.
# Each value is the pair (alpha, d alpha / dt).
SCHEDULE_MAP: dict[str, tuple[ScheduleFn, ScheduleFn]] = {
    # alpha(t) = 1 - t
    "linear": (linear_schedule, linear_schedule_deriv),
    # alpha(t) = cos(pi * t / 2)
    "cosine": (cosine_schedule, cosine_schedule_deriv),
}