RaghuCourage9605 committed on
Commit 469ce0f · verified · Parent(s): f327afe

Update anubis_moe.py

Files changed (1):
  anubis_moe.py (+164, -192)
anubis_moe.py CHANGED
@@ -1,31 +1,40 @@
- from transformers import PretrainedConfig, PreTrainedModel
- from transformers import AutoModel, AutoTokenizer
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  import torch.utils.checkpoint as checkpoint

- class AnubisMoeConfig(PretrainedConfig):

-     model_type = "anubis_moe"

      def __init__(
          self,
          vocab_size=50257,
-         embed_dim=768,
          context_length=1024,
          n_layers=12,
          n_heads=12,
          drop_rate=0.1,
          qkv_bias=False,
          num_experts=8,
          top_k_experts=2,
-         expert_capacity=64,
-         **kwargs,
      ):
          self.vocab_size = vocab_size
-         self.embed_dim = embed_dim
          self.context_length = context_length
          self.n_layers = n_layers
          self.n_heads = n_heads
          self.drop_rate = drop_rate
@@ -35,6 +44,7 @@ class AnubisMoeConfig(PretrainedConfig):
          self.expert_capacity = expert_capacity
          super().__init__(**kwargs)

  class LayerNorm(nn.Module):
      def __init__(self, emb_dim):
          super().__init__()
@@ -51,132 +61,93 @@ class LayerNorm(nn.Module):
  class GELU(nn.Module):
      def __init__(self):
          super().__init__()
-
      def forward(self, x):
          return 0.5 * x * (1 + torch.tanh(
              torch.sqrt(torch.tensor(2.0 / torch.pi)) *
              (x + 0.044715 * torch.pow(x, 3))
          ))

- class FeedForward(nn.Module):
-     def __init__(self, cfg):
-         super().__init__()
-         self.layers = nn.Sequential(
-             nn.Linear(cfg.embed_dim, 4 * cfg.embed_dim),
              GELU(),
-             nn.Linear(4 * cfg.embed_dim, cfg.embed_dim),
          )

-     def forward(self, x):
-         return self.layers(x)
-
  class MultiHeadAttention(nn.Module):
-     def __init__(self, cfg):
          super().__init__()
-         d_out = cfg.embed_dim
-         num_heads = cfg.n_heads
-         d_in = cfg.embed_dim
-         context_length = cfg.context_length
-         dropout = cfg.drop_rate
-         qkv_bias = cfg.qkv_bias
-
-         assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

          self.d_out = d_out
          self.num_heads = num_heads
          self.head_dim = d_out // num_heads
-
-         self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
-         self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
-         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
          self.out_proj = nn.Linear(d_out, d_out)
          self.dropout = nn.Dropout(dropout)
-         self.register_buffer(
-             "mask",
-             torch.triu(torch.ones(context_length, context_length), diagonal=1)
-         )

      def forward(self, x):
          b, num_tokens, d_in = x.shape
-         keys = self.W_key(x)
-         queries = self.W_query(x)
-         values = self.W_value(x)
-
-         keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
-         values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
-         queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
-
          attn_scores = queries @ keys.transpose(2, 3)
          mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
          attn_scores.masked_fill_(mask_bool, -torch.inf)
-
-         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
          attn_weights = self.dropout(attn_weights)
-
-         context_vec = (attn_weights @ values).transpose(1, 2)
-         context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
          context_vec = self.out_proj(context_vec)
          return context_vec

  class MixtureOfExperts(nn.Module):
-     def __init__(self, cfg):
          super().__init__()
-         self.num_experts = cfg.num_experts
-         self.top_k = cfg.top_k_experts
-         self.embed_dim = cfg.embed_dim
-         self.expert_capacity = cfg.expert_capacity
-
-         self.router = nn.Linear(self.embed_dim, self.num_experts)
-         self.experts = nn.ModuleList(
-             [FeedForward(cfg) for _ in range(self.num_experts)]
-         )

      def forward(self, x):
          batch_size, seq_len, embed_dim = x.shape
          x_flat = x.view(-1, embed_dim)
-
          router_logits = self.router(x_flat)
-         routing_probs = F.softmax(router_logits, dim=-1)
-
-         topk_routing_probs, topk_indices = torch.topk(routing_probs, self.top_k, dim=-1)
-         topk_routing_probs = topk_routing_probs / topk_routing_probs.sum(dim=-1, keepdim=True)
-
-         expert_mask = torch.zeros(
-             x_flat.shape[0],
-             self.num_experts,
-             device=x.device,
-             dtype=topk_routing_probs.dtype
-         )
          for i in range(self.top_k):
-             expert_mask.scatter_add_(1, topk_indices[:, i:i+1], topk_routing_probs[:, i:i+1])
-
-         expert_outputs = torch.zeros_like(x_flat)
-
-         for expert_idx, expert in enumerate(self.experts):
-             token_indices = torch.nonzero(expert_mask[:, expert_idx], as_tuple=False).squeeze(-1)
-             if token_indices.numel() == 0:
-                 continue
-
-             if self.expert_capacity > 0 and token_indices.shape[0] > self.expert_capacity:
-                 perm = torch.randperm(token_indices.shape[0], device=x.device)
-                 token_indices = token_indices[perm[:self.expert_capacity]]
-
-             expert_input = x_flat[token_indices]
-             expert_output = expert(expert_input)
-             expert_outputs.index_add_(0, token_indices, expert_output * expert_mask[token_indices, expert_idx].unsqueeze(-1))
-
-         return expert_outputs.view(batch_size, seq_len, embed_dim)

  class TransformerBlockMOE(nn.Module):
-     def __init__(self, cfg):
          super().__init__()
-         self.ln1 = LayerNorm(cfg.embed_dim)
-         self.attn = MultiHeadAttention(cfg)
-         self.ln2 = LayerNorm(cfg.embed_dim)
-         self.ffn = MixtureOfExperts(cfg)
-         self.drop = nn.Dropout(cfg.drop_rate)
-
      def forward(self, x):
          attn_out = self.attn(self.ln1(x))
          x = x + self.drop(attn_out)
@@ -184,115 +155,116 @@ class TransformerBlockMOE(nn.Module):
          x = x + self.drop(ffn_out)
          return x

- class AnubisMoeForCausalLM(PreTrainedModel):
-     config_class = AnubisMoeConfig
-
-     def __init__(self, cfg):
-         super().__init__(cfg)
-         self.config = cfg
-         self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
-         self.pos_emb = nn.Embedding(cfg.context_length, cfg.embed_dim)
-         self.drop_emb = nn.Dropout(cfg.drop_rate)
          self.trf_blocks = nn.Sequential(
-             *[TransformerBlockMOE(cfg) for _ in range(cfg.n_layers)]
          )
-
-         self.final_norm = LayerNorm(cfg.embed_dim)
-         self.out_head = nn.Linear(cfg.embed_dim, cfg.vocab_size, bias=False)

      def forward(self, input_ids, **kwargs):
          batch_size, seq_len = input_ids.shape
          tok_embeds = self.tok_emb(input_ids)
-
          pos_ids = torch.arange(seq_len, device=input_ids.device)
          pos_embeds = self.pos_emb(pos_ids)
-
          x = tok_embeds + pos_embeds
          x = self.drop_emb(x)
          x = self.trf_blocks(x)
          x = self.final_norm(x)
          logits = self.out_head(x)
-
-         # The model should return a dictionary-like object in training/evaluation
-         # For generation, we can simplify, but it's good practice to be consistent.
          return {"logits": logits}

-     def generate(self, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
-         """
-         Greedy/Top-k sampling generate method.
-         """
-         for _ in range(max_new_tokens):
-             idx_cond = idx[:, -context_size:]
-             with torch.no_grad():
-                 outputs = self.forward(idx_cond)
-                 logits = outputs["logits"]
-             logits = logits[:, -1, :]
-
-             # --- top_k sampling ---
-             if top_k is not None:
-                 top_logits, _ = torch.topk(logits, top_k)
-                 min_val = top_logits[:, -1]
-                 logits = torch.where(
-                     logits < min_val,
-                     torch.tensor(float("-inf")).to(logits.device),
-                     logits,
-                 )
-
-             # --- temperature scaling ---
-             if temperature > 0.0:
-                 logits = logits / temperature
-                 probs = torch.softmax(logits, dim=-1)
-                 idx_next = torch.multinomial(probs, num_samples=1)
-             else:
-                 idx_next = torch.argmax(logits, dim=-1, keepdim=True)
-
-             # --- stop if EOS ---
-             if eos_id is not None and idx_next.item() == eos_id:
-                 break
-
-             # --- append ---
-             idx = torch.cat((idx, idx_next), dim=1)
-
-         return idx
-
-     def stream_generate(self, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
-         """
-         Stream tokens one by one instead of returning final sequence.
-         Yields (next_token, current_sequence).
-         """
-         for _ in range(max_new_tokens):
-             idx_cond = idx[:, -context_size:]
-
-             with torch.no_grad():
-                 outputs = self.forward(idx_cond)
-                 logits = outputs["logits"]
-             logits = logits[:, -1, :]
-
-             # --- top_k sampling ---
-             if top_k is not None:
-                 top_logits, _ = torch.topk(logits, top_k)
-                 min_val = top_logits[:, -1]
-                 logits = torch.where(
-                     logits < min_val,
-                     torch.tensor(float("-inf")).to(logits.device),
-                     logits,
-                 )
-
-             # --- temperature ---
-             if temperature > 0.0:
-                 logits = logits / temperature
-                 probs = torch.softmax(logits, dim=-1)
-                 idx_next = torch.multinomial(probs, num_samples=1)
-             else:
-                 idx_next = torch.argmax(logits, dim=-1, keepdim=True)
-
-             # --- stop if EOS ---
-             if eos_id is not None and idx_next.item() == eos_id:
-                 break
-
-             # --- append and yield ---
-             idx = torch.cat((idx, idx_next), dim=1)
-             yield idx_next.item(), idx
-
- AutoModel.register(AnubisMoeConfig, AnubisMoeForCausalLM)
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  import torch.utils.checkpoint as checkpoint

+ from transformers import (
+     AutoConfig,
+     AutoTokenizer,
+     PretrainedConfig,
+     PreTrainedModel,
+     AutoModelForCausalLM,
+ )

+ class AnubisMoeConfig(PretrainedConfig):
+     """
+     Configuration class for an AnubisMOE model. It is used to instantiate the
+     model according to the specified arguments, defining the model architecture.
+     """
+     model_type = "anubis_moe"  # custom model type name used for Auto* registration

      def __init__(
          self,
          vocab_size=50257,
          context_length=1024,
+         embed_dim=768,
          n_layers=12,
          n_heads=12,
          drop_rate=0.1,
          qkv_bias=False,
          num_experts=8,
          top_k_experts=2,
+         expert_capacity=0,
+         **kwargs
      ):
          self.vocab_size = vocab_size
          self.context_length = context_length
+         self.embed_dim = embed_dim
          self.n_layers = n_layers
          self.n_heads = n_heads
          self.drop_rate = drop_rate

          self.expert_capacity = expert_capacity
          super().__init__(**kwargs)
+
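A minimal usage sketch of the config (an editor's illustration, not part of this commit; the directory name is hypothetical). It behaves like any PretrainedConfig subclass:

    config = AnubisMoeConfig(n_layers=6, num_experts=4)     # override a few fields; the rest keep the defaults above
    config.save_pretrained("anubis-moe-small")              # writes config.json with model_type "anubis_moe"
    reloaded = AnubisMoeConfig.from_pretrained("anubis-moe-small")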
  class LayerNorm(nn.Module):
      def __init__(self, emb_dim):
          super().__init__()

  class GELU(nn.Module):
      def __init__(self):
          super().__init__()
      def forward(self, x):
          return 0.5 * x * (1 + torch.tanh(
              torch.sqrt(torch.tensor(2.0 / torch.pi)) *
              (x + 0.044715 * torch.pow(x, 3))
          ))

+ # FeedForward as an nn.Sequential subclass: position-wise MLP with 4x expansion
+ class FeedForward(nn.Sequential):
+     def __init__(self, config):
+         super().__init__(
+             nn.Linear(config.embed_dim, 4 * config.embed_dim),
              GELU(),
+             nn.Linear(4 * config.embed_dim, config.embed_dim),
          )

  class MultiHeadAttention(nn.Module):
+     def __init__(self, config):
          super().__init__()
+         d_out = config.embed_dim
+         num_heads = config.n_heads
+         context_length = config.context_length
+         dropout = config.drop_rate
+         qkv_bias = config.qkv_bias

+         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
          self.d_out = d_out
          self.num_heads = num_heads
          self.head_dim = d_out // num_heads
+         self.W_query = nn.Linear(config.embed_dim, d_out, bias=qkv_bias)
+         self.W_key = nn.Linear(config.embed_dim, d_out, bias=qkv_bias)
+         self.W_value = nn.Linear(config.embed_dim, d_out, bias=qkv_bias)
          self.out_proj = nn.Linear(d_out, d_out)
          self.dropout = nn.Dropout(dropout)
+         # causal mask: ones above the diagonal mark future positions to be hidden
+         self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

      def forward(self, x):
          b, num_tokens, d_in = x.shape
+         keys = self.W_key(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
+         queries = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
+         values = self.W_value(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
          attn_scores = queries @ keys.transpose(2, 3)
          mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
          attn_scores.masked_fill_(mask_bool, -torch.inf)
+         attn_weights = F.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
          attn_weights = self.dropout(attn_weights)
+         context_vec = (attn_weights @ values).transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)
          context_vec = self.out_proj(context_vec)
          return context_vec
  class MixtureOfExperts(nn.Module):
+     def __init__(self, config):
          super().__init__()
+         self.num_experts = config.num_experts
+         self.top_k = config.top_k_experts
+         self.router = nn.Linear(config.embed_dim, self.num_experts)
+         self.experts = nn.ModuleList([FeedForward(config) for _ in range(self.num_experts)])

      def forward(self, x):
+         # Simplified top-k token dispatch: every routed token is processed (no
+         # expert-capacity limit), and the top-k routing weights are applied
+         # without renormalization.
          batch_size, seq_len, embed_dim = x.shape
          x_flat = x.view(-1, embed_dim)
          router_logits = self.router(x_flat)
+         routing_weights = F.softmax(router_logits, dim=-1)
+         topk_weights, topk_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+
+         final_output = torch.zeros_like(x_flat)
          for i in range(self.top_k):
+             expert_indices = topk_indices[:, i]
+             for exp_idx in range(self.num_experts):
+                 token_indices = (expert_indices == exp_idx).nonzero(as_tuple=True)[0]
+                 if token_indices.numel() > 0:
+                     tokens_for_expert = x_flat[token_indices]
+                     expert_output = self.experts[exp_idx](tokens_for_expert)
+                     final_output.index_add_(0, token_indices, expert_output * topk_weights[token_indices, i].unsqueeze(1))
+
+         return final_output.view(batch_size, seq_len, embed_dim)
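A quick shape check of the routing layer (an editor's sketch under the classes above, not part of this commit) confirms the dispatch preserves the input shape:

    cfg = AnubisMoeConfig(num_experts=4, top_k_experts=2, embed_dim=32)
    moe = MixtureOfExperts(cfg)
    out = moe(torch.randn(2, 8, 32))   # (batch=2, seq=8, embed_dim=32)
    assert out.shape == (2, 8, 32)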
  class TransformerBlockMOE(nn.Module):
+     def __init__(self, config):
          super().__init__()
+         self.ln1 = LayerNorm(config.embed_dim)
+         self.attn = MultiHeadAttention(config)
+         self.ln2 = LayerNorm(config.embed_dim)
+         self.ffn = MixtureOfExperts(config)
+         self.drop = nn.Dropout(config.drop_rate)
+
      def forward(self, x):
          attn_out = self.attn(self.ln1(x))
          x = x + self.drop(attn_out)

          x = x + self.drop(ffn_out)
          return x

+ # --- Step 2: Adapt the main model to inherit from PreTrainedModel ---

+ class AnubisMoeForCausalLM(PreTrainedModel):
+     config_class = AnubisMoeConfig  # link the config class
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.tok_emb = nn.Embedding(config.vocab_size, config.embed_dim)
+         self.pos_emb = nn.Embedding(config.context_length, config.embed_dim)
+         self.drop_emb = nn.Dropout(config.drop_rate)
+
          self.trf_blocks = nn.Sequential(
+             *[TransformerBlockMOE(config) for _ in range(config.n_layers)]
          )
+
+         self.final_norm = LayerNorm(config.embed_dim)
+         self.out_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

      def forward(self, input_ids, **kwargs):
          batch_size, seq_len = input_ids.shape
          tok_embeds = self.tok_emb(input_ids)
          pos_ids = torch.arange(seq_len, device=input_ids.device)
          pos_embeds = self.pos_emb(pos_ids)
+
          x = tok_embeds + pos_embeds
          x = self.drop_emb(x)
          x = self.trf_blocks(x)
          x = self.final_norm(x)
          logits = self.out_head(x)
+
+         # The model must return a dictionary-like object
          return {"logits": logits}

+ # --- Step 3: Register the custom classes with the Auto* classes ---
+ AutoConfig.register("anubis_moe", AnubisMoeConfig)
+ AutoModelForCausalLM.register(AnubisMoeConfig, AnubisMoeForCausalLM)
+
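With the registrations above, the model resolves through the standard Auto* entry points. A minimal sketch (an editor's illustration, not part of this commit):

    config = AutoConfig.for_model("anubis_moe")         # resolves to AnubisMoeConfig
    model = AutoModelForCausalLM.from_config(config)    # resolves to AnubisMoeForCausalLM
    out = model(torch.randint(0, config.vocab_size, (1, 8)))
    print(out["logits"].shape)                          # (1, 8, vocab_size)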
+ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
+     # In each step: crop the context, get the logits, and focus on the last time step
+     for _ in range(max_new_tokens):
+         idx_cond = idx[:, -context_size:]
+         with torch.no_grad():
+             outputs = model(idx_cond)
+         logits = outputs["logits"][:, -1, :]  # the model returns a dict, so unpack the logits tensor
+
+         # Filter logits with top_k sampling: keep only the top_k values
+         if top_k is not None:
+             top_logits, _ = torch.topk(logits, top_k)
+             min_val = top_logits[:, -1]
+             logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)
+
+         if temperature > 0.0:
+             # Apply temperature scaling, then sample from the softmax distribution
+             logits = logits / temperature
+             probs = torch.softmax(logits, dim=-1)               # (batch_size, vocab_size)
+             idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
+         else:
+             # Greedy decoding: take the vocab entry with the highest logit
+             idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
+
+         # Stop generating early if the end-of-sequence token is encountered and eos_id is specified
+         if eos_id is not None and idx_next.item() == eos_id:
+             break
+
+         # Append the sampled index to the running sequence
+         idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
+
+     return idx
+
+ def stream_generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
+     """
+     Stream tokens one by one instead of returning the final sequence.
+     Yields (next_token, current_sequence).
+     """
+     for _ in range(max_new_tokens):
+         idx_cond = idx[:, -context_size:]
+
+         with torch.no_grad():
+             outputs = model(idx_cond)
+         logits = outputs["logits"][:, -1, :]
+
+         # --- top_k sampling ---
+         if top_k is not None:
+             top_logits, _ = torch.topk(logits, top_k)
+             min_val = top_logits[:, -1]
+             logits = torch.where(
+                 logits < min_val,
+                 torch.tensor(float("-inf")).to(logits.device),
+                 logits,
+             )
+
+         # --- temperature ---
+         if temperature > 0.0:
+             logits = logits / temperature
+             probs = torch.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+         else:
+             idx_next = torch.argmax(logits, dim=-1, keepdim=True)
+
+         # --- stop if EOS ---
+         if eos_id is not None and idx_next.item() == eos_id:
+             break
+
+         # --- append and yield ---
+         idx = torch.cat((idx, idx_next), dim=1)
+         yield idx_next.item(), idx
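An end-to-end usage sketch (an editor's illustration, not part of this commit): since vocab_size defaults to 50257, pairing with the GPT-2 tokenizer is an assumption that happens to match the vocabulary size. An untrained model will produce arbitrary tokens; this only smoke-tests the pipeline.

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed pairing; any 50257-token vocab works
    config = AnubisMoeConfig()
    model = AnubisMoeForCausalLM(config).eval()

    idx = tokenizer("The pyramids of", return_tensors="pt").input_ids
    out = generate(model, idx, max_new_tokens=20, context_size=config.context_length,
                   temperature=0.8, top_k=40, eos_id=tokenizer.eos_token_id)
    print(tokenizer.decode(out[0]))

    # Streaming variant: consume tokens as they are produced
    for token_id, seq in stream_generate(model, idx, max_new_tokens=20,
                                         context_size=config.context_length):
        print(tokenizer.decode([token_id]), end="", flush=True)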