| | |
| | |
| | |
| | |
| |
|
| | import uuid |
| | from typing import Dict, Optional |
| |
|
| | from torch import Tensor |
| |
|
| |
|
class FairseqIncrementalState:
    """Mixin giving each instance a private namespace in a shared state dict.

    At construction every instance draws a random UUID string. The get/set
    helpers prefix caller-supplied keys with that UUID. Several objects can
    therefore keep entries in one ``incremental_state`` dictionary without
    their keys colliding.
    """

    def __init__(self, *args, **kwargs):
        # Let the rest of the MRO (e.g. the class this mixin is combined
        # with) initialise first, then allocate this instance's namespace.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Give this instance a fresh random namespace id (a UUID4 string)."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``key`` namespaced to this instance as ``"<uuid>.<key>"``."""
        namespace = self._incremental_state_id
        return "{}.{}".format(namespace, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Look up this instance's entry for ``key`` in ``incremental_state``.

        Args:
            incremental_state: shared state dictionary, or ``None``.
            key: un-namespaced key; it is prefixed with this instance's id.

        Returns:
            The stored value, or ``None`` when ``incremental_state`` is
            ``None`` or holds no entry under the namespaced key.
        """
        namespaced = self._get_full_incremental_state_key(key)
        # Explicit None/membership test rather than dict.get keeps the
        # Optional refinement simple for static/script compilers.
        if incremental_state is not None and namespaced in incremental_state:
            return incremental_state[namespaced]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Store ``value`` under this instance's namespaced ``key``.

        The dictionary is mutated in place. When ``incremental_state`` is
        ``None`` nothing is stored.

        Args:
            incremental_state: shared state dictionary, or ``None``.
            key: un-namespaced key; it is prefixed with this instance's id.
            value: mapping to store.

        Returns:
            The same ``incremental_state`` object that was passed in
            (possibly ``None``).
        """
        if incremental_state is None:
            return None
        incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state
| |
|
| |
|
def with_incremental_state(cls):
    """Class decorator that makes ``FairseqIncrementalState`` the first base of ``cls``.

    ``cls`` is modified in place by reassigning its ``__bases__`` tuple, and
    the same class object is returned. Any existing
    ``FairseqIncrementalState`` entry is removed from the bases before being
    prepended. It is put first so that its ``__init__`` runs first in the MRO
    and forwards constructor arguments onward via ``super()``.

    NOTE(review): CPython may reject ``__bases__`` reassignment for a class
    whose only base is ``object``; decorated classes are presumably expected
    to derive from a real base class. Confirm at call sites.
    """
    other_bases = [base for base in cls.__bases__ if base != FairseqIncrementalState]
    cls.__bases__ = tuple([FairseqIncrementalState] + other_bases)
    return cls
| |
|