import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig
class ResidualBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * 2),
nn.GELU(),
nn.Linear(dim * 2, dim)
)
self.norm = nn.LayerNorm(dim)
def forward(self, x):
return self.norm(x + self.net(x))
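# Shape sketch: ResidualBlock is shape-preserving ([*, dim] -> [*, dim]),
# so it can be stacked freely inside the compressor/decompressor below, e.g.:
#   blk = ResidualBlock(768)
#   y = blk(torch.randn(2, 16, 768))  # y.shape == (2, 16, 768)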
class ResidualAutoencoder(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
# 1. Encoder (Frozen)
print(f"Loading Encoder: {cfg.encoder_name}...")
self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
self.hidden_dim = self.encoder.config.hidden_size
for p in self.encoder.parameters(): p.requires_grad = False
# 2. Latent Processor (No Dimension Reduction)
        # Keep the 768-dim hidden size; only reorganize the features.
        # Residual blocks are used to preserve gradient flow.
self.compressor = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),  # optional
ResidualBlock(self.hidden_dim),
            # ResidualBlock(self.hidden_dim)  # optional: deepen the stack
)
self.decompressor = nn.Sequential(
ResidualBlock(self.hidden_dim),
# ResidualBlock(self.hidden_dim),
nn.Linear(self.hidden_dim, self.hidden_dim)
)
# 3. Decoder (Pretrained)
print(f"Loading Decoder: {cfg.encoder_name}...")
self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
self.decoder.config.is_decoder = False
# 4. Head
self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
with torch.no_grad():
self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
self.lm_head.weight.requires_grad = True
def encode(self, input_ids, attention_mask):
with torch.no_grad():
enc_out = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
return self.compressor(enc_out)
def decode(self, z, attention_mask):
h = self.decompressor(z)
dec_out = self.decoder(inputs_embeds=h, attention_mask=attention_mask).last_hidden_state
return self.lm_head(dec_out)
def forward(self, input_ids, attention_mask):
z = self.encode(input_ids, attention_mask)
logits = self.decode(z, attention_mask)
return logits, z
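# Usage sketch for ResidualAutoencoder (cfg is assumed to provide at least `encoder_name`):
#   model = ResidualAutoencoder(cfg)
#   logits, z = model(input_ids, attention_mask)  # input_ids, attention_mask: [B, L]
#   # z:      [B, L, hidden_dim]  -- compressor output, same width as the encoder
#   # logits: [B, L, vocab_size]  -- lm_head applied to the decoder states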
class ReshapedAutoencoder(nn.Module):
"""
Sequence-to-Sequence Autoencoder with Spherical Latent Space.
Logic: Token -> Jina -> Linear -> Linear -> Decoder -> Token
"""
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
        self.latent_scale = getattr(cfg, "latent_scale", 10.0)
# 1. Encoder (Frozen Jina)
print(f"Loading Pretrained Encoder: {cfg.encoder_name}...")
# self.encoder = AutoModel.from_pretrained(cfg.encoder_name,local_files_only=True, trust_remote_code=False)
self.encoder = AutoModel.from_pretrained(cfg.encoder_name,trust_remote_code=True)
self.hidden_dim = self.encoder.config.hidden_size
self.vocab_size = self.encoder.config.vocab_size
        # Freeze the encoder parameters
for param in self.encoder.parameters():
param.requires_grad = False
        # 2. Compressor
        # Drop the hard normalize; use LayerNorm as a "soft constraint" instead.
        # Structure: Hidden -> Project -> LayerNorm -> Latent
self.compress = nn.Sequential(
nn.Linear(self.hidden_dim, cfg.latent_dim),
nn.GELU(),
nn.Linear(cfg.latent_dim, cfg.latent_dim),
            nn.LayerNorm(cfg.latent_dim)  # key: keeps the latent distribution stable, which helps the Flow
)
# 3. Decompressor
self.decompress = nn.Sequential(
nn.Linear(cfg.latent_dim, self.hidden_dim),
nn.GELU(),
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.LayerNorm(self.hidden_dim)
)
# 4. Decoder (Pretrained!)
        # <--- loaded from the pretraining config --->
print(f"Loading Pretrained Decoder: {cfg.encoder_name}...")
# self.decoder = AutoModel.from_pretrained(cfg.encoder_name,local_files_only=True,trust_remote_code=False)
self.decoder = AutoModel.from_pretrained(cfg.encoder_name,trust_remote_code=True)
        # For BERT, is_decoder=False gives bidirectional attention, which is exactly what NAR needs;
        # no causal mask is required.
self.decoder.config.is_decoder = False
# 5. Output Head (Trainable)
        # Initialized from the encoder's word embeddings, but left trainable
self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
with torch.no_grad():
self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        # Allow fine-tuning so the head can adapt to drift in the decoder's outputs
self.lm_head.weight.requires_grad = True
def encode(self, input_ids, attention_mask):
"""
Input: [B, L]
Output: [B, L, Latent_Dim]
"""
with torch.no_grad():
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
# Compression
        z = self.compress(outputs.last_hidden_state)  # [B, L, latent_dim]
## increase the scale
# z = z * self.latent_scale
return z
    # An attention_mask must be passed in; the open question is how to handle inference,
    # where no mask is available and, by the look of it, there is no EOS check either.
    def decode(self, latents, attention_mask=None):
"""
Input: [B, L, Latent_Dim]
Output: [B, L, Vocab]
"""
## back to the original scale
# latents = latents / self.latent_scale
# 1. Decompress (back to Hidden Size)
hidden = self.decompress(latents)
        # 2. Backbone forward (latents injected via inputs_embeds)
        # AutoModel handles the mask itself (typically full bidirectional attention in NAR mode)
decoder_outputs = self.decoder(
inputs_embeds=hidden,
attention_mask=attention_mask
)
sequence_output = decoder_outputs.last_hidden_state
# 3. Logits
return self.lm_head(sequence_output)
# def forward(self, input_ids, attention_mask):
# z = self.encode(input_ids, attention_mask)
# logits= self.decode(z, attention_mask=attention_mask)
# return logits, z
def forward(self, input_ids, encoder_mask, decoder_mask=None):
if decoder_mask is None:
decoder_mask = encoder_mask
z = self.encode(input_ids, encoder_mask)
logits = self.decode(z, attention_mask=decoder_mask)
return logits, z |
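# A minimal smoke-test sketch (not part of the training pipeline). The checkpoint
# name and latent_dim below are placeholders standing in for the project's cfg;
# the checkpoint's remote code is assumed to accept inputs_embeds as used above.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        encoder_name="jinaai/jina-embeddings-v2-base-en",  # placeholder Jina checkpoint
        latent_dim=256,                                    # placeholder latent width
        latent_scale=10.0,
    )
    model = ReshapedAutoencoder(cfg).eval()

    input_ids = torch.randint(0, model.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    with torch.no_grad():
        logits, z = model(input_ids, attention_mask)
    print(logits.shape)  # expected: torch.Size([2, 16, vocab_size])
    print(z.shape)       # expected: torch.Size([2, 16, 256])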