import torch
import torch.nn as nn
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
from .configuration_sora import SoraConfig


class SoraForSLM(PreTrainedModel, GenerationMixin):
    config_class = SoraConfig

    def __init__(self, config):
        super().__init__(config)
        # Token embeddings plus learned absolute position embeddings.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Stack of standard Transformer blocks; causal masking is applied in forward().
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                dim_feedforward=config.hidden_size * 4,
                batch_first=True,
                activation="gelu",
            )
            for _ in range(config.num_layers)
        ])

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # Forward the attention mask so generation respects padded positions.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        seq_length = input_ids.size(1)
        positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        x = self.embeddings(input_ids) + self.position_embeddings(positions)

        # Causal mask: position i may only attend to positions <= i (True = blocked).
        causal_mask = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )
        # Convert the HF-style attention_mask (1 = keep, 0 = pad) into a key-padding mask.
        padding_mask = (attention_mask == 0) if attention_mask is not None else None

        for layer in self.layers:
            x = layer(x, src_mask=causal_mask, src_key_padding_mask=padding_mask)

        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n; CrossEntropyLoss ignores -100 labels.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        return CausalLMOutput(loss=loss, logits=logits)
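

# Usage sketch with illustrative values, assuming SoraConfig defines the fields
# referenced above (vocab_size, hidden_size, num_heads, num_layers,
# max_position_embeddings):
#
#   config = SoraConfig(vocab_size=32000, hidden_size=512, num_heads=8,
#                       num_layers=6, max_position_embeddings=1024)
#   model = SoraForSLM(config)
#   input_ids = torch.randint(0, config.vocab_size, (1, 16))
#   outputs = model(input_ids=input_ids, labels=input_ids)    # outputs.loss, outputs.logits
#   generated = model.generate(input_ids, max_new_tokens=20)  # greedy decoding via GenerationMixin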