LLM_D4 / LLM_2.py
firdavsus's picture
Upload folder using huggingface_hub
943bd92 verified
import math
import inspect
from dataclasses import dataclass
from contextlib import nullcontext
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
import inspect
from transformers.modeling_outputs import CausalLMOutput
from manager import MANAGER
torch.manual_seed(101)
def precompute_freqs_cis(config):
# We now return cos and sin directly instead of a complex polar tensor
freqs = 1.0 / (config.theta ** (torch.arange(0, config.d_rotate, 2)[: (config.d_rotate // 2)].float() / config.d_rotate))
t = torch.arange(config.block_size, device=freqs.device)
freqs = torch.outer(t, freqs).float() # [seq_len, d_rotate/2]
# Cos and Sin are what Inductor can easily optimize
cos = torch.cos(freqs)
sin = torch.sin(freqs)
# Repeat along the last dimension to match the d_rotate size
# [seq_len, d_rotate/2] -> [seq_len, d_rotate]
cos = torch.repeat_interleave(cos, 2, dim=-1)
sin = torch.repeat_interleave(sin, 2, dim=-1)
return cos, sin
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
# x: [..., d_rotate]
# Split into [x1, x2, x3, x4...] -> x1, x2 are pairs
# We use the interleaving pattern: [-x2, x1, -x4, x3...]
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
# Reshape freqs for broadcasting: [seq_len, d_rotate] -> [1, seq_len, 1, d_rotate]
# This matches (batch, seq, head, dim)
cos = freqs_cos[:xq.shape[1]].view(1, xq.shape[1], 1, xq.shape[-1])
sin = freqs_sin[:xq.shape[1]].view(1, xq.shape[1], 1, xq.shape[-1])
# The RoPE formula: x_out = x * cos + rotate_half(x) * sin
xq_out = (xq * cos) + (rotate_half(xq) * sin)
xk_out = (xk * cos) + (rotate_half(xk) * sin)
return xq_out.type_as(xq), xk_out.type_as(xk)
class MultiHeadLatentAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.d_model = config.n_embd
self.num_head = config.n_head
self.d_head = self.d_model // self.num_head
self.d_c = config.d_c
self.d_c1 = config.d_c1
self.d_rotate = config.d_rotate
# ==========================================
# FUSION 1: All Projections from 'x'
# Replaces DQ_proj, DKV_proj, and RK_proj
# ==========================================
self.W_down = nn.Linear(
self.d_model,
self.d_c1 + self.d_c + self.d_rotate,
bias=config.bias
)
self.W_down.is_attention = True
# ==========================================
# FUSION 2: All Q Up-Projections from 'C_Q'
# Replaces UQ_proj and RQ_proj
# ==========================================
self.W_up_q = nn.Linear(
self.d_c1,
self.d_model + (self.num_head * self.d_rotate),
bias=config.bias
)
self.W_up_q.is_attention = True
# ==========================================
# FUSION 3: All KV Up-Projections from 'C_KV'
# Replaces UK_proj and UV_proj (STILL STRICTLY SEPARATE WEIGHTS)
# ==========================================
self.W_up_kv = nn.Linear(
self.d_c,
self.d_model + self.d_model, # d_model for K, d_model for V
bias=config.bias
)
self.W_up_kv.is_attention = True
self.q_norm = nn.RMSNorm(self.d_c1)
self.kv_norm = nn.RMSNorm(self.d_c)
# Output projection and Regularization
self.output_proj = nn.Linear(self.d_model, self.d_model, bias=config.bias)
self.output_proj.output_proj_marker = True
self.output_proj.is_attention = True
self.dropout = nn.Dropout(config.dropout)
self.attn_dropout_p = config.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
cos, sin = precompute_freqs_cis(config)
self.register_buffer("freqs_cos", cos, persistent=False)
self.register_buffer("freqs_sin", sin, persistent=False)
def forward(self, x):
batch_size, seq_len, _ = x.size()
# ---------------------------------------------------------
# 1. KERNEL 1: Down-project everything at once
# ---------------------------------------------------------
down_out = self.W_down(x)
# Split into the 3 exact latents your math requires
C_Q, C_KV, K_rotate = down_out.split(
[self.d_c1, self.d_c, self.d_rotate], dim=-1
)
C_Q = self.q_norm(C_Q)
C_KV = self.kv_norm(C_KV)
# ---------------------------------------------------------
# 2. KERNEL 2: Up-project Query content and RoPE
# ---------------------------------------------------------
q_up_out = self.W_up_q(C_Q)
Q_state, Q_rotate = q_up_out.split(
[self.d_model, self.num_head * self.d_rotate], dim=-1
)
Q_state = Q_state.view(batch_size, seq_len, self.num_head, self.d_head)
Q_rotate = Q_rotate.view(batch_size, seq_len, self.num_head, self.d_rotate)
# ---------------------------------------------------------
# 3. KERNEL 3: Up-project Key and Value content independently
# ---------------------------------------------------------
kv_up_out = self.W_up_kv(C_KV)
K_state, V_state = kv_up_out.split(
[self.d_model, self.d_model], dim=-1
)
K_state = K_state.view(batch_size, seq_len, self.num_head, self.d_head)
V_state = V_state.view(batch_size, seq_len, self.num_head, self.d_head)
# Prepare shared RoPE Key
K_rotate = K_rotate.view(batch_size, seq_len, 1, self.d_rotate).expand(-1, -1, self.num_head, -1)
# ---------------------------------------------------------
# 4. Apply RoPE, Concatenate, and Attention
# ---------------------------------------------------------
Q_rotate, K_rotate = apply_rotary_emb(
Q_rotate,
K_rotate,
self.freqs_cos,
self.freqs_sin
)
Q = torch.cat([Q_state, Q_rotate], dim=-1).transpose(1, 2)
K = torch.cat([K_state, K_rotate], dim=-1).transpose(1, 2)
V = V_state.transpose(1, 2)
if self.flash:
att_output = F.scaled_dot_product_attention(
Q, K, V,
dropout_p=self.attn_dropout_p if self.training else 0.0,
is_causal=True
)
else:
scaler = 1.0 / math.sqrt(self.d_head + self.d_rotate)
att_matrix = (Q @ K.transpose(-2, -1)) * scaler
mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).view(1, 1, seq_len, seq_len)
att_matrix = att_matrix.masked_fill(mask == 0, float('-inf'))
att_score = self.dropout(F.softmax(att_matrix, dim=-1))
att_output = att_score @ V
att_output = att_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.output_proj(att_output)
class Router(nn.Module):
def __init__(self, config):
super().__init__()
# router settings
self.top_k = config.top_k
self.n_exp = config.n_exp
assert self.top_k >= 1 and self.top_k <= config.n_exp
self.use_noisy_top_k = config.use_noisy_top_k
self.train_capacity = config.train_capacity
self.eval_capacity = config.eval_capacity
self.min_capacity = config.min_capacity
self.router_use_full_prec = config.router_use_full_prec
# auxiliary / load balancing loss settings
self.use_aux_loss = config.use_aux_loss
self.use_router_z_loss = config.use_router_z_loss
# linear projection for (noisy) softmax gating
# no bias is used, see page 4 eq (4) in (https://arxiv.org/abs/1701.06538)
self.w_g = nn.Linear(config.n_embd, config.n_exp, bias=False)
self.w_g.router_marker = True
self.w_noise = nn.Linear(config.n_embd, config.n_exp, bias=False) if self.use_noisy_top_k else None
def forward(self, x):
# optionally run the router in full precision to avoid instability during training
# see discussion on pg. 9 here: https://arxiv.org/abs/2101.03961
# setting enabled to False in autocast automatically puts everything in float32
device_type = 'cuda' if torch.cuda.is_available() else 'cpu' # for later use in torch.autocast
ctx = nullcontext() if not self.router_use_full_prec else torch.amp.autocast(device_type=device_type, enabled=False)
with ctx:
B, T, _ = x.size()
num_tokens = B * T
# eq (4) in (https://arxiv.org/abs/1701.06538)
logits = self.w_g(x) # [B, T, n_exp]
if self.use_noisy_top_k:
# optionally add noise into the router
noise = F.softplus(self.w_noise(x))
noise *= torch.randn_like(noise)
logits += noise
# router z loss, computed on logits (before softmax)
# this loss prevents router logits from becoming too large
if self.use_router_z_loss:
z_loss = self.compute_router_z_loss(logits)
MANAGER.add_router_z_loss(z_loss)
# find top k experts for each token
top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1) # [B, T, k]
# normalize expert probabilities
# Question: should we normalize over all experts or just top-k?
# we choose to normalize over top-k, other option is commented out below
# Shazeer et al (https://arxiv.org/abs/1701.06538) does only topk
# see page 4 eq (3)-(5), the code for this is commented out below
router_probs = torch.full_like(logits, float('-inf')) # [B, T, n_exp]
router_probs.scatter_(-1, top_k_indices, top_k_logits)
router_probs = F.softmax(router_probs, dim=-1)
# # normalize all router logits (not just top-k) via softmax
router_probs = F.softmax(logits, dim=-1)
# compute auxiliary load balancing loss
# this loss encourages equal probability assigned to each expert
# and equal load balancing of tokens assigned to each expert
if self.use_aux_loss:
aux_loss = self.compute_aux_loss(router_probs, top_k_indices)
MANAGER.add_aux_loss(aux_loss)
# compute expert capacity
exp_capacity = self.get_capacity(num_tokens)
# make a multi-hot mask of chosen experts, size [B, T, n_exp]
# entries are 0 if expert not chosen and 1 if expert chosen
exp_mask = F.one_hot(top_k_indices, num_classes=self.n_exp) # [B, T, k, n_exp]
exp_mask = exp_mask.view(num_tokens, self.top_k, self.n_exp) # [B * T, k, n_exp]
exp_mask = exp_mask.permute(1, 0, 2) # [k, B * T, n_exp]
# compute cumulative sum of each token over experts, this stores
# the index of each token within the batch of each expert
# NOTE: cumsum should count all top-1 first, top-2 second, etc.
# so that we prioritize top experts when dropping tokens (this is
# done by putting k dimension first for the reshape operation)
exp_rank = exp_mask.reshape(self.top_k * num_tokens, self.n_exp) # [k * B * T, n_exp]
exp_rank = torch.cumsum(exp_rank, dim=0) - 1 # cumulative sum of expert selections [k * B * T, n_exp]
exp_rank = exp_rank.reshape(self.top_k, num_tokens, self.n_exp) # [k, B * T, n_exp]
# mask out (set to zero) entries that go beyond expert capacity
# compute amount of used capacity by taking a sum over mask
exp_mask *= torch.lt(exp_rank, exp_capacity) # [k, B * T, n_exp]
used_capacity = torch.sum(exp_mask, dim=(0, 1)) # [n_exp]
# mask rank to only include tokens that are selected
# perform a sum so each row only contains index of token
# for the expert that is selected in that row
# result is a matrix that contains the position of each token
# in the batch of its corresponding expert
exp_rank = torch.sum(exp_mask * exp_rank, dim=-1) # [k, B * T]
# mask probabilities to only include selected experts
router_probs = router_probs.view(num_tokens, self.n_exp)[None, :] # [1, B * T, n_exp]
exp_weights = exp_mask * router_probs # [k, B * T, n_exp]
# convert rank into one-hot vectors over the available capacity
# stores the position of each token within the capacity of the selected expert
exp_rank_sc = F.one_hot(exp_rank, num_classes=exp_capacity) # [k, B * T, exp_capacity]
# create a vector that stores, for each token, the weight of selected
# experts at token's position in the capacity of that expert
# size of tensor is [B * T, n_exp, exp_capacity]
cb_weight = torch.sum(exp_weights.unsqueeze(3) * exp_rank_sc.unsqueeze(2), dim=0)
sec_mask = cb_weight.bool() # binary mask of selected experts for each token
return used_capacity, cb_weight, sec_mask
def compute_aux_loss(self, expert_probs: torch.Tensor, indices: torch.Tensor):
"""
Computes Switch Transformer auxiliary loss (https://arxiv.org/abs/2101.03961)
See equations (4)-(6) on page 7
"""
# equation (5): compute ratio of tokens allocated to each expert
# total number of tokens is defined as total tokens in batch * k
# (k = 1) for the Switch Transformer
with torch.no_grad():
one_hot_indices = F.one_hot(indices, num_classes=self.n_exp) # [B, T, k, n_exp]
one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) # [B, T, n_exp] (sum over k dimension)
tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1))
# equation (6): compute ratio of router probability allocated to each expert
prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1))
# equation (4): take a scaled dot product between prob/token allocation vectors
# multiply the result by the number of experts
return self.n_exp * torch.sum(prob_per_expert * tokens_per_expert)
def compute_router_z_loss(self, logits: torch.Tensor):
"""
Computes ST-MoE router z loss (https://arxiv.org/abs/2202.08906)
See equation (5) on page 7
"""
# exponentiate logits, sum logits of each expert, take log, and square
# code below is the same as:
# > z_loss = torch.exp(logits)
# > z_loss = torch.sum(z_loss, dim=-1)
# > z_loss = torch.log(z_loss) ** 2.0
z_loss = torch.logsumexp(logits, dim=-1) ** 2.0 # [B, T, n_exp]
# sum over all tokens and divide by total number of tokens
return torch.mean(z_loss)
def get_capacity(self, tokens_per_batch):
# expert capacity is given by (tokens_per_batch / num_experts) * capacity_factor
# see eq (3) in Switch Transformer (https://arxiv.org/abs/2101.03961)
capacity_factor = self.train_capacity if self.training else self.eval_capacity
capacity = math.floor(self.top_k * capacity_factor * tokens_per_batch / self.n_exp)
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return int(capacity)
# FEEDFORWARD
class MLP(nn.Module):
def __init__(self, config, ffn_dim=None):
super().__init__()
if ffn_dim==None:
ffn_dim = config.ffn_dim
self.fc1 = nn.Linear(config.n_embd, 2 * ffn_dim, bias=config.bias)
self.fc1.is_swiglu = True
self.swish = nn.SiLU()
self.fc2 = nn.Linear(ffn_dim, config.n_embd, bias=config.bias)
self.fc2.output_proj_marker = True
self.dropout1 = nn.Dropout(config.dropout)
self.dropout2 = nn.Dropout(config.dropout)
# nn.init.xavier_uniform_(self.fc1.weight, gain=math.sqrt(2.0))
# nn.init.xavier_uniform_(self.fc2.weight, gain=1.0)
def forward(self, x):
x = self.fc1(x)
# Inline SwiGLU: Split the doubled dimension and apply gate
x, gate = x.chunk(2, dim=-1)
x = x * self.swish(gate)
x = self.dropout1(x)
x = self.fc2(x)
return self.dropout2(x)
class MLPExperts(nn.Module):
def __init__(self, config):
super().__init__()
self.n_exp = config.n_exp
self.n_embd = config.n_embd
self.bias = config.bias
self.c_fc = nn.Parameter(torch.empty(self.n_exp, self.n_embd, 2 * config.expert_dim))
self.c_proj = nn.Parameter(torch.empty(self.n_exp, config.expert_dim, self.n_embd))
self.swish = nn.SiLU()
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
x = torch.bmm(x, self.c_fc)
x, gate = x.chunk(2, dim=-1)
x = x * self.swish(gate)
x = torch.bmm(x, self.c_proj)
return self.dropout(x)
class MOELayer(nn.Module):
def __init__(self, config):
super().__init__()
self.router = Router(config) # (noisy) top k router
self.experts = MLPExperts(config) # group of MLPs (experts)
self.shared_expert = MLP(config, ffn_dim=config.shared_dim)
def forward(self, x: torch.Tensor):
B, T, n_embd = x.size()
num_tokens = (B * T)
shared_out = self.shared_expert(x)
used_capacity, exp_weight, exp_mask = self.router(x)
x = x.view(num_tokens, n_embd)
# [n_exp, exp_capacity, B * T] * [B * T, n_embd] -> [n_exp, exp_capacity, n_embd]
exp_batches = exp_mask.permute(1, 2, 0).type_as(x) @ x
exp_out = self.experts(exp_batches) # [n_exp, exp_capacity, n_embd]
# aggregate expert outputs based on router weights
# eq (2) on page 4 of ST-MoE (https://arxiv.org/abs/2202.08906)
# similar equations are used for other MoE papers
exp_weight = exp_weight.view(num_tokens, -1) # [B * T, n_exp * exp_capacity]
exp_out = exp_out.view(-1, n_embd) # [n_exp * exp_capacity, n_embd]
output = exp_weight @ exp_out # [B * T, n_embd]
moe_out = output.view(B, T, n_embd)
return moe_out + shared_out
class Block(nn.Module):
def __init__(self, config, use_moe=False):
super().__init__()
self.ln_1 = nn.RMSNorm(config.n_embd)
self.attn = MultiHeadLatentAttention(config)
self.ln_2 = nn.RMSNorm(config.n_embd)
if use_moe:
self.mlp = MOELayer(config)
else:
self.mlp = MLP(config)
def forward(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
@dataclass
class GPTConfig:
block_size: int = 2048
vocab_size: int = 50304
n_layer: int = 24
n_head: int = 10
n_embd: int = 640
dropout: float = 0.0
ffn_dim: int = 640*4
bias: bool = False
# MLA - High Efficiency
d_c: int = 192
d_c1: int = 192
d_rotate: int = 64
theta: float = 10000.0
# MoE - Maximally Smart
n_exp: int = 12
top_k: int = 3
expert_dim: int = 640
shared_dim: int = 640
stride: int = 2
# Stability (Standard Production Settings)
use_aux_loss: bool = True
use_router_z_loss: bool = True
use_noisy_top_k: bool = True
aux_loss_weight: float = 0.01
router_z_loss_weight: float = 0.001
train_capacity: float = 1.25
eval_capacity: float = 2.0
min_capacity: int = 4
use_switch_tfm_init: bool = True
switch_tfm_init_scale: float = 1.0
router_use_full_prec: bool = True
# Training Hyperparameters
batch_size: int = 8
grad_acc: int = 128
num_train_epochs: int = 1
learning_rate: float = 3e-4
weight_decay: float = 0.1
betas: tuple = (0.9, 0.95)
warm_up: int = 5000
eos_token_id = 0
bos_token_id = 0
pad_token_id = 0
class HybridOptimizer(torch.optim.Optimizer):
def __init__(self, optimizers):
self.optimizers = optimizers
self.param_groups = []
for opt in self.optimizers:
self.param_groups.extend(opt.param_groups)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for opt in self.optimizers:
opt.step()
return loss
def zero_grad(self, set_to_none=True):
for opt in self.optimizers:
opt.zero_grad(set_to_none=set_to_none)
def state_dict(self):
return [opt.state_dict() for opt in self.optimizers]
def load_state_dict(self, state_dict):
for opt, sd in zip(self.optimizers, state_dict):
opt.load_state_dict(sd)
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
self.can_return_loss = True
self.accepts_loss_kwargs = False
if config.n_exp == 1:
blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
else:
blocks = []
for i in range(config.n_layer):
use_moe = False if (i < config.stride or i > config.n_layer - config.stride-1) else True
blocks.append(Block(config, use_moe=use_moe))
blocks = nn.ModuleList(blocks)
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
h = blocks,
ln_f = nn.RMSNorm(config.n_embd),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.transformer.wte.weight = self.lm_head.weight
self.apply(self._init_weights)
print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
def get_num_params(self, non_embedding=True):
n_params = sum(p.numel() for p in self.parameters())
return n_params
@torch.no_grad()
def _init_weights(self, module):
# Setup base configuration
scale = self.config.switch_tfm_init_scale if hasattr(self.config, 'switch_tfm_init_scale') else 1.0
n_layer = self.config.n_layer
if isinstance(module, nn.Linear):
# Calculate standard fan-in (input dimension)
w_fan_in = module.weight.shape[-1]
base_std = (scale / w_fan_in) ** 0.5
# Determine specific scaling per layer type
if hasattr(module, 'router_marker'):
# Small std for routers ensures balanced initial expert distribution
final_std = 0.01
elif hasattr(module, 'output_proj_marker'):
# Residual scaling: keeps variance from exploding in deep networks
final_std = base_std / math.sqrt(2 * n_layer)
elif hasattr(module, 'is_attention'):
# Attn weights often benefit from a slight dampener
final_std = base_std * 0.7
else:
# Standard hidden/up-projections
final_std = base_std
# Apply truncated normal initialization
torch.nn.init.trunc_normal_(
module.weight, mean=0.0, std=final_std, a=-2*final_std, b=2*final_std
)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
# Handling custom Parameter-based MLPExperts
elif isinstance(module, MLPExperts):
# UP-PROJECTION (c_fc)
c_fc_fan_in = module.c_fc.shape[-2]
final_fc_std = (scale / c_fc_fan_in) ** 0.5
torch.nn.init.trunc_normal_(module.c_fc, std=final_fc_std, a=-2*final_fc_std, b=2*final_fc_std)
# DOWN-PROJECTION (c_proj)
c_proj_fan_in = module.c_proj.shape[-2]
# Residual scaling for MoE outputs
final_proj_std = ((scale / c_proj_fan_in) ** 0.5) / math.sqrt(2 * n_layer)
torch.nn.init.trunc_normal_(module.c_proj, std=final_proj_std, a=-2*final_proj_std, b=2*final_proj_std)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
# elif isinstance(module, nn.RMSNorm):
# # Initializing to 0.01 as requested
# # Note: 1.0 is standard, 0.01 will significantly dampen initial signal
# torch.nn.init.constant_(module.weight, 1.0)
def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
_, t = input_ids.size()
assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"
x = self.transformer.wte(input_ids)
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
if labels is not None:
logits = self.lm_head(x)
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
# print("\n\nlabel: ", shift_labels, "\ninput: ", input_ids)
loss_fct = nn.CrossEntropyLoss(
ignore_index=-100,
label_smoothing=0.1,
reduction='mean'
)
main_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
loss = main_loss
if self.config.n_exp > 1:
if self.config.use_aux_loss:
loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
MANAGER.reset_aux_loss()
if self.config.use_router_z_loss:
loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
MANAGER.reset_router_z_loss()
else:
logits = self.lm_head(x[:, [-1], :])
loss = None
return CausalLMOutput(loss=loss, logits=logits)
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# TODO: add expert config
# start with all of the candidate parameters
param_dict = {pn: p for pn, p in self.named_parameters()}
# filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
# add an extra check for "bias" string to account for bias terms in MoE layers
decay_params = [p for n, p in param_dict.items() if (p.dim() >= 2 and not n.endswith('bias'))]
nodecay_params = [p for n, p in param_dict.items() if (p.dim() < 2 or n.endswith('bias'))]
optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
for _ in range(max_new_tokens):
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
# Correctly unpack the dataclass output
outputs = self(idx_cond)
logits = outputs.logits[:, -1, :] / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
return idx