| import math
|
| import inspect
|
| from dataclasses import dataclass
|
| from contextlib import nullcontext
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.nn import functional as F
|
| from typing import Tuple
|
| import inspect
|
|
|
| from transformers.modeling_outputs import CausalLMOutput
|
| from manager import MANAGER
|
|
|
| torch.manual_seed(101)  # module-level side effect: seeds torch's global RNG with 101 whenever this file is imported
|
|
|
def precompute_freqs_cis(config):
    """Build the rotary-embedding cosine and sine lookup tables.

    Reads ``config.theta`` (frequency base), ``config.d_rotate`` (rotary
    width) and ``config.block_size`` (number of positions).

    Returns:
        ``(cos, sin)``: two float tensors with one row per position. Each
        per-pair angle is repeated twice along the last dimension, so
        columns ``2i`` and ``2i + 1`` hold the same value.
    """
    rot_dim = config.d_rotate
    n_pairs = rot_dim // 2

    # One inverse frequency per (even, odd) channel pair.
    exponents = torch.arange(0, rot_dim, 2)[:n_pairs].float() / rot_dim
    inv_freq = 1.0 / (config.theta ** exponents)

    # angle[p, i] = p * inv_freq[i]
    positions = torch.arange(config.block_size, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()

    # Duplicate every column so the tables line up with interleaved channel pairs.
    cos, sin = (
        torch.repeat_interleave(trig(angles), 2, dim=-1)
        for trig in (torch.cos, torch.sin)
    )
    return cos, sin
|
|
|
def rotate_half(x):
    """Rotate each adjacent channel pair of the last dimension by 90 degrees.

    The last dimension is read as interleaved pairs ``(a, b)``; every pair is
    replaced by ``(-b, a)``. The output has the same shape as ``x``.
    """
    even, odd = x[..., ::2], x[..., 1::2]
    # Stack as (..., n_pairs, 2) and then merge the two trailing axes back
    # into one so the pairs are interleaved again.
    swapped = torch.stack([odd.neg(), even], dim=-1)
    return swapped.flatten(start_dim=-2)
|
|
|
def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
    """Apply rotary position embeddings to a query tensor and a key tensor.

    The sequence length is taken from ``xq.shape[1]`` and the rotary width
    from ``xq.shape[-1]``. The first ``seq_len`` rows of the tables are
    reshaped to ``(1, seq_len, 1, width)`` so they broadcast over the
    remaining axes of ``xq`` / ``xk``.

    Returns:
        ``(xq_out, xk_out)``, each cast back to the dtype of its input.
    """
    seq_len, rot_dim = xq.shape[1], xq.shape[-1]
    table_shape = (1, seq_len, 1, rot_dim)
    cos = freqs_cos[:seq_len].view(table_shape)
    sin = freqs_sin[:seq_len].view(table_shape)

    def _rotate(t):
        # Pairwise 2-D rotation: t * cos + rot90(t) * sin.
        return ((t * cos) + (rotate_half(t) * sin)).type_as(t)

    return _rotate(xq), _rotate(xk)
|
|
|
class MultiHeadLatentAttention(nn.Module):
    """Causal multi-head attention built on low-rank latents plus a rotary slice.

    ``W_down`` projects the input to three pieces: a query latent of width
    ``d_c1``, a key/value latent of width ``d_c`` and one rotary key of width
    ``d_rotate`` that is shared by all heads. The two latents are
    RMS-normalised and expanded per head by ``W_up_q`` / ``W_up_kv``. Rotary
    embeddings are applied only to the ``d_rotate``-wide query/key slices,
    which are concatenated to the non-rotary parts before attention. Values
    keep width ``d_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.d_model = config.n_embd
        self.num_head = config.n_head
        self.d_head = self.d_model // self.num_head

        self.d_c = config.d_c
        self.d_c1 = config.d_c1
        self.d_rotate = config.d_rotate

        def attention_linear(n_in, n_out):
            # Every projection in this module carries the ``is_attention`` tag,
            # which GPT._init_weights checks when choosing the init std.
            layer = nn.Linear(n_in, n_out, bias=config.bias)
            layer.is_attention = True
            return layer

        # x -> [query latent | key/value latent | shared rotary key]
        self.W_down = attention_linear(
            self.d_model, self.d_c1 + self.d_c + self.d_rotate
        )
        # query latent -> [per-head queries | per-head rotary queries]
        self.W_up_q = attention_linear(
            self.d_c1, self.d_model + (self.num_head * self.d_rotate)
        )
        # key/value latent -> [per-head keys | per-head values]
        self.W_up_kv = attention_linear(self.d_c, self.d_model + self.d_model)

        self.q_norm = nn.RMSNorm(self.d_c1)
        self.kv_norm = nn.RMSNorm(self.d_c)

        self.output_proj = attention_linear(self.d_model, self.d_model)
        self.output_proj.output_proj_marker = True

        self.dropout = nn.Dropout(config.dropout)
        self.attn_dropout_p = config.dropout

        # Use the fused kernel whenever this torch build provides it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        # Rotary tables are derived from config, so they are kept out of the state dict.
        cos, sin = precompute_freqs_cis(config)
        self.register_buffer("freqs_cos", cos, persistent=False)
        self.register_buffer("freqs_sin", sin, persistent=False)

    def forward(self, x):
        """Run causal self-attention on ``x`` of shape ``(batch, seq, d_model)``.

        Returns a tensor of shape ``(batch, seq, d_model)``.
        """
        bsz, seq_len, _ = x.size()
        n_head = self.num_head

        # Joint down-projection, then split into the three latent pieces.
        q_latent, kv_latent, k_rope = torch.split(
            self.W_down(x), [self.d_c1, self.d_c, self.d_rotate], dim=-1
        )
        q_latent = self.q_norm(q_latent)
        kv_latent = self.kv_norm(kv_latent)

        # Queries: non-rotary part plus a rotary part per head.
        q_nope, q_rope = torch.split(
            self.W_up_q(q_latent), [self.d_model, n_head * self.d_rotate], dim=-1
        )
        q_nope = q_nope.view(bsz, seq_len, n_head, self.d_head)
        q_rope = q_rope.view(bsz, seq_len, n_head, self.d_rotate)

        # Keys and values come out of one shared up-projection.
        k_nope, values = torch.split(
            self.W_up_kv(kv_latent), [self.d_model, self.d_model], dim=-1
        )
        k_nope = k_nope.view(bsz, seq_len, n_head, self.d_head)
        values = values.view(bsz, seq_len, n_head, self.d_head)

        # The single rotary key is broadcast to every head.
        k_rope = k_rope.view(bsz, seq_len, 1, self.d_rotate).expand(-1, -1, n_head, -1)

        q_rope, k_rope = apply_rotary_emb(
            q_rope, k_rope, self.freqs_cos, self.freqs_sin
        )

        # (batch, head, seq, dim) layout for attention.
        queries = torch.cat([q_nope, q_rope], dim=-1).transpose(1, 2)
        keys = torch.cat([k_nope, k_rope], dim=-1).transpose(1, 2)
        values = values.transpose(1, 2)

        if not self.flash:
            # Manual fallback: scaled dot-product with a lower-triangular mask.
            scale = 1.0 / math.sqrt(self.d_head + self.d_rotate)
            scores = (queries @ keys.transpose(-2, -1)) * scale
            causal = torch.tril(
                torch.ones(seq_len, seq_len, device=x.device)
            ).view(1, 1, seq_len, seq_len)
            scores = scores.masked_fill(causal == 0, float('-inf'))
            weights = self.dropout(F.softmax(scores, dim=-1))
            attended = weights @ values
        else:
            attended = F.scaled_dot_product_attention(
                queries, keys, values,
                dropout_p=self.attn_dropout_p if self.training else 0.0,
                is_causal=True
            )

        # Merge the heads back into the model width.
        merged = attended.transpose(1, 2).contiguous().view(bsz, seq_len, self.d_model)
        return self.output_proj(merged)
|
|
|
class Router(nn.Module):
    """Top-k token router with a fixed per-expert capacity.

    For every token the router scores all ``n_exp`` experts, keeps the
    ``top_k`` best, and assigns the token a slot in each chosen expert's
    buffer. Tokens that arrive after an expert's buffer is full are dropped
    for that expert. Optional auxiliary and z losses are pushed to
    ``MANAGER`` as a side effect of ``forward``.
    """

    def __init__(self, config):
        super().__init__()

        # Routing hyper-parameters.
        self.top_k = config.top_k
        self.n_exp = config.n_exp
        assert self.top_k >= 1 and self.top_k <= config.n_exp
        self.use_noisy_top_k = config.use_noisy_top_k
        self.train_capacity = config.train_capacity
        self.eval_capacity = config.eval_capacity
        self.min_capacity = config.min_capacity
        self.router_use_full_prec = config.router_use_full_prec

        # Which regularisation losses to report to MANAGER.
        self.use_aux_loss = config.use_aux_loss
        self.use_router_z_loss = config.use_router_z_loss

        # Gating projection; ``router_marker`` selects its init std in GPT._init_weights.
        self.w_g = nn.Linear(config.n_embd, config.n_exp, bias=False)
        self.w_g.router_marker = True
        # Per-expert noise scale, only allocated when noisy top-k is enabled.
        self.w_noise = nn.Linear(config.n_embd, config.n_exp, bias=False) if self.use_noisy_top_k else None

    def forward(self, x):
        """Route the tokens of ``x`` (shape ``(B, T, n_embd)``) to experts.

        Returns:
            used_capacity: ``(n_exp,)`` number of slots filled per expert.
            cb_weight: ``(B*T, n_exp, capacity)`` combine weights; entry
                ``[t, e, s]`` is the router probability of expert ``e`` for
                token ``t`` if that token occupies slot ``s``, else 0.
            sec_mask: ``cb_weight.bool()``, the matching dispatch mask.
        """
        # NOTE(review): the autocast device is chosen from CUDA availability,
        # not from ``x.device`` -- confirm this is intended for CPU tensors on
        # a CUDA machine.
        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        ctx = nullcontext() if not self.router_use_full_prec else torch.amp.autocast(device_type=device_type, enabled=False)

        with ctx:
            B, T, _ = x.size()
            num_tokens = B * T

            logits = self.w_g(x)
            if self.use_noisy_top_k:
                # NOTE(review): noise is added in eval mode as well as in training.
                noise = F.softplus(self.w_noise(x))
                noise *= torch.randn_like(noise)
                logits += noise

            if self.use_router_z_loss:
                z_loss = self.compute_router_z_loss(logits)
                MANAGER.add_router_z_loss(z_loss)

            # Only the indices of the selected experts are needed below.
            _, top_k_indices = logits.topk(self.top_k, dim=-1)

            # Probabilities are normalised over ALL experts. The previous code
            # also built a top-k-only softmax immediately before this line and
            # then overwrote it; that dead computation has been dropped.
            router_probs = F.softmax(logits, dim=-1)

            if self.use_aux_loss:
                aux_loss = self.compute_aux_loss(router_probs, top_k_indices)
                MANAGER.add_aux_loss(aux_loss)

            exp_capacity = self.get_capacity(num_tokens)

            # One-hot expert choice per (rank, token): shape (top_k, num_tokens, n_exp).
            exp_mask = F.one_hot(top_k_indices, num_classes=self.n_exp)
            exp_mask = exp_mask.view(num_tokens, self.top_k, self.n_exp)
            exp_mask = exp_mask.permute(1, 0, 2)

            # Running count per expert gives each assignment its buffer slot.
            # All first choices are counted before any second choice, so
            # higher-ranked choices get slots first.
            exp_rank = exp_mask.reshape(self.top_k * num_tokens, self.n_exp)
            exp_rank = torch.cumsum(exp_rank, dim=0) - 1
            exp_rank = exp_rank.reshape(self.top_k, num_tokens, self.n_exp)

            # Drop assignments whose slot index is past the capacity.
            exp_mask *= torch.lt(exp_rank, exp_capacity)
            used_capacity = torch.sum(exp_mask, dim=(0, 1))

            # Collapse the expert axis: slot index of each surviving (rank, token).
            exp_rank = torch.sum(exp_mask * exp_rank, dim=-1)

            # Router probability of each surviving assignment.
            router_probs = router_probs.view(num_tokens, self.n_exp)[None, :]
            exp_weights = exp_mask * router_probs

            # One-hot over slots: (top_k, num_tokens, capacity).
            exp_rank_sc = F.one_hot(exp_rank, num_classes=exp_capacity)

            # Outer product of expert weights and slot one-hots, summed over top_k.
            cb_weight = torch.sum(exp_weights.unsqueeze(3) * exp_rank_sc.unsqueeze(2), dim=0)
            sec_mask = cb_weight.bool()
            return used_capacity, cb_weight, sec_mask

    def compute_aux_loss(self, expert_probs: torch.Tensor, indices: torch.Tensor):
        """
        Computes Switch Transformer auxiliary loss (https://arxiv.org/abs/2101.03961)
        See equations (4)-(6) on page 7

        ``expert_probs`` is ``(B, T, n_exp)`` and ``indices`` is
        ``(B, T, top_k)``. Returns a scalar tensor.
        """
        # The fraction of tokens sent to each expert is treated as a constant
        # (no gradient flows through the discrete selection).
        with torch.no_grad():
            one_hot_indices = F.one_hot(indices, num_classes=self.n_exp)
            one_hot_indices = torch.sum(one_hot_indices.float(), dim=2)
            tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1))

        # Mean router probability per expert (differentiable).
        prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1))

        return self.n_exp * torch.sum(prob_per_expert * tokens_per_expert)

    def compute_router_z_loss(self, logits: torch.Tensor):
        """
        Computes ST-MoE router z loss (https://arxiv.org/abs/2202.08906)
        See equation (5) on page 7

        Returns the mean over tokens of ``logsumexp(logits) ** 2``.
        """
        z_loss = torch.logsumexp(logits, dim=-1) ** 2.0
        return torch.mean(z_loss)

    def get_capacity(self, tokens_per_batch):
        """Return the number of slots each expert gets for ``tokens_per_batch`` tokens.

        Uses ``train_capacity`` in training mode and ``eval_capacity``
        otherwise, rounds the result up to an even number, and never goes
        below ``min_capacity``.
        """
        capacity_factor = self.train_capacity if self.training else self.eval_capacity
        capacity = math.floor(self.top_k * capacity_factor * tokens_per_batch / self.n_exp)
        capacity += capacity % 2  # round odd values up to the next even number
        capacity = max(capacity, self.min_capacity)
        assert capacity > 0
        return int(capacity)
|
|
|
|
|
class MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block.

    ``fc1`` produces ``2 * ffn_dim`` features that are split in half: the
    first half is multiplied by ``SiLU`` of the second half, then ``fc2``
    projects back to ``n_embd``. Dropout is applied after the gate and after
    ``fc2``.

    Args:
        config: object providing ``n_embd``, ``ffn_dim``, ``bias`` and ``dropout``.
        ffn_dim: hidden width; falls back to ``config.ffn_dim`` when ``None``.
    """

    def __init__(self, config, ffn_dim=None):
        super().__init__()

        # Identity check instead of ``== None`` (the usual idiom for the
        # ``None`` singleton).
        if ffn_dim is None:
            ffn_dim = config.ffn_dim

        # Value and gate halves are produced by a single matmul.
        self.fc1 = nn.Linear(config.n_embd, 2 * ffn_dim, bias=config.bias)
        self.fc1.is_swiglu = True
        self.swish = nn.SiLU()
        self.fc2 = nn.Linear(ffn_dim, config.n_embd, bias=config.bias)
        # Tag read by GPT._init_weights to scale down the init of output projections.
        self.fc2.output_proj_marker = True

        self.dropout1 = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the gated feed-forward transform along the last dimension of ``x``."""
        x = self.fc1(x)

        # First half carries the values, second half the gate.
        x, gate = x.chunk(2, dim=-1)
        x = x * self.swish(gate)

        x = self.dropout1(x)
        x = self.fc2(x)
        return self.dropout2(x)
|
|
|
|
|
class MLPExperts(nn.Module):
    """A bank of ``n_exp`` gated feed-forward experts evaluated with batched matmuls.

    Weights live in two stacked parameters with the expert index as the
    leading axis. They are allocated uninitialised (``torch.empty``) here;
    in this file GPT._init_weights fills them.
    """

    def __init__(self, config):
        super().__init__()
        self.n_exp = config.n_exp
        self.n_embd = config.n_embd
        self.bias = config.bias

        hidden = config.expert_dim
        # (expert, in, 2 * hidden): value and gate halves share one matmul.
        self.c_fc = nn.Parameter(torch.empty(self.n_exp, self.n_embd, 2 * hidden))
        # (expert, hidden, out)
        self.c_proj = nn.Parameter(torch.empty(self.n_exp, hidden, self.n_embd))

        self.swish = nn.SiLU()
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply expert ``e`` to slice ``x[e]``.

        ``x`` must be 3-D with the expert index first, as ``torch.bmm``
        requires; the output has the same shape as ``x``.
        """
        projected = torch.bmm(x, self.c_fc)
        value, gate = projected.chunk(2, dim=-1)
        activated = value * self.swish(gate)
        return self.dropout(torch.bmm(activated, self.c_proj))
|
|
|
class MOELayer(nn.Module):
    """Mixture-of-experts feed-forward layer with an always-on shared expert.

    Output = routed-expert mixture + shared expert, both computed from the
    same input.
    """

    def __init__(self, config):
        super().__init__()
        self.router = Router(config)
        self.experts = MLPExperts(config)
        # Dense MLP of width ``config.shared_dim`` applied to every token.
        self.shared_expert = MLP(config, ffn_dim=config.shared_dim)

    def forward(self, x: torch.Tensor):
        """Process ``x`` of shape ``(B, T, n_embd)`` and return the same shape."""
        B, T, n_embd = x.size()
        num_tokens = B * T

        # Dense path first, then routing (keeps the original call order).
        dense_out = self.shared_expert(x)
        used_capacity, combine_w, dispatch = self.router(x)

        tokens = x.view(num_tokens, n_embd)

        # Dispatch: (n_exp, capacity, num_tokens) @ (num_tokens, n_embd)
        # gathers each expert's tokens into its slot buffer.
        expert_in = dispatch.permute(1, 2, 0).type_as(tokens) @ tokens
        expert_out = self.experts(expert_in)

        # Combine: flatten (expert, slot) and mix the slot outputs back into
        # token order using the router weights.
        combine_w = combine_w.view(num_tokens, -1)
        expert_out = expert_out.view(-1, n_embd)
        routed_out = (combine_w @ expert_out).view(B, T, n_embd)

        return routed_out + dense_out
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual.

    Args:
        config: model configuration passed through to the sub-modules.
        use_moe: when true the feed-forward is an ``MOELayer``, otherwise a dense ``MLP``.
    """

    def __init__(self, config, use_moe=False):
        super().__init__()
        self.ln_1 = nn.RMSNorm(config.n_embd)
        self.attn = MultiHeadLatentAttention(config)
        self.ln_2 = nn.RMSNorm(config.n_embd)
        self.mlp = MOELayer(config) if use_moe else MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        ffn_out = self.mlp(self.ln_2(x))
        return x + ffn_out
|
|
|
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT model defined in this file.

    The special-token ids at the end are annotated so that they are real
    dataclass fields: they can be passed to the constructor and show up in
    ``dataclasses.fields`` / ``asdict``. They are declared last, so
    positional construction is unchanged.
    """

    # --- core transformer ---
    block_size: int = 2048      # maximum sequence length (checked in GPT.forward)
    vocab_size: int = 50304
    n_layer: int = 24
    n_head: int = 10
    n_embd: int = 640
    dropout: float = 0.0
    ffn_dim: int = 640*4        # hidden width of the dense MLP
    bias: bool = False          # bias flag for the attention and MLP linear layers

    # --- latent attention / rotary embedding ---
    d_c: int = 192              # key/value latent width
    d_c1: int = 192             # query latent width
    d_rotate: int = 64          # rotary width per head
    theta: float = 10000.0      # rotary frequency base

    # --- mixture of experts ---
    n_exp: int = 12             # number of routed experts (1 disables MoE blocks in GPT)
    top_k: int = 3              # experts selected per token
    expert_dim: int = 640       # hidden width of each routed expert
    shared_dim: int = 640       # hidden width of the shared expert
    stride: int = 2             # number of dense blocks kept at each end of the stack

    # --- router losses, capacity and init ---
    use_aux_loss: bool = True
    use_router_z_loss: bool = True
    use_noisy_top_k: bool = True
    aux_loss_weight: float = 0.01
    router_z_loss_weight: float = 0.001
    train_capacity: float = 1.25    # capacity factor in training mode
    eval_capacity: float = 2.0      # capacity factor in eval mode
    min_capacity: int = 4
    use_switch_tfm_init: bool = True    # not read anywhere in this file
    switch_tfm_init_scale: float = 1.0
    router_use_full_prec: bool = True   # run the router with autocast disabled

    # --- training settings (not read in this file; presumably used by the training script) ---
    batch_size: int = 8
    grad_acc: int = 128
    num_train_epochs: int = 1
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    betas: tuple = (0.9, 0.95)
    warm_up: int = 5000

    # --- special token ids ---
    eos_token_id: int = 0
    bos_token_id: int = 0
    pad_token_id: int = 0
|
|
|
class HybridOptimizer(torch.optim.Optimizer):
    """Drive several optimizers through a single optimizer-like object.

    ``param_groups`` is a flat list holding the *same* dict objects as the
    wrapped optimizers' groups, so a change made through this wrapper (for
    example a new ``lr`` value) is seen by the optimizer that owns the group.

    ``torch.optim.Optimizer.__init__`` is not called, so base-class
    attributes such as ``state`` and ``defaults`` are not set on the wrapper.

    Args:
        optimizers: sequence of already-constructed optimizers.
    """

    def __init__(self, optimizers):
        self.optimizers = optimizers
        self.param_groups = []
        self._sync_param_groups()

    def _sync_param_groups(self):
        """Rebuild the flat group list (in place) from the wrapped optimizers."""
        self.param_groups[:] = [
            group for opt in self.optimizers for group in opt.param_groups
        ]

    def step(self, closure=None):
        """Call ``closure`` once (if given), then step every wrapped optimizer.

        Returns the closure's result, or ``None`` when no closure is passed.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for opt in self.optimizers:
            opt.step()
        return loss

    def zero_grad(self, set_to_none=True):
        """Clear the gradients held by every wrapped optimizer."""
        for opt in self.optimizers:
            opt.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        """Return a list with one state dict per wrapped optimizer, in order."""
        return [opt.state_dict() for opt in self.optimizers]

    def load_state_dict(self, state_dict):
        """Load a list of state dicts as produced by :meth:`state_dict`.

        Loading replaces each wrapped optimizer's ``param_groups`` with new
        dict objects, so the flat list is rebuilt afterwards; otherwise it
        would keep pointing at the discarded groups and later edits through
        the wrapper would never reach the real optimizers.
        """
        for opt, sd in zip(self.optimizers, state_dict):
            opt.load_state_dict(sd)
        self._sync_param_groups()
|
|
|
class GPT(nn.Module):
    """Decoder-only language model built from ``Block`` layers.

    Blocks use a dense MLP when ``config.n_exp == 1``. Otherwise the first
    and last ``config.stride`` blocks are dense and every block in between
    uses a mixture-of-experts feed-forward. The token embedding and the
    output head share one weight tensor.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        # NOTE(review): presumably flags read by a Hugging Face-style trainer -- confirm.
        self.can_return_loss = True
        self.accepts_loss_kwargs = False

        # MoE only in the middle of the stack, and only when there is more than one expert.
        last_moe_layer = config.n_layer - config.stride - 1
        blocks = nn.ModuleList([
            Block(
                config,
                use_moe=(config.n_exp != 1 and config.stride <= i <= last_moe_layer),
            )
            for i in range(config.n_layer)
        ])

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = blocks,
            ln_f = nn.RMSNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the embedding reuses the output head's weight tensor.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """Return the total number of parameters.

        ``non_embedding`` is accepted for interface compatibility but is not
        used: nothing is subtracted from the count.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialise ``module`` in place; applied to every sub-module via ``self.apply``.

        Linear layers get a truncated normal (cut at two standard deviations)
        with ``std = sqrt(scale / fan_in)``, adjusted by the marker attributes
        set elsewhere in this file:

        * ``router_marker``      -> fixed std of 0.01
        * ``output_proj_marker`` -> base std divided by ``sqrt(2 * n_layer)``
        * ``is_attention``       -> base std times 0.7

        ``MLPExperts`` parameters follow the same rule per expert, and
        embeddings use a normal with std 0.02.
        """
        scale = getattr(self.config, 'switch_tfm_init_scale', 1.0)
        n_layer = self.config.n_layer

        if isinstance(module, nn.Linear):
            w_fan_in = module.weight.shape[-1]
            base_std = (scale / w_fan_in) ** 0.5

            # Marker precedence matters: the attention output projection
            # carries both ``output_proj_marker`` and ``is_attention``.
            if hasattr(module, 'router_marker'):
                final_std = 0.01
            elif hasattr(module, 'output_proj_marker'):
                final_std = base_std / math.sqrt(2 * n_layer)
            elif hasattr(module, 'is_attention'):
                final_std = base_std * 0.7
            else:
                final_std = base_std

            torch.nn.init.trunc_normal_(
                module.weight, mean=0.0, std=final_std, a=-2*final_std, b=2*final_std
            )
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, MLPExperts):
            # Stacked expert weights are (expert, in, out), so fan-in is the
            # second-to-last axis.
            c_fc_fan_in = module.c_fc.shape[-2]
            final_fc_std = (scale / c_fc_fan_in) ** 0.5
            torch.nn.init.trunc_normal_(module.c_fc, std=final_fc_std, a=-2*final_fc_std, b=2*final_fc_std)

            c_proj_fan_in = module.c_proj.shape[-2]
            final_proj_std = ((scale / c_proj_fan_in) ** 0.5) / math.sqrt(2 * n_layer)
            torch.nn.init.trunc_normal_(module.c_proj, std=final_proj_std, a=-2*final_proj_std, b=2*final_proj_std)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _consume_router_losses(self, loss):
        """Fold the pending router losses into ``loss`` and clear them from MANAGER.

        The accumulated losses are always reset, including when ``loss`` is
        ``None`` (no labels). The routers add their losses on every forward
        pass, so without the reset they would pile up during generation and
        be summed into the next training loss.
        """
        cfg = self.config
        if cfg.n_exp > 1:
            if cfg.use_aux_loss:
                if loss is not None:
                    loss = loss + cfg.aux_loss_weight * MANAGER.aggregate_aux_loss()
                MANAGER.reset_aux_loss()

            if cfg.use_router_z_loss:
                if loss is not None:
                    loss = loss + cfg.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
                MANAGER.reset_router_z_loss()
        return loss

    def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
        """Run the model on ``input_ids`` of shape ``(batch, seq)``.

        Args:
            input_ids: token ids; ``seq`` must not exceed ``config.block_size``.
            labels: optional targets with the same shape as ``input_ids``.
                They are shifted by one position here, and entries equal to
                -100 are ignored by the loss.
            attention_mask: accepted but not used.
            **kwargs: accepted and ignored.

        Returns:
            ``CausalLMOutput``. With labels, ``logits`` covers every position
            and ``loss`` is label-smoothed cross-entropy plus the weighted
            router losses. Without labels, ``logits`` covers only the last
            position and ``loss`` is ``None``.
        """
        _, t = input_ids.size()
        assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"

        x = self.transformer.wte(input_ids)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if labels is not None:
            logits = self.lm_head(x)

            # Position i predicts token i + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(
                ignore_index=-100,
                label_smoothing=0.1,
                reduction='mean'
            )
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
        else:
            # Only the last position is projected; the list index keeps the time axis.
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        loss = self._consume_router_losses(loss)

        return CausalLMOutput(loss=loss, logits=logits)

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Create an AdamW optimizer with weight decay on matrices only.

        Trainable parameters with at least two dimensions whose name does not
        end in ``bias`` are decayed; everything else is not. The fused AdamW
        implementation is used when this torch build offers it and
        ``device_type == 'cuda'``.
        """
        trainable = {name: p for name, p in self.named_parameters() if p.requires_grad}

        decay_params, nodecay_params = [], []
        for name, p in trainable.items():
            if p.dim() >= 2 and not name.endswith('bias'):
                decay_params.append(p)
            else:
                nodecay_params.append(p)

        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively after ``idx``.

        Args:
            idx: ``(batch, seq)`` tensor of prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the logits before sampling.
            top_k: if given, only the ``top_k`` highest logits can be sampled.

        Returns:
            ``idx`` with the sampled tokens appended along dimension 1.
        """
        for _ in range(max_new_tokens):
            # Keep only the most recent ``block_size`` tokens as context.
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            outputs = self(idx_cond)
            logits = outputs.logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Mask everything below the k-th largest logit.
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)

            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx