"""AnCoder: Anchored DLM with Qwen3 backbone.""" from dataclasses import dataclass import torch from transformers import Qwen3Config, Qwen3ForCausalLM, PretrainedConfig, PreTrainedModel from transformers.modeling_outputs import ModelOutput class BiQwen3Config(Qwen3Config): model_type = "biqwen3" class BiQwen3(Qwen3ForCausalLM): config_class = BiQwen3Config def __init__(self, config): super().__init__(config) for layer in self.model.layers: layer.self_attn.is_causal = False def forward( self, input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, inputs_embeds: torch.FloatTensor | None = None, compute_logits: bool = True, **kwargs, ) -> ModelOutput: if (input_ids is None) == (inputs_embeds is None): raise ValueError("Must provide exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) # Dict-keyed mask bypasses HF's auto-causal construction. if attention_mask is None or attention_mask.all(): attn_mask = {"full_attention": None} else: # 4D additive mask: 0=attend, -10000=ignore. B, L, _ = inputs_embeds.shape mask_4d = attention_mask.reshape(B, 1, 1, L).expand(B, 1, L, L) mask_4d = (1.0 - mask_4d.to(inputs_embeds.dtype)) * -10000 attn_mask = {"full_attention": mask_4d} out = self.model(inputs_embeds=inputs_embeds, attention_mask=attn_mask, use_cache=False) hidden = out.last_hidden_state logits = self.lm_head(hidden) if compute_logits else None return ModelOutput(hidden=hidden, logits=logits) @dataclass class AnCoderOutput(ModelOutput): anchor_hidden: torch.FloatTensor | None = None # (B, L, d) anchor_logits: torch.FloatTensor | None = None # (B, L, V), None when compute_anchor_logits=False denoiser_hidden: torch.FloatTensor | None = None # (B, L, d) logits: torch.FloatTensor | None = None # (B, L, V), final predictions class AnCoderConfig(PretrainedConfig): model_type = "ancoder" def __init__( self, anchor_config=None, denoiser_config=None, shift_logits: bool = True, bos_token_id: int | None = 151644, # <|im_start|>, distinct from PAD=151643 **kwargs, ): # Ensures that save_pretrained deduplicates and from_pretrained re-ties. kwargs.setdefault("tie_word_embeddings", True) super().__init__(bos_token_id=bos_token_id, **kwargs) self.anchor_config = anchor_config self.denoiser_config = denoiser_config self.shift_logits = shift_logits class AnCoder(PreTrainedModel): config_class = AnCoderConfig supports_gradient_checkpointing = True # Maps shared params so that save_pretrained writes only one copy. _tied_weights_keys = { "anchor.lm_head.weight": "anchor.model.embed_tokens.weight", "denoiser.model.embed_tokens.weight": "anchor.model.embed_tokens.weight", "denoiser.lm_head.weight": "anchor.model.embed_tokens.weight", } def __init__(self, config: AnCoderConfig, anchor=None, denoiser=None): super().__init__(config) self.anchor = anchor or BiQwen3(BiQwen3Config(**config.anchor_config)) self.denoiser = denoiser or BiQwen3(BiQwen3Config(**config.denoiser_config)) self.tie_weights() def tie_weights(self): # Override: _tied_weights_keys is save-only; runtime tying done here. 
self.anchor.lm_head.weight = self.anchor.model.embed_tokens.weight self.denoiser.model.embed_tokens.weight = self.anchor.model.embed_tokens.weight self.denoiser.lm_head.weight = self.anchor.model.embed_tokens.weight def get_input_embeddings(self) -> torch.nn.Embedding: return self.anchor.model.embed_tokens def get_output_embeddings(self) -> torch.nn.Linear: return self.denoiser.lm_head def forward( self, input_ids: torch.LongTensor, attention_mask: torch.Tensor | None = None, compute_anchor_logits: bool = True, **kwargs, ) -> AnCoderOutput: """When shift_logits=True, BOS is prepended so anchor->denoiser run on (B, L+1, *); hidden is sliced to (B, L, d) before lm_head, so logits emerge as (B, L, V) directly without materializing (B, L+1, V). """ if self.config.shift_logits: B = input_ids.shape[0] bos_id = self.config.bos_token_id if bos_id is None: raise ValueError("shift_logits=True requires bos_token_id on the config") bos = torch.full((B, 1), bos_id, dtype=input_ids.dtype, device=input_ids.device) input_ids_ = torch.cat([bos, input_ids], dim=1) # (B, L+1) if attention_mask is not None: ones = torch.ones((B, 1), dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([ones, attention_mask], dim=1) # (B, L+1) else: input_ids_ = input_ids anchor_out = self.anchor(input_ids=input_ids_, attention_mask=attention_mask, compute_logits=False) denoiser_out = self.denoiser(inputs_embeds=anchor_out.hidden, attention_mask=attention_mask, compute_logits=False) end = -1 if self.config.shift_logits else None anchor_hidden = anchor_out.hidden[:, :end, :].contiguous() denoiser_hidden = denoiser_out.hidden[:, :end, :].contiguous() return AnCoderOutput( anchor_hidden=anchor_hidden, anchor_logits=self.anchor.lm_head(anchor_hidden) if compute_anchor_logits else None, denoiser_hidden=denoiser_hidden, logits=self.denoiser.lm_head(denoiser_hidden), ) # Ensures that save_pretrained emits auto_map and copies this file. AnCoderConfig.register_for_auto_class() AnCoder.register_for_auto_class("AutoModel")
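

# ---------------------------------------------------------------------------
# Everything below is an illustrative usage sketch, not part of the model
# definition or training code.
# ---------------------------------------------------------------------------


def _save_and_reload_sketch(model: "AnCoder") -> "AnCoder":
    """Hedged sketch of the serialization round trip this module is set up for.

    Assumes save_pretrained deduplicates the tied embedding (per the
    _tied_weights_keys comment above) and that loading re-ties via
    tie_weights(); the helper name and temporary directory are illustrative.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        model.save_pretrained(tmp)  # config carries auto_map; this file is copied alongside
        reloaded = AnCoder.from_pretrained(tmp)  # __init__/tie_weights() re-share the embedding
        # A downstream consumer without this module on the path would instead use
        # AutoModel.from_pretrained(tmp, trust_remote_code=True).
    return reloaded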
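

if __name__ == "__main__":
    # Minimal smoke test with assumed tiny hyperparameters (not the trained
    # configuration); requires a transformers version whose Qwen3Model accepts
    # the dict-keyed attention_mask used in BiQwen3.forward.
    tiny = dict(
        vocab_size=151_936,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        max_position_embeddings=256,
    )
    model = AnCoder(AnCoderConfig(anchor_config=dict(tiny), denoiser_config=dict(tiny)))
    ids = torch.randint(0, tiny["vocab_size"], (2, 16))
    with torch.no_grad():
        out = model(input_ids=ids)
    # shift_logits=True consumes the prepended BOS internally, so both logit
    # tensors align with the 16 input positions: (2, 16, vocab_size).
    print(out.logits.shape, out.anchor_logits.shape)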