File size: 2,167 Bytes
c81d8be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Noise schedule functions for MDLM diffusion.

Ported from the Craftax JAX implementation (src/diffusion/schedules.py).
All functions operate on PyTorch tensors and are pure (no global state).

Convention: alpha(t) is the fraction of tokens that remain *unmasked*.
  - alpha(0) = 1.0  (fully clean)
  - alpha(1) = 0.0  (fully masked)
"""

from __future__ import annotations

import math
from typing import Callable

import torch
from torch import Tensor


def linear_schedule(t: Tensor) -> Tensor:
    """Return the linear retention schedule ``alpha(t) = 1 - t``.

    Args:
        t: Diffusion time in [0, 1]. Any shape.

    Returns:
        Fraction of tokens left unmasked at time *t*; same shape as *t*.
    """
    # Under the linear schedule the masked fraction grows exactly as t does,
    # so the retained fraction is simply its complement.
    masked_fraction = t
    return 1.0 - masked_fraction


def cosine_schedule(t: Tensor) -> Tensor:
    """Return the cosine retention schedule ``alpha(t) = cos(pi/2 * t)^2``.

    Args:
        t: Diffusion time in [0, 1]. Any shape.

    Returns:
        Fraction of tokens left unmasked at time *t*; same shape as *t*.
    """
    # Map t in [0, 1] onto a quarter turn [0, pi/2]; squaring the cosine
    # gives 1 at t=0 and (numerically) 0 at t=1.
    quarter_turn = math.pi / 2.0
    angle = t * quarter_turn
    return torch.cos(angle) ** 2


# Registry of named schedules consulted by ``get_schedule``; insertion order
# is what the "Available: [...]" error listing shows.
_SCHEDULE_MAP: dict[str, Callable[[Tensor], Tensor]] = dict(
    linear=linear_schedule,
    cosine=cosine_schedule,
)


def get_schedule(name: str) -> Callable[[Tensor], Tensor]:
    """Resolve a registered noise schedule from its name.

    The lookup is an exact (case-sensitive) key match against the module's
    schedule registry.

    Args:
        name: Registered schedule name, e.g. ``"linear"`` or ``"cosine"``.

    Returns:
        The schedule function mapping diffusion time ``t`` to ``alpha(t)``.

    Raises:
        KeyError: If *name* is not registered; the message lists the
            available names.
    """
    # Registered values are always functions, never None, so a None result
    # unambiguously means "not registered" and one lookup suffices.
    schedule_fn = _SCHEDULE_MAP.get(name)
    if schedule_fn is None:
        raise KeyError(
            f"Unknown schedule '{name}'. "
            f"Available: {list(_SCHEDULE_MAP.keys())}"
        )
    return schedule_fn


def alpha_prime(
    t: Tensor,
    schedule_fn: Callable[[Tensor], Tensor],
    eps: float = 1e-5,
) -> Tensor:
    """Numerical derivative d(alpha)/dt via central difference.

    The evaluation point is first clamped to ``[eps, 1 - eps]`` so that both
    stencil points ``t +/- eps`` stay inside the schedule's domain ``[0, 1]``.
    As a consequence, for ``t < eps`` or ``t > 1 - eps`` the returned value is
    the derivative at the nearest interior point, not a one-sided derivative
    at *t* itself.

    With float32 inputs, round-off in ``schedule_fn`` is amplified by
    ``1 / (2 * eps)``, so only a few significant digits are reliable at the
    default *eps*; pass a float64 *t* when more precision is needed.

    Args:
        t: Diffusion time in [0, 1]. Any shape.
        schedule_fn: Noise schedule returning alpha(t).
        eps: Half-width for finite-difference stencil. Must satisfy
            ``0 < eps <= 0.5``.

    Returns:
        Approximate derivative, same shape as *t*.

    Raises:
        ValueError: If *eps* is not in ``(0, 0.5]`` (including NaN).
    """
    # eps == 0 divides by zero; eps < 0 or eps > 0.5 makes the clamp bounds
    # cross (torch then sets every element to the upper bound), which pushes
    # the stencil outside [0, 1]. All of these silently yield meaningless
    # values, so reject them up front. The chained comparison also rejects NaN.
    if not 0.0 < eps <= 0.5:
        raise ValueError(f"eps must satisfy 0 < eps <= 0.5, got {eps!r}")

    t_clamped = t.clamp(eps, 1.0 - eps)
    forward = schedule_fn(t_clamped + eps)
    backward = schedule_fn(t_clamped - eps)
    return (forward - backward) / (2.0 * eps)