from typing import Optional, Union, List, Tuple, Dict
import math
import torch
import torch.nn as nn
import treetensor.torch as ttorch
import ding
from ding.torch_utils.network.normalization import build_normalization

if ding.enable_hpc_rl:
    from hpc_rll.torch_utils.network.rnn import LSTM as HPCLSTM
else:
    HPCLSTM = None


def is_sequence(data):
    """
    Overview:
        Determines if the input data is of type list or tuple.
    Arguments:
        - data: The input data to be checked.
    Returns:
        - boolean: True if the input is a list or a tuple, False otherwise.
    """
    return isinstance(data, list) or isinstance(data, tuple)
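

# Hedged usage sketch (added for illustration; not part of the original
# module): ``is_sequence`` matches only list and tuple, so strings and
# tensors are not treated as sequences.
def _is_sequence_example() -> bool:
    # Expected to evaluate to True.
    return is_sequence([1, 2]) and is_sequence((1, 2)) and not is_sequence("ab")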


def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.BoolTensor:
    """
    Overview:
        Generates a boolean mask for a batch of sequences with differing lengths.
    Arguments:
        - lengths (:obj:`torch.Tensor`): A tensor with the lengths of each sequence. Shape could be (n, 1) or (n).
        - max_len (:obj:`int`, optional): The padding size. If max_len is None, the padding size is the max length of \
            sequences.
    Returns:
        - masks (:obj:`torch.BoolTensor`): A boolean mask tensor. The mask has the same device as lengths.
    """
    if len(lengths.shape) == 1:
        lengths = lengths.unsqueeze(dim=1)
    bz = lengths.numel()
    if max_len is None:
        max_len = lengths.max()
    else:
        max_len = min(max_len, lengths.max())
    return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device)
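

# Hedged usage sketch (added for illustration; not part of the original
# module): shows the typical shape behaviour of ``sequence_mask``; the
# concrete lengths below are hypothetical.
def _sequence_mask_example() -> torch.BoolTensor:
    lengths = torch.LongTensor([2, 4, 3])
    # The result has shape (3, 4); row i is True at positions j < lengths[i],
    # e.g. the first row is [True, True, False, False].
    return sequence_mask(lengths, max_len=4)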


class LSTMForwardWrapper(object):
    """
    Overview:
        Class providing methods to use before and after the LSTM `forward` method.
        Wraps the LSTM `forward` method.
    Interfaces:
        ``_before_forward``, ``_after_forward``
    """

    def _before_forward(self, inputs: torch.Tensor, prev_state: Union[None, List[Dict]]) -> torch.Tensor:
        """
        Overview:
            Preprocesses the inputs and previous states before the LSTM `forward` method.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of the LSTM cell. Shape: [seq_len, batch_size, input_size]
            - prev_state (:obj:`Union[None, List[Dict]]`): The previous state. Either None, a dict with 'h' and 'c' \
                tensors of shape [num_directions*num_layers, batch_size, hidden_size], or a list with one entry per \
                batch sample (each entry None or a per-sample 'h'/'c' dict). If None, prev_state is initialized to \
                all zeros.
        Returns:
            - prev_state (:obj:`Tuple[torch.Tensor]`): The preprocessed (h, c) previous state, batched for the LSTM.
        """
        assert hasattr(self, 'num_layers')
        assert hasattr(self, 'hidden_size')
        seq_len, batch_size = inputs.shape[:2]
        if prev_state is None:
            num_directions = 1
            zeros = torch.zeros(
                num_directions * self.num_layers,
                batch_size,
                self.hidden_size,
                dtype=inputs.dtype,
                device=inputs.device
            )
            prev_state = (zeros, zeros)
        elif is_sequence(prev_state):
            if len(prev_state) != batch_size:
                raise RuntimeError(
                    "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size)
                )
            num_directions = 1
            zeros = torch.zeros(
                num_directions * self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device
            )
            state = []
            for prev in prev_state:
                if prev is None:
                    state.append([zeros, zeros])
                else:
                    if isinstance(prev, (Dict, ttorch.Tensor)):
                        state.append([v for v in prev.values()])
                    else:
                        state.append(prev)
            state = list(zip(*state))
            prev_state = [torch.cat(t, dim=1) for t in state]
        elif isinstance(prev_state, dict):
            prev_state = list(prev_state.values())
        else:
            raise TypeError("not support prev_state type: {}".format(type(prev_state)))
        return prev_state

    def _after_forward(self,
                       next_state: Tuple[torch.Tensor],
                       list_next_state: bool = False) -> Union[List[Dict], Dict[str, torch.Tensor]]:
        """
        Overview:
            Post-processes the next_state after the LSTM `forward` method.
        Arguments:
            - next_state (:obj:`Tuple[torch.Tensor]`): Tuple containing the next state (h, c).
            - list_next_state (:obj:`bool`, optional): Determines the format of the returned next_state. \
                If True, returns next_state in list format. Default is False.
        Returns:
            - next_state (:obj:`Union[List[Dict], Dict[str, torch.Tensor]]`): The post-processed next_state.
        """
        if list_next_state:
            h, c = next_state
            batch_size = h.shape[1]
            next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)]
            next_state = list(zip(*next_state))
            next_state = [{k: v for k, v in zip(['h', 'c'], item)} for item in next_state]
        else:
            next_state = {k: v for k, v in zip(['h', 'c'], next_state)}
        return next_state
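

# Hedged sketch (added for illustration; not part of the original module):
# demonstrates the two output formats produced by ``_after_forward``. The
# method does not touch ``self``, so a bare wrapper instance suffices here.
def _after_forward_format_example() -> None:
    h = torch.zeros(2, 3, 8)  # (num_layers, batch_size, hidden_size)
    c = torch.zeros(2, 3, 8)
    wrapper = LSTMForwardWrapper()
    # List format: one {'h': ..., 'c': ...} dict per batch sample, each tensor (2, 1, 8).
    per_sample = wrapper._after_forward((h, c), list_next_state=True)
    assert len(per_sample) == 3 and per_sample[0]['h'].shape == (2, 1, 8)
    # Dict format: the stacked tensors are returned unchanged.
    stacked = wrapper._after_forward((h, c), list_next_state=False)
    assert stacked['h'].shape == (2, 3, 8)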


class LSTM(nn.Module, LSTMForwardWrapper):
    """
    Overview:
        Implementation of an LSTM cell with Layer Normalization (LN).
    Interfaces:
        ``__init__``, ``forward``

    .. note::
        For a primer on LSTM, refer to https://zhuanlan.zhihu.com/p/32085405.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int,
            norm_type: Optional[str] = None,
            dropout: float = 0.
    ) -> None:
        """
        Overview:
            Initialize LSTM cell parameters.
        Arguments:
            - input_size (:obj:`int`): Size of the input vector.
            - hidden_size (:obj:`int`): Size of the hidden state vector.
            - num_layers (:obj:`int`): Number of LSTM layers.
            - norm_type (:obj:`Optional[str]`): Normalization type, default is None.
            - dropout (:obj:`float`): Dropout rate, default is 0.
        """
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        norm_func = build_normalization(norm_type)
        self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)])
        self.wx = nn.ParameterList()
        self.wh = nn.ParameterList()
        dims = [input_size] + [hidden_size] * num_layers
        for l in range(num_layers):
            self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4)))
            self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4)))
        self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4))
        self.use_dropout = dropout > 0.
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout)
        self._init()

    def _init(self):
        """
        Overview:
            Initialize the parameters of the LSTM cell.
        """
        gain = math.sqrt(1. / self.hidden_size)
        for l in range(self.num_layers):
            torch.nn.init.uniform_(self.wx[l], -gain, gain)
            torch.nn.init.uniform_(self.wh[l], -gain, gain)
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias[l], -gain, gain)

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        """
        Overview:
            Compute output and next state given previous state and input.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
            - prev_state (:obj:`torch.Tensor`): Previous state, \
                size [num_directions*num_layers, batch_size, hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
        Returns:
            - x (:obj:`torch.Tensor`): Output from LSTM.
            - next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.
        """
        seq_len, batch_size = inputs.shape[:2]
        prev_state = self._before_forward(inputs, prev_state)

        H, C = prev_state
        x = inputs
        next_state = []
        for l in range(self.num_layers):
            h, c = H[l], C[l]
            new_x = []
            for s in range(seq_len):
                # layer norm is applied separately to the input-to-hidden and hidden-to-hidden projections
                gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l])
                                        ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l]))
                if self.bias is not None:
                    gate += self.bias[l]
                gate = list(torch.chunk(gate, 4, dim=1))
                # i, f, o, u: input gate, forget gate, output gate and candidate cell update
                i, f, o, u = gate
                i = torch.sigmoid(i)
                f = torch.sigmoid(f)
                o = torch.sigmoid(o)
                u = torch.tanh(u)
                c = f * c + i * u
                h = o * torch.tanh(c)
                new_x.append(h)
            next_state.append((h, c))
            x = torch.stack(new_x, dim=0)
            if self.use_dropout and l != self.num_layers - 1:
                x = self.dropout(x)
        next_state = [torch.stack(t, dim=0) for t in zip(*next_state)]
        next_state = self._after_forward(next_state, list_next_state)
        return x, next_state
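

# Hedged usage sketch (added for illustration; not part of the original
# module): a minimal end-to-end call of the layer-normalised LSTM defined
# above, with hypothetical sizes.
def _lstm_example() -> None:
    lstm = LSTM(input_size=4, hidden_size=8, num_layers=2, norm_type='LN')
    inputs = torch.randn(5, 3, 4)  # (seq_len, batch_size, input_size)
    # prev_state=None is zero-initialised inside _before_forward.
    output, next_state = lstm(inputs, None, list_next_state=True)
    assert output.shape == (5, 3, 8)
    # One {'h': ..., 'c': ...} dict per batch sample, each tensor of shape
    # (num_layers, 1, hidden_size).
    assert len(next_state) == 3 and next_state[0]['h'].shape == (2, 1, 8)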


class PytorchLSTM(nn.LSTM, LSTMForwardWrapper):
    """
    Overview:
        Wrapper class for PyTorch's nn.LSTM, formats the input and output. For more details on nn.LSTM,
        refer to https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
    Interfaces:
        ``forward``
    """

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        """
        Overview:
            Executes nn.LSTM.forward with preprocessed input.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
            - prev_state (:obj:`torch.Tensor`): Previous state, size [num_directions*num_layers, batch_size, \
                hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
        Returns:
            - output (:obj:`torch.Tensor`): Output from LSTM.
            - next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.
        """
        prev_state = self._before_forward(inputs, prev_state)
        output, next_state = nn.LSTM.forward(self, inputs, prev_state)
        next_state = self._after_forward(next_state, list_next_state)
        return output, next_state
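

# Hedged usage sketch (added for illustration; not part of the original
# module): PytorchLSTM reuses PyTorch's built-in nn.LSTM computation while
# keeping the same (inputs, prev_state, list_next_state) calling convention.
def _pytorch_lstm_example() -> None:
    lstm = PytorchLSTM(input_size=4, hidden_size=8, num_layers=2)
    inputs = torch.randn(5, 3, 4)  # (seq_len, batch_size, input_size)
    output, next_state = lstm(inputs, None, list_next_state=False)
    assert output.shape == (5, 3, 8)
    # Stacked (num_directions*num_layers, batch_size, hidden_size) tensors.
    assert next_state['h'].shape == (2, 3, 8) and next_state['c'].shape == (2, 3, 8)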


class GRU(nn.GRUCell, LSTMForwardWrapper):
    """
    Overview:
        This class extends the `torch.nn.GRUCell` and `LSTMForwardWrapper` classes, and formats inputs and outputs
        accordingly.
    Interfaces:
        ``__init__``, ``forward``
    Properties:
        hidden_size, num_layers

    .. note::
        For further details, refer to the official PyTorch documentation:
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU>
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
        """
        Overview:
            Initialize the GRU class with input size, hidden size, and number of layers.
        Arguments:
            - input_size (:obj:`int`): The size of the input vector.
            - hidden_size (:obj:`int`): The size of the hidden state vector.
            - num_layers (:obj:`int`): The number of GRU layers.
        """
        super(GRU, self).__init__(input_size, hidden_size)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self,
                inputs: torch.Tensor,
                prev_state: Optional[torch.Tensor] = None,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, List]]:
        """
        Overview:
            Wrap the `nn.GRUCell.forward` method so that it follows the same calling convention as the LSTM cells.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, tensor of size [seq_len, batch_size, input_size].
            - prev_state (:obj:`Optional[torch.Tensor]`): None or tensor of \
                size [num_directions*num_layers, batch_size, hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format (default is True).
        Returns:
            - output (:obj:`torch.Tensor`): Output from GRU.
            - next_state (:obj:`torch.Tensor` or :obj:`list`): Hidden state from GRU.
        """
        # for compatibility with the LSTM state interface, only the hidden state (h) is used
        prev_state, _ = self._before_forward(inputs, prev_state)
        inputs, prev_state = inputs.squeeze(0), prev_state.squeeze(0)
        next_state = nn.GRUCell.forward(self, inputs, prev_state)
        next_state = next_state.unsqueeze(0)
        x = next_state
        # for compatibility, the single GRU hidden state is duplicated into an (h, c)-style pair
        next_state = self._after_forward([next_state, next_state.clone()], list_next_state)
        return x, next_state
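

# Hedged usage sketch (added for illustration; not part of the original
# module): the GRU wrapper squeezes the time dimension before calling
# nn.GRUCell.forward, so it effectively assumes seq_len == 1 and a single
# layer.
def _gru_example() -> None:
    gru = GRU(input_size=4, hidden_size=8, num_layers=1)
    inputs = torch.randn(1, 3, 4)  # (seq_len=1, batch_size, input_size)
    output, next_state = gru(inputs, None, list_next_state=False)
    assert output.shape == (1, 3, 8)
    # h and c carry the same GRU hidden state, kept only for LSTM-style
    # interface compatibility.
    assert torch.equal(next_state['h'], next_state['c'])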


def get_lstm(
        lstm_type: str,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        norm_type: str = 'LN',
        dropout: float = 0.,
        seq_len: Optional[int] = None,
        batch_size: Optional[int] = None
) -> Union[LSTM, PytorchLSTM]:
    """
    Overview:
        Build and return the corresponding LSTM cell based on the provided parameters.
    Arguments:
        - lstm_type (:obj:`str`): Version of RNN cell. Supported options are ['normal', 'pytorch', 'hpc', 'gru'].
        - input_size (:obj:`int`): Size of the input vector.
        - hidden_size (:obj:`int`): Size of the hidden state vector.
        - num_layers (:obj:`int`): Number of LSTM layers (default is 1).
        - norm_type (:obj:`str`): Type of normalization (default is 'LN').
        - dropout (:obj:`float`): Dropout rate (default is 0.0).
        - seq_len (:obj:`Optional[int]`): Sequence length (default is None).
        - batch_size (:obj:`Optional[int]`): Batch size (default is None).
    Returns:
        - lstm (:obj:`Union[LSTM, PytorchLSTM]`): The corresponding LSTM cell.
    """
    assert lstm_type in ['normal', 'pytorch', 'hpc', 'gru']
    if lstm_type == 'normal':
        return LSTM(input_size, hidden_size, num_layers, norm_type, dropout=dropout)
    elif lstm_type == 'pytorch':
        return PytorchLSTM(input_size, hidden_size, num_layers, dropout=dropout)
    elif lstm_type == 'hpc':
        return HPCLSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout).cuda()
    elif lstm_type == 'gru':
        assert num_layers == 1
        return GRU(input_size, hidden_size, num_layers)
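

# Hedged usage sketch (added for illustration; not part of the original
# module): get_lstm is the usual entry point. 'normal' builds the
# layer-normalised LSTM above, 'pytorch' wraps nn.LSTM, 'hpc' requires the
# hpc_rll package plus a CUDA device, and 'gru' builds the single-layer GRU.
if __name__ == "__main__":
    rnn = get_lstm('normal', input_size=4, hidden_size=8, num_layers=2)
    example_inputs = torch.randn(5, 3, 4)  # (seq_len, batch_size, input_size)
    example_output, example_state = rnn(example_inputs, None)
    print(example_output.shape)  # torch.Size([5, 3, 8])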