ikaganacar committed on
Commit
2414ac8
·
1 Parent(s): 61dc72d

Model Architecture

Browse files
Model_Architecture/kernel.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SOURCE: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
2
+
3
+ from typing import Tuple, Optional
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+ from triton import Config
9
+
10
+
11
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
    """
    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.

    Each program instance quantizes one contiguous block of BLOCK_SIZE elements
    with a single shared scale.

    Args:
        x_ptr (triton.Pointer): Pointer to the input tensor.
        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored (one per block).
        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
        scale_fmt (tl.constexpr): Optional scale format; "ue8m0" rounds each scale up to the next power of two.

    Returns:
        None
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    amax = tl.max(tl.abs(x))  # reduction: block-wide absolute maximum
    amax = tl.maximum(amax, 1e-4)  # clamp to 1e-4 so all-zero blocks don't divide by zero
    s = amax / 448.  # 448 is the largest finite value of float8_e4m3fn
    if scale_fmt == "ue8m0":
        # Power-of-two scale: exponent-only (UE8M0) representation.
        exp = tl.math.ceil(tl.math.log2(s))
        s = tl.math.exp2(exp)
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)
38
+
39
+
40
def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Block-wise FP8 quantization of `x`.

    The last dimension is split into contiguous blocks of `block_size` elements;
    each block is quantized to float8_e4m3fn with one fp32 scale.

    Args:
        x: Input tensor. Must be contiguous and its last dimension size must be
            divisible by `block_size`.
        block_size: Elements per quantization block. Default is 128.
        scale_fmt: Optional scale format forwarded to the kernel ("ue8m0" rounds
            scales up to a power of two). Default is None.

    Returns:
        A tuple of:
            - the quantized tensor with dtype `torch.float8_e4m3fn`, and
            - the fp32 scales of shape `x.shape[:-1] + (x.shape[-1] // block_size,)`.
    """
    assert x.is_contiguous(), 'Input tensor must be contiguous'
    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
    quantized = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scale_shape = (*x.size()[:-1], x.size(-1) // block_size)
    scales = x.new_empty(scale_shape, dtype=torch.float32)
    # One program instance per block of the flattened input.
    launch_grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
    act_quant_kernel[launch_grid](x, quantized, scales, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
    return quantized, scales
60
+
61
+
62
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    One program instance handles one (BLOCK_SIZE x BLOCK_SIZE) tile; the tile
    shares a single scalar scale.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors (one per tile, row-major).
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)  # number of scale tiles per row of tiles
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]  # row-major element offsets
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)  # guard the ragged edge tiles
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)  # the tile's single scalar scale
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)
89
+
90
+
91
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantize a 2-D weight matrix with per-tile scales.

    Args:
        x: Quantized weight tensor of shape (M, N), contiguous.
        s: Scale tensor of shape (ceil(M/block_size), ceil(N/block_size)), contiguous.
        block_size: Side length of one scale tile. Defaults to 128.

    Returns:
        The dequantized weights in the default dtype, same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or not 2-D.
    """
    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
    rows, cols = x.size()
    out = torch.empty_like(x, dtype=torch.get_default_dtype())
    # 2-D launch: one program per (BLOCK_SIZE x BLOCK_SIZE) tile.
    launch_grid = lambda meta: (triton.cdiv(rows, meta['BLOCK_SIZE']), triton.cdiv(cols, meta['BLOCK_SIZE']))
    weight_dequant_kernel[launch_grid](x, s, out, rows, cols, BLOCK_SIZE=block_size)
    return out
113
+
114
+
115
# Autotuning search space for the FP8 GEMM. BLOCK_SIZE_K is fixed at 128 —
# the same value as the quantization block size — while the M/N tile shapes
# and pipeline depth (num_stages) are swept.
fp8_gemm_configs = [
    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
119
+
120
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                    a_s_ptr, b_s_ptr,
                    M, N: tl.constexpr, K: tl.constexpr,
                    BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    Computes C = A @ B^T, dequantizing each K-block of the accumulator with the
    matching activation and weight scales.

    Args:
        a_ptr (tl.tensor): Pointer to the first input matrix A, stored (M, K) row-major.
        b_ptr (tl.tensor): Pointer to the second input matrix B, stored (N, K) row-major.
        c_ptr (tl.tensor): Pointer to the output matrix C.
        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A (one per row per K-block).
        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B (one per N-tile per K-block).
        M (int): Number of rows in matrix A and C.
        N (tl.constexpr): Number of columns in matrix B and C.
        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)  # number of K-blocks == stride of the scale tensors along K
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]  # B is (N, K); loaded transposed into (K, N)
    a_s_ptrs = a_s_ptr + offs_m * k  # per-row activation scale for the current K-block
    # NOTE(review): `offs_n // BLOCK_SIZE_K` implies the weight's quantization
    # tile along N equals BLOCK_SIZE_K (128, matching weight_dequant's default).
    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(k):
        # Mask the final partial K-block; out-of-range elements contribute 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)
        # Dequantize the partial product of this K-block as it is accumulated.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
        b_s_ptrs += 1
    c = accumulator.to(c_ptr.dtype.element_ty)
    # Recompute unwrapped offsets for the store (the load offsets were taken mod M/N).
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)
175
+
176
+
177
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
    """
    Perform a matrix multiplication using FP8 precision.

    Computes `a @ b.T` (B is stored row-major as (N, K)), applying per-block
    dequantization scales inside the kernel.

    Args:
        a: The first input matrix, must be contiguous.
        a_s: The scaling factor for the first input matrix, must be contiguous.
        b: The second input matrix, must be contiguous.
        b_s: The scaling factor for the second input matrix, must be contiguous.

    Returns:
        torch.Tensor: Result of shape `a.shape[:-1] + (N,)` in the default dtype.
    """
    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
    K = a.size(-1)
    M = a.numel() // K  # leading dims of `a` are flattened into rows
    N = b.size(0)
    out = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    launch_grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
    fp8_gemm_kernel[launch_grid](a, b, out, a_s, b_s, M, N, K)
    return out
Model_Architecture/model.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.data import Dataset, DataLoader
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+ from typing import Tuple, Optional, Literal
9
+
10
+ import torch.nn.functional as F
11
+ import torch.distributed as dist
12
+
13
+ from kernel import act_quant, weight_dequant, fp8_gemm
14
+
15
+ #####################################
16
+ # CONFIGURATION
17
+ #####################################
18
@dataclass
class ModelArgs:
    """Hyperparameters for the MLA + MoE transformer defined in this file."""
    # Runtime limits; used to size the per-layer KV/PE attention caches.
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    # Parameter dtype and activation-scale format for the FP8 path.
    dtype: Literal["bf16", "fp8"] = "bf16"
    scale_fmt: Optional[str] = None
    vocab_size: int = 102400
    dim: int = 2048  # model (embedding) width
    inter_dim: int = 10944  # dense MLP hidden width
    moe_inter_dim: int = 1408  # per-expert hidden width
    n_layers: int = 27
    n_dense_layers: int = 1  # leading layers that use a dense MLP instead of MoE
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2  # NOTE(review): not referenced by the code in this file
    n_activated_experts: int = 6  # experts selected per token (top-k)
    n_expert_groups: int = 1  # NOTE(review): not referenced by the code in this file
    n_limited_groups: int = 1  # NOTE(review): not referenced by the code in this file
    score_func: Literal["softmax", "sigmoid"] = "softmax"  # NOTE(review): unused; gating always softmaxes top-k logits
    route_scale: float = 1.  # NOTE(review): not referenced by the code in this file
    # mla
    q_lora_rank: int = 0  # 0 disables the low-rank query projection
    kv_lora_rank: int = 512  # latent width of the compressed KV cache
    qk_nope_head_dim: int = 128  # per-head query/key dim without positional encoding
    qk_rope_head_dim: int = 64  # per-head query/key dim carrying rotary encoding
    v_head_dim: int = 128
    # yarn — RoPE length extension, active only when max_seq_len > original_seq_len
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
52
+
53
# others
# Module-level runtime settings read by the linear layers below.
world_size = 1  # number of model-parallel ranks (1 = no parallelism)
rank = 0  # this process's rank; unused in the visible code
block_size = 128  # quantization block size, shared with kernel.py
gemm_impl: Literal["bf16", "fp8"] = "bf16"  # which GEMM path `linear` dispatches to
58
+
59
+
60
+ #####################################
61
+ # DATA
62
+ #####################################
63
class Dataset(Dataset):
    """Sliding-window next-token dataset over a single text.

    Each sample is a (input_ids, target_ids) pair of length `max_length`,
    with targets shifted one token to the right. Consecutive windows start
    `stride` tokens apart.

    NOTE(review): this class name shadows torch.utils.data.Dataset after
    definition; renaming would require updating callers.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        # Tokenize the whole text once up front.
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        self.input_ids = []
        self.target_ids = []
        # Slide a window of max_length over the tokens, stepping by `stride`.
        for start in range(0, len(token_ids) - max_length, stride):
            window = token_ids[start:start + max_length]
            shifted = token_ids[start + 1:start + max_length + 1]
            self.input_ids.append(torch.tensor(window))
            self.target_ids.append(torch.tensor(shifted))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]
