"""AnCoder: Anchored DLM with Qwen3 backbone."""
from dataclasses import dataclass
import torch
from transformers import Qwen3Config, Qwen3ForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput


class BiQwen3Config(Qwen3Config):
    model_type = "biqwen3"


class BiQwen3(Qwen3ForCausalLM):
    config_class = BiQwen3Config

    def __init__(self, config):
        super().__init__(config)
        # Disable causal masking in every layer so attention is fully bidirectional.
        for layer in self.model.layers:
            layer.self_attn.is_causal = False

    def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
compute_logits: bool = True,
**kwargs,
) -> ModelOutput:
if (input_ids is None) == (inputs_embeds is None):
raise ValueError("Must provide exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# Dict-keyed mask bypasses HF's auto-causal construction.
if attention_mask is None or attention_mask.all():
attn_mask = {"full_attention": None}
else:
# 4D additive mask: 0=attend, -10000=ignore.
B, L, _ = inputs_embeds.shape
mask_4d = attention_mask.reshape(B, 1, 1, L).expand(B, 1, L, L)
mask_4d = (1.0 - mask_4d.to(inputs_embeds.dtype)) * -10000
attn_mask = {"full_attention": mask_4d}
out = self.model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, use_cache=False)
hidden = out.last_hidden_state
logits = self.lm_head(hidden) if compute_logits else None
return ModelOutput(hidden=hidden, logits=logits)


@dataclass
class AnCoderOutput(ModelOutput):
anchor_hidden: torch.FloatTensor | None = None # (B, L, d)
anchor_logits: torch.FloatTensor | None = None # (B, L, V), None when compute_anchor_logits=False
denoiser_hidden: torch.FloatTensor | None = None # (B, L, d)
logits: torch.FloatTensor | None = None # (B, L, V), final predictions


class AnCoderConfig(PretrainedConfig):
    model_type = "ancoder"

    def __init__(
self,
anchor_config=None,
denoiser_config=None,
shift_logits: bool = True,
bos_token_id: int | None = 151644, # <|im_start|>, distinct from PAD=151643
**kwargs,
):
# Ensures that save_pretrained deduplicates and from_pretrained re-ties.
kwargs.setdefault("tie_word_embeddings", True)
super().__init__(bos_token_id=bos_token_id, **kwargs)
self.anchor_config = anchor_config
self.denoiser_config = denoiser_config
self.shift_logits = shift_logits


class AnCoder(PreTrainedModel):
config_class = AnCoderConfig
supports_gradient_checkpointing = True
# Maps shared params so that save_pretrained writes only one copy.
_tied_weights_keys = {
"anchor.lm_head.weight": "anchor.model.embed_tokens.weight",
"denoiser.model.embed_tokens.weight": "anchor.model.embed_tokens.weight",
"denoiser.lm_head.weight": "anchor.model.embed_tokens.weight",
}

    def __init__(self, config: AnCoderConfig, anchor=None, denoiser=None):
super().__init__(config)
self.anchor = anchor or BiQwen3(BiQwen3Config(**config.anchor_config))
self.denoiser = denoiser or BiQwen3(BiQwen3Config(**config.denoiser_config))
self.tie_weights()

    def tie_weights(self):
# Override: _tied_weights_keys is save-only; runtime tying done here.
self.anchor.lm_head.weight = self.anchor.model.embed_tokens.weight
self.denoiser.model.embed_tokens.weight = self.anchor.model.embed_tokens.weight
self.denoiser.lm_head.weight = self.anchor.model.embed_tokens.weight
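        # Post-condition: all four matrices are literally the same Parameter object,
        # e.g. self.denoiser.lm_head.weight is self.anchor.model.embed_tokens.weight.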

    def get_input_embeddings(self) -> torch.nn.Embedding:
return self.anchor.model.embed_tokens

    def get_output_embeddings(self) -> torch.nn.Linear:
return self.denoiser.lm_head

    def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None = None,
compute_anchor_logits: bool = True,
**kwargs,
) -> AnCoderOutput:
"""When shift_logits=True, BOS is prepended so anchor->denoiser run on
(B, L+1, *); hidden is sliced to (B, L, d) before lm_head, so logits
emerge as (B, L, V) directly without materializing (B, L+1, V).
"""
if self.config.shift_logits:
B = input_ids.shape[0]
bos_id = self.config.bos_token_id
if bos_id is None:
raise ValueError("shift_logits=True requires bos_token_id on the config")
bos = torch.full((B, 1), bos_id, dtype=input_ids.dtype, device=input_ids.device)
input_ids_ = torch.cat([bos, input_ids], dim=1) # (B, L+1)
if attention_mask is not None:
ones = torch.ones((B, 1), dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([ones, attention_mask], dim=1) # (B, L+1)
else:
input_ids_ = input_ids
anchor_out = self.anchor(input_ids=input_ids_, attention_mask=attention_mask, compute_logits=False)
denoiser_out = self.denoiser(inputs_embeds=anchor_out.hidden, attention_mask=attention_mask, compute_logits=False)
end = -1 if self.config.shift_logits else None
anchor_hidden = anchor_out.hidden[:, :end, :].contiguous()
denoiser_hidden = denoiser_out.hidden[:, :end, :].contiguous()
return AnCoderOutput(
anchor_hidden=anchor_hidden,
anchor_logits=self.anchor.lm_head(anchor_hidden) if compute_anchor_logits else None,
denoiser_hidden=denoiser_hidden,
logits=self.denoiser.lm_head(denoiser_hidden),
)


# Ensures that save_pretrained emits auto_map and copies this file.
AnCoderConfig.register_for_auto_class()
AnCoder.register_for_auto_class("AutoModel")
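

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the sizes below are illustrative
    # stand-ins, not the released 1.0B shapes. vocab_size matches Qwen3's so
    # the default bos_token_id=151644 stays in range.
    tiny = dict(
        vocab_size=151936,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
    )
    model = AnCoder(AnCoderConfig(anchor_config=tiny, denoiser_config=tiny)).eval()
    input_ids = torch.randint(0, 1000, (2, 8))
    attention_mask = torch.ones(2, 8, dtype=torch.long)
    attention_mask[1, -2:] = 0  # treat the last two positions of row 1 as padding
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    assert out.logits.shape == (2, 8, tiny["vocab_size"])            # (B, L, V)
    assert out.denoiser_hidden.shape == (2, 8, tiny["hidden_size"])  # (B, L, d)
    assert model.denoiser.lm_head.weight is model.anchor.model.embed_tokens.weight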