import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_duchifat_v2 import DuchifatConfig


class DuchifatBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.wo = nn.Linear(config.hidden_size, config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, 4 * config.hidden_size),
            nn.GELU(approximate='tanh'),
            nn.Linear(4 * config.hidden_size, config.hidden_size)
        )
        self.n_head = config.nhead
        self.head_dim = config.hidden_size // config.nhead

    def forward(self, x):
        norm_x = self.ln1(x)
        B, T, C = norm_x.size()
        # Project to Q/K/V and split heads: (B, T, C) -> (3, B, n_head, T, head_dim)
        qkv = self.qkv(norm_x).view(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # SDPA (dispatches to Flash Attention kernels when available)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)
        # Pre-norm residual connections around attention and MLP
        x = x + self.wo(attn_out)
        x = x + self.mlp(self.ln2(x))
        return x


class DuchifatPreTrainedModel(PreTrainedModel):
    config_class = DuchifatConfig
    base_model_prefix = "model"
    _no_split_modules = ["DuchifatBlock"]


class DuchifatCore(DuchifatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_seq, config.hidden_size)
        self.blocks = nn.ModuleList([DuchifatBlock(config) for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value):
        self.wte = value

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Handle the case where input_ids was not passed properly
        if input_ids is None:
            raise ValueError("You must specify input_ids")

        B, T = input_ids.size()
        device = input_ids.device

        # Fail early with a clear message instead of an out-of-range embedding lookup
        if T > self.config.max_seq:
            raise ValueError(f"Sequence length {T} exceeds max_seq ({self.config.max_seq})")

        # Build position ids (absolute positional embeddings)
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        x = self.wte(input_ids) + self.wpe(pos)

        # Note: attention_mask is accepted for API compatibility but is not applied;
        # the blocks use a causal mask only (no padding mask).
        for block in self.blocks:
            x = block(x)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            # Shift logits/labels for causal language modeling:
            # the logit at position t predicts the token at position t + 1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1)
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits
        )

    # Essential hook that allows generate() to work
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }

    # Basic beam-search / cache support: no KV cache is kept, so there is nothing to reorder
    def _reorder_cache(self, past_key_values, beam_idx):
        return past_key_values
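

# Minimal smoke-test sketch for the module above. Assumptions: DuchifatConfig
# accepts the fields referenced in the code (vocab_size, hidden_size, nhead,
# num_layers, max_seq) as keyword arguments, and the concrete values below are
# illustrative, not the model's real hyperparameters. Because of the relative
# import at the top, this file must be run from within its package
# (e.g. via `python -m <package>.<module>`), not as a standalone script.
if __name__ == "__main__":
    config = DuchifatConfig(
        vocab_size=32000,   # illustrative value
        hidden_size=256,    # illustrative value
        nhead=4,            # illustrative value
        num_layers=2,       # illustrative value
        max_seq=512,        # illustrative value
    )
    model = DuchifatCore(config)

    # Forward pass with labels to verify that logits and loss are produced
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids=input_ids, labels=input_ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 32000])
    print(out.loss)          # scalar cross-entropy loss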