"""Shared sampler types for the open-source release."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable

from torch import Tensor


@dataclass
class SamplerState:
    """Per-step sampler state, as passed to a ``StepCallback``.

    This is a plain container: the generated ``__init__`` performs no
    validation or copying, so tensors are stored by reference and a callback
    that mutates ``x_t`` or ``logits`` in place affects every other holder of
    those tensors.

    Field meanings below that go beyond the type annotations are inferred
    from the field names and should be confirmed against the sampler that
    populates this object.
    """

    # Index of the current step (0- vs 1-based is not fixed here -- TODO confirm).
    step: int
    # Presumably the total number of steps the sampler will run.
    total_steps: int
    # Presumably the diffusion/noise time for this step -- confirm range and direction.
    t: float
    # Current sequence tensor at time ``t``; shape/dtype not fixed here
    # (presumably token ids, e.g. (batch, seq_len) -- TODO confirm).
    x_t: Tensor
    # Model output for this step; shape not fixed here
    # (presumably (..., vocab_size) -- TODO confirm).
    logits: Tensor
    # Presumably the count of still-masked positions before this step's update.
    remain_before: int
    # Presumably the count of still-masked positions after this step's update.
    remain_after: int
    # Presumably the number of positions selected (unmasked) at this step.
    selected: int
    # Presumably the length of a fixed prompt prefix that is not sampled.
    prefix_len: int
    # Token id marking masked positions in ``x_t`` (per the name -- confirm).
    mask_token_id: int
    # Free-form extra data; default_factory gives each instance its own empty dict
    # (a shared mutable default would leak entries between states).
    metadata: dict = field(default_factory=dict)


# Signature for per-step hooks: called with a ``SamplerState``, expected to
# return ``None``.
StepCallback = Callable[[SamplerState], None]