# app/models/decoding/base.py

from abc import ABC, abstractmethod

import torch

from models.runtime.state import GenerationState


class DecodingPolicy(ABC):
    """Abstract interface for token-level decoding.

    A concrete policy answers two questions at each generation step:

    1. Which token comes next, given the model's logits? (``select``)
    2. Should generation end after that token? (``should_stop``)

    Implementations MUST either hold no state of their own or derive any
    state they need exclusively from the ``GenerationState`` passed in.
    """

    # Capability flags; both default to False and are meant to be
    # overridden by subclasses that provide the corresponding feature.
    supports_sampling: bool = False
    supports_streaming: bool = False

    @abstractmethod
    def select(self, logits: torch.Tensor, state: GenerationState) -> int:
        """Pick the id of the next token from the model's logits.

        Args:
            logits: Tensor of shape [batch, vocab].
            state: The current GenerationState.

        Returns:
            The chosen token id.

        NOTE(review): logits are documented as [batch, vocab] while a
        single int is returned; confirm how batch sizes above 1 are
        expected to be handled by implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def should_stop(self, token_id: int, state: GenerationState) -> bool:
        """Report whether generation should terminate.

        Args:
            token_id: The token id that was just selected.
            state: The current GenerationState.

        Returns:
            True when generation should end, False otherwise.
        """
        raise NotImplementedError