# Diff_LLaMA_336M_sudoku_sft_v2_640 / modeling_diff_llama.py
# (Hugging Face blob header: uploaded by zzy1123, commit 9566595 "Update modeling_diff_llama.py")
import math
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import init
from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from einops import rearrange, repeat
from xformers.ops import SwiGLU
from .configuration_diff_llama import DiffusionLlamaConfig
# ===========================================================================
# IMPORTS & CHECKS
# ===========================================================================
try:
from lightning_utilities.core.imports import RequirementCache
FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1")
except ImportError:
# Fallback if lightning_utilities is missing
FlashAttention2Available = False
# Import compiled extensions if available
try:
import rotary_emb
except ImportError:
rotary_emb = None
try:
import dropout_layer_norm
except ImportError:
dropout_layer_norm = None
# ===========================================================================
# PART 1: ROTARY EMBEDDING (Autograd Function for Training)
# ===========================================================================
class ApplyRotaryEmb(torch.autograd.Function):
    """Rotary position embedding via the fused `rotary_emb` CUDA kernel.

    Rotates the first ``rotary_dim`` channels of the head dimension and leaves
    the remaining channels untouched. Supports both the "half" layout
    (first/second half of the rotary slice) and the interleaved layout
    (even/odd channels), selected by ``interleaved``.
    """

    @staticmethod
    @torch.compiler.disable  # opaque CUDA extension; keep torch.compile away from it
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
        Full forward pass from fused_rotary_embedding.py

        Args (shapes per the unpacking below):
            x: (batch, seqlen, nheads, headdim)
            cos, sin: (rotary_seqlen, rotary_dim // 2)
            interleaved: rotate (even, odd) channel pairs instead of halves
            inplace: write results back into ``x`` (no output allocation)
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2  # cos/sin each describe half of the rotated channels
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        x_ro = x[..., :rotary_dim]
        # Split the rotary slice into the two channel groups to be rotated.
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            o1, o2 = x1, x2
        else:
            o1, o2 = (
                out_ro.chunk(2, dim=-1)
                if not interleaved
                else (out_ro[..., ::2], out_ro[..., 1::2])
            )
        if rotary_emb is None:
            # Fallback or error if extension is missing but this code path is hit
            raise ImportError("rotary_emb extension not found. Please install it to use fused rotary embeddings.")
        # Final False flag selects the forward rotation direction
        # (backward passes True below) -- per the kernel's convention; TODO confirm.
        rotary_emb.apply_rotary(
            x1, x2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            o1, o2,
            False,
        )
        if not inplace and rotary_dim < headdim:
            # Non-rotary channels pass through unchanged.
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        """
        Full backward pass from fused_rotary_embedding.py to support training

        Mirrors forward but passes ``True`` to the kernel, which applies the
        rotation in the opposite direction to the incoming gradient.
        """
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1] * 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (
            do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
        )
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (
                dx_ro.chunk(2, dim=-1)
                if not ctx.interleaved
                else (dx_ro[..., ::2], dx_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            do1, do2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dx1, dx2,
            True,  # reverse rotation for the gradient
        )
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        # One gradient slot per forward input: (x, cos, sin, interleaved, inplace).
        return dx, None, None, None, None


# Convenience alias used by the attention layer below.
apply_rotary_emb_func = ApplyRotaryEmb.apply
def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cos/sin tables for rotary position embeddings.

    Args:
        seq_len: number of positions to cache.
        n_elem: number of head channels that get rotated (the rotary dim).
        dtype: model dtype; selects the precision of the returned tables.
        device: device to build the cache on.
        base: RoPE frequency base.
        condense_ratio: divide position indices by this factor (context condensing).

    Returns:
        ``(cos, sin)``, each of shape ``(seq_len, n_elem // 2)``.
    """
    # theta_i = 1 / base^(2i / n_elem): one frequency per rotated channel pair.
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))
    seq_idx = torch.arange(seq_len, device=device) / condense_ratio
    # Angle matrix: position m times frequency theta_i.
    idx_theta = torch.outer(seq_idx, theta)
    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
    if dtype == torch.bfloat16:
        return cos.bfloat16(), sin.bfloat16()
    # bfloat16 already returned above (it was an unreachable member of this
    # check); int8-quantized models use half-precision tables.
    if dtype in (torch.float16, torch.int8):
        return cos.half(), sin.half()
    return cos, sin
# ===========================================================================
# PART 2: NORMALIZATION (Fused RMS Norm)
# ===========================================================================
def maybe_align(x, alignment_in_bytes=16):
    """Return ``x`` if its storage is aligned to ``alignment_in_bytes``; otherwise a clone."""
    is_aligned = (x.data_ptr() % alignment_in_bytes) == 0
    return x if is_aligned else x.clone()
def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Forward wrapper over the `dropout_layer_norm` CUDA extension.

    Flattens inputs to 2-D ``(rows, hidden_size)`` and invokes the fused
    dropout + residual-add + (layer|rms)norm kernel.

    Assume that arguments are contiguous and aligned to 16 bytes.

    Returns:
        zmat: normalized output (2-D).
        xmat: pre-norm sum (falls back to the flattened input when the
            kernel returns None for it).
        dmask: dropout mask, or None when ``dropout_p == 0.0``.
        mu, rsigma: per-row statistics saved for the backward pass.
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    # The kernel takes positional arguments only -- this order must not change.
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Backward wrapper over the `dropout_layer_norm` CUDA extension.

    Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.

    Returns ``(dx0, dresidual, dgamma, dbeta)``, plus ``dcolscale`` when a
    colscale was supplied; ``dresidual`` is None when ``has_residual`` is False.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    # Positional kernel call -- argument order must not change.
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        # Trailing extra output from the kernel is the colscale gradient.
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
class DropoutAddLayerNormFn(torch.autograd.Function):
    """Autograd wrapper for the fused dropout + add + (layer|rms)norm kernel.

    With ``prenorm=True`` the forward additionally returns the pre-norm
    residual sum; with ``return_dmask=True`` it also returns the dropout mask.
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        # The kernel requires 16-byte-aligned storage; clone anything misaligned.
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            # Materialize an all-ones mask when dropout was a no-op, so callers
            # always receive a tensor.
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        # In prenorm mode the second forward output (the residual sum) also
        # receives an incoming gradient, delivered in args[0].
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        # Exactly one slot per forward argument after ctx (12 in total).
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )
def rms_norm(x, weight, epsilon):
    """RMS-normalize ``x`` with learnable scale ``weight`` via the fused kernel."""
    # Positional signature of DropoutAddLayerNormFn.forward:
    # (x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
    #  residual_in_fp32, prenorm, is_rms_norm)
    kernel_args = (x, None, weight, None, None, None, 0.0, epsilon, False, False, True)
    return DropoutAddLayerNormFn.apply(*kernel_args)
class FusedRMSNorm(torch.nn.Module):
    """RMSNorm backed by the fused ``dropout_layer_norm`` CUDA kernel."""

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learnable elementwise scale, initialized to ones.
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.dim = dim  # kept for interface parity with RMSNorm

    def reset_parameters(self):
        """Reset the scale back to all-ones."""
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)
class RMSNorm(torch.nn.Module):
    """Plain PyTorch RMSNorm (non-fused fallback).

    Normalizes by the root-mean-square over ``dim`` and applies a learnable
    elementwise scale; no mean subtraction and no bias, unlike LayerNorm.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def reset_parameters(self) -> None:
        """Reset the scale back to all-ones (added for parity with FusedRMSNorm)."""
        init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the normalized dimension, then scale by 1/rms.
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.weight * x_normed
# ===========================================================================
# PART 3: BLOCKS & LAYERS
# ===========================================================================
class GptNeoxMLP(nn.Module):
    """Standard two-layer GELU feed-forward block (GPT-NeoX style)."""

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        # Up-projection to the intermediate size, then back down to n_embd.
        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = torch.nn.functional.gelu(self.fc(x))
        return self.proj(hidden)
class LLaMAMLP(nn.Module):
    """LLaMA-style feed-forward block using xformers' fused SwiGLU."""

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        hidden, inner = config.n_embd, config.intermediate_size
        # _pack_weights=False keeps w1/w2/w3 as separate parameters, which the
        # per-name initialization in this file relies on ("w3.weight").
        self.swiglu = SwiGLU(hidden, inner, bias=False, _pack_weights=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.swiglu(x)
class SelfAttention(nn.Module):
    """Multi-head self-attention with grouped-query (GQA/MQA) support.

    A single linear layer produces Q, K and V packed per query group; rotary
    embeddings are applied in place by the fused kernel, and attention runs
    bidirectionally (``causal=False`` in both backends).
    """

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        # n_head query heads plus one K and one V head per query group.
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor, rope: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        B, T, C = x.size()
        qkv = self.attn(x)
        # Packed layout per query group: q_per_kv query heads, then 1 K, then 1 V.
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
        # Flatten the group dimension -> (B, T, n_heads, head_size).
        q = q.reshape(B, T, -1, self.config.head_size)
        k = k.reshape(B, T, -1, self.config.head_size)
        v = v.reshape(B, T, -1, self.config.head_size)
        cos, sin = rope
        # Apply Rotary (interleaved=False, inplace=True: kernel writes q/k in place)
        q = apply_rotary_emb_func(q, cos, sin, False, True)
        k = apply_rotary_emb_func(k, cos, sin, False, True)
        y = self.scaled_dot_product_attention(q, k, v)
        # Merge heads back into the embedding dimension (C == n_head * head_size).
        y = y.reshape(B, T, C)
        y = self.proj(y)
        return y

    def scaled_dot_product_attention(self, q, k, v):
        # q, k, v arrive as (B, T, heads, head_size); K/V may have fewer heads (GQA).
        scale = 1.0 / math.sqrt(self.config.head_size)
        # Use Flash Attention 2 if available and on CUDA
        if FlashAttention2Available and q.device.type == "cuda" and q.dtype in (torch.float16, torch.bfloat16):
            from flash_attn import flash_attn_func
            # flash_attn_func consumes the (B, T, heads, dim) layout directly.
            return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=False)
        # Fallback to SDPA, which expects (B, heads, T, dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Handle GQA/MQA broadcast by repeating K/V heads up to the query head count
        if q.size() != k.size():
            k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1)
            v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, scale=scale, is_causal=False
        )
        # Back to (B, T, heads, dim) to match the flash-attn branch's layout.
        return y.transpose(1, 2)
