Commit 42d6c22 (verified) · committed by Trouter-Library · 1 parent: 9daefda

Create modeling_helion.py

Files changed (1):
  1. modeling_helion.py +404 -0
modeling_helion.py ADDED

"""PyTorch Helion model implementation."""

import math
from typing import Optional, Tuple, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging

from .configuration_helion import HelionConfig

logger = logging.get_logger(__name__)


class HelionRMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
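
As a quick illustration of what HelionRMSNorm computes (a standalone sketch, not part of the committed file, assuming the class above is in scope): each hidden vector is divided by the root mean square of its elements and then scaled by a learned per-channel weight.

# Illustrative sketch, not part of modeling_helion.py.
import torch

torch.manual_seed(0)
x = torch.randn(2, 5, 16)                       # (batch, seq_len, hidden_size)
norm = HelionRMSNorm(hidden_size=16, eps=1e-6)

rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
expected = norm.weight * (x / rms)
print(torch.allclose(norm(x), expected, atol=1e-5))   # True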

class HelionRotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE)."""

    def __init__(self, dim: int, max_position_embeddings: int = 8192, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, seq_len: int):
        # Generate position indices
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embedding to queries and keys."""
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
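
A shape-level sketch of the rotary path above (standalone, assuming the definitions above are in scope): the embedding returns cos/sin tables of shape (seq_len, head_dim), which are gathered by position_ids and applied to queries and keys; since each pair of dimensions is rotated, per-head vector norms are preserved.

# Illustrative sketch, not part of modeling_helion.py.
import torch

bsz, n_heads, seq_len, head_dim = 1, 2, 6, 8
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

rope = HelionRotaryEmbedding(dim=head_dim, max_position_embeddings=32)
cos, sin = rope(q, seq_len=seq_len)                  # each: (seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)    # (1, seq_len)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)                      # torch.Size([1, 2, 6, 8]) twice
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True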

class HelionAttention(nn.Module):
    """Multi-head attention with Grouped Query Attention (GQA)."""

    def __init__(self, config: HelionConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = HelionRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # Repeat k/v heads for GQA
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights if output_attentions else None, past_key_value
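
The grouped-query trick in HelionAttention.forward is the repeat_interleave call: keys and values are stored with num_key_value_heads heads and expanded on the fly so that each group of num_heads // num_key_value_heads query heads shares one K/V head. A toy illustration (standalone, not part of the committed file):

# Illustrative sketch, not part of modeling_helion.py.
import torch

bsz, kv_heads, num_heads, seq_len, head_dim = 1, 2, 8, 4, 8
groups = num_heads // kv_heads                       # 4 query heads per KV head

k = torch.randn(bsz, kv_heads, seq_len, head_dim)
k_expanded = k.repeat_interleave(groups, dim=1)      # (1, 8, 4, 8)

# Query heads 0-3 share KV head 0; query heads 4-7 share KV head 1.
print(k_expanded.shape)
print(torch.equal(k_expanded[:, 0], k[:, 0]), torch.equal(k_expanded[:, 3], k[:, 0]))
print(torch.equal(k_expanded[:, 4], k[:, 1]), torch.equal(k_expanded[:, 7], k[:, 1]))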

class HelionMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: HelionConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
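
HelionMLP is the SwiGLU feed-forward block: the gate branch goes through SiLU and multiplies the up branch elementwise before the down projection, i.e. down(SiLU(gate(x)) * up(x)). A minimal equivalence check (standalone; TinyConfig is a hypothetical stand-in for HelionConfig carrying only the two fields the class reads):

# Illustrative sketch, not part of modeling_helion.py.
import torch
import torch.nn.functional as F

class TinyConfig:                 # hypothetical stand-in for HelionConfig
    hidden_size = 16
    intermediate_size = 32

mlp = HelionMLP(TinyConfig())
x = torch.randn(2, 5, 16)

expected = mlp.down_proj(F.silu(mlp.gate_proj(x)) * mlp.up_proj(x))
print(torch.allclose(mlp(x), expected))   # True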

class HelionDecoderLayer(nn.Module):
    """Single transformer decoder layer."""

    def __init__(self, config: HelionConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = HelionAttention(config)
        self.mlp = HelionMLP(config)
        self.input_layernorm = HelionRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = HelionRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)

        return outputs


class HelionPreTrainedModel(PreTrainedModel):
    """Helion pretrained model base class."""

    config_class = HelionConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HelionDecoderLayer"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class HelionModel(HelionPreTrainedModel):
    """Helion transformer model."""

    def __init__(self, config: HelionConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([HelionDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = HelionRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device if input_ids is not None else inputs_embeds.device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        return hidden_states


class HelionForCausalLM(HelionPreTrainedModel):
    """Helion model with language modeling head."""

    def __init__(self, config):
        super().__init__(config)
        self.model = HelionModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        })

        return model_inputs
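
For reference, a sketch of how the language-model head might be exercised end to end. configuration_helion.py is not part of this commit, so the HelionConfig keyword arguments below are assumptions inferred from the attributes the modeling code reads; adjust them to the real config class.

# Illustrative sketch, not part of modeling_helion.py.
# The HelionConfig kwargs are assumed from the fields read above
# (hidden_size, num_attention_heads, num_key_value_heads, ...).
import torch

config = HelionConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=2,
    max_position_embeddings=128,
    rope_theta=10000,
    rms_norm_eps=1e-6,
    initializer_range=0.02,
    pad_token_id=0,
)

model = HelionForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 16))

out = model(input_ids=input_ids, labels=input_ids)
print(out.logits.shape)   # torch.Size([1, 16, 1000])
print(out.loss)           # scalar next-token cross-entropy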