| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutput |
| | from .configuration_duchifat_v2 import DuchifatConfig |
| |
|
class DuchifatBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention followed by an MLP.

    Both sub-layers are wrapped in a residual connection, and each one is
    preceded by its own LayerNorm.

    Args:
        config: Object exposing ``hidden_size`` (model width) and ``nhead``
            (number of attention heads).

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of ``nhead``.
    """

    def __init__(self, config):
        super().__init__()
        # Fail early: with a non-divisible width the per-head reshape in
        # forward() cannot succeed and would only surface as a shape error.
        if config.hidden_size % config.nhead != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"nhead ({config.nhead})"
            )
        self.ln1 = nn.LayerNorm(config.hidden_size)
        # Single projection producing queries, keys and values at once.
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.wo = nn.Linear(config.hidden_size, config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, 4 * config.hidden_size),
            nn.GELU(approximate='tanh'),
            nn.Linear(4 * config.hidden_size, config.hidden_size),
        )
        self.n_head = config.nhead
        self.head_dim = config.hidden_size // config.nhead

    def forward(self, x):
        """Apply the block to ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` where ``C`` equals ``hidden_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        norm_x = self.ln1(x)
        B, T, C = norm_x.size()

        # (B, T, 3*C) -> (3, B, n_head, T, head_dim); index 0/1/2 = q/k/v.
        qkv = self.qkv(norm_x).view(B, T, 3, self.n_head, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # is_causal=True: each position attends only to itself and earlier ones.
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge the heads back: (B, n_head, T, head_dim) -> (B, T, C).
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)

        x = x + self.wo(attn_out)
        x = x + self.mlp(self.ln2(x))
        return x
| |
|
class DuchifatPreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the Duchifat model classes.

    Declares class-level settings only; no methods are overridden here.
    """

    # Module class names listed as units that must not be split
    # (presumably consumed by Transformers' device-map logic — confirm).
    _no_split_modules = ["DuchifatBlock"]
    # Prefix string registered for the base model.
    base_model_prefix = "model"
    # Configuration class paired with this model family.
    config_class = DuchifatConfig
| |
|
class DuchifatCore(DuchifatPreTrainedModel):
    """Causal language model built from a stack of ``DuchifatBlock`` layers.

    Token embeddings and learned absolute position embeddings are summed,
    passed through ``config.num_layers`` blocks and a final LayerNorm, then
    projected to vocabulary logits by ``lm_head``. No key/value cache is
    produced, so each forward pass processes the whole input sequence.
    """

    def __init__(self, config):
        super().__init__(config)
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned position table; its row count caps the usable sequence length.
        self.wpe = nn.Embedding(config.max_seq, config.hidden_size)
        self.blocks = nn.ModuleList(
            [DuchifatBlock(config) for _ in range(config.num_layers)]
        )
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.wte = value

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids of shape ``(B, T)``. Required.
            attention_mask: Accepted for interface compatibility but not read
                by this method; attention inside the blocks is causal only.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted here, so position ``t`` is scored against token ``t + 1``.
            **kwargs: Ignored.

        Returns:
            CausalLMOutput with ``logits`` of shape ``(B, T, vocab_size)`` and
            ``loss`` (``None`` when ``labels`` is not provided).

        Raises:
            ValueError: If ``input_ids`` is missing, or if ``T`` is larger than
                the number of rows in the position embedding table.
        """
        if input_ids is None:
            raise ValueError("You must specify input_ids")

        B, T = input_ids.size()

        # Guard the position lookup: an index past the table would otherwise
        # fail inside the embedding op with a much less helpful error.
        max_positions = self.wpe.num_embeddings
        if T > max_positions:
            raise ValueError(
                f"Sequence length {T} exceeds the maximum supported length "
                f"{max_positions}"
            )

        pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device)

        # (B, T, C) + (T, C): position embeddings broadcast over the batch.
        x = self.wte(input_ids) + self.wpe(pos)

        for block in self.blocks:
            x = block(x)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            # Drop the last prediction and the first label so that each
            # position is trained to predict the following token.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Build the forward() keyword arguments for one generation step.

        The full ``input_ids`` are passed through unchanged; any extra
        ``kwargs`` are dropped.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Return ``past_key_values`` unchanged; ``beam_idx`` is not used."""
        return past_key_values