HALION-AI committed
Commit b08cac6 · verified · 1 Parent(s): ae3ce31

Add modeling_helionx.py

Files changed (1)
  1. modeling_helionx.py +71 -0
modeling_helionx.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel
+ from .configuration_helionx import HelionXConfig
+
+
+ class HelionXSelfAttention(nn.Module):
+     """Multi-head self-attention wrapper around nn.MultiheadAttention."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.attn = nn.MultiheadAttention(
+             embed_dim=config.hidden_size,
+             num_heads=config.num_attention_heads,
+             batch_first=True,
+         )
+
+     def forward(self, x):
+         # Self-attention: query, key, and value are all the same tensor.
+         # No attention mask is passed, so attention is unmasked (bidirectional).
+         out, _ = self.attn(x, x, x, need_weights=False)
+         return out
+
+
+ class HelionXBlock(nn.Module):
+     """Pre-norm transformer block: self-attention followed by a GELU MLP."""
+
+     def __init__(self, config):
+         super().__init__()
+
+         self.self_attn = HelionXSelfAttention(config)
+         self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+         self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.act = nn.GELU()
+
+     def forward(self, x):
+         # Residual connections around the attention and feed-forward sublayers.
+         x = x + self.self_attn(self.norm1(x))
+         x = x + self.linear2(self.act(self.linear1(self.norm2(x))))
+         return x
+
+
+ class HelionXLM(PreTrainedModel):
+     """Language model: token and learned position embeddings, a stack of
+     HelionXBlock layers, a final LayerNorm, and an untied LM head."""
+
+     config_class = HelionXConfig
+     base_model_prefix = "helionx"
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+         self.layers = nn.ModuleList(
+             [HelionXBlock(config) for _ in range(config.num_hidden_layers)]
+         )
+
+         self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.post_init()
+
+     def forward(self, input_ids, **kwargs):
+         bsz, seq_len = input_ids.shape
+         # Position ids 0..seq_len-1, broadcast across the batch dimension.
+         pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
+
+         x = self.embed(input_ids) + self.pos_embed(pos)
+
+         for layer in self.layers:
+             x = layer(x)
+
+         x = self.ln(x)
+         logits = self.lm_head(x)
+
+         return {"logits": logits}
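
For orientation, a minimal usage sketch of the model defined above. It assumes that HelionXConfig (defined in configuration_helionx.py, not shown in this commit) follows the usual PretrainedConfig pattern and accepts the fields referenced by the model as keyword arguments; the sizes below are illustrative placeholders, not values from this repository, and the import lines assume the two files are importable as a local package (on the Hub they would normally be loaded via trust_remote_code instead).

# Smoke test for HelionXLM (sketch; config values and import paths are assumptions).
import torch

from configuration_helionx import HelionXConfig  # hypothetical import path
from modeling_helionx import HelionXLM           # hypothetical import path

config = HelionXConfig(
    vocab_size=32000,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=1024,
    max_position_embeddings=512,
    layer_norm_eps=1e-5,
)

model = HelionXLM(config).eval()

# A batch of 2 sequences, 16 tokens each.
input_ids = torch.randint(0, config.vocab_size, (2, 16))
with torch.no_grad():
    out = model(input_ids)

print(out["logits"].shape)  # expected: torch.Size([2, 16, 32000])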