83
+
84
+
85
def create_dataloader(txt, batch_size=4, max_length=256,
                      stride=128, shuffle=True, drop_last=True, num_workers=0):
    """Build a DataLoader of sliding-window samples over `txt`.

    Uses the GPT-2 BPE tokenizer from tiktoken and the Dataset class above.
    """
    # GPT-2 byte-pair encoding.
    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = Dataset(txt, tokenizer, max_length, stride)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
98
+
99
+ #####################################
100
+ # RoPE
101
+ #####################################
102
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """Precompute complex rotary-embedding values for every position.

    Returns a (max_seq_len, qk_rope_head_dim // 2) complex tensor of unit
    phasors. When max_seq_len exceeds the original training length, the
    frequencies are rescaled with the YaRN interpolation scheme.
    """
    rope_dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    base = args.rope_theta
    factor = args.rope_factor

    def _correction_dim(num_rotations, dim, base, max_seq_len):
        # Dimension index at which `num_rotations` full rotations fit in max_seq_len.
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def _correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def _linear_ramp(lo, hi, dim):
        if lo == hi:
            hi += 0.001  # avoid division by zero when the range collapses
        ramp = (torch.arange(dim, dtype=torch.float32) - lo) / (hi - lo)
        return torch.clamp(ramp, 0, 1)

    freqs = 1.0 / (base ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    if seqlen > args.original_seq_len:
        # YaRN: interpolate low-rotation dims by `factor`, keep fast dims as-is.
        low, high = _correction_range(args.beta_fast, args.beta_slow, rope_dim, base, args.original_seq_len)
        smooth = 1 - _linear_ramp(low, high, rope_dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    positions = torch.arange(seqlen)
    angles = torch.outer(positions, freqs)
    # Unit-magnitude complex numbers e^{i * angle}.
    return torch.polar(torch.ones_like(angles), angles)
135
+
136
+
137
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to `x`.

    The last dimension of `x` is interpreted as interleaved (real, imag)
    pairs; each pair is rotated by the matching complex phasor in
    `freqs_cis`. The result keeps x's original dtype.
    """
    in_dtype = x.dtype
    # (..., d) -> (..., d/2, 2) -> complex (..., d/2)
    as_complex = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    # Broadcast phasors over batch and head dimensions.
    phasors = freqs_cis.view(1, as_complex.size(1), 1, as_complex.size(-1))
    rotated = torch.view_as_real(as_complex * phasors).flatten(3)
    return rotated.to(in_dtype)
143
+
144
+
145
+ #####################################
146
+ # LINEAR LAYERS
147
+ #####################################
148
+
149
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
    """Dispatch a linear transform based on the weight's storage dtype.

    Weights wider than one byte go straight through F.linear. One-byte (FP8)
    weights are either dequantized first (module-level gemm_impl == "bf16")
    or multiplied with block-quantized activations via the FP8 GEMM kernel.
    """
    if weight.element_size() > 1:
        # Full-precision weights: nothing to dequantize.
        return F.linear(x, weight, bias)
    if gemm_impl == "bf16":
        # FP8 weights, bf16 compute: dequantize the weight, then a normal GEMM.
        dequantized = weight_dequant(weight, weight.scale)
        return F.linear(x, dequantized, bias)
    # FP8 compute: quantize activations block-wise and run the FP8 kernel.
    x_q, x_scale = act_quant(x, block_size, scale_fmt)
    out = fp8_gemm(x_q, x_scale, weight, weight.scale)
    if bias is not None:
        out += bias
    return out
162
+
163
+
164
class Linear(nn.Module):
    """Linear layer supporting wide (default bf16) or FP8 one-byte weights.

    FP8 weights carry a per-tile fp32 scale parameter sized by the
    module-level `block_size`.
    """
    # Default parameter dtype for new instances.
    dtype = torch.bfloat16
    # Activation-scale format forwarded to the `linear` dispatcher.
    scale_fmt: Optional[str] = None

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
        if self.weight.element_size() != 1:
            # Wide dtype: no quantization scales needed.
            self.register_parameter("scale", None)
        else:
            # FP8 weights: one fp32 scale per (block_size x block_size) tile, ceil-divided.
            scale_rows = (out_features + block_size - 1) // block_size
            scale_cols = (in_features + block_size - 1) // block_size
            self.scale = nn.Parameter(torch.empty(scale_rows, scale_cols, dtype=torch.float32))
            # Mirror the scale on the weight so callers holding only the
            # Parameter (e.g. `linear`) can reach it as `weight.scale`.
            self.weight.scale = self.scale
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer via the module-level `linear` dispatcher."""
        return linear(x, self.weight, self.bias, self.scale_fmt)
187
+
188
+
189
class ColumnParallelLinear(Linear):
    """Linear layer whose output features are partitioned across ranks.

    Each rank holds `out_features // world_size` rows of the weight, so no
    communication is needed in forward.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute this rank's slice of the output.

        Fix: forward `scale_fmt` to the dispatcher, matching `Linear.forward`
        (it was previously dropped here, so the FP8 path silently ignored a
        configured scale format).
        """
        y = linear(x, self.weight, self.bias, self.scale_fmt)
        return y
198
+
199
+
200
class RowParallelLinear(Linear):
    """Linear layer whose input features are partitioned across ranks.

    Each rank multiplies its input slice, then the partial outputs are
    summed with an all-reduce; the bias is added once, after the reduction.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Partial matmul on this rank's input slice, reduce, then add bias.

        Fix: forward `scale_fmt` to the dispatcher, matching `Linear.forward`
        (it was previously dropped here).
        """
        y = linear(x, self.weight, scale_fmt=self.scale_fmt)
        if world_size > 1:
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y
213
+
214
+ #####################################
215
+ # NORMALIZATION
216
+ #####################################
217
+
218
+ class RMSNorm(nn.Module):
219
+ def __init__(self, dim: int, eps: float = 1e-6):
220
+ super().__init__()
221
+ self.dim = dim
222
+ self.eps = eps
223
+ self.weight = nn.Parameter(torch.ones(dim))
224
+
225
+ def forward(self, x: torch.Tensor):
226
+ return F.rms_norm(x, (self.dim,), self.weight, self.eps)
227
+
228
+
229
+ #####################################
230
+ # ATTENTION
231
+ #####################################
232
+
233
class MultiHeadLatentAttention(nn.Module):
    """Multi-head Latent Attention (MLA).

    Keys and values are compressed into a low-rank latent of width
    `kv_lora_rank`; only that latent (plus a small rotary component) is
    cached, and attention scores are computed directly against the latent
    using the absorbed `wkv_b` projection.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size  # heads held by this rank
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        # Query projection: direct, or low-rank (down-project, norm, up-project).
        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)

        # KV path: wkv_a produces the latent plus the shared rotary key part;
        # wkv_b expands the latent into per-head nope-keys and values.
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5

        # YaRN attention-entropy correction when running beyond the trained length.
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale


        # Caches hold the normalized latent and the rotary key component
        # (not full K/V), one slot per (batch, position). Not checkpointed.
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
        self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """Attend over cached positions [0, end_pos) and return (bsz, seqlen, dim).

        Args:
            x: Input of shape (bsz, seqlen, dim).
            start_pos: Index of the first new token in the cache.
            freqs_cis: Rotary phasors for positions [start_pos, start_pos + seqlen).
            mask: Optional additive attention mask of shape (seqlen, end_pos).
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        # Split each query head into non-positional and rotary parts.
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)


        # Dequantize wkv_b if its weights are FP8, then absorb its key half
        # into the queries so scores can be taken against the latent directly.
        wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
        wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
        q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
        # Write the new positions into the caches.
        self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
        self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
        # Scores = latent (content) term + rotary (position) term.
        scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                  torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

        if mask is not None:
            scores += mask.unsqueeze(1)
        # Softmax in fp32 for stability, then cast back to the input dtype.
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)


        # Weighted latent, then expand to values with wkv_b's value half.
        x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x
