import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from kernel import act_quant, weight_dequant, fp8_gemm

# Module-level defaults. world_size and rank are overwritten in Transformer.__init__
# when torch.distributed has been initialized; block_size is the tile size used for
# blockwise quantization of weights and activations.
world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
attn_impl: Literal["naive", "absorb"] = "absorb"


@dataclass
class ModelArgs:
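    """Model hyperparameters: core sizes, MoE routing, MLA attention dimensions, and
    rotary-embedding scaling for long-context extrapolation."""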
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    # mixture-of-experts routing
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # multi-head latent attention (MLA)
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # rotary-embedding scaling beyond the original training length
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.


class ParallelEmbedding(nn.Module):
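    """Token embedding whose vocabulary is sharded across model-parallel ranks."""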
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0
        self.part_vocab_size = vocab_size // world_size
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if world_size > 1:
            # Tokens outside this rank's vocabulary shard are zeroed out and restored by all_reduce.
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0
            dist.all_reduce(y)
        return y


def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
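    """Dispatches to F.linear, a dequantize-then-matmul path, or the fp8 GEMM kernel,
    depending on the weight's storage dtype and the global gemm_impl setting."""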
    if weight.element_size() > 1:
        # bf16/fp32 weights: plain matmul.
        return F.linear(x, weight, bias)
    elif gemm_impl == "bf16":
        # fp8 weights with bf16 GEMM: dequantize the weight first.
        weight = weight_dequant(weight, weight.scale)
        return F.linear(x, weight, bias)
    else:
        # fp8 weights with fp8 GEMM: quantize the activation blockwise and use the fp8 kernel.
        x, scale = act_quant(x, block_size)
        y = fp8_gemm(x, scale, weight, weight.scale)
        if bias is not None:
            y += bias
        return y


class Linear(nn.Module):
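    """Linear layer that stores its weight in Linear.dtype (bf16 or fp8) and keeps a
    per-block float32 scale tensor when the weight is quantized to fp8."""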
    dtype = torch.bfloat16

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
        if self.weight.element_size() == 1:
            # fp8 weights carry one scale entry per block_size x block_size tile.
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
        else:
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)


class ColumnParallelLinear(Linear):
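    """Linear layer whose output features are split across model-parallel ranks."""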
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = linear(x, self.weight, self.bias)
        return y


class RowParallelLinear(Linear):
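    """Linear layer whose input features are split across model-parallel ranks; the
    partial outputs are summed with an all_reduce."""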
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = linear(x, self.weight)
        if world_size > 1:
            dist.all_reduce(y)
        if self.bias is not None:
            # The bias is added after the reduction so it is applied exactly once.
            y += self.bias
        return y


class RMSNorm(nn.Module):
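    """Root-mean-square layer normalization with a learned per-dimension scale."""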
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):
        x = x.float()
        y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return y.type_as(self.weight) * self.weight


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
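    """Precomputes complex rotary-embedding frequencies for positions up to
    args.max_seq_len, blending in 1/rope_factor-scaled frequencies when the target
    context length exceeds args.original_seq_len."""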
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        # Dimension index whose rotary frequency completes `num_rotations` over max_seq_len.
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        # Range of dimensions over which the correction is blended, clamped to [0, dim - 1].
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min, max, dim):
        # Linear ramp from 0 to 1 between `min` and `max`, clamped outside that range.
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        # Blend original and 1/factor-scaled frequencies across the correction range.
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
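    """Applies rotary position embeddings to x by complex multiplication with freqs_cis."""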
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)


class MLA(nn.Module):
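    """Multi-head latent attention: queries and key/value states are projected through
    low-rank bottlenecks, with a decoupled rotary branch of qk_rope_head_dim dimensions.
    The cache layout depends on the global attn_impl ("naive" caches full per-head keys
    and values, "absorb" caches only the compressed latent and rotary keys)."""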
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        if self.q_lora_rank == 0:
            # Full-rank query projection.
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            # Low-rank query projection with an intermediate RMSNorm.
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            # Extra attention scaling when running beyond the original training length.
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        if attn_impl == "naive":
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        if attn_impl == "naive":
            # Materialize full per-head keys and values and cache them.
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            # Absorb wkv_b into the query so only the compressed latent and rotary keys are cached.
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
        if mask is not None:
            scores += mask.unsqueeze(1)
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            # Attend over the latent cache, then expand back to per-head value dimensions.
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x


class MLP(nn.Module):
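    """Gated (SwiGLU-style) feed-forward block used for dense layers and shared experts."""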
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim)
        self.w3 = ColumnParallelLinear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Gate(nn.Module):
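    """Routing gate for the MoE layer: scores each token against every routed expert and
    selects the top-k, optionally restricting the choice to the best expert groups."""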
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        # Only the 7168-dim configuration uses a routing bias.
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        scores = linear(x, self.weight)
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()
        original_scores = scores
        if self.bias is not None:
            scores = scores + self.bias
        if self.n_groups > 1:
            # Group-limited routing: keep only the top-scoring expert groups per token.
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
            scores = (scores * mask.unsqueeze(-1)).flatten(1)
        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            # Sigmoid scores are unnormalized, so renormalize the selected weights.
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights.type_as(x), indices


class Expert(nn.Module):
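    """A single routed expert; same gated-MLP shape as MLP but without tensor parallelism."""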
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class MoE(nn.Module):
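    """Mixture-of-experts layer: routed experts are sharded across ranks (only the local
    slice is instantiated) and combined with always-active shared experts."""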
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(args)
        # Non-local expert slots stay None so that expert indices remain global.
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]
        z = self.shared_experts(x)
        if world_size > 1:
            # Sum the routed-expert outputs computed on each rank.
            dist.all_reduce(y)
        return (y + z).view(shape)


class Block(nn.Module):
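    """Transformer block: MLA attention followed by a dense MLP (first n_dense_layers
    layers) or an MoE feed-forward, each with a pre-RMSNorm residual connection."""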
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class Transformer(nn.Module):
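    """Full decoder-only transformer with a parallel embedding, stacked Blocks, a final
    RMSNorm, and an output head; forward returns logits for the last position only."""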
    def __init__(self, args: ModelArgs):
        global world_size, rank
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        seqlen = tokens.size(1)
        h = self.embed(tokens)
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        mask = None
        if seqlen > 1:
            # Causal mask over the current prompt chunk.
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)[:, -1]
        logits = self.head(h)
        if world_size > 1:
            # Gather the vocabulary shards of the column-parallel output head.
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits
| if __name__ == "__main__": |
| torch.set_default_dtype(torch.bfloat16) |
| torch.set_default_device("cuda") |
| torch.manual_seed(0) |
| args = ModelArgs() |
| x = torch.randint(0, args.vocab_size, (2, 128)) |
| model = Transformer(args) |
| print(model(x).size()) |
|
|