# MCGPT-1 / modeling_mcgpt.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from .configuration_mcgpt import MCGPTConfig


class MCGPTBlock(nn.Module):
    """Transformer decoder block: self-attention followed by a feed-forward layer, post-LayerNorm."""

    def __init__(self, hidden_size, nhead, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=nhead, batch_first=True, dropout=dropout
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ff = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # mask is a (seq_len, seq_len) bool tensor; True marks positions that may not be attended to.
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask, need_weights=False)
        x = self.norm1(x + self.dropout(attn_out))
        x = self.norm2(x + self.dropout(self.ff(x)))
        return x
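
# A minimal smoke test for MCGPTBlock, kept as a comment so the module stays
# import-clean (a sketch; the sizes here are illustrative assumptions):
#
#   block = MCGPTBlock(hidden_size=64, nhead=4)
#   x = torch.randn(2, 10, 64)                         # (batch, seq, hidden)
#   causal = torch.triu(torch.ones(10, 10), 1).bool()  # True = may not attend
#   out = block(x, mask=causal)                        # out.shape == (2, 10, 64)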


class Expert(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

    def forward(self, x):
        return self.net(x)


class MCGPTForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = MCGPTConfig
    main_input_name = "input_ids"

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.blocks = nn.ModuleList(
            [MCGPTBlock(config.hidden_size, config.nhead) for _ in range(config.num_layers)]
        )
        self.experts = nn.ModuleList([Expert(config.hidden_size) for _ in range(config.num_experts)])
        self.router = nn.Linear(config.hidden_size, config.num_experts)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value
    def forward(self, input_ids, labels=None, attention_mask=None, return_dict=True, **kwargs):
        # attention_mask was added here (and is ignored) to prevent the ValueError in Hugging Face.
        batch_size, seq_len = input_ids.shape
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(pos)
        # Build a causal mask to prevent attending to future positions.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), 1).bool()
        for block in self.blocks:
            x = block(x, mask=mask)
        # Soft mixture-of-experts: every expert runs on every token and the
        # router's softmax weights blend the expert outputs per token.
        weights = torch.softmax(self.router(x), dim=-1)
        moe_out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            moe_out += weights[:, :, i : i + 1] * expert(x)
        logits = self.lm_head(moe_out)
        loss = None
        if labels is not None:
            # Shift so tokens at position t predict the token at position t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
            )
        if not return_dict:
            # Hugging Face tuple convention: loss first, then logits.
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits)
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Ensures only input_ids is passed through, to prevent duplicate arguments.
        return {"input_ids": input_ids}