import math
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import init
from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from einops import rearrange, repeat
from xformers.ops import SwiGLU

from .configuration_diff_llama import DiffusionLlamaConfig

# ===========================================================================
# IMPORTS & CHECKS
# ===========================================================================
try:
    from lightning_utilities.core.imports import RequirementCache

    FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1")
except ImportError:
    # Fallback if lightning_utilities is missing
    FlashAttention2Available = False

# Import compiled extensions if available
try:
    import rotary_emb
except ImportError:
    rotary_emb = None

try:
    import dropout_layer_norm
except ImportError:
    dropout_layer_norm = None


# ===========================================================================
# PART 1: ROTARY EMBEDDING (Autograd Function for Training)
# ===========================================================================
class ApplyRotaryEmb(torch.autograd.Function):
    """Autograd wrapper around the compiled ``rotary_emb.apply_rotary`` kernel.

    The forward pass rotates the first ``2 * cos.shape[-1]`` channels of the
    last dimension of ``x``; the backward pass calls the same kernel with its
    final flag set to ``True`` on the incoming gradient.
    """

    @staticmethod
    @torch.compiler.disable
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """Apply the rotary embedding to ``x``.

        Args:
            x: 4-D tensor unpacked as ``(batch, seqlen, nheads, headdim)``.
            cos: 2-D tensor unpacked as ``(rotary_seqlen, rotary_dim / 2)``.
            sin: Tensor indexed and reshaped exactly like ``cos``.
            interleaved: If true, the two rotated halves are the even / odd
                channels instead of the first / second contiguous halves.
            inplace: If true, the result is written into ``x`` and ``x`` is
                returned; otherwise a new tensor is allocated.

        Returns:
            The rotated tensor (``x`` itself when ``inplace`` is true).

        Raises:
            AssertionError: If the rotary dim exceeds ``headdim`` or the
                sequence is longer than the cos/sin table.
            ImportError: If the ``rotary_emb`` extension is not installed.
        """
        _, seqlen, _, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            # Outputs alias the inputs, so the kernel overwrites x directly.
            o1, o2 = x1, x2
        else:
            o1, o2 = (
                out_ro.chunk(2, dim=-1)
                if not interleaved
                else (out_ro[..., ::2], out_ro[..., 1::2])
            )
        if rotary_emb is None:
            # Fallback or error if extension is missing but this code path is hit
            raise ImportError(
                "rotary_emb extension not found. Please install it to use fused rotary embeddings."
            )
        rotary_emb.apply_rotary(
            x1,
            x2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            o1,
            o2,
            False,
        )
        if not inplace and rotary_dim < headdim:
            # Channels beyond the rotary dim pass through unchanged.
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        """Propagate the gradient ``do`` through the rotary embedding.

        Mirrors ``forward`` but passes ``True`` as the kernel's last argument.
        Returns one gradient per ``forward`` input; only ``x`` receives a
        gradient, the remaining four slots are ``None``.
        """
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1] * 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (
            do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
        )
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (
                dx_ro.chunk(2, dim=-1)
                if not ctx.interleaved
                else (dx_ro[..., ::2], dx_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            do1,
            do2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dx1,
            dx2,
            True,
        )
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None


apply_rotary_emb_func = ApplyRotaryEmb.apply


def build_rope_cache(
    seq_len: int,
    n_elem: int,
    dtype: torch.dtype,
    device: torch.device,
    base: int = 10000,
    condense_ratio: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the rotary cos/sin tables.

    Args:
        seq_len: Number of positions to tabulate.
        n_elem: Number of rotated channels; one frequency is produced for
            every second index in ``range(0, n_elem, 2)``.
        dtype: Selects the output precision: ``bfloat16`` gives bfloat16
            tables, ``float16`` / ``int8`` give half-precision tables, and
            anything else leaves the tables as computed.
        device: Device on which the tables are created.
        base: Base of the geometric frequency progression.
        condense_ratio: Divisor applied to the position indices.

    Returns:
        ``(cos, sin)``, each the outer product of positions and frequencies
        passed through ``cos`` / ``sin``.
    """
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))
    seq_idx = torch.arange(seq_len, device=device) / condense_ratio
    idx_theta = torch.outer(seq_idx, theta)
    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
    if dtype == torch.bfloat16:
        return cos.bfloat16(), sin.bfloat16()
    # bfloat16 is fully handled above, so it is not repeated here.
    if dtype in (torch.float16, torch.int8):
        return cos.half(), sin.half()
    return cos, sin


# ===========================================================================
# PART 2: NORMALIZATION (Fused RMS Norm)
# ===========================================================================
def maybe_align(x, alignment_in_bytes=16):
    """Return ``x`` if its data pointer is suitably aligned, else a clone of it."""
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()


def _require_dropout_layer_norm() -> None:
    """Raise a clear ImportError when the fused layer-norm extension is absent."""
    if dropout_layer_norm is None:
        raise ImportError(
            "dropout_layer_norm extension not found. Please install it to use FusedRMSNorm."
        )


def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Call the fused forward kernel on inputs flattened to ``(-1, hidden_size)``.

    Assume that arguments are contiguous and aligned to 16 bytes.

    Returns:
        ``(zmat, xmat, dmask, mu, rsigma)`` as produced by the kernel, except
        that ``x0mat`` is substituted when the kernel returns ``None`` for
        ``xmat``.

    Raises:
        ImportError: If the ``dropout_layer_norm`` extension is not installed.
    """
    _require_dropout_layer_norm()
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Call the fused backward kernel on inputs flattened to ``(-1, hidden_size)``.

    Assume that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.

    Returns:
        ``(dx0mat, dresidualmat, dgamma, dbeta)`` and, only when ``colscale``
        is given, a fifth element ``dcolscale``.

    Raises:
        ImportError: If the ``dropout_layer_norm`` extension is not installed.
    """
    _require_dropout_layer_norm()
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


class DropoutAddLayerNormFn(torch.autograd.Function):
    """Autograd function wrapping the fused dropout + add + (RMS/layer) norm kernels."""

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        """Run the fused forward kernel and stash what backward needs.

        Every tensor argument is made contiguous and 16-byte aligned first.

        Returns:
            The normalized output reshaped like ``x0``. With ``prenorm`` the
            pre-norm sum is returned as well; with ``return_dmask`` the
            (non-differentiable) dropout mask is appended.
        """
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            # With no dropout the kernel yields no mask, so an all-ones one is synthesized.
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        """Run the fused backward kernel.

        Returns one entry per ``forward`` input: gradients for ``x0``,
        ``residual``, ``gamma``, ``beta`` (only if a beta was supplied) and
        ``colscale`` (only if supplied); every other slot is ``None``.
        """
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def rms_norm(x, weight, epsilon):
    """Fused RMS norm of ``x`` with scale ``weight``: no residual, bias or dropout."""
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )


class FusedRMSNorm(torch.nn.Module):
    """RMS norm module backed by the fused ``dropout_layer_norm`` extension."""

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(size))
        # Stored for parity with RMSNorm; forward() does not read it.
        self.dim = dim

    def reset_parameters(self):
        """Reset the scale to all ones."""
        init.ones_(self.weight)

    def forward(self, x):
        """Return the fused RMS norm of ``x``."""
        return rms_norm(x, self.weight, self.eps)


class RMSNorm(torch.nn.Module):
    """Pure-PyTorch RMS norm: ``weight * x * rsqrt(mean(x * x, dim) + eps)``."""

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along ``self.dim`` and apply the learned scale."""
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.weight * x_normed


def _resolve_norm_class(config: DiffusionLlamaConfig):
    """Map ``config.norm_class`` to a norm class, falling back to ``torch.nn``."""
    if config.norm_class == "RMSNorm":
        return RMSNorm
    if config.norm_class == "FusedRMSNorm":
        return FusedRMSNorm
    return getattr(torch.nn, config.norm_class)


# ===========================================================================
# PART 3: BLOCKS & LAYERS
# ===========================================================================
class GptNeoxMLP(nn.Module):
    """Two-layer MLP with a GELU between the linear layers."""

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``proj(gelu(fc(x)))``."""
        x = self.fc(x)
        x = torch.nn.functional.gelu(x)
        return self.proj(x)


class LLaMAMLP(nn.Module):
    """MLP delegating to xformers' ``SwiGLU`` (bias-free, unpacked weights)."""

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        self.swiglu = SwiGLU(
            config.n_embd, config.intermediate_size, bias=False, _pack_weights=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``swiglu(x)``."""
        return self.swiglu(x)


class SelfAttention(nn.Module):
    """Non-causal self-attention with a fused QKV projection and rotary embeddings."""

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        # One projection emits n_head query heads plus n_query_groups key and value heads.
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor, rope: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Attend over ``x``.

        Args:
            x: Tensor unpacked as ``(B, T, C)``.
            rope: ``(cos, sin)`` tables handed to the rotary function.

        Returns:
            Tensor of shape ``(B, T, C)`` after the output projection.
        """
        B, T, C = x.size()
        qkv = self.attn(x)

        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2
        # Each query group holds q_per_kv query heads followed by one key and one value head.
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)

        q = q.reshape(B, T, -1, self.config.head_size)
        k = k.reshape(B, T, -1, self.config.head_size)
        v = v.reshape(B, T, -1, self.config.head_size)

        cos, sin = rope
        # Apply Rotary (non-interleaved, in place)
        q = apply_rotary_emb_func(q, cos, sin, False, True)
        k = apply_rotary_emb_func(k, cos, sin, False, True)

        y = self.scaled_dot_product_attention(q, k, v)
        y = y.reshape(B, T, C)
        y = self.proj(y)
        return y

    def scaled_dot_product_attention(self, q, k, v):
        """Compute non-causal attention, preferring FlashAttention-2 when usable.

        Inputs are indexed as ``(B, T, heads, head_size)``; the result uses
        the same layout.
        """
        scale = 1.0 / math.sqrt(self.config.head_size)

        # Use Flash Attention 2 if available and on CUDA
        if (
            FlashAttention2Available
            and q.device.type == "cuda"
            and q.dtype in (torch.float16, torch.bfloat16)
        ):
            from flash_attn import flash_attn_func

            return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=False)

        # Fallback to SDPA
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Handle GQA/MQA broadcast
        if q.size() != k.size():
            k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
            v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)

        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, scale=scale, is_causal=False
        )
        return y.transpose(1, 2)


