# Initial release: W1-4B dLLM Base (commit 267f903, by Cynthiawhaletech).
"""Shared sampler types for the open-source release."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
from torch import Tensor
@dataclass
class SamplerState:
    """Per-step snapshot of sampler progress.

    This is the argument type of ``StepCallback``. Fields are generated into
    ``__init__`` in the declared order below, so callers may construct it
    positionally; keep the order stable.
    """

    # Progress counters: current step, total number of steps, and the
    # float ``t`` value for this step (presumably the diffusion time; confirm
    # against the sampler).
    step: int
    total_steps: int
    t: float

    # Tensors for this step (presumably the current sequence and the model
    # logits). Shapes and dtypes are not fixed here; confirm with the sampler.
    x_t: Tensor
    logits: Tensor

    # Integer bookkeeping counts for this step. Judging by the names, these
    # are the remaining count before and after the step, and the number
    # selected during it. Verify against the caller.
    remain_before: int
    remain_after: int
    selected: int

    # Sequence-layout values: prefix length and the id of the mask token.
    prefix_len: int
    mask_token_id: int

    # Free-form extras. ``default_factory`` gives every instance its own dict,
    # so instances never share one mutable default.
    metadata: dict = field(default_factory=dict)
# Type of a per-step hook: a callable that receives one ``SamplerState`` and
# returns ``None`` (presumably invoked once per sampling step; confirm with the
# sampler). This alias is evaluated at import time rather than as a lazy
# annotation, so ``SamplerState`` must be defined above this line.
StepCallback = Callable[[SamplerState], None]