| | |
| | |
| | |
| | |
| | |
| |
|
"""
Streaming module API that should be implemented by all Streaming components.
"""
| |
|
| | from contextlib import contextmanager |
| | import typing as tp |
| | from torch import nn |
| | import torch |
| |
|
| |
|
# Streaming state of one module: maps a state name to a tensor.
# Per the StreamingModule contract, names must not contain dots and the first
# dim of each tensor is, by convention, the batch size.
State = tp.Dict[str, torch.Tensor]
| |
|
| |
|
class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming child modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` below.
    """

    def __init__(self) -> None:
        super().__init__()
        # Per-module state. Keys must not contain dots: dots are used as the
        # separator between submodule names in `get_streaming_state`.
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(
        self, fn: tp.Callable[[str, "StreamingModule"], None]
    ) -> None:
        """Call `fn(name, module)` for every StreamingModule in `self.named_modules()`.

        `name` is the dotted path yielded by `named_modules` (empty for `self`).
        Non-streaming modules are skipped.
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool) -> None:
        """Set `_is_streaming` on this module and all streaming submodules."""
        def _set_flag(name: str, module: "StreamingModule") -> None:
            module._is_streaming = streaming

        self._apply_named_streaming(_set_flag)

    @contextmanager
    def streaming(self) -> tp.Iterator[None]:
        """Context manager to enter streaming mode. Reset streaming state on exit.

        On exit (normal or through an exception) streaming is disabled and the
        state is cleared for this module and all streaming submodules. This also
        happens when the context is nested inside another `streaming()` context,
        so the outer context does not keep its state after the inner one exits.
        """
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self) -> None:
        """Reset the streaming state of this module and all streaming submodules."""
        def _reset(name: str, module: "StreamingModule") -> None:
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules.

        The result is a flat dict: entries of a submodule are keyed as
        `"<dotted submodule path>.<local key>"`, entries of `self` by their
        local key. The tensors are not copied.
        """
        state: State = {}

        def _add(name: str, module: "StreamingModule") -> None:
            prefix = f"{name}." if name else ""
            for key, value in module._streaming_state.items():
                state[prefix + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State) -> None:
        """Set the streaming state, including that of sub-modules.

        Args:
            state: flat mapping in the format returned by `get_streaming_state`.
                The mapping itself is not modified. Every streaming module's
                current state is cleared before its entries are loaded.

        Raises:
            AssertionError: if some keys of `state` do not belong to any
                streaming module. Raised explicitly rather than via `assert`
                so the check still runs under `python -O`.
        """
        remaining = dict(state)

        def _set(name: str, module: "StreamingModule") -> None:
            prefix = f"{name}." if name else ""
            module._streaming_state.clear()
            # Quadratic in (modules x keys), which is fine for typical sizes.
            for key, value in list(remaining.items()):
                if not key.startswith(prefix):
                    continue
                local_key = key[len(prefix):]
                # A further dot means the entry belongs to a deeper submodule.
                if '.' not in local_key:
                    module._streaming_state[local_key] = value
                    del remaining[key]

        self._apply_named_streaming(_set)
        if remaining:
            raise AssertionError(list(remaining.keys()))

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.
        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        spat out a flushed buffer.

        The default implementation returns None when `x` is None and
        otherwise simply calls the module on `x`.
        """
        if x is None:
            return None
        return self(x)
| |
|
| |
|
class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming compatible alternative of `nn.Sequential`.

    Children are passed exactly as for `nn.Sequential`. `flush` walks the
    children in order so that output still buffered by earlier streaming
    modules is propagated through the later ones.
    """

    def __init__(self, *args) -> None:
        # `StreamingModule.__init__` comes first in the MRO and accepts no
        # arguments, so without this override `StreamingSequential(a, b)`
        # raises TypeError. Let `nn.Sequential` register the children (its
        # own `super().__init__()` resolves to `nn.Module`), then set up the
        # streaming attributes that `StreamingModule.__init__` would have set.
        nn.Sequential.__init__(self, *args)
        self._streaming_state = {}
        self._is_streaming = False

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush every child in order, feeding each output into the next child.

        Streaming children always have `flush` called, even when `x` is None,
        so they can emit buffered output. Other children are only called when
        there is a tensor to process.
        """
        for module in self:
            if isinstance(module, StreamingModule):
                x = module.flush(x)
            elif x is not None:
                x = module(x)
        return x
| |
|