# Source: remdm-craftax/src/diffusion/schedules.py
# Author: MathisW78
# Commit 6140064 (verified): "Upload COMP0258 demo bundle (code + diffusion/PPO checkpoints + ablation assets)"
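"""Noise schedules for masked discrete diffusion.

alpha(t) is the retention probability at diffusion time t in [0, 1], i.e. the
probability that a position is still clean (unmasked): alpha(0) = 1 (clean),
alpha(1) = 0 (fully masked). Each schedule is paired with its time derivative
d alpha / dt in SCHEDULE_MAP.
"""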
from __future__ import annotations
from typing import Callable
import jax.numpy as jnp
# Common signature of every schedule and schedule-derivative function in this
# module: takes diffusion time(s) t and returns a jnp array.
ScheduleFn = Callable[[jnp.ndarray], jnp.ndarray]
def linear_schedule(t: jnp.ndarray) -> jnp.ndarray:
    """Linear retention schedule alpha(t) = 1 - t (the MDLM / ReMDM default).

    Args:
        t: Diffusion time(s), expected in [0, 1].

    Returns:
        alpha(t): 1 at t = 0, falling linearly to 0 at t = 1.
    """
    retained = 1.0 - t
    return retained
def linear_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
    """Time derivative of the linear schedule: d alpha / dt = -1 everywhere.

    Args:
        t: Diffusion time(s); only its shape and dtype matter.

    Returns:
        An array shaped like ``t`` filled with -1.0.
    """
    # Fill from alpha(t) = 1.0 - t rather than from t itself. The result then
    # has the same (always floating) dtype as linear_schedule(t). Filling ``t``
    # directly would yield an integer array of -1 for integer-typed times.
    return jnp.full_like(1.0 - t, -1.0)
def cosine_schedule(t: jnp.ndarray) -> jnp.ndarray:
    """Cosine retention schedule alpha(t) = cos(pi * t / 2).

    Evaluated through the identity cos(pi * t / 2) = sin(pi * (1 - t) / 2) so
    the endpoints hold in floating point: alpha(1) is exactly 0 and alpha(0)
    is 1 up to rounding. Evaluating cos(pi * t / 2) directly rounds pi / 2
    before the cosine, which in float32 makes alpha(1) about -4.4e-8, i.e. a
    slightly negative retention probability.

    Args:
        t: Diffusion time(s), expected in [0, 1].

    Returns:
        alpha(t), computed elementwise over ``t``.
    """
    # 1 - t is computed exactly for t in [0.5, 1], so small alpha values near
    # t = 1 keep their relative precision.
    return jnp.sin((1.0 - t) * jnp.pi / 2.0)
def cosine_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray:
    """Time derivative of the cosine schedule: -(pi / 2) * sin(pi * t / 2).

    Args:
        t: Diffusion time(s).

    Returns:
        d alpha / dt at ``t``: 0 at t = 0, approximately -pi / 2 at t = 1.
    """
    phase = t * jnp.pi / 2.0
    slope = -(jnp.pi / 2.0)
    return slope * jnp.sin(phase)
# Registry of available schedules, keyed by name. Each value is the pair
# (alpha, d_alpha_dt): the retention schedule and its time derivative.
SCHEDULE_MAP: dict[str, tuple[ScheduleFn, ScheduleFn]] = dict(
    linear=(linear_schedule, linear_schedule_deriv),
    cosine=(cosine_schedule, cosine_schedule_deriv),
)