300
+
301
+
302
+ #####################################
303
+ # MOE FEEDFORWARD
304
+ #####################################
305
+
306
class MoEFeedForward(nn.Module):
    """
    Mixture of Experts Feed-Forward Network using custom Linear modules.
    Based on the architecture from gpt_with_kv_moe.py but adapted to use
    the custom Linear, ColumnParallelLinear, and RowParallelLinear classes.

    Routing: top-k gate logits per token, softmax over the selected logits,
    experts applied per unique expert id with results scatter-added back.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_experts_per_tok = args.n_activated_experts
        self.num_experts = args.n_routed_experts
        self.emb_dim = args.dim
        self.hidden_dim = args.moe_inter_dim

        # Gate network uses custom Linear
        self.gate = Linear(args.dim, args.n_routed_experts, bias=False)

        # Expert networks using custom Linear modules
        # fc1 and fc2 are the two input projections (for SwiGLU-style activation)
        # fc3 is the output projection
        self.fc1 = nn.ModuleList(
            [
                Linear(args.dim, args.moe_inter_dim, bias=False)
                for _ in range(self.num_experts)
            ]
        )
        self.fc2 = nn.ModuleList(
            [
                Linear(args.dim, args.moe_inter_dim, bias=False)
                for _ in range(self.num_experts)
            ]
        )
        self.fc3 = nn.ModuleList(
            [
                Linear(args.moe_inter_dim, args.dim, bias=False)
                for _ in range(self.num_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and mix their outputs.

        Returns a tensor of the same shape as `x`.
        """
        # x: (batch, seq_len, emb_dim)
        scores = self.gate(x)  # (b, seq_len, num_experts)
        topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
        # Routing weights: softmax over the k selected logits only.
        topk_probs = torch.softmax(topk_scores, dim=-1)

        batch, seq_len, _ = x.shape
        x_flat = x.reshape(batch * seq_len, -1)
        out_flat = torch.zeros(batch * seq_len, self.emb_dim, device=x.device, dtype=x.dtype)

        topk_indices_flat = topk_indices.reshape(-1, self.num_experts_per_tok)
        topk_probs_flat = topk_probs.reshape(-1, self.num_experts_per_tok)

        # Iterate only over experts that were actually selected by some token.
        unique_experts = torch.unique(topk_indices_flat)

        for expert_id_tensor in unique_experts:
            expert_id = int(expert_id_tensor.item())

            # (tokens, k) boolean: where this expert appears in a token's top-k.
            mask = topk_indices_flat == expert_id
            if not mask.any():
                continue

            # Tokens that routed to this expert in at least one slot.
            token_mask = mask.any(dim=-1)
            selected_idx = token_mask.nonzero(as_tuple=False).squeeze(-1)
            if selected_idx.numel() == 0:
                continue

            expert_input = x_flat.index_select(0, selected_idx)
            # SwiGLU-style activation: silu(fc1(x)) * fc2(x)
            hidden = torch.nn.functional.silu(self.fc1[expert_id](expert_input)) * self.fc2[
                expert_id
            ](expert_input)
            expert_out = self.fc3[expert_id](hidden)

            # Find which top-k slot matched for each selected token (argmax of
            # the boolean row gives the first — and, since topk indices are
            # distinct per token, only — matching slot) and gather its weight.
            mask_selected = mask[selected_idx]
            slot_indices = mask_selected.int().argmax(dim=-1, keepdim=True)
            selected_probs = torch.gather(
                topk_probs_flat.index_select(0, selected_idx), dim=-1, index=slot_indices
            ).squeeze(-1)

            # Accumulate the weighted expert output into the token rows.
            out_flat.index_add_(0, selected_idx, expert_out * selected_probs.unsqueeze(-1))

        return out_flat.reshape(batch, seq_len, self.emb_dim)
