import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x + self.net(x))
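
# A minimal sketch of how ResidualBlock behaves: it is shape-preserving, so it can
# be stacked freely. The dimensions below are illustrative, not from any config.
def _demo_residual_block():
    block = ResidualBlock(768)
    x = torch.randn(2, 16, 768)  # [batch, seq_len, dim]
    y = block(x)
    assert y.shape == x.shape  # residual add + LayerNorm keep [B, L, dim] unchanged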

class ResidualAutoencoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        
        # 1. Encoder (Frozen)
        print(f"Loading Encoder: {cfg.encoder_name}...")
        self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.hidden_dim = self.encoder.config.hidden_size
        for p in self.encoder.parameters():
            p.requires_grad = False
            
        # 2. Latent Processor (no dimension reduction)
        # Keep the 768-dim hidden size; only reorganize the features.
        # Residual blocks preserve gradient flow.
        self.compressor = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),  # optional
            ResidualBlock(self.hidden_dim),
            # ResidualBlock(self.hidden_dim)  # optional: deepen the stack
        )
        
        self.decompressor = nn.Sequential(
            ResidualBlock(self.hidden_dim),
            # ResidualBlock(self.hidden_dim),
            nn.Linear(self.hidden_dim, self.hidden_dim)
        )

        # 3. Decoder (Pretrained)
        print(f"Loading Decoder: {cfg.encoder_name}...")
        self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.decoder.config.is_decoder = False 
        
        # 4. Head
        self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
        with torch.no_grad():
            self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        self.lm_head.weight.requires_grad = True

    def encode(self, input_ids, attention_mask):
        with torch.no_grad():
            enc_out = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.compressor(enc_out)

    def decode(self, z, attention_mask):
        h = self.decompressor(z)
        dec_out = self.decoder(inputs_embeds=h, attention_mask=attention_mask).last_hidden_state
        return self.lm_head(dec_out)

    def forward(self, input_ids, attention_mask):
        z = self.encode(input_ids, attention_mask)
        logits = self.decode(z, attention_mask)
        return logits, z
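
# A hedged usage sketch for ResidualAutoencoder. The config is a hypothetical
# stand-in: any BERT-style checkpoint exposing .embeddings.word_embeddings should
# work; "bert-base-uncased" is an assumption, not the project's actual setting.
def _demo_residual_autoencoder():
    from types import SimpleNamespace
    cfg = SimpleNamespace(encoder_name="bert-base-uncased")
    model = ResidualAutoencoder(cfg)
    input_ids = torch.randint(0, model.encoder.config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    logits, z = model(input_ids, attention_mask)
    # logits: [2, 16, vocab_size] token predictions; z: [2, 16, hidden_dim] latents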

class ReshapedAutoencoder(nn.Module):
    """
    Sequence-to-Sequence Autoencoder with Spherical Latent Space.
    Logic: Token -> Jina -> Linear -> Linear -> Decoder -> Token
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.latent_scale = getattr(cfg, "latent_scale", 10.0)
        
        # 1. Encoder (Frozen Jina)
        print(f"Loading Pretrained Encoder: {cfg.encoder_name}...")
        # self.encoder = AutoModel.from_pretrained(cfg.encoder_name,local_files_only=True, trust_remote_code=False)
        self.encoder = AutoModel.from_pretrained(cfg.encoder_name,trust_remote_code=True)
        self.hidden_dim = self.encoder.config.hidden_size
        self.vocab_size = self.encoder.config.vocab_size
        
        # 冻结 Encoder 参数
        for param in self.encoder.parameters():
            param.requires_grad = False
          
        # 2. Compressor
        # Drop the hard normalization; use LayerNorm as a "soft constraint" instead.
        # Structure: Hidden -> Project -> LayerNorm -> Latent
        self.compress = nn.Sequential(
            nn.Linear(self.hidden_dim, cfg.latent_dim),
            nn.GELU(),
            nn.Linear(cfg.latent_dim, cfg.latent_dim),
            nn.LayerNorm(cfg.latent_dim)  # key: keeps the latent distribution stable, which helps the Flow
        )
        
        # 3. Decompressor
        self.decompress = nn.Sequential(
            nn.Linear(cfg.latent_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.LayerNorm(self.hidden_dim)
        )

        # 4. Decoder (pretrained!)
        # <--- loaded from the pretraining config --->
        print(f"Loading Pretrained Decoder: {cfg.encoder_name}...")
        # self.decoder = AutoModel.from_pretrained(cfg.encoder_name, local_files_only=True, trust_remote_code=False)
        self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)

        # For BERT, is_decoder=False keeps attention bidirectional, which is exactly
        # what NAR decoding needs: no causal mask required.
        self.decoder.config.is_decoder = False
        
        # 5. Output Head (trainable)
        # Initialized from the encoder's token embeddings, but left trainable.
        self.lm_head = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        with torch.no_grad():
            self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        # Allow fine-tuning so the head can adapt to drift in the decoder's outputs.
        self.lm_head.weight.requires_grad = True

    def encode(self, input_ids, attention_mask):
        """
        Input:  [B, L]
        Output: [B, L, latent_dim]
        """
        with torch.no_grad():
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Compression
        z = self.compress(outputs.last_hidden_state)  # [B, L, latent_dim]

        ## optionally increase the scale
        # z = z * self.latent_scale

        return z

    ## An attention mask must be passed in. Open question: at inference time there
    ## may be no mask available, and nothing here checks for an EOS token either.
    def decode(self, latents, attention_mask=None):
        """
        Input:  [B, L, latent_dim]
        Output: [B, L, vocab_size]
        """
        ## back to the original scale
        # latents = latents / self.latent_scale

        # 1. Decompress (back to hidden size)
        hidden = self.decompress(latents)

        # 2. Backbone forward (latents injected via inputs_embeds)
        # AutoModel handles the mask itself (NAR mode uses full bidirectional attention).
        decoder_outputs = self.decoder(
            inputs_embeds=hidden,
            attention_mask=attention_mask
        )

        sequence_output = decoder_outputs.last_hidden_state

        # 3. Logits
        return self.lm_head(sequence_output)

    def forward(self, input_ids, encoder_mask, decoder_mask=None):
        if decoder_mask is None:
            decoder_mask = encoder_mask
        z = self.encode(input_ids, encoder_mask)
        logits = self.decode(z, attention_mask=decoder_mask)
        return logits, z
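
# A hedged sketch of one reconstruction training step. The optimizer, learning
# rate, checkpoint name, and latent_dim=64 are illustrative assumptions; only the
# parameters left trainable (compress/decompress/lm_head and the decoder) get
# gradients, since the encoder is frozen above.
def _demo_training_step():
    from types import SimpleNamespace
    cfg = SimpleNamespace(encoder_name="bert-base-uncased", latent_dim=64)
    model = ReshapedAutoencoder(cfg)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-4)

    input_ids = torch.randint(0, model.vocab_size, (2, 16))
    mask = torch.ones_like(input_ids)

    logits, z = model(input_ids, mask)  # [2, 16, vocab_size], [2, 16, latent_dim]
    # Reconstruction objective: predict the input tokens back from the latents.
    loss = F.cross_entropy(logits.view(-1, model.vocab_size), input_ids.view(-1))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()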