import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Zhang & Sennrich, 2019)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the RMS over the last dimension, then apply a learned scale.
        norm_x = torch.mean(x * x, dim=-1, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.scale * x_normed


class MLP(nn.Module):
    """SwiGLU feed-forward network (LLaMA-style gated MLP)."""

    def __init__(self, n_embd: int):
        super().__init__()
        # Shrink the 4x hidden width by 2/3 so the gated MLP has roughly the
        # same parameter count as a standard 4x MLP with a single projection.
        hidden_dim = 4 * n_embd
        n_hidden = int(2 * hidden_dim / 3)
        self.c_fc1 = nn.Linear(n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.c_proj(F.silu(self.c_fc1(x)) * self.c_fc2(x))


class GroupedRNN(nn.Module):
    """Splits the embedding into `n_groups` equal chunks and runs an
    independent GRU/LSTM over each chunk, concatenating the per-group
    outputs back to the full embedding width."""

    def __init__(self, n_embd: int, n_layer: int, n_groups: int,
                 rnn_type: str = 'gru', bidirectional: bool = False):
        super().__init__()
        assert n_embd % n_groups == 0, "n_embd must be divisible by n_groups"
        self.n_groups = n_groups
        self.group_size = n_embd // n_groups

        rnn_type = rnn_type.lower()
        if rnn_type not in ('gru', 'lstm'):
            raise ValueError("rnn_type must be 'gru' or 'lstm'")
        rnn_class = nn.GRU if rnn_type == 'gru' else nn.LSTM

        if bidirectional:
            # A bidirectional RNN concatenates forward and backward states, so
            # halve the hidden size to keep the per-group output width equal
            # to its input width. The group size must be even for this to work.
            assert self.group_size % 2 == 0, \
                "group_size must be even when bidirectional=True"
            rnn_hidden_size = self.group_size // 2
        else:
            rnn_hidden_size = self.group_size

        self.rnns = nn.ModuleList([
            rnn_class(
                input_size=self.group_size,
                hidden_size=rnn_hidden_size,
                num_layers=n_layer,
                bias=False,
                batch_first=True,
                bidirectional=bidirectional,
            )
            for _ in range(self.n_groups)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, n_embd) -> (batch, seq, n_groups, group_size)
        B, T, _ = x.shape
        x_groups = x.view(B, T, self.n_groups, self.group_size)
        rnn_outputs = []
        for g in range(self.n_groups):
            # .contiguous() because the per-group slice is strided and cuDNN
            # RNN kernels expect contiguous input.
            x_g = x_groups[:, :, g, :].contiguous()
            rnn_out_g, _ = self.rnns[g](x_g)
            rnn_outputs.append(rnn_out_g)
        # Concatenate per-group outputs back to (batch, seq, n_embd).
        return torch.cat(rnn_outputs, dim=-1)


class RNNBlock(nn.Module):
    """Pre-norm residual block: x + RNN(norm(x)), then x + MLP(norm(x))."""

    def __init__(self, n_embd: int, n_layer: int, n_groups: int,
                 rnn_type: str = 'gru', bidirectional: bool = False):
        super().__init__()
        self.rnn_norm = RMSNorm(n_embd)
        self.rnn = GroupedRNN(n_embd, n_layer, n_groups, rnn_type, bidirectional)
        self.ffn_norm = RMSNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.rnn(self.rnn_norm(x))
        x = x + self.mlp(self.ffn_norm(x))
        return x


if __name__ == '__main__':
    n_embd = 512
    batch_size = 4
    seq_len = 100
    n_groups = 4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_tensor = torch.randn(batch_size, seq_len, n_embd).to(device)
    n_layer_rnn = 2

    # Unidirectional (causal) GRU block.
    model_gru = RNNBlock(
        n_embd=n_embd,
        n_layer=n_layer_rnn,
        n_groups=n_groups,
        rnn_type='gru'
    )
    model_gru.to(device)
    model_gru.eval()
    with torch.no_grad():
        output_gru = model_gru(input_tensor)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape (GRU): {output_gru.shape}")
    assert output_gru.shape == (batch_size, seq_len, n_embd)

    # Bidirectional GRU block. Note: a bidirectional RNN sees future
    # timesteps, so this variant is unsuitable for causal language modeling.
    model_bigru = RNNBlock(
        n_embd=n_embd,
        n_layer=n_layer_rnn,
        n_groups=n_groups,
        rnn_type='gru',
        bidirectional=True
    )
    model_bigru.to(device)
    model_bigru.eval()
    with torch.no_grad():
        output_bigru = model_bigru(input_tensor)
    print(f"Output shape (BiGRU): {output_bigru.shape}")
    assert output_bigru.shape == (batch_size, seq_len, n_embd)
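    # --- Additional sanity checks (not in the original demo; a minimal
    # sketch assuming the classes above are defined as-is). ---

    # The 'lstm' path shares all code with the GRU path except the RNN class,
    # so a shape-only smoke test is enough here.
    model_lstm = RNNBlock(
        n_embd=n_embd,
        n_layer=n_layer_rnn,
        n_groups=n_groups,
        rnn_type='lstm'
    )
    model_lstm.to(device)
    model_lstm.eval()
    with torch.no_grad():
        output_lstm = model_lstm(input_tensor)
    print(f"Output shape (LSTM): {output_lstm.shape}")
    assert output_lstm.shape == (batch_size, seq_len, n_embd)

    # Causality check: a unidirectional RNN block is causal per position, so
    # perturbing the last timestep must not change outputs at earlier steps.
    # The tolerance assumes kernels compute identical prefixes identically.
    perturbed = input_tensor.clone()
    perturbed[:, -1, :] += 1.0
    with torch.no_grad():
        output_perturbed = model_gru(perturbed)
    assert torch.allclose(
        output_gru[:, :-1, :], output_perturbed[:, :-1, :], atol=1e-5
    ), "unidirectional GRU block leaked future information"
    print("Causality check passed for the unidirectional GRU block.")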