387
+
388
+
389
+ #####################################
390
+ # DENSE FEEDFORWARD (MLP)
391
+ #####################################
392
+
393
class MLP(nn.Module):
    """
    Dense SwiGLU feed-forward network using custom Linear modules.
    Used for the dense (non-MoE) layers.
    """
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        # fc1/fc2 are the gate and value up-projections; fc3 projects back down.
        self.fc1 = Linear(dim, inter_dim, bias=False)
        self.fc2 = Linear(dim, inter_dim, bias=False)
        self.fc3 = Linear(inter_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """SwiGLU: silu(fc1(x)) * fc2(x), then down-projection."""
        gate = F.silu(self.fc1(x))
        value = self.fc2(x)
        return self.fc3(gate * value)
407
+
408
+
409
+ #####################################
410
+ # TRANSFORMER BLOCKS
411
+ #####################################
412
+
413
class Block(nn.Module):
    """One transformer layer: MLA attention plus a dense or MoE feed-forward,
    each wrapped in a pre-norm residual connection."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MultiHeadLatentAttention(args)
        # The first n_dense_layers use a dense MLP; later layers use MoE.
        if layer_id < args.n_dense_layers:
            self.ffn = MLP(args.dim, args.inter_dim)
        else:
            self.ffn = MoEFeedForward(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Pre-norm residual attention followed by pre-norm residual FFN."""
        h = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        return h + self.ffn(self.ffn_norm(h))
426
+
427
+
428
+ #####################################
429
+ # TRANSFORMER MODEL
430
+ #####################################
431
+
432
class Transformer(nn.Module):
    """Decoder-only transformer with MLA attention and MoE feed-forward layers."""
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers

        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([Block(i, args) for i in range(args.n_layers)])
        self.norm = RMSNorm(args.dim)
        self.output = Linear(args.dim, args.vocab_size, bias=False)

        # RoPE table is fully determined by args, so it is not checkpointed.
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """Return logits of shape (batch, seqlen, vocab_size) for `tokens`,
        which occupy cache positions [start_pos, start_pos + seqlen)."""
        _, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]

        # Causal mask over the new tokens; cached positions (columns before
        # start_pos) stay fully visible. Single-token decode needs no mask.
        mask = None
        if seqlen > 1:
            causal = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
            causal = torch.triu(causal, diagonal=1)
            prefix = torch.zeros((seqlen, start_pos), device=tokens.device)
            mask = torch.hstack([prefix, causal]).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        return self.output(self.norm(h))