class Block(nn.Module):
    """One transformer layer: normed attention and normed MLP with residuals.

    Supports the parallel-residual layout (attention and MLP branch from the
    same layer input, optionally sharing one norm) and the sequential
    pre-norm layout.
    """

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        # Resolve the norm implementation from its config name; names not
        # defined in this module are looked up on torch.nn (e.g. "LayerNorm").
        local_norms = {"RMSNorm": RMSNorm, "FusedRMSNorm": FusedRMSNorm}
        norm_cls = local_norms.get(config.norm_class) or getattr(torch.nn, config.norm_class)
        mlp_cls = LLaMAMLP if config.mlp_class == "LLaMAMLP" else GptNeoxMLP
        self.norm_1 = norm_cls(config.n_embd, eps=config.norm_eps)
        self.attn = SelfAttention(config)
        if not config.shared_attention_norm:
            # Separate norm for the MLP branch.
            self.norm_2 = norm_cls(config.n_embd, eps=config.norm_eps)
        self.mlp = mlp_cls(config)
        self.config = config

    def forward(self, x: torch.Tensor, rope: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        normed = self.norm_1(x)
        attn_out = self.attn(normed, rope)
        if self.config.parallel_residual:
            # Both branches read a norm of the same layer input.
            mlp_in = normed if self.config.shared_attention_norm else self.norm_2(x)
            return x + attn_out + self.mlp(mlp_in)
        if self.config.shared_attention_norm:
            raise NotImplementedError("Shared attention norm not supported with non-parallel residual")
        # Sequential pre-norm: attention residual first, then the MLP residual.
        x = x + attn_out
        return x + self.mlp(self.norm_2(x))
# ===========================================================================
# PART 4: MAIN MODEL CLASSES
# ===========================================================================
class TransEncoder(nn.Module):
    """Transformer backbone: token embedding -> blocks -> final norm -> LM head.

    The rotary cos/sin cache is built lazily on the first forward call.
    """

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config
        # Resolve the norm class for the final layer norm.
        if config.norm_class == "RMSNorm":
            norm_cls = RMSNorm
        elif config.norm_class == "FusedRMSNorm":
            norm_cls = FusedRMSNorm
        else:
            norm_cls = getattr(torch.nn, config.norm_class)
        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
        self.transformer = nn.ModuleDict(
            dict(
                # One extra embedding row beyond the padded vocab -- presumably
                # the diffusion [MASK] token id; verify against the training code.
                wte=nn.Embedding(config.padded_vocab_size + 1, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=norm_cls(config.n_embd, eps=config.norm_eps),
            )
        )
        # Lazily-built (cos, sin) tables covering block_size positions.
        self.rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (B, T) integer token ids; returns (B, T, padded_vocab_size) logits.
        B, T = idx.size()
        # Build Rope cache if needed
        if self.rope_cache is None:
            # NOTE(review): dtype is hard-coded to bfloat16 regardless of the
            # model dtype, and the cache stays on the device of the first
            # batch -- confirm this matches the intended deployment.
            self.rope_cache = build_rope_cache(
                seq_len=self.config.block_size,
                n_elem=int(self.config.rotary_percentage * self.config.head_size),
                dtype=torch.bfloat16,
                device=idx.device,
                condense_ratio=self.config.condense_ratio,
            )
        # Retrieve and slice cache to the current sequence length
        cos, sin = self.rope_cache
        cos = cos[:T]
        sin = sin[:T]
        x = self.transformer.wte(idx)
        for block in self.transformer.h:
            x = block(x, (cos, sin))
        x = self.transformer.ln_f(x)
        return self.lm_head(x)
class DiffusionLlamaLM(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around :class:`TransEncoder`.

    Exposes the backbone through the HF API (``from_pretrained``, weight
    initialization via ``post_init``).
    """

    config_class = DiffusionLlamaConfig
    base_model_prefix = "model"

    def __init__(self, config: DiffusionLlamaConfig):
        super().__init__(config)
        self.model = TransEncoder(config)
        # Initialize weights (Training feature)
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialization logic for training.
        Adapted from original TransEncoder._init_weights.

        Linear and embedding weights get N(0, sqrt(2 / (5 * n_embd)));
        output projections (attention proj, GptNeoxMLP proj, SwiGLU w3) are
        re-drawn with the std additionally divided by n_layer.
        """
        n_layer = self.config.n_layer
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        # Special initialization for SwiGLU / Projections based on names
        # In HF _init_weights, 'module' is the current leaf. We check specific instances.
        if isinstance(module, LLaMAMLP):
            for name, p in module.named_parameters():
                if "proj.weight" in name:
                    nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
        if isinstance(module, SwiGLU):
            for name, p in module.named_parameters():
                if "w3.weight" in name:
                    nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
        if isinstance(module, SelfAttention):
            for name, p in module.named_parameters():
                if "proj.weight" in name:
                    nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, **kwargs) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone and optionally compute a shifted cross-entropy loss.

        NOTE(review): logits/labels are shifted as in a causal LM even though
        the attention here is bidirectional (``causal=False``) -- confirm this
        matches the intended diffusion training objective.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        logits = self.model(input_ids)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if not return_dict:
            # Legacy tuple output: (loss, logits) or (logits,).
            return ((loss,) + (logits,)) if loss is not None else (logits,)
        return CausalLMOutputWithPast(loss=loss, logits=logits)