# m97j's picture
# add decoding interface
# fe9397b
# app/models/decoding/base.py
from abc import ABC, abstractmethod
import torch
from models.runtime.state import GenerationState
class DecodingPolicy(ABC):
    """
    Abstract interface for token-level decoding decisions.

    A concrete policy has two jobs:
        - choose the next token given the model's logits
        - report whether generation should come to an end

    Implementations MUST be stateless, or else derive whatever state
    they need solely from the GenerationState handed to each call.
    """

    # Class-level flags, False on the base class; presumably overridden
    # by subclasses that support the feature -- confirm against callers.
    supports_sampling: bool = False
    supports_streaming: bool = False

    @abstractmethod
    def select(self, logits: torch.Tensor, state: GenerationState) -> int:
        """
        Choose the id of the next token from the model logits.

        Args:
            logits: tensor of shape [batch, vocab]
            state: the current GenerationState

        Returns:
            The chosen token id, as an int.

        NOTE(review): a single int is returned for a [batch, vocab]
        input -- confirm how callers handle batch sizes above one.
        """
        raise NotImplementedError

    @abstractmethod
    def should_stop(self, token_id: int, state: GenerationState) -> bool:
        """
        Tell the caller whether generation should terminate.

        Args:
            token_id: the token id that was selected
            state: the current GenerationState

        Returns:
            A bool; per the class contract, True is taken to mean
            "stop generating".
        """
        raise NotImplementedError