| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel, GenerationMixin |
| | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
| | from .configuration_mcgpt import MCGPTConfig |
| |
|
| | class MCGPTBlock(nn.Module): |
| | def __init__(self, hidden_size, nhead, dropout=0.1): |
| | super().__init__() |
| | self.self_attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=nhead, batch_first=True, dropout=dropout) |
| | self.norm1 = nn.LayerNorm(hidden_size) |
| | self.ff = nn.Sequential( |
| | nn.Linear(hidden_size, hidden_size * 4), |
| | nn.GELU(), |
| | nn.Linear(hidden_size * 4, hidden_size) |
| | ) |
| | self.norm2 = nn.LayerNorm(hidden_size) |
| | self.dropout = nn.Dropout(dropout) |
| | |
| | def forward(self, x, mask=None): |
| | attn_out, _ = self.self_attn(x, x, x, attn_mask=mask, need_weights=False) |
| | x = self.norm1(x + self.dropout(attn_out)) |
| | x = self.norm2(x + self.dropout(self.ff(x))) |
| | return x |
| |
|
| | class Expert(nn.Module): |
| | def __init__(self, hidden_size): |
| | super().__init__() |
| | self.net = nn.Sequential( |
| | nn.Linear(hidden_size, hidden_size * 4), |
| | nn.GELU(), |
| | nn.Linear(hidden_size * 4, hidden_size) |
| | ) |
| | def forward(self, x): |
| | return self.net(x) |
| |
|
class MCGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal LM with a dense mixture-of-experts output stage.

    Token + learned position embeddings feed a stack of MCGPTBlock layers;
    final hidden states are combined as a softmax-weighted sum over ALL
    experts (dense MoE — every expert runs on every token) before the LM head.

    Fixes over the previous revision:
      * non-`return_dict` tuple now follows the HF convention of loss FIRST
        (`(loss, logits)`), which `Trainer` and most callers rely on;
      * `attention_mask` was accepted but silently ignored — padded positions
        are now masked out of attention;
      * `prepare_inputs_for_generation` forwards `attention_mask` so
        `.generate()` respects padding.
    """

    config_class = MCGPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.blocks = nn.ModuleList(
            [MCGPTBlock(config.hidden_size, config.nhead) for _ in range(config.num_layers)]
        )
        self.experts = nn.ModuleList(
            [Expert(config.hidden_size) for _ in range(config.num_experts)]
        )
        # Dense router: produces one mixing logit per expert for every token.
        self.router = nn.Linear(config.hidden_size, config.num_experts)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.main_input_name = "input_ids"
        self.post_init()

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def _build_attn_mask(self, attention_mask, seq_len, device):
        """Combine the causal mask with an optional padding mask.

        Returns a boolean mask where True means "do not attend":
          * (seq_len, seq_len) causal mask when `attention_mask` is None
            (identical to the previous behavior);
          * (batch * nhead, seq_len, seq_len) otherwise — the 3D layout
            nn.MultiheadAttention expects — with padded *keys* masked out.
        `attention_mask` uses the HF convention: 1 = real token, 0 = padding.
        """
        causal = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )
        if attention_mask is None:
            return causal
        pad_keys = attention_mask.to(device=device).eq(0)[:, None, :]  # (B, 1, S)
        combined = causal.unsqueeze(0) | pad_keys                      # (B, S, S)
        # Keep the diagonal attendable: with left padding, a padded query row
        # could otherwise be fully masked, which makes softmax produce NaNs.
        idx = torch.arange(seq_len, device=device)
        combined[:, idx, idx] = False
        # Expand per-batch mask to per-head, as required for batched 3D masks.
        return combined.repeat_interleave(self.config.nhead, dim=0)

    def forward(self, input_ids, labels=None, attention_mask=None, return_dict=True, **kwargs):
        """Run the model.

        Args:
            input_ids: LongTensor of shape (batch, seq_len).
            labels: optional LongTensor, same shape; shifted internally for
                next-token prediction (positions with -100 are ignored by
                CrossEntropyLoss's default ignore_index).
            attention_mask: optional (batch, seq_len) mask, 1 = keep / 0 = pad.
            return_dict: when False, returns a tuple `(loss, logits)` (loss
                omitted if no labels) instead of a ModelOutput.

        Returns:
            CausalLMOutputWithCrossAttentions with `loss` and `logits`.
        """
        batch_size, seq_len = input_ids.shape
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(positions)

        mask = self._build_attn_mask(attention_mask, seq_len, input_ids.device)

        for block in self.blocks:
            x = block(x, mask=mask)

        # Dense MoE: softmax router weights mix every expert's output.
        weights = torch.softmax(self.router(x), dim=-1)
        moe_out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            moe_out = moe_out + weights[:, :, i : i + 1] * expert(x)

        logits = self.lm_head(moe_out)

        loss = None
        if labels is not None:
            # Standard causal shift: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
            )

        if not return_dict:
            # HF tuple convention: loss (when present) comes first.
            return (loss, logits) if loss is not None else (logits,)

        return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Assemble per-step generation inputs; forwards attention_mask so
        padded prompt positions stay masked throughout decoding."""
        model_inputs = {"input_ids": input_ids}
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        return model_inputs