| """AnCoder: Anchored DLM with Qwen3 backbone.""" |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from transformers import Qwen3Config, Qwen3ForCausalLM, PretrainedConfig, PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
|
|
|
|
class BiQwen3Config(Qwen3Config):
    model_type = "biqwen3"


class BiQwen3(Qwen3ForCausalLM):
    """Qwen3 causal LM converted to full bidirectional attention."""

    config_class = BiQwen3Config

    def __init__(self, config):
        super().__init__(config)
        # Disable the causal flag on every attention layer so positions can attend
        # both ways; the mask built in forward() below does the rest.
        for layer in self.model.layers:
            layer.self_attn.is_causal = False
|
|
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        compute_logits: bool = True,
        **kwargs,
    ) -> ModelOutput:
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Must provide exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
|
|
        # No padding anywhere: skip mask construction and let every position attend
        # to every other (the "full_attention" layers receive no mask at all).
        if attention_mask is None or attention_mask.all():
            attn_mask = {"full_attention": None}
        else:
            # Expand the 2D padding mask (B, L) into an additive 4D bias (B, 1, L, L):
            # 0 where attention is allowed, -10000 at padded key positions.
            B, L, _ = inputs_embeds.shape
            mask_4d = attention_mask.reshape(B, 1, 1, L).expand(B, 1, L, L)
            mask_4d = (1.0 - mask_4d.to(inputs_embeds.dtype)) * -10000.0
            attn_mask = {"full_attention": mask_4d}
|
|
        out = self.model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, use_cache=False)
        hidden = out.last_hidden_state
        logits = self.lm_head(hidden) if compute_logits else None
        return ModelOutput(hidden=hidden, logits=logits)
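

# A minimal sketch (not part of the original module) of the additive padding bias
# that BiQwen3.forward builds for the {"full_attention": ...} entry: 0.0 where a
# key position is a real token, -10000.0 where it is padding. The helper name and
# the toy mask below are illustrative.
def _example_full_attention_bias() -> torch.Tensor:
    attention_mask = torch.tensor([[1, 1, 1, 0]])              # (B=1, L=4), last slot padded
    B, L = attention_mask.shape
    bias = attention_mask.reshape(B, 1, 1, L).expand(B, 1, L, L)
    bias = (1.0 - bias.to(torch.float32)) * -10000.0
    # bias[0, 0] is a 4x4 matrix whose last column is -10000.0 and all other entries 0.0,
    # i.e. every query may attend to the three real tokens but never to the padded one.
    return bias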
|
|
|
|
@dataclass
class AnCoderOutput(ModelOutput):
    anchor_hidden: torch.FloatTensor | None = None
    anchor_logits: torch.FloatTensor | None = None
    denoiser_hidden: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
|
|
|
|
class AnCoderConfig(PretrainedConfig):
    model_type = "ancoder"

    def __init__(
        self,
        anchor_config=None,
        denoiser_config=None,
        shift_logits: bool = True,
        bos_token_id: int | None = 151644,
        **kwargs,
    ):
        # The anchor and denoiser share one embedding matrix (see AnCoder.tie_weights),
        # so report tied word embeddings unless the caller overrides it.
        kwargs.setdefault("tie_word_embeddings", True)
        super().__init__(bos_token_id=bos_token_id, **kwargs)
        self.anchor_config = anchor_config
        self.denoiser_config = denoiser_config
        self.shift_logits = shift_logits
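

# A hedged example (sizes are illustrative, not the original training configuration):
# anchor_config and denoiser_config are plain dicts of Qwen3Config keyword arguments,
# since AnCoder.__init__ below expands them with BiQwen3Config(**...).
def _example_ancoder_config() -> AnCoderConfig:
    backbone = dict(
        vocab_size=151936,
        hidden_size=512,
        intermediate_size=1536,
        num_hidden_layers=4,
        num_attention_heads=8,
        num_key_value_heads=4,
    )
    return AnCoderConfig(anchor_config=backbone, denoiser_config=dict(backbone), shift_logits=True)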
|
|
|
|
class AnCoder(PreTrainedModel):
    config_class = AnCoderConfig
    supports_gradient_checkpointing = True

    # All of these parameters alias the anchor's token-embedding matrix; listing them
    # here lets transformers treat them as tied when saving and loading checkpoints.
    _tied_weights_keys = {
        "anchor.lm_head.weight": "anchor.model.embed_tokens.weight",
        "denoiser.model.embed_tokens.weight": "anchor.model.embed_tokens.weight",
        "denoiser.lm_head.weight": "anchor.model.embed_tokens.weight",
    }
|
|
    def __init__(self, config: AnCoderConfig, anchor=None, denoiser=None):
        super().__init__(config)
        self.anchor = anchor or BiQwen3(BiQwen3Config(**config.anchor_config))
        self.denoiser = denoiser or BiQwen3(BiQwen3Config(**config.denoiser_config))
        self.tie_weights()

    def tie_weights(self):
        # Share a single embedding matrix: the anchor's input embeddings double as
        # both LM heads and as the denoiser's input embeddings.
        self.anchor.lm_head.weight = self.anchor.model.embed_tokens.weight
        self.denoiser.model.embed_tokens.weight = self.anchor.model.embed_tokens.weight
        self.denoiser.lm_head.weight = self.anchor.model.embed_tokens.weight
|
|
    def get_input_embeddings(self) -> torch.nn.Embedding:
        return self.anchor.model.embed_tokens

    def get_output_embeddings(self) -> torch.nn.Linear:
        return self.denoiser.lm_head
|
|
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        compute_anchor_logits: bool = True,
        **kwargs,
    ) -> AnCoderOutput:
        """Run the anchor, then the denoiser on the anchor's hidden states.

        When shift_logits=True, a BOS token is prepended so both stacks run on
        (B, L+1, *) inputs; the hidden states are sliced back to (B, L, d) before
        the lm_head, so logits come out as (B, L, V) directly without ever
        materializing a (B, L+1, V) tensor.
        """
        if self.config.shift_logits:
            B = input_ids.shape[0]
            bos_id = self.config.bos_token_id
            if bos_id is None:
                raise ValueError("shift_logits=True requires bos_token_id on the config")
            bos = torch.full((B, 1), bos_id, dtype=input_ids.dtype, device=input_ids.device)
            input_ids_ = torch.cat([bos, input_ids], dim=1)
            if attention_mask is not None:
                ones = torch.ones((B, 1), dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([ones, attention_mask], dim=1)
        else:
            input_ids_ = input_ids

        anchor_out = self.anchor(input_ids=input_ids_, attention_mask=attention_mask, compute_logits=False)
        denoiser_out = self.denoiser(inputs_embeds=anchor_out.hidden, attention_mask=attention_mask, compute_logits=False)

        # Drop the final position when a BOS was prepended, so the logits at slot i
        # are the predictions for input_ids[:, i].
        end = -1 if self.config.shift_logits else None
        anchor_hidden = anchor_out.hidden[:, :end, :].contiguous()
        denoiser_hidden = denoiser_out.hidden[:, :end, :].contiguous()
        return AnCoderOutput(
            anchor_hidden=anchor_hidden,
            anchor_logits=self.anchor.lm_head(anchor_hidden) if compute_anchor_logits else None,
            denoiser_hidden=denoiser_hidden,
            logits=self.denoiser.lm_head(denoiser_hidden),
        )
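

# A hedged sketch of the shift_logits bookkeeping above (the helper name is
# illustrative, not part of the original module): with a prepended BOS, slot i of
# the shifted sequence holds input_ids[:, i - 1], and once the final position is
# sliced off, logits[:, i] line up with input_ids[:, i] as targets.
def _example_shift_targets(input_ids: torch.Tensor, bos_token_id: int) -> torch.Tensor:
    B = input_ids.shape[0]
    bos = torch.full((B, 1), bos_token_id, dtype=input_ids.dtype, device=input_ids.device)
    shifted = torch.cat([bos, input_ids], dim=1)   # (B, L + 1): what both stacks see
    targets = shifted[:, 1:]                       # (B, L): identical to input_ids
    return targets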
|
|
|
|
# Register for the Auto* machinery so AnCoder checkpoints can be loaded via
# AutoConfig / AutoModel with trust_remote_code=True.
AnCoderConfig.register_for_auto_class()
AnCoder.register_for_auto_class("AutoModel")
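

# A small smoke-test sketch (tiny, illustrative sizes; not the original configuration).
# It exercises the shift_logits shape contract from AnCoder.forward and the weight
# sharing from AnCoder.tie_weights.
if __name__ == "__main__":
    tiny = dict(
        vocab_size=1024,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        max_position_embeddings=128,
    )
    # bos_token_id must stay inside the toy vocabulary, hence 0 instead of the default.
    config = AnCoderConfig(anchor_config=tiny, denoiser_config=dict(tiny), bos_token_id=0)
    model = AnCoder(config).eval()

    input_ids = torch.randint(0, tiny["vocab_size"], (2, 16))
    with torch.no_grad():
        out = model(input_ids=input_ids)

    # The prepended BOS position is sliced away again, so logits cover exactly L slots.
    assert out.logits.shape == (2, 16, tiny["vocab_size"])
    assert out.anchor_logits.shape == (2, 16, tiny["vocab_size"])
    # Both LM heads and the denoiser embeddings alias the anchor's token embedding.
    assert model.denoiser.lm_head.weight.data_ptr() == model.anchor.model.embed_tokens.weight.data_ptr()
    print("ok:", tuple(out.logits.shape))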
|
|