463
+
464
+
465
+ #####################################
466
+ # GENERATION
467
+ #####################################
468
+
469
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Greedy decoding: repeatedly append the argmax next token.

    Args:
        model: Callable mapping (B, T) token ids to (B, T, vocab) logits.
        idx: (B, T) tensor of context token ids.
        max_new_tokens: Number of tokens to append.
        context_size: Maximum context length the model supports; older
            tokens are cropped from the left.

    Returns:
        (B, T + max_new_tokens) tensor of token ids.
    """
    for _ in range(max_new_tokens):
        # Keep only the most recent `context_size` tokens as context.
        context = idx[:, -context_size:]

        with torch.no_grad():
            logits = model(context)

        # Greedy pick from the final position's distribution:
        # (batch, n_token, vocab) -> (batch, vocab) -> (batch, 1).
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)

        # Grow the running sequence.
        idx = torch.cat((idx, next_token), dim=1)

    return idx
493
+
494
+
495
if __name__ == "__main__":
    # Smoke test: build a scaled-down model and greedily generate a few tokens.
    # Example configuration - similar to DeepSeek-V3 but smaller for testing
    args = ModelArgs(
        max_batch_size=4,
        max_seq_len=1024,
        vocab_size=50257, # GPT-2 vocab size for compatibility
        dim=768,
        inter_dim=3072,
        moe_inter_dim=768,
        n_layers=12,
        n_dense_layers=1, # First layer is dense, rest are MoE
        n_heads=12,
        n_routed_experts=8,
        n_shared_experts=2,
        n_activated_experts=2,
        kv_lora_rank=256,
        qk_nope_head_dim=64,
        qk_rope_head_dim=32,
        v_head_dim=64,
    )

    torch.manual_seed(123)  # deterministic weight initialization
    model = Transformer(args)
    model.eval()

    start_context = "Hello, I am"
    tokenizer = tiktoken.get_encoding("gpt2")
    encoded = tokenizer.encode(start_context)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension

    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
    print("\nInput text:", start_context)
    print("Encoded input text:", encoded)
    print("encoded_tensor.shape:", encoded_tensor.shape)

    # Greedy generation with the model's full context window.
    out = generate_text_simple(
        model=model,
        idx=encoded_tensor,
        max_new_tokens=10,
        context_size=args.max_seq_len
    )
    decoded_text = tokenizer.decode(out.squeeze(0).tolist())

    print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
    print("\nOutput:", out)
    print("Output length:", len(out[0]))
    print("Output text:", decoded_text)
+ print("Output text:", decoded_text)