Spaces:
Sleeping
Sleeping
| # app/models/decoding/base.py | |
| from abc import ABC, abstractmethod | |
| import torch | |
| from models.runtime.state import GenerationState | |
class DecodingPolicy(ABC):
    """
    Strategy object that drives generation one token at a time.

    A policy has exactly two jobs:
      * pick the next token id from the model's logits, and
      * report whether generation ought to end after that token.

    Implementations MUST be stateless, or derive any state they need
    exclusively from the GenerationState they are handed.

    NOTE(review): ``abstractmethod`` is imported alongside ``ABC`` but is not
    applied here; both hooks below simply raise NotImplementedError, so this
    base class can still be instantiated directly.
    """

    # Capability flags, both False on the base class; presumably overridden
    # by subclasses that support the corresponding feature.
    supports_sampling: bool = False
    supports_streaming: bool = False

    def select(self, logits: torch.Tensor, state: GenerationState) -> int:
        """
        Choose the id of the next token.

        Args:
            logits: model output scores, expected shape [batch, vocab].
            state: the GenerationState of the ongoing generation.

        Returns:
            The chosen token id as a plain int.

        Raises:
            NotImplementedError: always on the base class; concrete
                policies must override this hook.

        NOTE(review): a single int return does not obviously fit a
        [batch, vocab] input with batch > 1 — confirm whether callers
        only ever pass batch == 1.
        """
        raise NotImplementedError()

    def should_stop(self, token_id: int, state: GenerationState) -> bool:
        """
        Decide whether generation ends after ``token_id``.

        Args:
            token_id: the id that was just selected.
            state: the GenerationState of the ongoing generation.

        Returns:
            True to terminate generation, False to keep going.

        Raises:
            NotImplementedError: always on the base class; concrete
                policies must override this hook.
        """
        raise NotImplementedError()