Spaces:
Sleeping
Sleeping
File size: 1,225 Bytes
fe9397b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | # app/models/decoding/base.py
from abc import ABC, abstractmethod
import torch
from models.runtime.state import GenerationState
class DecodingPolicy(ABC):
    """
    Abstract interface for a token-level decoding policy.

    A policy has two jobs:
    - choose the next token given the model's logits
    - report whether generation should come to an end

    Implementations MUST either keep no state of their own or derive
    whatever state they need solely from the GenerationState passed in.
    """

    # Capability flags; both are False on this base class.
    supports_sampling: bool = False
    supports_streaming: bool = False

    @abstractmethod
    def select(self, logits: torch.Tensor, state: GenerationState) -> int:
        """
        Pick the id of the next token from the model's logits.

        Args:
            logits: [batch, vocab] tensor
            state: current GenerationState

        Returns:
            token_id (int)
        """
        # Abstract hook: the base class provides no selection strategy.
        raise NotImplementedError()

    @abstractmethod
    def should_stop(self, token_id: int, state: GenerationState) -> bool:
        """
        Report whether generation should terminate.

        Args:
            token_id: the token id that was selected
            state: current GenerationState

        Returns:
            bool
        """
        # Abstract hook: the base class provides no stopping rule.
        raise NotImplementedError()
|