2ira's picture
offline_compression_graph_code
72c0672 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn import functional as F
from lingua.transformer import InitStdFactor, RMSNorm
from lingua.probe import log_stats
from apps.fastRNN.component.rnn_common import conv1d, scan
@dataclass
class BaseMinLSTMArgs:
    """Hyper-parameters for a stack of minimal-LSTM residual blocks.

    ``init_std_factor`` is kept as a plain string here; the model converts it
    with ``InitStdFactor(...)`` when it is built.
    """

    # Width / depth of the stack.
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 1
    # Sizing of the recurrent state: 2/3 of the requested width, optionally
    # scaled by ffn_dim_multiplier, rounded up to a multiple of multiple_of.
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    # Kernel size of the optional causal conv on the gate pre-activations.
    conv_size: Optional[int] = None
    # Epsilon of the pre-block RMSNorm.
    norm_eps: float = 1e-5
    # Weight initialisation: base std (None -> fan-in defaults) and the
    # name of the per-layer scaling rule.
    init_base_std: Optional[float] = None
    init_std_factor: str = "disabled"
class LSTM(nn.Module):
    """Minimal-LSTM sequence-mixing layer with a SiLU output gate.

    The input is projected to a forget/input gate pair (``wfi``), a candidate
    state (``wh_tilde``) and an output gate (``w``). The gated linear
    recurrence is evaluated by ``scan``; its hidden state is multiplied by
    ``silu(w(x))`` and projected back to ``dim`` by ``wo``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,  # h_t dim (state expansion)
        n_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        conv_size: Optional[int] = None,
    ):
        """Create the projections (weights are set by ``reset_parameters``).

        Args:
            dim: input/output width.
            hidden_dim: requested state width. The effective width is
                ``int(2 * hidden_dim / 3)``, optionally scaled by
                ``ffn_dim_multiplier``, then rounded up to a multiple of
                ``multiple_of``.
            n_heads: number of heads; must divide the effective width.
            multiple_of: rounding granularity for the effective width.
            ffn_dim_multiplier: optional extra scale on the effective width.
            conv_size: if not None, kernel size (2, 3 or 4) of a causal conv
                applied to the gate pre-activations; requires the effective
                width to be a multiple of 8.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round up to the next multiple of `multiple_of`.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        assert (
            hidden_dim % n_heads == 0
        ), f"Hidden dim must be divisible by n_heads: {hidden_dim} % {n_heads} != 0"
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads
        self.conv_size = conv_size
        if conv_size is not None:
            assert ((self.hidden_dim) % 8 == 0) and (
                conv_size in [2, 3, 4]
            ), f"Causal conv1d only supports conv_size in [2, 3, 4] and hidden_dim % 8 == 0, got {self.hidden_dim} and {conv_size}"
            # One kernel row per gate channel (forget + input -> 2 * hidden_dim).
            self.conv_dim = 2 * self.hidden_dim
            self.conv_weight = nn.Parameter(torch.empty((self.conv_dim, conv_size)))
        self.w = nn.Linear(
            dim,
            hidden_dim,
            bias=False,
        )
        self.wfi = nn.Linear(
            dim,
            2 * hidden_dim,
            bias=False,
        )
        self.wh_tilde = nn.Linear(
            dim,
            hidden_dim,
            bias=False,
        )
        self.wo = nn.Linear(
            hidden_dim,
            dim,
            bias=False,
        )

    def forward(
        self, x: torch.Tensor, tok_idx: torch.Tensor, cu_seqlens: torch.Tensor, impl: str = "parallel"
    ) -> torch.Tensor:
        """Apply the layer.

        Args:
            x: input of shape ``(bsz, seq_len, dim)``.
            tok_idx, cu_seqlens: packed-sequence metadata forwarded to
                ``conv1d`` / ``scan`` (semantics defined there).
            impl: implementation selector forwarded to ``conv1d`` / ``scan``.

        Returns:
            Tensor of shape ``(bsz, seq_len, dim)``.
        """
        bsz, seq_len, _ = x.shape
        # NOTE(review): `x.view_as(x)` mirrors the pattern used by lingua's
        # other layers; presumably it gives each projection its own view node
        # (e.g. for hooks/probing) — confirm before simplifying.
        w0 = self.w(x.view_as(x))
        # Channels-first (bsz, channels, seq_len) for the conv and the scan.
        fi = self.wfi(x.view_as(x)).transpose(1, 2)
        h_tilde = self.wh_tilde(x.view_as(x)).transpose(1, 2)
        if self.conv_size is not None:
            conv1d_w = log_stats(self.conv_weight, "conv1d.w")
            fi = conv1d(
                x=fi,
                conv_weight=conv1d_w,
                tok_idx=tok_idx,
                cu_seqlens=cu_seqlens,
                impl=impl,
                # `cache` is not created here; presumably attached externally
                # for incremental decoding.
                cache=self.cache.conv_cache if hasattr(self, "cache") else None,
            )
        # Fold heads into the batch dimension.
        # NOTE(review): each head receives a contiguous slice of 2 * head_dim
        # channels of wfi's output before the f/i split, so for n_heads > 1 a
        # head's "f" and "i" are not the same channel halves as for
        # n_heads == 1. wfi and the conv kernel are learned, so this is only a
        # fixed relabeling of rows; left as-is for checkpoint compatibility.
        fi = fi.reshape(bsz * self.n_heads, 2 * self.head_dim, seq_len)
        h_tilde = h_tilde.reshape(bsz * self.n_heads, self.head_dim, seq_len)
        f, i = fi.chunk(2, dim=1)
        f, i = F.sigmoid(f), F.sigmoid(i)
        # Normalise the gates so f' + i' ~= 1; 1e-4 guards the division.
        denom = 1 / (f + i + 1e-4)
        # Presumably computes h_t = a_t * h_{t-1} + b_t along the last axis.
        h = scan(
            a=(f * denom),
            b=(h_tilde * i * denom),
            cu_seqlens=cu_seqlens,
            impl=impl,
            cache=self.cache.state_cache if hasattr(self, "cache") else None,
        )
        # Unfold heads and go back to (bsz, seq_len, hidden_dim).
        h = h.view(bsz, self.hidden_dim, seq_len).transpose(1, 2)
        h = log_stats(h, "hidden_state")
        h = h * F.silu(w0)
        out = self.wo(h)
        return out

    def reset_parameters(self, init_std, factor):
        """Re-initialise all weights from truncated normals cut at +/- 3 std.

        Args:
            init_std: base std. When falsy, fan-in defaults are used:
                ``dim ** -0.5`` for the input projections, ``hidden_dim ** -0.5``
                for ``wo`` and ``conv_size ** -0.5`` for the conv kernel.
            factor: divisor applied to the output projection's std only.
        """
        in_init_std = init_std or (self.dim ** (-0.5))
        out_init_std = init_std or (self.hidden_dim ** (-0.5))
        out_init_std = out_init_std / factor
        for w in [self.w, self.wfi, self.wh_tilde]:
            nn.init.trunc_normal_(
                w.weight, std=in_init_std, a=-3 * in_init_std, b=3 * in_init_std
            )
        # Truncate wo relative to its *own* std (it used to reuse the input
        # projections' bounds, which ignores `factor` and the hidden width).
        nn.init.trunc_normal_(
            self.wo.weight, std=out_init_std, a=-3 * out_init_std, b=3 * out_init_std
        )
        if self.conv_size is not None:
            conv_std = init_std or (self.conv_size ** (-0.5))
            nn.init.trunc_normal_(
                self.conv_weight,
                mean=0.0,
                std=conv_std,
                a=-3 * conv_std,
                b=3 * conv_std,
            )
