import torch
import torch.nn as nn

from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

from .configuration_sora import SoraConfig


class SoraForSLM(PreTrainedModel, GenerationMixin):
    config_class = SoraConfig

    def __init__(self, config):
        super().__init__(config)
        # Token embeddings plus learned absolute position embeddings
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                dim_feedforward=config.hidden_size * 4,
                batch_first=True,
                activation="gelu",
            )
            for _ in range(config.num_layers)
        ])
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Re-feed the full sequence at every decoding step (no KV cache)
        return {"input_ids": input_ids}

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Position indices for the learned position embeddings
        seq_length = input_ids.size(1)
        positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        # Token + position embeddings
        x = self.embeddings(input_ids) + self.position_embeddings(positions)

        # Causal mask: True above the diagonal means "may not attend to that (future) position"
        causal_mask = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )

        # Optional padding mask derived from attention_mask (1 = keep, 0 = padding)
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = attention_mask == 0

        for layer in self.layers:
            x = layer(x, src_mask=causal_mask, src_key_padding_mask=key_padding_mask)

        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        return CausalLMOutput(loss=loss, logits=logits)
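

# Minimal smoke-test sketch, assuming SoraConfig accepts the fields referenced above
# (vocab_size, hidden_size, num_heads, num_layers, max_position_embeddings) as keyword
# arguments; the concrete values below are illustrative placeholders, not the real
# model's hyperparameters.
if __name__ == "__main__":
    config = SoraConfig(
        vocab_size=32000,
        hidden_size=256,
        num_heads=4,
        num_layers=2,
        max_position_embeddings=512,
    )
    model = SoraForSLM(config)

    # Random token ids stand in for tokenizer output (no tokenizer is defined here).
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    out = model(input_ids=input_ids, labels=input_ids)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))

    # Greedy decoding via GenerationMixin; prepare_inputs_for_generation above
    # re-feeds the full sequence at each step, so no KV cache is used.
    generated = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print("generated shape:", tuple(generated.shape))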