# src/models/autoencoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


class ResidualBlock(nn.Module):
    """Dimension-preserving residual MLP block: LayerNorm(x + MLP(x))."""
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * 2),
nn.GELU(),
nn.Linear(dim * 2, dim)
)
self.norm = nn.LayerNorm(dim)
def forward(self, x):
return self.norm(x + self.net(x))
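

# Quick shape check (illustrative sketch, not part of the original code):
# ResidualBlock maps [B, L, dim] -> [B, L, dim], so it can be stacked freely
# inside the compressor/decompressor modules below.
#
#   blk = ResidualBlock(768)
#   x = torch.randn(2, 16, 768)
#   assert blk(x).shape == x.shape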


class ResidualAutoencoder(nn.Module):
    """Autoencoder that keeps the encoder's hidden size end-to-end:
    frozen encoder -> residual compressor -> residual decompressor -> pretrained decoder -> LM head."""
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
# 1. Encoder (Frozen)
print(f"Loading Encoder: {cfg.encoder_name}...")
self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
self.hidden_dim = self.encoder.config.hidden_size
for p in self.encoder.parameters(): p.requires_grad = False
# 2. Latent Processor (No Dimension Reduction)
        # Keep the 768-dim hidden size; only re-organize the features.
        # Residual blocks preserve gradient flow.
self.compressor = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),  # optional input projection
ResidualBlock(self.hidden_dim),
            # ResidualBlock(self.hidden_dim)  # optional: deepen the stack
)
self.decompressor = nn.Sequential(
ResidualBlock(self.hidden_dim),
# ResidualBlock(self.hidden_dim),
nn.Linear(self.hidden_dim, self.hidden_dim)
)
# 3. Decoder (Pretrained)
print(f"Loading Decoder: {cfg.encoder_name}...")
self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.decoder.config.is_decoder = False  # keep bidirectional attention (no causal mask)
# 4. Head
self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
with torch.no_grad():
self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
self.lm_head.weight.requires_grad = True
def encode(self, input_ids, attention_mask):
with torch.no_grad():
enc_out = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
return self.compressor(enc_out)
def decode(self, z, attention_mask):
h = self.decompressor(z)
dec_out = self.decoder(inputs_embeds=h, attention_mask=attention_mask).last_hidden_state
return self.lm_head(dec_out)
def forward(self, input_ids, attention_mask):
z = self.encode(input_ids, attention_mask)
logits = self.decode(z, attention_mask)
return logits, z
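

# Illustrative helper (a sketch, not part of the original training code): the usual
# reconstruction objective for these autoencoders is a masked token-level
# cross-entropy between the decoder logits and the original input ids.
def masked_reconstruction_loss(logits, input_ids, attention_mask):
    """Mean cross-entropy over non-padding positions.

    logits:         [B, L, vocab]   output of ResidualAutoencoder / ReshapedAutoencoder
    input_ids:      [B, L]          reconstruction targets (the encoder inputs)
    attention_mask: [B, L]          1 for real tokens, 0 for padding
    """
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        input_ids.reshape(-1),
        reduction="none",
    )
    mask = attention_mask.reshape(-1).float()
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)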


class ReshapedAutoencoder(nn.Module):
    """
    Sequence-to-Sequence Autoencoder with a LayerNorm-regularized ("soft spherical") latent space.
    Logic: Token -> Jina -> Linear (compress) -> Linear (decompress) -> Decoder -> Token
    """
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
        self.latent_scale = getattr(cfg, "latent_scale", 10.0)
# 1. Encoder (Frozen Jina)
print(f"Loading Pretrained Encoder: {cfg.encoder_name}...")
# self.encoder = AutoModel.from_pretrained(cfg.encoder_name,local_files_only=True, trust_remote_code=False)
        self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
self.hidden_dim = self.encoder.config.hidden_size
self.vocab_size = self.encoder.config.vocab_size
        # Freeze the encoder parameters
for param in self.encoder.parameters():
param.requires_grad = False
        # 2. Compressor
        # Instead of hard-normalizing the latent, use LayerNorm as a "soft" constraint.
        # Structure: Hidden -> Project -> LayerNorm -> Latent
self.compress = nn.Sequential(
nn.Linear(self.hidden_dim, cfg.latent_dim),
nn.GELU(),
nn.Linear(cfg.latent_dim, cfg.latent_dim),
            nn.LayerNorm(cfg.latent_dim)  # key: keeps the latent distribution stable, which helps the flow model
)
# 3. Decompressor
self.decompress = nn.Sequential(
nn.Linear(cfg.latent_dim, self.hidden_dim),
nn.GELU(),
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.LayerNorm(self.hidden_dim)
)
# 4. Decoder (Pretrained!)
        # <--- loaded from the same pretrained checkpoint as the encoder --->
print(f"Loading Pretrained Decoder: {cfg.encoder_name}...")
# self.decoder = AutoModel.from_pretrained(cfg.encoder_name,local_files_only=True,trust_remote_code=False)
        self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        # For BERT, is_decoder=False keeps bidirectional attention, which is exactly what NAR decoding needs;
        # no causal mask is required.
self.decoder.config.is_decoder = False
# 5. Output Head (Trainable)
        # Initialized from the encoder's word embeddings, but kept trainable
self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
with torch.no_grad():
self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        # Allow fine-tuning so the head can adapt to the decoder's output distribution
self.lm_head.weight.requires_grad = True
def encode(self, input_ids, attention_mask):
"""
Input: [B, L]
Output: [B, L, Latent_Dim]
"""
with torch.no_grad():
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
# Compression
        z = self.compress(outputs.last_hidden_state)  # [B, L, latent_dim]
        # Optionally scale up the latent (currently disabled):
        # z = z * self.latent_scale
return z
    # NOTE: an attention mask must be passed in; open question: at inference time
    # there is no mask available, and there is no EOS handling either.
    def decode(self, latents, attention_mask=None):
"""
Input: [B, L, Latent_Dim]
Output: [B, L, Vocab]
"""
        # Undo the optional latent scaling (currently disabled):
        # latents = latents / self.latent_scale
# 1. Decompress (back to Hidden Size)
hidden = self.decompress(latents)
        # 2. Backbone forward pass (latents injected via inputs_embeds)
        # AutoModel handles the mask automatically (NAR mode uses full bidirectional attention)
decoder_outputs = self.decoder(
inputs_embeds=hidden,
attention_mask=attention_mask
)
sequence_output = decoder_outputs.last_hidden_state
# 3. Logits
return self.lm_head(sequence_output)
# def forward(self, input_ids, attention_mask):
# z = self.encode(input_ids, attention_mask)
# logits= self.decode(z, attention_mask=attention_mask)
# return logits, z
def forward(self, input_ids, encoder_mask, decoder_mask=None):
if decoder_mask is None:
decoder_mask = encoder_mask
z = self.encode(input_ids, encoder_mask)
logits = self.decode(z, attention_mask=decoder_mask)
return logits, z
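

if __name__ == "__main__":
    # Minimal smoke test (a sketch). The config fields mirror how `cfg` is used
    # above (encoder_name, latent_dim, latent_scale); the checkpoint name is a
    # placeholder assumption, not necessarily the backbone used in this project.
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    cfg = SimpleNamespace(
        encoder_name="bert-base-uncased",  # stand-in backbone for the sketch
        latent_dim=256,
        latent_scale=10.0,
    )
    tokenizer = AutoTokenizer.from_pretrained(cfg.encoder_name)
    batch = tokenizer(["a quick smoke test"], return_tensors="pt", padding=True)

    model = ReshapedAutoencoder(cfg)
    model.eval()
    with torch.no_grad():
        logits, z = model(batch["input_ids"], batch["attention_mask"])
    print("logits:", tuple(logits.shape))  # (B, L, vocab_size)
    print("latent:", tuple(z.shape))       # (B, L, cfg.latent_dim)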