class Block(nn.Module):
    """One transformer layer: norm, self-attention, (norm,) MLP with residuals."""

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        # Determine classes dynamically based on config strings
        norm_cls = _resolve_norm_class(config)
        mlp_cls = LLaMAMLP if config.mlp_class == "LLaMAMLP" else GptNeoxMLP

        self.norm_1 = norm_cls(config.n_embd, eps=config.norm_eps)
        self.attn = SelfAttention(config)
        if not config.shared_attention_norm:
            self.norm_2 = norm_cls(config.n_embd, eps=config.norm_eps)
        self.mlp = mlp_cls(config)
        self.config = config

    def forward(self, x: torch.Tensor, rope: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Apply the layer to ``x``.

        Raises:
            NotImplementedError: If ``shared_attention_norm`` is combined with
                a non-parallel residual.
        """
        n_1 = self.norm_1(x)
        h = self.attn(n_1, rope)
        if self.config.parallel_residual:
            # Attention and MLP both read the block input; their outputs are summed.
            n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
            x = x + h + self.mlp(n_2)
        else:
            if self.config.shared_attention_norm:
                raise NotImplementedError(
                    "Shared attention norm not supported with non-parallel residual"
                )
            x = x + h
            x = x + self.mlp(self.norm_2(x))
        return x


# ===========================================================================
# PART 4: MAIN MODEL CLASSES
# ===========================================================================
class TransEncoder(nn.Module):
    """Embedding, stack of ``Block`` layers, final norm and LM head."""

    def __init__(self, config: DiffusionLlamaConfig) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        norm_cls = _resolve_norm_class(config)

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
        self.transformer = nn.ModuleDict(
            dict(
                # The embedding has one more row than the LM head has outputs.
                wte=nn.Embedding(config.padded_vocab_size + 1, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=norm_cls(config.n_embd, eps=config.norm_eps),
            )
        )
        self.rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Return logits for the token ids ``idx`` (unpacked as ``(B, T)``)."""
        B, T = idx.size()

        # Build Rope cache if needed. It is also rebuilt when the input lives on a
        # different device than the cached tables (e.g. after ``model.to(...)``),
        # otherwise the stale tables would be fed to the rotary kernel.
        if self.rope_cache is None or self.rope_cache[0].device != idx.device:
            self.rope_cache = build_rope_cache(
                seq_len=self.config.block_size,
                n_elem=int(self.config.rotary_percentage * self.config.head_size),
                dtype=torch.bfloat16,
                device=idx.device,
                condense_ratio=self.config.condense_ratio,
            )

        # Retrieve and slice cache
        cos, sin = self.rope_cache
        cos = cos[:T]
        sin = sin[:T]

        x = self.transformer.wte(idx)
        for block in self.transformer.h:
            x = block(x, (cos, sin))
        x = self.transformer.ln_f(x)
        return self.lm_head(x)


class DiffusionLlamaLM(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around ``TransEncoder``."""

    config_class = DiffusionLlamaConfig
    base_model_prefix = "model"

    def __init__(self, config: DiffusionLlamaConfig):
        super().__init__(config)
        self.model = TransEncoder(config)
        # Initialize weights (Training feature)
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialization logic for training.

        Adapted from original TransEncoder._init_weights. Embeddings and
        linear layers get ``N(0, sqrt(2 / 5 / n_embd))`` weights and zero
        biases. Selected projection weights are then re-drawn with std
        ``1 / sqrt(n_embd) / n_layer``.
        """
        n_layer = self.config.n_layer
        base_std = math.sqrt(2.0 / 5 / self.config.n_embd)

        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=base_std)
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=base_std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        # Special initialization for SwiGLU / Projections based on names.
        # In HF _init_weights, 'module' is the current module being visited, so the
        # container types are checked and their matching parameters re-initialized.
        scaled_std = 1 / math.sqrt(self.config.n_embd) / n_layer
        scaled_targets = (
            (LLaMAMLP, "proj.weight"),
            (SwiGLU, "w3.weight"),
            (SelfAttention, "proj.weight"),
        )
        for module_cls, needle in scaled_targets:
            if isinstance(module, module_cls):
                for name, p in module.named_parameters():
                    if needle in name:
                        nn.init.normal_(p, mean=0.0, std=scaled_std)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits and, when ``labels`` is given, a shifted cross-entropy loss.

        Args:
            input_ids: Token ids passed to the encoder.
            labels: Optional targets; position ``t`` of the logits is scored
                against position ``t + 1`` of the labels.
            return_dict: Overrides ``config.use_return_dict`` when not ``None``.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast(loss, logits)``, or the tuple
            ``(loss, logits)`` / ``(logits,)`` when ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            return ((loss,) + (logits,)) if loss is not None else (logits,)

        return CausalLMOutputWithPast(loss=loss, logits=logits)