# Transformer_500M/layers.py
import torch
import torch.nn as nn

from .modules import STU, MLP, Attention

try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP

    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm

    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm

    triton_norm = False
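
# Note: the `triton_mlp` / `triton_norm` flags above only record whether the Liger
# Triton kernels imported successfully; the layer classes below branch on them once,
# at construction time, to pick either the Triton or the fallback implementation.
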
class STULayer(nn.Module):
    def __init__(self, config, phi, n):
        super(STULayer, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype

        self.stu_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.stu = STU(config, phi, n).to(dtype=torch_dtype)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)

        # TODO: Write issue in Liger-Kernel repo to support user-defined dtype for MLP
        self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize and apply STU
        x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype)  # Match dtype for STU
        x_stu = self.stu(x_normed).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
        x = x + x_stu

        # Normalize and apply MLP
        x_normed_mlp = self.mlp_norm(x).to(dtype=self.mlp.gate_proj.weight.dtype)  # Match dtype for MLP
        x_mlp = self.mlp(x_normed_mlp).to(dtype=x.dtype)  # Ensure output matches `x`'s dtype
        x = x + x_mlp

        return x
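
# Example usage (a minimal sketch, not part of the original file). The config fields,
# filter shape, and sequence length below are illustrative assumptions; the real values
# come from this repository's config and its STU filter construction:
#
#   from types import SimpleNamespace
#
#   cfg = SimpleNamespace(n_embd=1024, torch_dtype="bfloat16")  # hypothetical config
#   phi = torch.randn(8192, 24)          # assumed spectral filters consumed by STU
#   layer = STULayer(cfg, phi, n=8192)   # `n` assumed to be the sequence length
#   x = torch.randn(1, 8192, cfg.n_embd, dtype=torch.bfloat16)
#   y = layer(x)                         # residual STU + MLP block; same shape as x
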
class AttentionLayer(nn.Module):
    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype

        self.attn_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.attn = Attention(config).to(dtype=torch_dtype)
        self.mlp_norm = (
            TritonNorm(config.n_embd)
            if triton_norm
            else RMSNorm(config.n_embd, dtype=torch_dtype)
        )
        self.mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)

        # TODO: Write issue in Liger-Kernel repo to support user-defined dtype for MLP
        self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
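
# A minimal sketch (not part of the original file) of how these layers might be
# composed into a stack. The alternation pattern and `cfg.n_layers` are assumptions
# for illustration; the actual composition lives in the model definition:
#
#   layers = nn.ModuleList(
#       STULayer(cfg, phi, n) if i % 2 == 0 else AttentionLayer(cfg)
#       for i in range(cfg.n_layers)
#   )
#
#   def forward_stack(x: torch.Tensor) -> torch.Tensor:
#       for layer in layers:
#           x = layer(x)
#       return x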