"""AnCoder: Anchored DLM with Qwen3 backbone."""

from dataclasses import dataclass

import torch
from transformers import Qwen3Config, Qwen3ForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput


class BiQwen3Config(Qwen3Config):
    model_type = "biqwen3"


class BiQwen3(Qwen3ForCausalLM):
    config_class = BiQwen3Config

    def __init__(self, config):
        super().__init__(config)
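        # is_causal=False switches every attention layer to bidirectional
        # attention; forward() below additionally bypasses the causal mask
        # that HF would otherwise construct, via the dict-keyed attention_mask.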
        for layer in self.model.layers:
            layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        compute_logits: bool = True,
        **kwargs,
    ) -> ModelOutput:
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Must provide exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Dict-keyed mask bypasses HF's auto-causal construction.
        if attention_mask is None or attention_mask.all():
            attn_mask = {"full_attention": None}
        else:
            # 4D additive mask: 0=attend, -10000=ignore.
            B, L, _ = inputs_embeds.shape
            mask_4d = attention_mask.reshape(B, 1, 1, L).expand(B, 1, L, L)
            mask_4d = (1.0 - mask_4d.to(inputs_embeds.dtype)) * -10000
            attn_mask = {"full_attention": mask_4d}

        out = self.model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, use_cache=False)
        hidden = out.last_hidden_state
        logits = self.lm_head(hidden) if compute_logits else None
        return ModelOutput(hidden=hidden, logits=logits)


@dataclass
class AnCoderOutput(ModelOutput):
    anchor_hidden: torch.FloatTensor | None = None    # (B, L, d)
    anchor_logits: torch.FloatTensor | None = None    # (B, L, V), None when compute_anchor_logits=False
    denoiser_hidden: torch.FloatTensor | None = None  # (B, L, d)
    logits: torch.FloatTensor | None = None           # (B, L, V), final predictions


class AnCoderConfig(PretrainedConfig):
    model_type = "ancoder"

    def __init__(
        self,
        anchor_config=None,
        denoiser_config=None,
        shift_logits: bool = True,
        bos_token_id: int | None = 151644,  # <|im_start|>, distinct from PAD=151643
        **kwargs,
    ):
        # Ensures that save_pretrained deduplicates and from_pretrained re-ties.
        kwargs.setdefault("tie_word_embeddings", True)
        super().__init__(bos_token_id=bos_token_id, **kwargs)
        self.anchor_config = anchor_config
        self.denoiser_config = denoiser_config
        self.shift_logits = shift_logits


class AnCoder(PreTrainedModel):
    config_class = AnCoderConfig
    supports_gradient_checkpointing = True

    # Maps shared params so that save_pretrained writes only one copy.
    _tied_weights_keys = {
        "anchor.lm_head.weight": "anchor.model.embed_tokens.weight",
        "denoiser.model.embed_tokens.weight": "anchor.model.embed_tokens.weight",
        "denoiser.lm_head.weight": "anchor.model.embed_tokens.weight",
    }

    def __init__(self, config: AnCoderConfig, anchor=None, denoiser=None):
        super().__init__(config)
        self.anchor = anchor or BiQwen3(BiQwen3Config(**config.anchor_config))
        self.denoiser = denoiser or BiQwen3(BiQwen3Config(**config.denoiser_config))
        self.tie_weights()

    def tie_weights(self):
        # Override: _tied_weights_keys is save-only; runtime tying done here.
        self.anchor.lm_head.weight = self.anchor.model.embed_tokens.weight
        self.denoiser.model.embed_tokens.weight = self.anchor.model.embed_tokens.weight
        self.denoiser.lm_head.weight = self.anchor.model.embed_tokens.weight
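        # After tying, all four matrices alias a single Parameter, e.g.:
        #   self.denoiser.lm_head.weight.data_ptr() \
        #       == self.anchor.model.embed_tokens.weight.data_ptr()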

    def get_input_embeddings(self) -> torch.nn.Embedding:
        return self.anchor.model.embed_tokens

    def get_output_embeddings(self) -> torch.nn.Linear:
        return self.denoiser.lm_head

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        compute_anchor_logits: bool = True,
        **kwargs,
    ) -> AnCoderOutput:
        """When shift_logits=True, BOS is prepended so anchor->denoiser run on
        (B, L+1, *); hidden is sliced to (B, L, d) before lm_head, so logits
        emerge as (B, L, V) directly without materializing (B, L+1, V).
        """
        if self.config.shift_logits:
            B = input_ids.shape[0]
            bos_id = self.config.bos_token_id
            if bos_id is None:
                raise ValueError("shift_logits=True requires bos_token_id on the config")
            bos = torch.full((B, 1), bos_id, dtype=input_ids.dtype, device=input_ids.device)
            input_ids_ = torch.cat([bos, input_ids], dim=1)                         # (B, L+1)
            if attention_mask is not None:
                ones = torch.ones((B, 1), dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([ones, attention_mask], dim=1)           # (B, L+1)
        else:
            input_ids_ = input_ids

        anchor_out = self.anchor(input_ids=input_ids_, attention_mask=attention_mask, compute_logits=False)
        denoiser_out = self.denoiser(inputs_embeds=anchor_out.hidden, attention_mask=attention_mask, compute_logits=False)

        end = -1 if self.config.shift_logits else None
        anchor_hidden = anchor_out.hidden[:, :end, :].contiguous()
        denoiser_hidden = denoiser_out.hidden[:, :end, :].contiguous()
        return AnCoderOutput(
            anchor_hidden=anchor_hidden,
            anchor_logits=self.anchor.lm_head(anchor_hidden) if compute_anchor_logits else None,
            denoiser_hidden=denoiser_hidden,
            logits=self.denoiser.lm_head(denoiser_hidden),
        )


# Ensures that save_pretrained emits auto_map and copies this file.
AnCoderConfig.register_for_auto_class()
AnCoder.register_for_auto_class("AutoModel")
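

if __name__ == "__main__":
    # Minimal smoke test (illustrative): tiny, randomly initialized configs.
    # The field names below are standard Qwen3Config arguments, but the sizes
    # are arbitrary and chosen only to keep the run cheap; bos_token_id is
    # overridden so the default Qwen BOS id fits inside the tiny vocab.
    tiny = dict(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=8,
        max_position_embeddings=64,
    )
    cfg = AnCoderConfig(anchor_config=tiny, denoiser_config=tiny, bos_token_id=1)
    model = AnCoder(cfg)

    ids = torch.randint(0, 128, (2, 10))
    mask = torch.ones(2, 10, dtype=torch.long)
    mask[1, -3:] = 0  # right-pad the second sequence
    with torch.inference_mode():
        out = model(input_ids=ids, attention_mask=mask)
    print(out.logits.shape)  # torch.Size([2, 10, 128])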