import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_helionx import HelionXConfig


class HelionXSelfAttention(nn.Module):
    """Multi-head self-attention wrapper around ``nn.MultiheadAttention``."""

    def __init__(self, config):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            batch_first=True,
        )

    def forward(self, x):
        # Self-attention: query, key and value are all the same hidden states.
        # NOTE: as written this is full (bidirectional) attention; no causal mask
        # is applied.
        out, _ = self.attn(x, x, x, need_weights=False)
        return out
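

# Illustrative sketch (not used by the classes in this file): if HelionXLM is meant
# to be decoded autoregressively, HelionXSelfAttention would additionally need a
# causal mask. ``nn.MultiheadAttention`` accepts a boolean ``attn_mask`` in which
# True marks positions that must NOT be attended to; one common construction is the
# helper below (the name ``build_causal_mask`` is hypothetical, not from the source).
def build_causal_mask(seq_len, device=None):
    # Boolean mask, True strictly above the diagonal: position i may only attend
    # to positions j <= i.
    return torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
    )
# Example use inside HelionXSelfAttention.forward:
#     mask = build_causal_mask(x.size(1), x.device)
#     out, _ = self.attn(x, x, x, attn_mask=mask, need_weights=False)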


class HelionXBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a GELU MLP."""

    def __init__(self, config):
        super().__init__()
        self.self_attn = HelionXSelfAttention(config)
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.act = nn.GELU()

    def forward(self, x):
        # Residual connections around the normalized sub-layers (pre-norm).
        x = x + self.self_attn(self.norm1(x))
        x = x + self.linear2(self.act(self.linear1(self.norm2(x))))
        return x


class HelionXLM(PreTrainedModel):
    """HelionX language model: token and learned position embeddings, a stack of
    HelionXBlock layers, a final LayerNorm, and a vocabulary projection head."""

    config_class = HelionXConfig
    base_model_prefix = "helionx"

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.layers = nn.ModuleList(
            [HelionXBlock(config) for _ in range(config.num_hidden_layers)]
        )

        self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Standard Hugging Face hook: runs weight initialization and final setup.
        self.post_init()

    def forward(self, input_ids, **kwargs):
        _, seq_len = input_ids.shape

        # Learned absolute position ids, broadcast over the batch dimension.
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embed(input_ids) + self.pos_embed(pos)

        for layer in self.layers:
            x = layer(x)

        x = self.ln(x)
        logits = self.lm_head(x)
        return {"logits": logits}