| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel |
| | from transformers import LogitsProcessorList, StoppingCriteriaList, MaxLengthCriteria, MinLengthLogitsProcessor |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
|
| | from .configure import RecombinationTransformerConfig |
| |
|
class MaskedSelfAttentionLayer(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` that returns only the
    attended values and discards the attention-weight tensor.

    Inputs follow the default (seq, batch, embed) layout of
    ``nn.MultiheadAttention`` (``batch_first=False``).
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v, attn_mask=None):
        # MultiheadAttention returns (output, attn_weights); only the
        # output is needed downstream.
        attended, _unused_weights = self.multihead_attn(q, k, v, attn_mask=attn_mask)
        return attended
| |
|
class FcLayer(nn.Module):
    """A single fully-connected (linear) projection layer."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.fc(x)
        return projected
| |
|
class SwishGLU(nn.Module):
    """Sigmoid-gated linear unit: ``sigmoid(W1 x) * (W2 x)``.

    Two parallel linear projections of the same input; the first is squashed
    through a sigmoid and used as an elementwise gate on the second.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        gate = torch.sigmoid(self.fc1(x))
        value = self.fc2(x)
        return gate * value
| |
|
class SpecialLayerF(nn.Module):
    """Fuses two attention outputs.

    Takes the elementwise product of the two inputs, pushes it through a
    plain linear branch and a gated (SwishGLU) branch, and returns the
    elementwise product of the two branch outputs.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.proj_up = nn.Linear(input_dim, input_dim)
        self.proj_gate = SwishGLU(input_dim)

    def forward(self, o2, o3):
        fused = o2 * o3
        return self.proj_up(fused) * self.proj_gate(fused)
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Zhang & Sennrich, 2019).

    Normalizes each feature vector by the root mean square of its last
    dimension, then applies a learned per-channel scale.

    Fix: the previous implementation divided by the full L2 norm
    (``x / (||x||_2 + eps)``), which both mismatches the RMSNorm definition
    and additionally shrinks activations by a factor of sqrt(embed_dim).
    """

    def __init__(self, embed_dim, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.embed_dim = embed_dim
        # eps sits inside the sqrt as a numerical-stability floor.
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(embed_dim))

    def forward(self, x):
        # rms(x) = sqrt(mean(x_i^2)) over the feature (last) dimension.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.scale * (x / rms)
| |
|
class MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    NOTE(review): ``act`` is itself a SwishGLU containing two more linear
    layers, so the gate path is doubly parameterized — presumably
    intentional, but worth confirming against the architecture spec.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.up_proj = nn.Linear(input_dim, hidden_dim)
        self.gate_proj = nn.Linear(input_dim, hidden_dim)
        self.act = SwishGLU(hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        gated = self.act(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
| |
|
class RecombinationTransformerLayer(nn.Module):
    """A single 'recombination' transformer block.

    Runs three masked multi-head attentions whose queries/keys/values are
    recombined across stages:

    * stage 1: standard Q/K/V projections of the input x       -> o1
    * stage 2: reuses Q from stage 1, projects K/V from o1     -> o2
    * stage 3: projects Q from o1, reuses K/V from stage 1     -> o3

    o2 and o3 are fused by SpecialLayerF and gated elementwise by o1, added
    residually to x, then a gated MLP with a second residual connection is
    applied.  RMSNorm is applied *after* each residual sum (post-norm).
    """

    def __init__(self, embed_dim, num_heads):
        super(RecombinationTransformerLayer, self).__init__()
        self.num_heads = num_heads

        # Stage 1: standard self-attention projections over the input.
        self.self_attention_1 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_q = FcLayer(embed_dim, embed_dim)
        self.fc_k = FcLayer(embed_dim, embed_dim)
        self.fc_v = FcLayer(embed_dim, embed_dim)

        # Stage 2: K/V are re-projected from the stage-1 output o1.
        self.self_attention_2 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_qc = FcLayer(embed_dim, embed_dim)  # used by stage 3 (Q from o1)
        self.fc_kb = FcLayer(embed_dim, embed_dim)
        self.fc_vb = FcLayer(embed_dim, embed_dim)

        # Stage 3: Q is re-projected from o1; K/V reuse stage-1 projections.
        self.self_attention_3 = MaskedSelfAttentionLayer(embed_dim, num_heads)

        # Fusion of the stage-2 and stage-3 outputs.
        self.special_layer_f = SpecialLayerF(embed_dim)

        # Feed-forward block and the two post-residual norms.
        self.mlp = MLP(embed_dim, embed_dim * 4)
        self.rms_norm1 = RMSNorm(embed_dim)
        self.rms_norm2 = RMSNorm(embed_dim)

    def forward(self, x, attn_mask=None):
        # x: (batch, seq, embed) — batch-first.  The attention modules use
        # nn.MultiheadAttention's default seq-first layout, hence the
        # transpose(0, 1) pairs below.
        batch_size, seq_length, _ = x.size()

        if attn_mask is not None:
            # Broadcast a (batch, seq, seq) mask to one copy per head:
            # (batch * num_heads, seq, seq), the 3-D layout expected by
            # nn.MultiheadAttention.  NOTE(review): mask semantics (boolean
            # "True = blocked" vs additive float) are whatever the caller
            # passed in — this reshape is dtype-agnostic.
            attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(batch_size * self.num_heads, seq_length, seq_length)

        # Stage 1: ordinary masked self-attention over x.
        q1 = self.fc_q(x).transpose(0, 1)
        k1 = self.fc_k(x).transpose(0, 1)
        v1 = self.fc_v(x).transpose(0, 1)
        o1 = self.self_attention_1(q1, k1, v1, attn_mask=attn_mask).transpose(0, 1)

        # Stage 2: same queries as stage 1; keys/values derived from o1.
        q2 = q1  # q1 is already in seq-first layout — no extra transpose
        k2 = self.fc_kb(o1).transpose(0, 1)
        v2 = self.fc_vb(o1).transpose(0, 1)
        o2 = self.self_attention_2(q2, k2, v2, attn_mask=attn_mask).transpose(0, 1)

        # Stage 3: queries derived from o1; keys/values reused from stage 1.
        q3 = self.fc_qc(o1).transpose(0, 1)
        k3 = k1  # already seq-first
        v3 = v1
        o3 = self.self_attention_3(q3, k3, v3, attn_mask=attn_mask).transpose(0, 1)

        # Fuse o2/o3 and gate the result elementwise by o1.
        output_f = self.special_layer_f(o2, o3) * o1

        # First residual connection, then post-norm.
        x = x + output_f
        x = self.rms_norm1(x)

        # Feed-forward block.
        mlp_output = self.mlp(x)

        # Second residual connection, then post-norm.
        x = x + mlp_output
        x = self.rms_norm2(x)

        return x
| |
|
| |
|
| |
|
class RecombinationTransformerForCausalLM(PreTrainedModel):
    """Causal language model built from stacked RecombinationTransformerLayers.

    Pipeline: token embedding -> N recombination layers (with a causal
    attention mask) -> final RMSNorm -> tied-free LM head producing
    vocabulary logits.

    NOTE(review): this model has no real key/value cache.  The
    ``past_key_values`` it returns are per-layer hidden states and are never
    consumed on subsequent forward passes, so every decoding step must re-feed
    the full sequence.
    """

    config_class = RecombinationTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.layers = nn.ModuleList([
            RecombinationTransformerLayer(config.embed_dim, config.num_heads)
            for _ in range(config.num_layers)
        ])
        self.final_rms_norm = RMSNorm(config.embed_dim)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None, **kwargs):
        """Run a full forward pass.

        Returns ``(logits, per_layer_hidden_states)`` as a tuple, or a
        ``CausalLMOutputWithPast`` when ``return_dict`` is truthy.
        """
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
        # NOTE(review): the padding attention_mask is accepted but not merged
        # into the mask passed to the layers — only causal masking is applied.

        batch_size, seq_length = input_ids.size()
        # Boolean causal mask: True marks positions a query must NOT attend to
        # (strictly-future positions).  Fix: the previous float tril-of-ones
        # mask was *added* to the attention scores by nn.MultiheadAttention
        # and therefore did not block future positions at all.
        causal_mask = torch.triu(
            torch.ones((seq_length, seq_length), dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        ).unsqueeze(0).expand(batch_size, -1, -1)

        x = self.embed_tokens(input_ids)

        # past_key_values is accepted for API compatibility but not consumed;
        # each layer's output hidden state is recorded in its place.
        new_past_key_values = []
        for layer in self.layers:
            x = layer(x, attn_mask=causal_mask)
            new_past_key_values.append(x)

        x = self.final_rms_norm(x)
        logits = self.lm_head(x)

        if not return_dict:
            return (logits, new_past_key_values)

        return CausalLMOutputWithPast(logits=logits, past_key_values=new_past_key_values)

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
        """Build the kwargs for one decoding step.

        Fix: the previous version truncated ``input_ids`` to the last token
        whenever ``past`` was set, as a KV-cached model would — but
        ``forward`` never consumes the cache, so all earlier context was
        silently dropped during generation.  The full sequence is now re-fed
        every step.
        """
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

    def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
        """Generate via the HF ``generate`` machinery, optionally enforcing a
        minimum length by suppressing EOS below ``min_length``."""
        logits_processor = LogitsProcessorList()

        if min_length is not None:
            logits_processor.append(
                MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id)
            )

        return super().generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            logits_processor=logits_processor,
        )
| |
|