"""AnCoder: Anchored DLM with Qwen3 backbone."""

from dataclasses import dataclass
import torch
from transformers import Qwen3Config, Qwen3ForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput


class BiQwen3Config(Qwen3Config):
    model_type = "biqwen3"


class BiQwen3(Qwen3ForCausalLM):
    config_class = BiQwen3Config

    def __init__(self, config):
        super().__init__(config)
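        # Drop the causal flag on every self-attention layer so attention is bidirectional over the full sequence.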
        for layer in self.model.layers:
            layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        compute_logits: bool = True,
        **kwargs,
    ) -> ModelOutput:
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Must provide exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
        # Dict-keyed mask bypasses HF's auto-causal construction.
        if attention_mask is None or attention_mask.all():
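            # No padding anywhere: a None mask lets every position attend to every other position.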
attn_mask = {"full_attention": None}
else:
# 4D additive mask: 0=attend, -10000=ignore.
B, L, _ = inputs_embeds.shape
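            # Broadcast the (B, L) key-padding mask over all L query rows so padded keys are ignored by every query.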
            mask_4d = attention_mask.reshape(B, 1, 1, L).expand(B, 1, L, L)
            mask_4d = (1.0 - mask_4d.to(inputs_embeds.dtype)) * -10000
            attn_mask = {"full_attention": mask_4d}
        out = self.model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, use_cache=False)
        hidden = out.last_hidden_state
        logits = self.lm_head(hidden) if compute_logits else None
        return ModelOutput(hidden=hidden, logits=logits)


@dataclass
class AnCoderOutput(ModelOutput):
    anchor_hidden: torch.FloatTensor | None = None  # (B, L, d)
    anchor_logits: torch.FloatTensor | None = None  # (B, L, V), None when compute_anchor_logits=False
    denoiser_hidden: torch.FloatTensor | None = None  # (B, L, d)
    logits: torch.FloatTensor | None = None  # (B, L, V), final predictions


class AnCoderConfig(PretrainedConfig):
    model_type = "ancoder"

    def __init__(
        self,
        anchor_config=None,
        denoiser_config=None,
        shift_logits: bool = True,
        bos_token_id: int | None = 151644,  # <|im_start|>, distinct from PAD=151643
        **kwargs,
    ):
        # Ensures that save_pretrained deduplicates and from_pretrained re-ties.
        kwargs.setdefault("tie_word_embeddings", True)
        super().__init__(bos_token_id=bos_token_id, **kwargs)
        self.anchor_config = anchor_config
        self.denoiser_config = denoiser_config
        self.shift_logits = shift_logits


class AnCoder(PreTrainedModel):
    config_class = AnCoderConfig
    supports_gradient_checkpointing = True

    # Maps shared params so that save_pretrained writes only one copy.
    _tied_weights_keys = {
        "anchor.lm_head.weight": "anchor.model.embed_tokens.weight",
        "denoiser.model.embed_tokens.weight": "anchor.model.embed_tokens.weight",
        "denoiser.lm_head.weight": "anchor.model.embed_tokens.weight",
    }

    def __init__(self, config: AnCoderConfig, anchor=None, denoiser=None):
        super().__init__(config)
        self.anchor = anchor or BiQwen3(BiQwen3Config(**config.anchor_config))
        self.denoiser = denoiser or BiQwen3(BiQwen3Config(**config.denoiser_config))
        self.tie_weights()

    def tie_weights(self):
        # Override: _tied_weights_keys is save-only; runtime tying done here.
        self.anchor.lm_head.weight = self.anchor.model.embed_tokens.weight
        self.denoiser.model.embed_tokens.weight = self.anchor.model.embed_tokens.weight
        self.denoiser.lm_head.weight = self.anchor.model.embed_tokens.weight

    def get_input_embeddings(self) -> torch.nn.Embedding:
        return self.anchor.model.embed_tokens

    def get_output_embeddings(self) -> torch.nn.Linear:
        return self.denoiser.lm_head

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        compute_anchor_logits: bool = True,
        **kwargs,
    ) -> AnCoderOutput:
        """When shift_logits=True, BOS is prepended so anchor->denoiser run on
        (B, L+1, *); hidden is sliced to (B, L, d) before lm_head, so logits
        emerge as (B, L, V) directly without materializing (B, L+1, V).
        """
        if self.config.shift_logits:
            B = input_ids.shape[0]
            bos_id = self.config.bos_token_id
            if bos_id is None:
                raise ValueError("shift_logits=True requires bos_token_id on the config")
            bos = torch.full((B, 1), bos_id, dtype=input_ids.dtype, device=input_ids.device)
            input_ids_ = torch.cat([bos, input_ids], dim=1)  # (B, L+1)
            if attention_mask is not None:
                ones = torch.ones((B, 1), dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([ones, attention_mask], dim=1)  # (B, L+1)
        else:
            input_ids_ = input_ids
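        # Two-pass encoding: the anchor reads token embeddings, the denoiser re-encodes the anchor's hidden states.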
        anchor_out = self.anchor(input_ids=input_ids_, attention_mask=attention_mask, compute_logits=False)
        denoiser_out = self.denoiser(inputs_embeds=anchor_out.hidden, attention_mask=attention_mask, compute_logits=False)
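        # With the prepended BOS, output row i lines up with original token i; drop the last row to get back to (B, L, d).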
        end = -1 if self.config.shift_logits else None
        anchor_hidden = anchor_out.hidden[:, :end, :].contiguous()
        denoiser_hidden = denoiser_out.hidden[:, :end, :].contiguous()
        return AnCoderOutput(
            anchor_hidden=anchor_hidden,
            anchor_logits=self.anchor.lm_head(anchor_hidden) if compute_anchor_logits else None,
            denoiser_hidden=denoiser_hidden,
            logits=self.denoiser.lm_head(denoiser_hidden),
        )


# Ensures that save_pretrained emits auto_map and copies this file.
AnCoderConfig.register_for_auto_class()
AnCoder.register_for_auto_class("AutoModel")
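

# Illustrative smoke test (not part of the original module). The tiny layer
# sizes below are arbitrary assumptions chosen so the example runs quickly on
# CPU; a real checkpoint would use full Qwen3 dimensions. vocab_size must
# still cover the default bos_token_id (151644).
if __name__ == "__main__":
    tiny = dict(
        vocab_size=151936,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
    )
    config = AnCoderConfig(anchor_config=tiny, denoiser_config=tiny)
    model = AnCoder(config).eval()
    ids = torch.randint(0, 1000, (2, 16))
    mask = torch.ones(2, 16, dtype=torch.long)
    with torch.no_grad():
        out = model(input_ids=ids, attention_mask=mask)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 151936])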