import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig


class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x + self.net(x))


class ResidualAutoencoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # 1. Encoder (frozen)
        print(f"Loading Encoder: {cfg.encoder_name}...")
        self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.hidden_dim = self.encoder.config.hidden_size
        for p in self.encoder.parameters():
            p.requires_grad = False

        # 2. Latent processor (no dimension reduction)
        # Keep the hidden size (e.g. 768) and only reshape the features;
        # residual blocks preserve gradient flow.
        self.compressor = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),  # optional
            ResidualBlock(self.hidden_dim),
            # ResidualBlock(self.hidden_dim)  # optional: go deeper
        )
        self.decompressor = nn.Sequential(
            ResidualBlock(self.hidden_dim),
            # ResidualBlock(self.hidden_dim),
            nn.Linear(self.hidden_dim, self.hidden_dim)
        )

        # 3. Decoder (pretrained)
        print(f"Loading Decoder: {cfg.encoder_name}...")
        self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.decoder.config.is_decoder = False

        # 4. Head: initialized from the encoder's word embeddings, kept trainable
        self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
        with torch.no_grad():
            self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        self.lm_head.weight.requires_grad = True

    def encode(self, input_ids, attention_mask):
        with torch.no_grad():
            enc_out = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.compressor(enc_out)

    def decode(self, z, attention_mask):
        h = self.decompressor(z)
        dec_out = self.decoder(inputs_embeds=h, attention_mask=attention_mask).last_hidden_state
        return self.lm_head(dec_out)

    def forward(self, input_ids, attention_mask):
        z = self.encode(input_ids, attention_mask)
        logits = self.decode(z, attention_mask)
        return logits, z


class ReshapedAutoencoder(nn.Module):
    """
    Sequence-to-sequence autoencoder with a LayerNorm-regularized latent space.
    Flow: Token -> Jina encoder -> Linear -> Linear -> Decoder -> Token
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.latent_scale = getattr(cfg, "latent_scale", 10.0)

        # 1. Encoder (frozen Jina)
        print(f"Loading Pretrained Encoder: {cfg.encoder_name}...")
        # self.encoder = AutoModel.from_pretrained(cfg.encoder_name, local_files_only=True, trust_remote_code=False)
        self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.hidden_dim = self.encoder.config.hidden_size
        self.vocab_size = self.encoder.config.vocab_size

        # Freeze the encoder parameters.
        for param in self.encoder.parameters():
            param.requires_grad = False

        # 2. Compressor
        # Instead of hard-normalizing the latent, use LayerNorm as a soft constraint.
        # Structure: Hidden -> Project -> LayerNorm -> Latent
        self.compress = nn.Sequential(
            nn.Linear(self.hidden_dim, cfg.latent_dim),
            nn.GELU(),
            nn.Linear(cfg.latent_dim, cfg.latent_dim),
            nn.LayerNorm(cfg.latent_dim)  # key: keeps the latent distribution stable, which helps the flow model
        )

        # 3. Decompressor
        self.decompress = nn.Sequential(
            nn.Linear(cfg.latent_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.LayerNorm(self.hidden_dim)
        )

        # 4. Decoder (pretrained)
        # <--- loaded from the same pretrained checkpoint/config as the encoder --->
        print(f"Loading Pretrained Decoder: {cfg.encoder_name}...")
        # self.decoder = AutoModel.from_pretrained(cfg.encoder_name, local_files_only=True, trust_remote_code=False)
        self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        # For BERT-style models, is_decoder=False keeps bidirectional attention,
        # which is exactly what NAR decoding needs: no causal mask.
        self.decoder.config.is_decoder = False

        # 5. Output head (trainable)
        # Initialized from the encoder's word embeddings, but left trainable
        # so it can adapt to the decoder's output distribution.
        self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
        with torch.no_grad():
            self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        self.lm_head.weight.requires_grad = True

    def encode(self, input_ids, attention_mask):
        """
        Input:  [B, L]
        Output: [B, L, latent_dim]
        """
        with torch.no_grad():
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Compression
        z = self.compress(outputs.last_hidden_state)  # [B, L, latent_dim]
        # Optionally increase the scale:
        # z = z * self.latent_scale
        return z

    # An attention mask has to be passed in; the open question is what to do at
    # inference time when no mask is available, and there is no EOS check here either.
    def decode(self, latents, attention_mask=None):
        """
        Input:  [B, L, latent_dim]
        Output: [B, L, vocab_size]
        """
        # Optionally scale back to the original range:
        # latents = latents / self.latent_scale

        # 1. Decompress back to the hidden size
        hidden = self.decompress(latents)

        # 2. Backbone forward pass, injected via inputs_embeds.
        # The backbone handles the mask itself (full bidirectional attention in NAR mode).
        decoder_outputs = self.decoder(
            inputs_embeds=hidden,
            attention_mask=attention_mask
        )
        sequence_output = decoder_outputs.last_hidden_state

        # 3. Logits
        return self.lm_head(sequence_output)

    # def forward(self, input_ids, attention_mask):
    #     z = self.encode(input_ids, attention_mask)
    #     logits = self.decode(z, attention_mask=attention_mask)
    #     return logits, z

    def forward(self, input_ids, encoder_mask, decoder_mask=None):
        if decoder_mask is None:
            decoder_mask = encoder_mask
        z = self.encode(input_ids, encoder_mask)
        logits = self.decode(z, attention_mask=decoder_mask)
        return logits, z
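

if __name__ == "__main__":
    # Minimal round-trip smoke test (a sketch, not part of the original module).
    # DemoConfig and the "jinaai/jina-embeddings-v2-base-en" checkpoint are
    # assumptions; substitute your own config object and encoder_name.
    from dataclasses import dataclass

    from transformers import AutoTokenizer

    @dataclass
    class DemoConfig:
        encoder_name: str = "jinaai/jina-embeddings-v2-base-en"  # assumed checkpoint
        latent_dim: int = 256                                    # assumed latent size
        latent_scale: float = 10.0

    cfg = DemoConfig()
    tokenizer = AutoTokenizer.from_pretrained(cfg.encoder_name, trust_remote_code=True)
    model = ReshapedAutoencoder(cfg)
    model.eval()

    batch = tokenizer(["a quick round-trip test"], return_tensors="pt", padding=True)
    with torch.no_grad():
        logits, z = model(batch["input_ids"], batch["attention_mask"])

    # logits: [B, L, vocab_size], z: [B, L, latent_dim]
    print(logits.shape, z.shape)
    # Greedy token recovery; expect rough reconstruction only before any training.
    print(tokenizer.decode(logits.argmax(dim=-1)[0], skip_special_tokens=True))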