import torch
import torch.nn as nn
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .configuration_mcgpt import MCGPTConfig


class MCGPTBlock(nn.Module):
    """A single pre-softmax Transformer decoder block: self-attention + feed-forward, each with a residual connection and LayerNorm."""

    def __init__(self, hidden_size, nhead, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=nhead, batch_first=True, dropout=dropout
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ff = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask, need_weights=False)
        x = self.norm1(x + self.dropout(attn_out))
        x = self.norm2(x + self.dropout(self.ff(x)))
        return x


class Expert(nn.Module):
    """One expert of the mixture: an independent position-wise feed-forward network."""

    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

    def forward(self, x):
        return self.net(x)


class MCGPTForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = MCGPTConfig
    main_input_name = "input_ids"

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.blocks = nn.ModuleList(
            [MCGPTBlock(config.hidden_size, config.nhead) for _ in range(config.num_layers)]
        )
        self.experts = nn.ModuleList([Expert(config.hidden_size) for _ in range(config.num_experts)])
        self.router = nn.Linear(config.hidden_size, config.num_experts)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def forward(self, input_ids, labels=None, attention_mask=None, return_dict=True, **kwargs):
        # attention_mask is accepted (but currently unused) so that Hugging Face
        # utilities that pass it do not raise a ValueError.
        batch_size, seq_len = input_ids.shape
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(pos)

        # Causal mask: True above the diagonal marks future positions that
        # must not be attended to.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), 1).bool()
        for block in self.blocks:
            x = block(x, mask=mask)

        # Soft (dense) mixture-of-experts: every expert processes every token,
        # and the router's softmax weights blend their outputs per token.
        weights = torch.softmax(self.router(x), dim=-1)
        moe_out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            moe_out += weights[:, :, i:i + 1] * expert(x)

        logits = self.lm_head(moe_out)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
            )

        if not return_dict:
            # Follow the Hugging Face tuple convention: loss first when present.
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Pass only input_ids to avoid duplicate keyword arguments during generation.
        return {"input_ids": input_ids}
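

# --- Usage sketch (illustrative only, not part of the model definition) ---
# A minimal smoke test of the classes above, assuming MCGPTConfig accepts the fields
# referenced by MCGPTForCausalLM (vocab_size, hidden_size, max_position_embeddings,
# nhead, num_layers, num_experts) as constructor keyword arguments; the real signature
# lives in configuration_mcgpt.py, and the values below are arbitrary placeholders.
# Because this file uses a relative import, run it as a module from the package root,
# e.g. `python -m <package>.modeling_mcgpt`.
if __name__ == "__main__":
    config = MCGPTConfig(
        vocab_size=1000,
        hidden_size=64,
        max_position_embeddings=128,
        nhead=4,
        num_layers=2,
        num_experts=4,
    )
    model = MCGPTForCausalLM(config)

    # Forward pass with labels to check that logits and loss come out with sane shapes.
    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids=dummy_ids, labels=dummy_ids)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))

    # Greedy decoding via GenerationMixin; no KV cache is implemented, so each step
    # re-runs the full prefix, which is fine for a small sanity check.
    generated = model.generate(dummy_ids[:, :4], max_new_tokens=8, do_sample=False)
    print("generated shape:", tuple(generated.shape))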