import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers import GenerationMixin
from transformers import LogitsProcessorList, StoppingCriteriaList, MaxLengthCriteria, MinLengthLogitsProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configure import RecombinationTransformerConfig


class MaskedSelfAttentionLayer(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` returning only the attention output.

    ``nn.MultiheadAttention`` is built with its default ``batch_first=False``,
    so ``q``, ``k`` and ``v`` must be laid out as ``(seq, batch, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads):
        super(MaskedSelfAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v, attn_mask=None):
        """Run multi-head attention.

        Args:
            q, k, v: Query/key/value tensors in ``(seq, batch, embed_dim)`` layout.
            attn_mask: Optional mask with ``nn.MultiheadAttention`` semantics:
                a bool mask marks positions that may NOT be attended (True = blocked);
                a float mask is added to the attention logits.

        Returns:
            The attention output, same layout as ``q``.
        """
        # The averaged attention weights are never used, so skip computing them.
        attn_output, _ = self.multihead_attn(q, k, v, attn_mask=attn_mask, need_weights=False)
        return attn_output


class FcLayer(nn.Module):
    """A single affine projection ``input_dim -> output_dim``."""

    def __init__(self, input_dim, output_dim):
        super(FcLayer, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x)


class SwishGLU(nn.Module):
    """Gated unit: ``sigmoid(fc1(x)) * fc2(x)``.

    NOTE: despite the name, the gate is a plain sigmoid (classic GLU), not
    Swish/SiLU. Kept as-is so existing checkpoints keep their semantics.
    """

    def __init__(self, input_dim):
        super(SwishGLU, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        return torch.sigmoid(self.fc1(x)) * self.fc2(x)


class SpecialLayerF(nn.Module):
    """Combine two attention outputs by elementwise product, then a gated projection.

    Computes ``proj_up(o2 * o3) * proj_gate(o2 * o3)``.
    """

    def __init__(self, input_dim):
        super(SpecialLayerF, self).__init__()
        self.proj_up = nn.Linear(input_dim, input_dim)
        self.proj_gate = SwishGLU(input_dim)

    def forward(self, o2, o3):
        cross_product = o2 * o3
        proj_up_output = self.proj_up(cross_product)
        proj_gate_output = self.proj_gate(cross_product)
        return proj_up_output * proj_gate_output


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization over the last dimension.

    ``y = scale * x / sqrt(mean(x**2, dim=-1) + eps)``
    """

    def __init__(self, embed_dim, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.embed_dim = embed_dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(embed_dim))

    def forward(self, x):
        # Divide by the RMS, not the L2 norm: the L2 norm is sqrt(embed_dim)
        # times larger and would shrink activations accordingly.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.scale * (x * inv_rms)


class MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, input_dim, hidden_dim):
        super(MLP, self).__init__()
        self.up_proj = nn.Linear(input_dim, hidden_dim)
        self.gate_proj = nn.Linear(input_dim, hidden_dim)
        self.act = SwishGLU(hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


class RecombinationTransformerLayer(nn.Module):
    """Decoder layer with three attention passes that recombine queries/keys/values.

    Input and output are ``(batch, seq, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads):
        super(RecombinationTransformerLayer, self).__init__()
        self.num_heads = num_heads

        # First self-attention layer
        self.self_attention_1 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_q = FcLayer(embed_dim, embed_dim)
        self.fc_k = FcLayer(embed_dim, embed_dim)
        self.fc_v = FcLayer(embed_dim, embed_dim)

        # Second self-attention layer
        self.self_attention_2 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_qc = FcLayer(embed_dim, embed_dim)
        self.fc_kb = FcLayer(embed_dim, embed_dim)
        self.fc_vb = FcLayer(embed_dim, embed_dim)

        # Third self-attention layer
        self.self_attention_3 = MaskedSelfAttentionLayer(embed_dim, num_heads)

        # Special layer F
        self.special_layer_f = SpecialLayerF(embed_dim)

        # MLP layer
        self.mlp = MLP(embed_dim, embed_dim * 4)

        self.rms_norm1 = RMSNorm(embed_dim)
        self.rms_norm2 = RMSNorm(embed_dim)

    def forward(self, x, attn_mask=None):
        """Apply the layer.

        Args:
            x: Hidden states, ``(batch, seq, embed_dim)``.
            attn_mask: Optional mask in ``nn.MultiheadAttention`` semantics
                (bool True = blocked, float = additive). Either a per-example
                ``(batch, seq, seq)`` mask, which is expanded across heads, or a
                shared ``(seq, seq)`` mask, which is passed through unchanged.

        Returns:
            Updated hidden states, ``(batch, seq, embed_dim)``.
        """
        batch_size, seq_length, _ = x.size()

        if attn_mask is not None and attn_mask.dim() == 3:
            # nn.MultiheadAttention wants (batch_size * num_heads, seq, seq),
            # batch-major (index b * num_heads + h).
            attn_mask = (
                attn_mask.unsqueeze(1)
                .expand(-1, self.num_heads, -1, -1)
                .reshape(batch_size * self.num_heads, seq_length, seq_length)
            )

        # First self-attention block (tensors transposed to (seq, batch, dim) for MHA)
        q1 = self.fc_q(x).transpose(0, 1)
        k1 = self.fc_k(x).transpose(0, 1)
        v1 = self.fc_v(x).transpose(0, 1)
        o1 = self.self_attention_1(q1, k1, v1, attn_mask=attn_mask).transpose(0, 1)

        # Second self-attention block: original queries, keys/values from o1
        q2 = q1
        k2 = self.fc_kb(o1).transpose(0, 1)
        v2 = self.fc_vb(o1).transpose(0, 1)
        o2 = self.self_attention_2(q2, k2, v2, attn_mask=attn_mask).transpose(0, 1)

        # Third self-attention block: queries from o1, original keys/values
        q3 = self.fc_qc(o1).transpose(0, 1)
        k3 = k1
        v3 = v1
        o3 = self.self_attention_3(q3, k3, v3, attn_mask=attn_mask).transpose(0, 1)

        # Special layer F
        output_f = self.special_layer_f(o2, o3) * o1

        # Add & Norm
        x = x + output_f
        x = self.rms_norm1(x)

        # MLP block
        mlp_output = self.mlp(x)

        # Add & Norm
        x = x + mlp_output
        x = self.rms_norm2(x)

        return x


class RecombinationTransformerForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal language model built from a stack of ``RecombinationTransformerLayer``.

    NOTE: no key/value cache is implemented. ``past_key_values`` is accepted for
    API compatibility but ignored, and generation re-encodes the full prefix at
    every step.
    """

    config_class = RecombinationTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.layers = nn.ModuleList([
            RecombinationTransformerLayer(config.embed_dim, config.num_heads)
            for _ in range(config.num_layers)
        ])
        self.final_rms_norm = RMSNorm(config.embed_dim)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        # Standard Transformers hook: weight init plus framework bookkeeping.
        self.post_init()

    @staticmethod
    def _build_attention_mask(attention_mask, seq_length, device):
        """Return a bool ``(batch, seq, seq)`` mask where True marks blocked positions.

        Combines causality (no attending to later tokens) with key padding
        (``attention_mask == 0``). The diagonal is always left open so that no
        row is fully masked, which would make softmax produce NaN.
        """
        causal = torch.ones((seq_length, seq_length), dtype=torch.bool, device=device).tril()
        key_valid = attention_mask.to(device=device, dtype=torch.bool).unsqueeze(1)  # (B, 1, S)
        allowed = causal.unsqueeze(0) & key_valid
        allowed = allowed | torch.eye(seq_length, dtype=torch.bool, device=device)
        return ~allowed

    def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None, labels=None, **kwargs):
        """Compute next-token logits (and optionally the LM loss).

        Args:
            input_ids: Token ids, ``(batch, seq)``.
            attention_mask: Optional ``(batch, seq)`` mask, 1/True for real tokens
                and 0/False for padding. Defaults to all ones.
            past_key_values: Ignored (no KV cache); kept for API compatibility.
            return_dict: When falsy, return a tuple; otherwise a
                ``CausalLMOutputWithPast``.
            labels: Optional ``(batch, seq)`` target ids. Positions equal to -100
                are ignored. Shifted by one internally for next-token prediction.

        Returns:
            ``(logits, None)`` or ``(loss, logits, None)`` when ``return_dict`` is
            falsy, otherwise ``CausalLMOutputWithPast(loss, logits, past_key_values=None)``.
        """
        batch_size, seq_length = input_ids.size()
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.bool, device=input_ids.device)

        # Bool mask (True = blocked). The previous float tril(ones) mask was
        # *added* to the logits by nn.MultiheadAttention and masked nothing.
        attn_mask = self._build_attention_mask(attention_mask, seq_length, input_ids.device)

        # Embedding
        x = self.embed_tokens(input_ids)

        for layer in self.layers:
            x = layer(x, attn_mask=attn_mask)

        # Final normalization
        x = self.final_rms_norm(x)

        # LM head
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Token t predicts token t + 1.
            shift_logits = logits[:, :-1, :]
            shift_labels = labels[:, 1:].to(logits.device)
            loss = nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits, None)
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None)

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
        """Build forward() kwargs for one generation step.

        Without a KV cache the model must see the whole prefix every step.
        Truncating to the last token, as cached models do, would discard
        the context, so ``past`` / ``past_key_values`` are ignored.
        """
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": None}

    def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1, **kwargs):
        """Generate sequences, optionally enforcing a minimum length.

        Extra keyword arguments are forwarded to ``GenerationMixin.generate``.
        A caller-supplied ``logits_processor`` is copied, not mutated.

        Raises:
            ValueError: if ``min_length`` is given but ``config.eos_token_id`` is None
                (there is no EOS token to suppress).
        """
        logits_processor = LogitsProcessorList(kwargs.pop("logits_processor", None) or [])
        if min_length is not None:
            eos_token_id = self.config.eos_token_id
            if eos_token_id is None:
                raise ValueError("min_length requires config.eos_token_id to be set")
            logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=eos_token_id))

        outputs = super().generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            logits_processor=logits_processor,
            **kwargs,
        )
        return outputs