joebruce1313 commited on
Commit
9ac5bef
·
verified ·
1 Parent(s): ac2ed15

Upload claudeson.py

Browse files
Files changed (1) hide show
  1. claudeson.py +644 -0
claudeson.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import faiss
5
+ import math
6
+ import numpy as np
7
+ from typing import Optional, Tuple, Literal
8
+ from dataclasses import dataclass
9
+
10
+ # Global configuration
11
+ world_size = 1
12
+ rank = 0
13
+ block_size = 128
14
+ gemm_impl: Literal["bf16", "fp8"] = "bf16"
15
+ attn_impl: Literal["naive", "absorb"] = "absorb"
16
+
17
+ @dataclass
18
+ class ModelArgs:
19
+ dim: int = 4096
20
+ n_layers: int = 32
21
+ n_heads: int = 32
22
+ n_kv_heads: int = 8
23
+ vocab_size: int = 32000
24
+ multiple_of: int = 256
25
+ ffn_dim_multiplier: Optional[float] = None
26
+ max_seq_len: int = 4096
27
+ original_seq_len: int = 4096
28
+ rope_theta: float = 10000.0
29
+ rope_factor: float = 1.0
30
+ beta_fast: float = 32.0
31
+ beta_slow: float = 1.0
32
+ mscale: float = 0.707
33
+ q_lora_rank: int = 0
34
+ kv_lora_rank: int = 0
35
+ qk_nope_head_dim: int = 128
36
+ qk_rope_head_dim: int = 64
37
+ v_head_dim: int = 128
38
+ n_routed_experts: int = 8
39
+ n_activated_experts: int = 2
40
+ n_expert_groups: int = 1
41
+ n_limited_groups: int = 1
42
+ score_func: str = "softmax"
43
+ route_scale: float = 1.0
44
+ n_dense_layers: int = 0
45
+ moe_inter_dim: int = None
46
+ n_shared_experts: int = 1
47
+ max_batch_size: int = 32
48
+ dtype: str = "bf16"
49
+
50
+ def __post_init__(self):
51
+ if self.ffn_dim_multiplier is None:
52
+ self.inter_dim = int(2 * self.dim / 3)
53
+ self.inter_dim = self.multiple_of * ((self.inter_dim + self.multiple_of - 1) // self.multiple_of)
54
+ else:
55
+ self.inter_dim = int(2 * self.dim * self.ffn_dim_multiplier)
56
+
57
+ if self.moe_inter_dim is None:
58
+ self.moe_inter_dim = int(2 * self.dim / 3)
59
+ self.moe_inter_dim = self.multiple_of * ((self.moe_inter_dim + self.multiple_of - 1) // self.multiple_of)
60
+
61
+ # Embedding layer
62
+ class ParallelEmbedding(nn.Module):
63
+ def __init__(self, vocab_size: int, dim: int):
64
+ super().__init__()
65
+ self.vocab_size = vocab_size
66
+ self.dim = dim
67
+ assert vocab_size % world_size == 0
68
+ self.part_vocab_size = (vocab_size // world_size)
69
+ self.vocab_start_idx = rank * self.part_vocab_size
70
+ self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
71
+ self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
72
+
73
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
74
+ if world_size > 1:
75
+ mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
76
+ x = x - self.vocab_start_idx
77
+ x[mask] = 0
78
+ y = F.embedding(x, self.weight)
79
+ if world_size > 1:
80
+ y[mask] = 0
81
+ torch.distributed.all_reduce(y)
82
+ return y
83
+
84
+ # Linear layer
85
+ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
86
+ if weight.element_size() > 1:
87
+ return F.linear(x, weight, bias)
88
+ elif gemm_impl == "bf16":
89
+ weight = weight_dequant(weight, weight.scale)
90
+ return F.linear(x, weight, bias)
91
+ else:
92
+ x, scale = act_quant(x, block_size)
93
+ y = fp8_gemm(x, scale, weight, weight.scale)
94
+ if bias is not None:
95
+ y += bias
96
+ return y
97
+
98
+ class Linear(nn.Module):
99
+ dtype = torch.bfloat16
100
+
101
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
102
+ super().__init__()
103
+ self.in_features = in_features
104
+ self.out_features = out_features
105
+ self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
106
+ if self.weight.element_size() == 1:
107
+ scale_out_features = (out_features + block_size - 1) // block_size
108
+ scale_in_features = (in_features + block_size - 1) // block_size
109
+ self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
110
+ else:
111
+ self.register_parameter("scale", None)
112
+ if bias:
113
+ self.bias = nn.Parameter(torch.empty(out_features))
114
+ else:
115
+ self.register_parameter("bias", None)
116
+
117
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
118
+ return linear(x, self.weight, self.bias)
119
+
120
+ class ColumnParallelLinear(Linear):
121
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
122
+ assert out_features % world_size == 0
123
+ self.part_out_features = out_features // world_size
124
+ super().__init__(in_features, self.part_out_features, bias, dtype)
125
+
126
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
127
+ return linear(x, self.weight, self.bias)
128
+
129
+ class RowParallelLinear(Linear):
130
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
131
+ assert in_features % world_size == 0
132
+ self.part_in_features = in_features // world_size
133
+ super().__init__(self.part_in_features, out_features, bias, dtype)
134
+
135
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
136
+ y = linear(x, self.weight)
137
+ if world_size > 1:
138
+ torch.distributed.all_reduce(y)
139
+ if self.bias is not None:
140
+ y += self.bias
141
+ return y
142
+
143
+ # Normalization layer
144
+ class RMSNorm(nn.Module):
145
+ def __init__(self, dim: int, eps: float = 1e-6):
146
+ super().__init__()
147
+ self.eps = eps
148
+ self.weight = nn.Parameter(torch.ones(dim))
149
+
150
+ def forward(self, x: torch.Tensor):
151
+ x = x.float()
152
+ y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
153
+ return y.type_as(self.weight) * self.weight
154
+
155
+ # Positional encoding
156
+ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
157
+ dim = args.qk_rope_head_dim
158
+ seqlen = args.max_seq_len
159
+ beta_fast = args.beta_fast
160
+ beta_slow = args.beta_slow
161
+ base = args.rope_theta
162
+ factor = args.rope_factor
163
+
164
+ def find_correction_dim(num_rotations, dim, base, max_seq_len):
165
+ return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
166
+
167
+ def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
168
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
169
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
170
+ return max(low, 0), min(high, dim-1)
171
+
172
+ def linear_ramp_factor(min, max, dim):
173
+ if min == max:
174
+ max += 0.001
175
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
176
+ ramp_func = torch.clamp(linear_func, 0, 1)
177
+ return ramp_func
178
+
179
+ freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
180
+ if seqlen > args.original_seq_len:
181
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
182
+ smooth = 1 - linear_ramp_factor(low, high, dim // 2)
183
+ freqs = freqs / factor * (1 - smooth) + freqs * smooth
184
+
185
+ t = torch.arange(seqlen)
186
+ freqs = torch.outer(t, freqs)
187
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
188
+ return freqs_cis
189
+
190
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
191
+ dtype = x.dtype
192
+ x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
193
+ freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
194
+ y = torch.view_as_real(x * freqs_cis).flatten(3)
195
+ return y.to(dtype)
196
+
197
+ # Multi-Head Latent Attention (MLA)
198
+ class MLA(nn.Module):
199
+ def __init__(self, args: ModelArgs):
200
+ super().__init__()
201
+ self.dim = args.dim
202
+ self.n_heads = args.n_heads
203
+ self.n_local_heads = args.n_heads // world_size
204
+ self.q_lora_rank = args.q_lora_rank
205
+ self.kv_lora_rank = args.kv_lora_rank
206
+ self.qk_nope_head_dim = args.qk_nope_head_dim
207
+ self.qk_rope_head_dim = args.qk_rope_head_dim
208
+ self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
209
+ self.v_head_dim = args.v_head_dim
210
+
211
+ if self.q_lora_rank == 0:
212
+ self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
213
+ else:
214
+ self.wq_a = Linear(self.dim, self.q_lora_rank)
215
+ self.q_norm = RMSNorm(self.q_lora_rank)
216
+ self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
217
+ self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
218
+ self.kv_norm = RMSNorm(self.kv_lora_rank)
219
+ self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
220
+ self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
221
+ self.softmax_scale = self.qk_head_dim ** -0.5
222
+ if args.max_seq_len > args.original_seq_len:
223
+ mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
224
+ self.softmax_scale = self.softmax_scale * mscale * mscale
225
+
226
+ if attn_impl == "naive":
227
+ self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
228
+ self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
229
+ else:
230
+ self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
231
+ self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
232
+
233
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
234
+ bsz, seqlen, _ = x.size()
235
+ end_pos = start_pos + seqlen
236
+ if self.q_lora_rank == 0:
237
+ q = self.wq(x)
238
+ else:
239
+ q = self.wq_b(self.q_norm(self.wq_a(x)))
240
+ q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
241
+ q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
242
+ q_pe = apply_rotary_emb(q_pe, freqs_cis)
243
+ kv = self.wkv_a(x)
244
+ kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
245
+ k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
246
+ if attn_impl == "naive":
247
+ q = torch.cat([q_nope, q_pe], dim=-1)
248
+ kv = self.wkv_b(self.kv_norm(kv))
249
+ kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
250
+ k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
251
+ k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
252
+ self.k_cache[:bsz, start_pos:end_pos] = k
253
+ self.v_cache[:bsz, start_pos:end_pos] = v
254
+ scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
255
+ else:
256
+ wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
257
+ wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
258
+ q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
259
+ self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
260
+ self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
261
+ scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
262
+ torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
263
+ if mask is not None:
264
+ scores += mask.unsqueeze(1)
265
+ scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
266
+ if attn_impl == "naive":
267
+ x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
268
+ else:
269
+ x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
270
+ x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
271
+ x = self.wo(x.flatten(2))
272
+ return x
273
+
274
+ # MLP layer
275
+ class MLP(nn.Module):
276
+ def __init__(self, dim: int, inter_dim: int):
277
+ super().__init__()
278
+ self.w1 = ColumnParallelLinear(dim, inter_dim)
279
+ self.w2 = RowParallelLinear(inter_dim, dim)
280
+ self.w3 = ColumnParallelLinear(dim, inter_dim)
281
+
282
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
283
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
284
+
285
+ # Mixture of Experts (MoE) components
286
+ class Gate(nn.Module):
287
+ def __init__(self, args: ModelArgs):
288
+ super().__init__()
289
+ self.dim = args.dim
290
+ self.topk = args.n_activated_experts
291
+ self.n_groups = args.n_expert_groups
292
+ self.topk_groups = args.n_limited_groups
293
+ self.score_func = args.score_func
294
+ self.route_scale = args.route_scale
295
+ self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
296
+ self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
297
+
298
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
299
+ scores = linear(x, self.weight)
300
+ if self.score_func == "softmax":
301
+ scores = scores.softmax(dim=-1, dtype=torch.float32)
302
+ else:
303
+ scores = scores.sigmoid()
304
+ original_scores = scores
305
+ if self.bias is not None:
306
+ scores = scores + self.bias
307
+ if self.n_groups > 1:
308
+ scores = scores.view(x.size(0), self.n_groups, -1)
309
+ if self.bias is None:
310
+ group_scores = scores.amax(dim=-1)
311
+ else:
312
+ group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
313
+ indices = group_scores.topk(self.topk_groups, dim=-1)[1]
314
+ mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
315
+ scores = (scores * mask.unsqueeze(-1)).flatten(1)
316
+ indices = torch.topk(scores, self.topk, dim=-1)[1]
317
+ weights = original_scores.gather(1, indices)
318
+ if self.score_func == "sigmoid":
319
+ weights /= weights.sum(dim=-1, keepdim=True)
320
+ weights *= self.route_scale
321
+ return weights.type_as(x), indices
322
+
323
+ class Expert(nn.Module):
324
+ def __init__(self, dim: int, inter_dim: int):
325
+ super().__init__()
326
+ self.w1 = Linear(dim, inter_dim)
327
+ self.w2 = Linear(inter_dim, dim)
328
+ self.w3 = Linear(dim, inter_dim)
329
+
330
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
331
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
332
+
333
+ class MoE(nn.Module):
334
+ def __init__(self, args: ModelArgs):
335
+ super().__init__()
336
+ self.dim = args.dim
337
+ assert args.n_routed_experts % world_size == 0
338
+ self.n_routed_experts = args.n_routed_experts
339
+ self.n_local_experts = args.n_routed_experts // world_size
340
+ self.n_activated_experts = args.n_activated_experts
341
+ self.experts_start_idx = rank * self.n_local_experts
342
+ self.experts_end_idx = self.experts_start_idx + self.n_local_experts
343
+ self.gate = Gate(args)
344
+ self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
345
+ for i in range(self.n_routed_experts)])
346
+ self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
347
+
348
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
349
+ shape = x.size()
350
+ x = x.view(-1, self.dim)
351
+ weights, indices = self.gate(x)
352
+ y = torch.zeros_like(x)
353
+ counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
354
+ for i in range(self.experts_start_idx, self.experts_end_idx):
355
+ if counts[i] == 0:
356
+ continue
357
+ expert = self.experts[i]
358
+ idx, top = torch.where(indices == i)
359
+ y[idx] += expert(x[idx]) * weights[idx, top, None]
360
+ z = self.shared_experts(x)
361
+ if world_size > 1:
362
+ torch.distributed.all_reduce(y)
363
+ return (y + z).view(shape)
364
+
365
+ # Transformer block
366
+ class Block(nn.Module):
367
+ def __init__(self, layer_id: int, args: ModelArgs):
368
+ super().__init__()
369
+ self.attn = MLA(args)
370
+ self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
371
+ self.attn_norm = RMSNorm(args.dim)
372
+ self.ffn_norm = RMSNorm(args.dim)
373
+
374
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
375
+ x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
376
+ x = x + self.ffn(self.ffn_norm(x))
377
+ return x
378
+
379
+ # Transformer model
380
+ class Transformer(nn.Module):
381
+ def __init__(self, args: ModelArgs):
382
+ global world_size, rank
383
+ world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
384
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
385
+ Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
386
+ super().__init__()
387
+ self.max_seq_len = args.max_seq_len
388
+ self.embed = ParallelEmbedding(args.vocab_size, args.dim)
389
+ self.layers = torch.nn.ModuleList()
390
+ for layer_id in range(args.n_layers):
391
+ self.layers.append(Block(layer_id, args))
392
+ self.norm = RMSNorm(args.dim)
393
+ self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
394
+ self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
395
+
396
+ @torch.inference_mode()
397
+ def forward(self, tokens: torch.Tensor, start_pos: int = 0):
398
+ seqlen = tokens.size(1)
399
+ h = self.embed(tokens)
400
+ freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
401
+ mask = None
402
+ if seqlen > 1:
403
+ mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
404
+ for layer in self.layers:
405
+ h = layer(h, start_pos, freqs_cis, mask)
406
+ h = self.norm(h)[:, -1]
407
+ logits = self.head(h)
408
+ if world_size > 1:
409
+ all_logits = [torch.empty_like(logits) for _ in range(world_size)]
410
+ torch.distributed.all_gather(all_logits, logits)
411
+ logits = torch.cat(all_logits, dim=-1)
412
+ return logits
413
+
414
+ # FAISS Retriever
415
+ class FAISSRetriever:
416
+ def __init__(self, knowledge_base: faiss.Index, dim: int = 768, num_results: int = 5):
417
+ self.index = knowledge_base
418
+ self.dim = dim
419
+ self.num_results = num_results
420
+
421
+ def search(self, query_embedding: torch.Tensor, k: int = None) -> torch.Tensor:
422
+ if k is None:
423
+ k = self.num_results
424
+ query_np = query_embedding.detach().cpu().numpy()
425
+ distances, indices = self.index.search(query_np, k)
426
+ return torch.tensor(indices, device=query_embedding.device)
427
+
428
+ # Complete Multi-Modal LLM
429
+ class CombinedMultiModalTransformer(nn.Module):
430
+ def __init__(self, args: ModelArgs, knowledge_base: faiss.Index):
431
+ super(CombinedMultiModalTransformer, self).__init__()
432
+ self.args = args
433
+ self.transformer = Transformer(args)
434
+
435
+ # Multi-modal components
436
+ self.audio_encoder = nn.Sequential(
437
+ nn.Conv1d(256, 256, kernel_size=11, stride=2, padding='same'),
438
+ nn.ReLU(),
439
+ nn.Conv1d(256, 256, kernel_size=11, stride=2, padding='same'),
440
+ nn.ReLU(),
441
+ nn.Conv1d(256, args.dim, kernel_size=11, stride=2, padding='same'),
442
+ nn.ReLU()
443
+ )
444
+
445
+ self.image_encoder = nn.Sequential(
446
+ # Simplified ResNet50 implementation
447
+ nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
448
+ nn.ReLU(),
449
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
450
+ nn.AdaptiveAvgPool2d((1, 1)),
451
+ nn.Flatten(),
452
+ nn.Linear(2048, args.dim)
453
+ )
454
+
455
+ # Music generation components
456
+ self.pitch_embedding = nn.Embedding(128, args.dim)
457
+ self.duration_embedding = nn.Embedding(32, args.dim)
458
+ self.velocity_embedding = nn.Embedding(128, args.dim)
459
+
460
+ # Anomaly detection components
461
+ self.anomaly_detector = nn.Sequential(
462
+ nn.Linear(args.dim, args.dim),
463
+ nn.ReLU(),
464
+ nn.Linear(args.dim, 1),
465
+ nn.Sigmoid()
466
+ )
467
+
468
+ # RAG components
469
+ self.knowledge_base = FAISSRetriever(knowledge_base)
470
+ self.query_encoder = nn.Sequential(
471
+ nn.Linear(args.dim, args.dim),
472
+ nn.ReLU(),
473
+ nn.Linear(args.dim, args.dim)
474
+ )
475
+
476
+ def forward(self, inputs, task, start_pos=0):
477
+ if task == 'text_generation':
478
+ # RAG component
479
+ query_embedding = self.query_encoder(self.transformer.embed(inputs))
480
+ retrieved_indices = self.knowledge_base.search(query_embedding, k=5)
481
+
482
+ # Concatenate retrieved docs with input
483
+ # In practice, you would convert indices to actual embeddings
484
+ retrieved_embeddings = torch.zeros_like(inputs[:, :5, :]) # Placeholder
485
+ inputs = torch.cat([retrieved_embeddings, inputs], dim=1)
486
+
487
+ # Pass through transformer
488
+ logits = self.transformer(inputs, start_pos)
489
+ return logits
490
+
491
+ elif task == 'speech_recognition':
492
+ x = self.audio_encoder(inputs)
493
+ # Convert audio encoder output to transformer format
494
+ batch_size, seq_len = x.shape[0], x.shape[1]
495
+ tokens = torch.zeros(batch_size, seq_len, dtype=torch.long, device=x.device)
496
+ logits = self.transformer(tokens, start_pos)
497
+ return logits
498
+
499
+ elif task == 'image_captioning':
500
+ image_features = self.image_encoder(inputs)
501
+ # Convert image features to transformer format
502
+ batch_size = image_features.shape[0]
503
+ tokens = torch.zeros(batch_size, 1, dtype=torch.long, device=image_features.device)
504
+ logits = self.transformer(tokens, start_pos)
505
+ return logits
506
+
507
+ elif task == 'music_generation':
508
+ pitch, duration, velocity = inputs
509
+ x = self.pitch_embedding(pitch) + self.duration_embedding(duration) + self.velocity_embedding(velocity)
510
+ # Convert music features to transformer format
511
+ batch_size, seq_len = x.shape[0], x.shape[1]
512
+ tokens = torch.zeros(batch_size, seq_len, dtype=torch.long, device=x.device)
513
+ logits = self.transformer(tokens, start_pos)
514
+ return logits
515
+
516
+ elif task == 'anomaly_detection':
517
+ x = self.transformer.embed(inputs)
518
+ anomaly_scores = self.anomaly_detector(x)
519
+ return anomaly_scores
520
+
521
+ else:
522
+ raise ValueError(f"Unknown task: {task}")
523
+
524
+ # Helper functions
525
+ def act_quant(x: torch.Tensor, block_size: int = 128):
526
+ # Simplified activation quantization function
527
+ return x, torch.ones(1, device=x.device)
528
+
529
+ def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = 128):
530
+ # Simplified weight dequantization function
531
+ return weight * scale
532
+
533
+ def fp8_gemm(x: torch.Tensor, x_scale: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor):
534
+ # Simplified FP8 GEMM function
535
+ return torch.matmul(x, weight.t()) * x_scale * weight_scale
536
+
537
+ # Training function
538
+ def train_model(model, dataloader, optimizer, criterion, device, num_epochs=10):
539
+ model.train()
540
+ model.to(device)
541
+
542
+ for epoch in range(num_epochs):
543
+ total_loss = 0.0
544
+ for batch_idx, (inputs, targets, tasks) in enumerate(dataloader):
545
+ inputs, targets = inputs.to(device), targets.to(device)
546
+
547
+ optimizer.zero_grad()
548
+ outputs = model(inputs, tasks)
549
+
550
+ if isinstance(outputs, dict):
551
+ # Handle multi-task outputs
552
+ loss = 0.0
553
+ for task, output in outputs.items():
554
+ task_targets = targets[task]
555
+ loss += criterion(output, task_targets)
556
+ else:
557
+ loss = criterion(outputs, targets)
558
+
559
+ loss.backward()
560
+ optimizer.step()
561
+
562
+ total_loss += loss.item()
563
+
564
+ if batch_idx % 100 == 0:
565
+ print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}')
566
+
567
+ avg_loss = total_loss / len(dataloader)
568
+ print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}')
569
+
570
+ # Inference function
571
+ def generate_text(model, prompt, max_length=100, temperature=1.0, device='cpu'):
572
+ model.eval()
573
+ model.to(device)
574
+
575
+ # Convert prompt to tokens
576
+ tokens = torch.tensor([prompt], dtype=torch.long, device=device)
577
+
578
+ with torch.no_grad():
579
+ for _ in range(max_length):
580
+ logits = model(tokens, 'text_generation')
581
+
582
+ # Apply temperature scaling
583
+ logits = logits[:, -1, :] / temperature
584
+
585
+ # Get probabilities for next token
586
+ probs = F.softmax(logits, dim=-1)
587
+
588
+ # Sample next token
589
+ next_token = torch.multinomial(probs, num_samples=1)
590
+
591
+ # Append new token to sequence
592
+ tokens = torch.cat([tokens, next_token], dim=1)
593
+
594
+ return tokens[0].tolist()
595
+
596
+ # Example usage
597
+ if __name__ == "__main__":
598
+ # Initialize model parameters
599
+ args = ModelArgs()
600
+
601
+ # Create a dummy knowledge base for testing
602
+ dim = args.dim
603
+ knowledge_base = faiss.IndexFlatL2(dim)
604
+ # Add some dummy vectors
605
+ vectors = np.random.rand(100, dim).astype('float32')
606
+ knowledge_base.add(vectors)
607
+
608
+ # Initialize model
609
+ model = CombinedMultiModalTransformer(args, knowledge_base)
610
+
611
+ # Print model structure
612
+ print(model)
613
+
614
+ # Test text generation
615
+ prompt = [1, 2, 3, 4, 5] # Example token sequence
616
+ generated_tokens = generate_text(model, prompt, max_length=20)
617
+ print(f"Generated tokens: {generated_tokens}")
618
+
619
+ # Test other tasks
620
+ # Note: In practice, you would provide appropriate input data
621
+ try:
622
+ # Speech recognition
623
+ audio_input = torch.randn(1, 256, 160) # Example audio input
624
+ speech_output = model(audio_input, 'speech_recognition')
625
+ print(f"Speech recognition output shape: {speech_output.shape}")
626
+
627
+ # Image captioning
628
+ image_input = torch.randn(1, 3, 224, 224) # Example image input
629
+ caption_output = model(image_input, 'image_captioning')
630
+ print(f"Image captioning output shape: {caption_output.shape}")
631
+
632
+ # Music generation
633
+ pitch = torch.randint(0, 128, (1, 100))
634
+ duration = torch.randint(0, 32, (1, 100))
635
+ velocity = torch.randint(0, 128, (1, 100))
636
+ music_output = model((pitch, duration, velocity), 'music_generation')
637
+ print(f"Music generation output shape: {music_output.shape}")
638
+
639
+ # Anomaly detection
640
+ anomaly_input = torch.randn(1, 100, args.dim)
641
+ anomaly_output = model(anomaly_input, 'anomaly_detection')
642
+ print(f"Anomaly detection output shape: {anomaly_output.shape}")
643
+ except Exception as e:
644
+ print(f"Error during testing: {e}")