import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers import LogitsProcessorList, MinLengthLogitsProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configure import RecombinationTransformerConfig
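# Thin wrapper around nn.MultiheadAttention that returns only the attended
# values and discards the attention weights. Inputs are expected in the
# (seq_len, batch, embed_dim) layout that nn.MultiheadAttention uses by default.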
class MaskedSelfAttentionLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MaskedSelfAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v, attn_mask=None):
        attn_output, _ = self.multihead_attn(q, k, v, attn_mask=attn_mask)
        return attn_output
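# Single linear projection, used for the Q/K/V projections of each attention block.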
class FcLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(FcLayer, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x)
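# Gated linear unit: a sigmoid gate computed from one projection modulates a
# second, parallel projection of the same input. Despite the name, the gate is
# a plain sigmoid rather than Swish/SiLU.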
class SwishGLU(nn.Module):
    def __init__(self, input_dim):
        super(SwishGLU, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        return torch.sigmoid(self.fc1(x)) * self.fc2(x)
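# Recombination layer F: takes the element-wise product of two attention
# outputs, then multiplies a linear projection of that product with a
# SwishGLU-gated projection of the same product.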
class SpecialLayerF(nn.Module):
    def __init__(self, input_dim):
        super(SpecialLayerF, self).__init__()
        self.proj_up = nn.Linear(input_dim, input_dim)
        self.proj_gate = SwishGLU(input_dim)

    def forward(self, o2, o3):
        cross_product = o2 * o3
        proj_up_output = self.proj_up(cross_product)
        proj_gate_output = self.proj_gate(cross_product)
        return proj_up_output * proj_gate_output
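# RMS-style normalization over the last dimension with a learned per-channel
# scale and no mean subtraction (unlike LayerNorm).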
class RMSNorm(nn.Module):
    def __init__(self, embed_dim, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.embed_dim = embed_dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(embed_dim))

    def forward(self, x):
        # Normalize by the root mean square of the features in the last
        # dimension, then apply the learned per-channel scale.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.scale * (x / rms)
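# Position-wise feed-forward block: gate and up projections into a wider
# hidden space, a gated activation on the gate path, element-wise product,
# then a down projection back to the model width.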
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(MLP, self).__init__()
        self.up_proj = nn.Linear(input_dim, hidden_dim)
        self.gate_proj = nn.Linear(input_dim, hidden_dim)
        self.act = SwishGLU(hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
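# One block of the recombination architecture. Three masked self-attention
# passes share projections of the block input x:
#   o1 = Attn(Q(x),   K(x),   V(x))    - standard self-attention
#   o2 = Attn(Q(x),   Kb(o1), Vb(o1))  - original queries attend over o1
#   o3 = Attn(Qc(o1), K(x),   V(x))    - queries derived from o1 attend over x
# SpecialLayerF recombines o2 and o3, the result is gated by o1, and the block
# finishes with residual + RMSNorm, an MLP, and a second residual + RMSNorm.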
class RecombinationTransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(RecombinationTransformerLayer, self).__init__()
        self.num_heads = num_heads
        # First self-attention layer
        self.self_attention_1 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_q = FcLayer(embed_dim, embed_dim)
        self.fc_k = FcLayer(embed_dim, embed_dim)
        self.fc_v = FcLayer(embed_dim, embed_dim)
        # Second self-attention layer
        self.self_attention_2 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_qc = FcLayer(embed_dim, embed_dim)
        self.fc_kb = FcLayer(embed_dim, embed_dim)
        self.fc_vb = FcLayer(embed_dim, embed_dim)
        # Third self-attention layer
        self.self_attention_3 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        # Special layer F
        self.special_layer_f = SpecialLayerF(embed_dim)
        # MLP layer
        self.mlp = MLP(embed_dim, embed_dim * 4)
        self.rms_norm1 = RMSNorm(embed_dim)
        self.rms_norm2 = RMSNorm(embed_dim)

    def forward(self, x, attn_mask=None):
        batch_size, seq_length, _ = x.size()
        if attn_mask is not None:
            # Expand the (batch, seq, seq) mask to (batch * num_heads, seq, seq),
            # the per-head layout nn.MultiheadAttention expects for 3-D masks
            attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(batch_size * self.num_heads, seq_length, seq_length)
        # First self-attention block: standard self-attention over x
        q1 = self.fc_q(x).transpose(0, 1)
        k1 = self.fc_k(x).transpose(0, 1)
        v1 = self.fc_v(x).transpose(0, 1)
        o1 = self.self_attention_1(q1, k1, v1, attn_mask=attn_mask).transpose(0, 1)
        # Second self-attention block: original queries attend over o1
        q2 = q1
        k2 = self.fc_kb(o1).transpose(0, 1)
        v2 = self.fc_vb(o1).transpose(0, 1)
        o2 = self.self_attention_2(q2, k2, v2, attn_mask=attn_mask).transpose(0, 1)
        # Third self-attention block: queries derived from o1 attend over x
        q3 = self.fc_qc(o1).transpose(0, 1)
        k3 = k1
        v3 = v1
        o3 = self.self_attention_3(q3, k3, v3, attn_mask=attn_mask).transpose(0, 1)
        # Special layer F recombines o2 and o3, gated by o1
        output_f = self.special_layer_f(o2, o3) * o1
        # Add & Norm
        x = x + output_f
        x = self.rms_norm1(x)
        # MLP block
        mlp_output = self.mlp(x)
        # Add & Norm
        x = x + mlp_output
        x = self.rms_norm2(x)
        return x
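# Causal language model built from the recombination layers: token embedding,
# a stack of RecombinationTransformerLayer blocks run under a causal mask,
# a final RMSNorm, and an untied linear head over the vocabulary.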
class RecombinationTransformerForCausalLM(PreTrainedModel):
    config_class = RecombinationTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.layers = nn.ModuleList([
            RecombinationTransformerLayer(config.embed_dim, config.num_heads) for _ in range(config.num_layers)
        ])
        self.final_rms_norm = RMSNorm(config.embed_dim)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
    def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None, **kwargs):
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
        batch_size, seq_length = input_ids.size()
        # Boolean causal mask: True marks future positions that must not be
        # attended to, which is the convention nn.MultiheadAttention uses for
        # bool masks. The padding attention_mask is accepted for API
        # compatibility but is not folded into the per-layer mask.
        causal_mask = torch.triu(
            torch.ones((seq_length, seq_length), dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        ).unsqueeze(0).expand(batch_size, -1, -1)
        # Embedding
        x = self.embed_tokens(input_ids)
        # No key/value cache is implemented; the per-layer hidden states are
        # collected only to populate the past_key_values field of the output.
        new_past_key_values = []
        for layer in self.layers:
            x = layer(x, attn_mask=causal_mask)
            new_past_key_values.append(x)
        # Final normalization
        x = self.final_rms_norm(x)
        # LM head
        logits = self.lm_head(x)
        if not return_dict:
            return (logits, new_past_key_values)
        return CausalLMOutputWithPast(logits=logits, past_key_values=new_past_key_values)
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        # No key/value cache is implemented, so every generation step re-encodes
        # the full sequence; input_ids must therefore not be truncated to the
        # last token here.
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
    def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
        logits_processor = LogitsProcessorList()
        if min_length is not None:
            logits_processor.append(MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id))
        outputs = super().generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            logits_processor=logits_processor,
        )
        return outputs
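if __name__ == "__main__":
    # Minimal smoke test, not part of the original model definition. It assumes
    # that RecombinationTransformerConfig (defined in .configure) accepts
    # vocab_size, embed_dim, num_heads, num_layers and eos_token_id as keyword
    # arguments, mirroring the attributes read in __init__ and generate above.
    # Because of the relative import, run it as a module
    # (python -m <package>.<this_file>) rather than as a script.
    config = RecombinationTransformerConfig(
        vocab_size=1000,
        embed_dim=64,
        num_heads=4,
        num_layers=2,
        eos_token_id=0,
    )
    model = RecombinationTransformerForCausalLM(config)
    dummy_ids = torch.randint(0, config.vocab_size, (2, 10))
    logits, _ = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([2, 10, 1000])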