# SoraForSLM-1 / modeling_sora.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
from .configuration_sora import SoraConfig
class SoraForSLM(PreTrainedModel, GenerationMixin):
    config_class = SoraConfig

    def __init__(self, config):
        super().__init__(config)
        # Token embeddings and learned absolute position embeddings
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # Stack of standard Transformer encoder layers
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                dim_feedforward=config.hidden_size * 4,
                batch_first=True,
                activation="gelu",
            )
            for _ in range(config.num_layers)
        ])
        # Project hidden states back to vocabulary logits (no weight tying)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Compute position indices for the input sequence
        seq_length = input_ids.size(1)
        positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        # Token + position embeddings
        x = self.embeddings(input_ids) + self.position_embeddings(positions)

        # Pass through the layers (no attention mask is applied, to avoid any
        # mask conflict; attention is therefore bidirectional)
        for layer in self.layers:
            x = layer(x)

        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift logits and labels for causal (next-token) training
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        return CausalLMOutput(loss=loss, logits=logits)
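

# --- Minimal usage sketch (illustrative, not part of the model definition) ---
# This smoke test shows how the class above could be instantiated and exercised
# through GenerationMixin. It assumes SoraConfig accepts the keyword arguments
# read in __init__/forward (vocab_size, hidden_size, num_heads, num_layers,
# max_position_embeddings); that constructor signature and the values below are
# assumptions, not taken from configuration_sora.py.
if __name__ == "__main__":
    config = SoraConfig(
        vocab_size=1000,
        hidden_size=64,
        num_heads=4,
        num_layers=2,
        max_position_embeddings=128,
    )
    model = SoraForSLM(config)
    model.eval()

    # Forward pass with labels to verify that logits and a loss are produced
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    with torch.no_grad():
        out = model(input_ids=input_ids, labels=input_ids)
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))

    # Greedy decoding via GenerationMixin (the full sequence is re-encoded at
    # each step, since the model keeps no key/value cache)
    generated = model.generate(input_ids[:, :4], max_new_tokens=8, do_sample=False)
    print("generated:", tuple(generated.shape))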