class LSTMBlock(nn.Module):
    """Pre-norm residual block computing ``x + LSTM(RMSNorm(x))``."""

    def __init__(self, args: BaseMinLSTMArgs):
        super().__init__()
        self.lstm_norm = RMSNorm(args.dim, eps=args.norm_eps)
        # LSTM keeps 2/3 of the requested width, so 3 * dim requests a state
        # width of 2 * dim before ffn_dim_multiplier and rounding.
        self.lstm = LSTM(
            dim=args.dim,
            hidden_dim=3 * args.dim,
            n_heads=args.n_heads,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
            conv_size=args.conv_size,
        )

    def forward(
        self, x: torch.Tensor, tok_idx: torch.Tensor, cu_seqlens: torch.Tensor, impl: str = "parallel"
    ) -> torch.Tensor:
        """Return ``x`` plus the LSTM applied to the normalised ``x``."""
        x = x + self.lstm(self.lstm_norm(x), tok_idx=tok_idx, cu_seqlens=cu_seqlens, impl=impl)
        return x

    def init_weights(self, init_std: Optional[float], factor: float):
        """Re-initialise the LSTM and the norm.

        Args:
            init_std: base std forwarded to ``LSTM.reset_parameters``
                (None/0 selects its fan-in defaults).
            factor: numeric divisor for the LSTM output-projection std. This is
                the float chosen from the ``InitStdFactor`` rule by the caller,
                not the enum member itself.
        """
        self.lstm.reset_parameters(init_std, factor)
        self.lstm_norm.reset_parameters()
class BaseMinLSTM(nn.Module):
    """A stack of ``args.n_layers`` residual :class:`LSTMBlock` layers."""

    def __init__(self, args: BaseMinLSTMArgs):
        super().__init__()
        self.dim = args.dim
        self.init_base_std = args.init_base_std
        # Fails early if the configured name is not a valid InitStdFactor.
        self.init_std_factor = InitStdFactor(args.init_std_factor)
        self.layers = nn.ModuleList(
            [LSTMBlock(args) for _ in range(args.n_layers)]
        )

    def forward(
        self, x: torch.Tensor, tok_idx: torch.Tensor, cu_seqlens: torch.Tensor, impl: str = "parallel"
    ) -> torch.Tensor:
        """Run ``x`` through every block in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, tok_idx=tok_idx, cu_seqlens=cu_seqlens, impl=impl)
        return hidden

    def reset_parameters(self):
        """No parameters are owned directly by the stack; nothing to reset."""

    def init_weights(self):
        """Initialise every block with the configured per-layer std divisor.

        The divisor depends on ``init_std_factor``:
        ``CURRENT_DEPTH`` -> sqrt(2 * (depth + 1)),
        ``GLOBAL_DEPTH`` -> sqrt(2 * (n_layers + 1)),
        ``DIM_RATIO`` -> dim / 4096,
        ``DISABLED`` -> 1.0.
        """
        self.reset_parameters()
        # These two do not depend on the layer index.
        global_depth_factor = (2 * (len(self.layers) + 1)) ** 0.5
        dim_ratio_factor = self.dim / 4096
        for depth, block in enumerate(self.layers):
            factor_by_rule = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: global_depth_factor,
                InitStdFactor.DIM_RATIO: dim_ratio_factor,
                InitStdFactor.DISABLED: 1.0,
            }
            block.init_weights(self.init_base_std, factor_by_rule[self.init_std_factor])