Feature Extraction
Transformers
Safetensors
English
diffretriever
information-retrieval
dense-retrieval
sparse-retrieval
colbert
diffusion-language-model
lora
custom_code
Instructions to use ielabgroup/diffretriever-llada-8b-single with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ielabgroup/diffretriever-llada-8b-single with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="ielabgroup/diffretriever-llada-8b-single", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ielabgroup/diffretriever-llada-8b-single", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| DiffRetriever — fine-tunable variant (paper §3.4 supervised fine-tuning). | |
| Contrastively fine-tunable wrapper around Dream / LLaDA backbones, | |
| sharing the masked-position prediction encoder with the zero-shot | |
| DreamRetriever / LLaDA2Retriever classes. Supports Dream, LLaDA1/1.5, | |
| LLaDA2 backbones via backbone_adapters.py. | |
| Encoding: | |
| steps=1 (fast): Single forward pass over [prefix][text][suffix][MASK×K]. | |
| repr_hidden[:, s, :] from MASK position s; quotation_emb from the | |
| closing " token. | |
| steps>1 (rich): Iterative denoising loop with mixed representations. | |
| At each step, gen positions already decoded in a prior step contribute | |
| their frozen hidden state (no gradient); positions still MASK contribute | |
| the current step's hidden (with gradient). | |
| Uniform LLaDA-style unmasking schedule: n_per_step = K // n_steps | |
| tokens are decoded per step. | |
| Loss (training): | |
| K > 1: ColBERT MaxSim InfoNCE over all K mixed vectors. | |
| steps>1: loss computed only at the final denoising step so gradient | |
| flows from the fully-contextualized representation without dilution. | |
| If progressive_step_weight > 0, retrieval loss is also applied at | |
| each intermediate step with linearly increasing weight (t/T), | |
| giving direct supervision to early denoising steps. | |
| Optional: sparse InfoNCE + FLOPS L1 regularization. | |
| At eval time the same encode() output (repr_hidden, quotation_emb, | |
| sparse_acts) supports all zero-shot retrieval modes: single_dense, | |
| multi_dense_k*, multi-denoise-step variants, and sparse versions of each. | |
| """ | |
| import re | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Dict, List, Optional, Tuple | |
| from pathlib import Path | |
| import logging | |
| from .sparse_utils import filter_sparse | |
| from .backbone_adapters import get_adapter, BackboneAdapter | |
logger = logging.getLogger(__name__)

# Spelled-out counts used by _adapt_for_k to rewrite "one word" into
# "<K> words" inside prompt templates; entry i is the word for i + 1.
_NUM_WORDS = "one two three four five six seven eight nine ten".split()
| class TrainableDiffusionRetriever(nn.Module): | |
| """ | |
| Trainable Diffusion Retriever. | |
| Build via from_backbone(); use tokenize() in your data collator, then pass | |
| pre-tokenized tensors to forward() during training. | |
| """ | |
def __init__(
    self,
    backbone: nn.Module,
    tokenizer,
    mask_token_id: int,
    hidden_size: int,
    query_prefix_ids: List[int],
    query_suffix_ids: List[int],
    passage_prefix_ids: List[int],
    passage_suffix_ids: List[int],
    max_length: int = 512,
    n_gen_tokens: int = 4,
    n_gen_q_tokens: Optional[int] = None,  # asymmetric K_q (defaults to n_gen_tokens)
    n_gen_p_tokens: Optional[int] = None,  # asymmetric K_p (defaults to n_gen_tokens)
    temperature: float = 0.01,
    num_denoise_steps: Optional[int] = 4,
    sparse_weight: float = 1.0,
    normalize: bool = True,
    flash_attn: bool = False,
    use_eos: bool = False,
    # K-adapter (joint training of K-router + retriever). When enabled,
    # the model encodes at K_max and at training time computes per-cell
    # InfoNCE losses; a tiny MLP head maps the query's quotation_emb to a
    # softmax over K-cell choices and is supervised by KL(teacher || head)
    # where teacher = softmax(-per_cell_loss / τ_T). See compute_loss().
    use_k_adapter: bool = False,
    adapter_weight: float = 1.0,
    teacher_temperature: float = 1.0,
    k_adapter_options: Optional[Tuple[int, ...]] = None,
    # Two-stage K-pre-encoder (true encoding savings via per-item K).
    # When enabled, a tiny MLP runs over the embedding-layer output and
    # predicts K_q (for queries) and K_p (for passages) per item. Each
    # item's input is sliced to its predicted K (removing K_max-K MASK
    # tokens), the batch is padded to max-K-in-batch, and main encoder
    # runs at variable length. See _forward_with_k_pre_encoder().
    use_k_pre_encoder: bool = False,
    gumbel_temperature: float = 1.0,
    k_cost_lambda: float = 0.001,
    k_pre_encoder_options: Optional[Tuple[int, ...]] = None,
):
    """Wrap a diffusion-LM backbone as a trainable retriever.

    Prefer from_backbone(), which loads the backbone and tokenizer and builds
    the prompt token IDs. This constructor stores configuration, builds the
    structural tail IDs, and creates the optional K-adapter / K-pre-encoder
    heads.

    Args:
        backbone: the (optionally LoRA-wrapped) diffusion language model.
        tokenizer: tokenizer matching ``backbone``.
        mask_token_id: ID of the MASK token placed in the gen block.
        hidden_size: backbone hidden size (input width of the K heads).
        query_prefix_ids / query_suffix_ids: prompt tokens around a query.
        passage_prefix_ids / passage_suffix_ids: prompt tokens around a passage.
        max_length: token budget used by tokenize().
        n_gen_tokens: symmetric K, used for any side whose override is None.
        n_gen_q_tokens / n_gen_p_tokens: per-side K overrides.
        temperature: stored as ``self.temperature``.
        num_denoise_steps: number of denoising steps, capped at the larger
            per-side K. ``None`` means one step per gen token (the larger K).
        sparse_weight: stored as ``self.sparse_weight`` (0 disables sparse loss).
        normalize: stored as ``self.normalize``.
        flash_attn: stored as ``self.flash_attn``; from_backbone() passes the
            backbone adapter's value.
        use_eos: kept for config compatibility; ``_n_tail`` is authoritative.
        use_k_adapter, adapter_weight, teacher_temperature, k_adapter_options:
            K-adapter configuration (see the parameter comments above).
        use_k_pre_encoder, gumbel_temperature, k_cost_lambda,
        k_pre_encoder_options: K-pre-encoder configuration.
    """
    super().__init__()
    self.backbone = backbone
    self.tokenizer = tokenizer
    self.mask_token_id = mask_token_id
    self.hidden_size = hidden_size
    self.register_buffer('_dummy', torch.zeros(1))  # for .device

    # Prompt IDs (cached lists, not parameters)
    self._query_prefix_ids = list(query_prefix_ids)
    self._query_suffix_ids = list(query_suffix_ids)
    self._passage_prefix_ids = list(passage_prefix_ids)
    self._passage_suffix_ids = list(passage_suffix_ids)
    self.max_length = max_length

    # Per-side K. When the user passes only `n_gen_tokens`, both sides
    # use it (symmetric — original behaviour). Pass `n_gen_q_tokens` /
    # `n_gen_p_tokens` to train asymmetric (e.g. K_q=4, K_p=16).
    self.n_gen_tokens = n_gen_tokens  # legacy, kept for back-compat
    self.n_gen_q_tokens = n_gen_q_tokens if n_gen_q_tokens is not None else n_gen_tokens
    self.n_gen_p_tokens = n_gen_p_tokens if n_gen_p_tokens is not None else n_gen_tokens
    self.use_eos = use_eos  # kept for config compat; _n_tail is authoritative
    max_k = max(self.n_gen_q_tokens, self.n_gen_p_tokens)

    # _n_tail is the structural tail (quote + chat_end + EOS); it's the
    # same regardless of K_q or K_p as long as at least one side has K>0.
    self._n_tail = 3 if max_k > 0 else 0
    # Needs self.tokenizer and self._n_tail, both set above.
    self._tail_ids: List[int] = self._build_tail_ids()

    self.temperature = temperature
    # Denoise step count is capped by the larger K (the side with more
    # mask tokens has more positions to iteratively decode). None means
    # "one step per gen token", matching from_backbone().
    if num_denoise_steps is None:
        num_denoise_steps = max_k
    self.num_denoise_steps = min(num_denoise_steps, max_k)
    self.sparse_weight = sparse_weight
    self.dense_weight = 1.0
    self.normalize = normalize
    self.flash_attn = flash_attn
    self.model_type = 'unknown'  # set by from_backbone() / load()
    self.adapter: Optional[BackboneAdapter] = None  # set by factory methods
    self._hook_registered = False  # set by _setup_hook()

    # Auxiliary losses (activated by setting weight > 0)
    self.denoising_weight = 0.0  # diffusion-native masked text denoising
    self.diversity_weight = 0.0  # explicit multi-vector diversity
    self.denoise_mask_ratio = 0.15  # fraction of text tokens to mask for denoising
    self.progressive_step_weight = 0.0  # progressive step supervision (multi-step only)
    self.use_fresh_final = False  # "fresh" ablation: use current hidden for ALL K at final step
    self.soft_denoising = False  # soft-token multi-step: differentiable embeddings instead of hard tokens
    self.soft_temperature = 1.0  # temperature for softmax in soft-token denoising
    self.corruption_rate = 0.0  # max text corruption rate for denoising-conditioned training (0 = off)
    self.debug_dense_metrics = False
    self.debug_compare_hidden_once = False
    self._debug_hidden_compared = False

    def _default_k_options() -> Tuple[int, ...]:
        # Powers of two (up to 16) that do not exceed the larger per-side K,
        # e.g. K_max=16 → (1, 2, 4, 8, 16); K_max=4 → (1, 2, 4);
        # K_max=12 → (1, 2, 4, 8).
        return tuple(k for k in (1, 2, 4, 8, 16) if k <= max_k)

    # Hidden width shared by the K-adapter and K-pre-encoder MLPs.
    head_dim = max(64, hidden_size // 4)

    # ── K-adapter (joint K-router + retriever) ─────────────────────────
    self.use_k_adapter = bool(use_k_adapter)
    self.adapter_weight = float(adapter_weight)
    self.teacher_temperature = float(teacher_temperature)
    if k_adapter_options is None:
        self.k_adapter_options: Tuple[int, ...] = _default_k_options()
    else:
        self.k_adapter_options = tuple(int(k) for k in k_adapter_options)
    self.n_K = len(self.k_adapter_options)
    if self.use_k_adapter:
        # Tiny MLP: hidden_size → head_dim → n_K * n_K.
        # Trained jointly with the backbone via the per-cell loss teacher.
        # Kept in fp32 to match the .float() conversion of representations
        # in encode(); HF Trainer's bf16 autocast handles mixed precision
        # at the call site without needing a manual dtype cast here.
        self.k_adapter = nn.Sequential(
            nn.Linear(hidden_size, head_dim),
            nn.GELU(),
            nn.Linear(head_dim, self.n_K * self.n_K),
        )
        logger.info(
            f"KAdapter enabled: K options={self.k_adapter_options}, "
            f"adapter_weight={self.adapter_weight}, "
            f"teacher_temperature={self.teacher_temperature}"
        )
    else:
        self.k_adapter = None

    # ── K-pre-encoder (two-stage encoding) ─────────────────────────────
    self.use_k_pre_encoder = bool(use_k_pre_encoder)
    self.gumbel_temperature = float(gumbel_temperature)
    self.k_cost_lambda = float(k_cost_lambda)
    if k_pre_encoder_options is None:
        self.k_pre_encoder_options: Tuple[int, ...] = _default_k_options()
    else:
        self.k_pre_encoder_options = tuple(int(k) for k in k_pre_encoder_options)
    self.n_K_pe = len(self.k_pre_encoder_options)
    # Buffer of K values for differentiable expected-K (cost regularizer).
    self.register_buffer(
        'k_pe_options_tensor',
        torch.tensor(self.k_pre_encoder_options, dtype=torch.float32),
        persistent=False,
    )
    if self.use_k_pre_encoder:
        # Separate q + p heads — query and passage have different optimal K
        # distributions and texts differ enough to justify dedicated heads.
        # Both heads operate on the (shared) embedding layer's output —
        # mean-pooled over real tokens. Cheap (~1% of full forward).
        self.k_pre_encoder_q = nn.Sequential(
            nn.Linear(hidden_size, head_dim),
            nn.GELU(),
            nn.Linear(head_dim, self.n_K_pe),
        )
        self.k_pre_encoder_p = nn.Sequential(
            nn.Linear(hidden_size, head_dim),
            nn.GELU(),
            nn.Linear(head_dim, self.n_K_pe),
        )
        logger.info(
            f"K pre-encoder enabled: K options={self.k_pre_encoder_options}, "
            f"gumbel_τ={self.gumbel_temperature}, "
            f"cost_λ={self.k_cost_lambda}"
        )
    else:
        self.k_pre_encoder_q = None
        self.k_pre_encoder_p = None
def _k(self, is_query: bool) -> int:
    """Return the gen-block size K for one side of the retriever.

    Callers write ``K = self._k(is_query)`` so the same code works for both
    symmetric (K_q == K_p) and asymmetric configurations.

    Args:
        is_query: truthy for the query side (K_q), falsy for the passage side (K_p).
    """
    if is_query:
        return self.n_gen_q_tokens
    return self.n_gen_p_tokens
| # ---------------------------------------------------------------- | |
| # Build from pretrained backbone | |
| # ---------------------------------------------------------------- | |
@classmethod
def from_backbone(
    cls,
    model_name: str,
    model_type: str,
    query_prompt: str,
    passage_prompt: str,
    max_length: int = 512,
    n_gen_tokens: int = 4,
    n_gen_q_tokens: Optional[int] = None,
    n_gen_p_tokens: Optional[int] = None,
    temperature: float = 0.01,
    num_denoise_steps: Optional[int] = 4,
    sparse_weight: float = 1.0,
    normalize: bool = True,
    gradient_checkpointing: bool = True,
    lora_rank: int = 0,
    lora_alpha: int = 64,
    lora_dropout: float = 0.05,
    device_map=None,
    use_eos: bool = False,
    disable_hidden_hook: bool = False,
    # K-adapter passthrough
    use_k_adapter: bool = False,
    adapter_weight: float = 1.0,
    teacher_temperature: float = 1.0,
    k_adapter_options: Optional[Tuple[int, ...]] = None,
    # K pre-encoder passthrough
    use_k_pre_encoder: bool = False,
    gumbel_temperature: float = 1.0,
    k_cost_lambda: float = 0.001,
    k_pre_encoder_options: Optional[Tuple[int, ...]] = None,
) -> 'TrainableDiffusionRetriever':
    """Build a retriever from a pretrained backbone checkpoint.

    Steps:
      1. Load the tokenizer (left padding) and the backbone through the
         adapter for ``model_type``.
      2. Optionally wrap the backbone with LoRA (``lora_rank > 0``) and
         enable gradient checkpointing.
      3. Build per-side prompt token IDs from the query/passage prompt YAML
         files. Counts such as "one word" are rewritten to match K_q / K_p.
      4. Construct the model and, unless ``disable_hidden_hook`` is set, ask
         the adapter to register a hidden-state hook.

    Args:
        model_name: checkpoint name or path, passed to the tokenizer and adapter.
        model_type: backbone family key understood by get_adapter().
        query_prompt / passage_prompt: paths to prompt-template YAML files.
        num_denoise_steps: ``None`` means one step per gen token (max(K_q, K_p)).
        lora_rank: 0 disables LoRA.
        disable_hidden_hook: force the output_hidden_states=True path.
        Remaining arguments are forwarded to the constructor unchanged.

    Returns:
        The constructed TrainableDiffusionRetriever.
    """
    from transformers import AutoTokenizer

    adapter = get_adapter(model_type)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.padding_side = 'left'
    backbone = adapter.load_backbone(model_name, device_map=device_map)

    if lora_rank > 0:
        from peft import get_peft_model
        lora_cfg = adapter.get_lora_config(lora_rank, lora_alpha, lora_dropout)
        backbone = get_peft_model(backbone, lora_cfg)
        backbone.print_trainable_parameters()
    if gradient_checkpointing:
        backbone.enable_input_require_grads()
        adapter.enable_gradient_checkpointing(backbone)

    mask_token_id = adapter.mask_token_id
    logger.info(f"Mask token: {model_type} → {mask_token_id}")

    # Resolve per-side K (default to symmetric n_gen_tokens). The
    # YAML adaptation ('the relevant document is one word' → 'four
    # words') uses the per-side K so the natural-language prompt
    # matches the actual MASK count the model will see.
    k_q = n_gen_q_tokens if n_gen_q_tokens is not None else n_gen_tokens
    k_p = n_gen_p_tokens if n_gen_p_tokens is not None else n_gen_tokens
    # Resolve the default step count before it is logged or passed on, so
    # the log reports the value actually used.
    if num_denoise_steps is None:
        num_denoise_steps = max(k_q, k_p)

    # Build prompt token IDs (per-side adaptation)
    q_yaml = dict(cls._load_yaml(query_prompt))
    p_yaml = dict(cls._load_yaml(passage_prompt))
    for side_yaml, side_k in ((q_yaml, k_q), (p_yaml, k_p)):
        for key in ('user_suffix', 'assistant_prefix'):
            side_yaml[key] = cls._adapt_for_k(side_yaml.get(key, ''), side_k)
    q_prefix_ids, q_suffix_ids = cls._build_prompt_ids(tokenizer, q_yaml)
    p_prefix_ids, p_suffix_ids = cls._build_prompt_ids(tokenizer, p_yaml)

    logger.info(f"Query prompt: {len(q_prefix_ids)} prefix + {len(q_suffix_ids)} suffix tokens")
    logger.info(f"Passage prompt: {len(p_prefix_ids)} prefix + {len(p_suffix_ids)} suffix tokens")
    if k_q == k_p:
        logger.info(f"n_gen_tokens={k_q} (symmetric), num_denoise_steps={num_denoise_steps}")
    else:
        logger.info(f"n_gen_q_tokens={k_q}, n_gen_p_tokens={k_p} (asymmetric), "
                    f"num_denoise_steps={num_denoise_steps}")
    if sparse_weight > 0:
        logger.info(f"Sparse loss: sparse_weight={sparse_weight}")

    hidden_size = backbone.config.hidden_size
    model = cls(
        backbone=backbone,
        tokenizer=tokenizer,
        mask_token_id=mask_token_id,
        hidden_size=hidden_size,
        query_prefix_ids=q_prefix_ids,
        query_suffix_ids=q_suffix_ids,
        passage_prefix_ids=p_prefix_ids,
        passage_suffix_ids=p_suffix_ids,
        max_length=max_length,
        n_gen_tokens=n_gen_tokens,
        n_gen_q_tokens=k_q,
        n_gen_p_tokens=k_p,
        temperature=temperature,
        num_denoise_steps=num_denoise_steps,
        sparse_weight=sparse_weight,
        normalize=normalize,
        flash_attn=adapter.flash_attn,
        use_eos=use_eos,
        use_k_adapter=use_k_adapter,
        adapter_weight=adapter_weight,
        teacher_temperature=teacher_temperature,
        k_adapter_options=k_adapter_options,
        use_k_pre_encoder=use_k_pre_encoder,
        gumbel_temperature=gumbel_temperature,
        k_cost_lambda=k_cost_lambda,
        k_pre_encoder_options=k_pre_encoder_options,
    )
    model.model_type = model_type
    model.adapter = adapter
    model.lora_rank = lora_rank
    model.lora_alpha = lora_alpha

    # Hook on output projection for efficient hidden state extraction.
    # _fwd() pops the captured tensor from this dict after each forward.
    model._last_hidden: Dict[str, torch.Tensor] = {}
    if disable_hidden_hook:
        logger.info(f"{model_type}: hidden hook disabled by flag")
        model._hook_registered = False
    else:
        model._hook_registered = adapter.register_hidden_hook(
            backbone, model._last_hidden)
        if not model._hook_registered:
            logger.info(f"{model_type}: no hook — will use output_hidden_states=True")
    return model
| # ---------------------------------------------------------------- | |
| # Prompt helpers (mirrors DreamRetriever) | |
| # ---------------------------------------------------------------- | |
@staticmethod
def _load_yaml(path: str) -> dict:
    """Load a prompt-template YAML file as a dict.

    Args:
        path: filesystem path (str or path-like) to the YAML file.

    Returns:
        The parsed top-level mapping; an empty file yields ``{}``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the top-level YAML node is not a mapping.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Prompt YAML not found: {path}")
    # Function-scope import: PyYAML is only needed when prompts are loaded.
    import yaml
    data = yaml.safe_load(p.read_text(encoding="utf-8"))
    # safe_load returns None for an empty document; callers do dict(...) on
    # the result, so normalise to an empty mapping.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Prompt YAML must contain a mapping at the top level, "
            f"got {type(data).__name__}: {path}"
        )
    return data
@staticmethod
def _adapt_for_k(text: str, k: int) -> str:
    """Rewrite a prompt fragment so its word count matches K MASK tokens.

    Any spelled-out count from ``_NUM_WORDS`` followed by "word" or "words"
    (e.g. "one word") becomes "<count> words". <count> is the English word
    for ``k`` when ``k <= len(_NUM_WORDS)``, and its digits otherwise
    (e.g. "16 words"). Any remaining "word is" then becomes "words are".

    Args:
        text: prompt fragment from the YAML file (may be empty or None).
        k: number of MASK tokens on this side.

    Returns:
        The adapted text; returned unchanged when ``k <= 1`` or ``text`` is falsy.
    """
    if k <= 1 or not text:
        return text
    count = _NUM_WORDS[k - 1] if k <= len(_NUM_WORDS) else str(k)
    number_phrase = re.compile(r'\b(?:' + '|'.join(_NUM_WORDS) + r')\b\s+words?')
    # `count` is a number word or digits, so it is safe as a literal repl string.
    result = number_phrase.sub(f'{count} words', text)
    return re.sub(r'\bword is\b', 'words are', result)
@staticmethod
def _build_prompt_ids(tokenizer, yaml_dict: dict) -> Tuple[List[int], List[int]]:
    """Render the chat template around a placeholder and split it into token IDs.

    A user turn ``user_prefix + <placeholder> + user_suffix`` (plus an
    optional system turn) is rendered with ``add_generation_prompt=True``.
    Everything before the placeholder becomes the prefix. Everything after
    it, followed by ``assistant_prefix``, becomes the suffix. The text to
    encode is later inserted between the two.

    Args:
        tokenizer: tokenizer providing ``apply_chat_template`` and ``encode``.
        yaml_dict: prompt fields ``system``, ``user_prefix``, ``user_suffix``,
            ``assistant_prefix``; missing or null fields count as "".

    Returns:
        (prefix_ids, suffix_ids), encoded without special tokens.

    Raises:
        ValueError: if the rendered template does not contain the placeholder.
    """
    sentinel = "XSENTINELX"
    # `or ''` also covers keys present with a null value in the YAML file.
    system = yaml_dict.get('system') or ''
    user_prefix = yaml_dict.get('user_prefix') or ''
    user_suffix = yaml_dict.get('user_suffix') or ''
    assistant_prefix = yaml_dict.get('assistant_prefix') or ''

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": user_prefix + sentinel + user_suffix})
    full_str = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    pre_str, found, post_str = full_str.partition(sentinel)
    if not found:
        raise ValueError(
            f"Chat template output does not contain the placeholder {sentinel!r}; "
            "cannot split the prompt into prefix and suffix."
        )
    prefix_ids = tokenizer.encode(pre_str, add_special_tokens=False)
    suffix_ids = tokenizer.encode(post_str + assistant_prefix, add_special_tokens=False)
    return prefix_ids, suffix_ids
| # ---------------------------------------------------------------- | |
| # Structural tail tokens (matches zero-shot dream/llada retrievers) | |
| # ---------------------------------------------------------------- | |
def _build_tail_ids(self) -> List[int]:
    """Return the structural tail placed after the MASK block.

    The tail is [closing quote, chat-template end token, EOS], truncated or
    EOS-padded to ``self._n_tail`` entries. It mirrors the zero-shot
    DreamRetriever / LLaDA2Retriever layout so training sees the same
    sequence shape as zero-shot inference.

    Slot 0 is the closing quote when '"' encodes to a single token, and EOS
    otherwise (with a warning). Slot 1 is '<|im_end|>' (Dream/Qwen) or
    '<|eot_id|>' (LLaDA), whichever resolves first to a real, non-EOS ID.
    A hard-coded LLaDA ID is used when '<|eot_id|>' does not resolve.
    Remaining slots are EOS.
    """
    n_tail = self._n_tail
    if n_tail <= 0:
        return []

    tok = self.tokenizer
    eos = tok.eos_token_id

    quote = tok.encode('"', add_special_tokens=False)
    if len(quote) == 1:
        ids: List[int] = [quote[0]]
    else:
        logger.warning('quote \'"\' tokenises to %d tokens; using EOS in slot 0', len(quote))
        ids = [eos]

    # (token string, fallback ID if the string does not resolve); tried in order.
    end_candidates = (('<|im_end|>', None), ('<|eot_id|>', 126348))
    for token_str, fallback in end_candidates:
        resolved = tok.convert_tokens_to_ids(token_str)
        unk_id = getattr(tok, 'unk_token_id', None)
        usable = (resolved is not None and resolved >= 0
                  and resolved != unk_id and resolved != eos)
        if usable:
            ids.append(resolved)
            break
        if fallback is not None and fallback != eos:
            ids.append(fallback)
            break

    ids.extend([eos] * max(0, n_tail - len(ids)))
    return ids[:n_tail]
| # ---------------------------------------------------------------- | |
| # Tokenization (called from the data collator, runs on CPU) | |
| # ---------------------------------------------------------------- | |
def tokenize(
    self, texts: List[str], is_query: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Tokenize texts into the retriever's input layout (runs on CPU).

    Each sequence is laid out as
        [prefix][text][suffix][MASK × K][" | chat_end | EOS]
    with the query or passage prompt and K chosen by ``is_query``. Only the
    text is truncated. It is cut so that the whole sequence (prompt, text,
    MASK block and structural tail) fits in ``self.max_length`` tokens.

    Args:
        texts: raw query or passage strings.
        is_query: selects the query prompt/K_q (True) or passage prompt/K_p.

    Returns:
        (input_ids, attention_mask) tensors padded by ``self.tokenizer.pad``.
        from_backbone() configures the tokenizer for left padding.

    Raises:
        ValueError: if the prompt and MASK block leave no room for text
            within ``self.max_length``.
    """
    if is_query:
        prefix_ids, suffix_ids = self._query_prefix_ids, self._query_suffix_ids
    else:
        prefix_ids, suffix_ids = self._passage_prefix_ids, self._passage_suffix_ids
    mask_block = [self.mask_token_id] * self._k(is_query) + list(self._tail_ids)

    # Everything except the text has fixed length, so the text gets the rest
    # of the max_length budget (previously the MASK block was not counted and
    # sequences overshot max_length by K + n_tail tokens).
    fixed_len = len(prefix_ids) + len(suffix_ids) + len(mask_block)
    max_text_len = self.max_length - fixed_len
    if max_text_len <= 0:
        raise ValueError(
            f"max_length={self.max_length} leaves no room for text: the prompt "
            f"and MASK block already use {fixed_len} tokens"
        )

    enc = self.tokenizer(
        texts,
        padding=False,
        truncation=True,
        max_length=max_text_len,
        return_attention_mask=False,
        return_token_type_ids=False,
        add_special_tokens=False,
    )
    enc['input_ids'] = [
        prefix_ids + ids + suffix_ids + mask_block
        for ids in enc['input_ids']
    ]
    collated = self.tokenizer.pad(
        enc,
        padding=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return collated['input_ids'], collated['attention_mask']
| # ---------------------------------------------------------------- | |
| # Attention mask: bidirectional (padding-only masking) | |
| # ---------------------------------------------------------------- | |
def _build_4d_mask(
    self, seq_len: int, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Build an additive bidirectional attention mask of shape [B, 1, S, S].

    Entry (b, 0, i, j) is 0 when both position i and position j are real
    tokens. It is the most negative finite value of the backbone's parameter
    dtype when either one is padding. There is no causal component.

    Args:
        seq_len: S, the sequence length (must match attention_mask's width).
        attention_mask: [B, S] mask where nonzero marks real tokens.
    """
    param_dtype = next(self.backbone.parameters()).dtype
    is_pad = attention_mask.bool().logical_not()                     # [B, S]
    pad_pair = is_pad[:, None, None, :] | is_pad[:, None, :, None]   # [B, 1, S, S]
    blank = torch.zeros(
        attention_mask.size(0), 1, seq_len, seq_len,
        device=attention_mask.device, dtype=param_dtype,
    )
    return blank.masked_fill(pad_pair, torch.finfo(param_dtype).min)
| # ---------------------------------------------------------------- | |
| # Backbone forward | |
| # ---------------------------------------------------------------- | |
def _fwd(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    need_logits: bool = False,
    mask_4d: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run a single backbone forward pass.

    Args:
        input_ids: [B, L] token IDs; fed to the backbone when
            ``inputs_embeds`` is None.
        attention_mask: [B, L] padding mask (nonzero = real token).
        need_logits: also return the backbone's logits.
        mask_4d: pre-computed 4D attention mask. Pass it in the multi-step
            loop to avoid recomputing the same mask at every denoising step.
            Only used when the adapter asks for a 4D mask.
        inputs_embeds: if provided, fed to the backbone instead of input_ids.

    Returns:
        (hidden [B, L, H], logits [B, L, V] or None), both kept in the
        backbone's native dtype. Callers convert small slices as needed.
    """
    # Adapter determines mask format: 4D [B,1,L,L] or 2D [B,L]
    use_4d = self.adapter.needs_4d_mask() if self.adapter else True
    if use_4d:
        if mask_4d is None:
            seq_len = inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
            mask_4d = self._build_4d_mask(seq_len, attention_mask)
        fwd_mask = mask_4d
    else:
        fwd_mask = attention_mask

    # The backbone receives either token IDs or (soft) embeddings, never both.
    if inputs_embeds is not None:
        model_inputs = {'inputs_embeds': inputs_embeds}
    else:
        model_inputs = {'input_ids': input_ids}

    def _final_hidden(out):
        # Last-layer hidden state from a backbone output, trying the common
        # output attributes in order.
        if getattr(out, 'last_hidden_state', None) is not None:
            return out.last_hidden_state
        if getattr(out, 'hidden_states', None) is not None:
            return out.hidden_states[-1]
        return out[0]

    # If hook is registered, skip output_hidden_states (saves ~1.2GB).
    outputs = self.backbone(
        **model_inputs,
        attention_mask=fwd_mask,
        output_hidden_states=not self._hook_registered,
        return_dict=True,
    )
    if self._hook_registered and 'h' in self._last_hidden:
        hidden = self._last_hidden.pop('h')
    else:
        hidden = _final_hidden(outputs)
    # Keep logits in native dtype (bf16) to avoid doubling memory.
    logits = outputs.logits if need_logits and hasattr(outputs, 'logits') else None

    if (self.debug_compare_hidden_once
            and self._hook_registered
            and not self._debug_hidden_compared):
        dist = torch.distributed
        is_main_rank = (
            not (dist.is_available() and dist.is_initialized())
            or dist.get_rank() == 0
        )
        if is_main_rank:
            # Reference pass over the SAME inputs (ids or embeddings) with
            # output_hidden_states=True, to validate the hooked hidden state.
            with torch.no_grad():
                ref_outputs = self.backbone(
                    **model_inputs,
                    attention_mask=fwd_mask,
                    output_hidden_states=True,
                    return_dict=True,
                )
            # The hook stores 'h' on backbone forwards (the main path above
            # pops it), so drop the capture from the reference pass rather
            # than leaving a stale tensor in _last_hidden.
            self._last_hidden.pop('h', None)
            ref_hidden = _final_hidden(ref_outputs)
            diff = (hidden.float() - ref_hidden.float()).abs()
            logger.warning(
                "Hidden-hook check: model=%s shape=%s mean_abs_diff=%.6g max_abs_diff=%.6g",
                self.model_type, tuple(hidden.shape),
                diff.mean().item(), diff.max().item(),
            )
        self._debug_hidden_compared = True
    return hidden, logits
| # ---------------------------------------------------------------- | |
| # Confidence-based unmasking | |
| # ---------------------------------------------------------------- | |
@staticmethod
def _sample_with_confidence(logits: torch.Tensor, alg: str = 'entropy'):
    """Greedy decode plus a confidence score for discrete-diffusion unmasking.

    Declared a staticmethod: it has no ``self`` parameter but is called as
    ``self._sample_with_confidence(ml)``, which would otherwise bind the
    instance to ``logits``.

    Args:
        logits: [N, V] logits at the N currently-masked positions.
        alg: 'entropy' (negative entropy; higher = more certain) or any other
            value for the max probability.

    Returns:
        (confidence [N], x0 [N]): confidence scores and argmax token IDs.
    """
    probs = torch.softmax(logits.float(), dim=-1)
    x0 = probs.argmax(dim=-1)
    if alg == 'entropy':
        # The +1e-10 keeps log finite where probs underflow to 0 (e.g. -inf logits).
        log_probs = torch.log(probs + 1e-10)
        confidence = (probs * log_probs).sum(dim=-1)  # neg entropy
    else:
        confidence = probs.max(dim=-1).values
    return confidence, x0
def _unmask_step(
    self,
    curr_ids: torch.Tensor,  # [B, L]
    logits: torch.Tensor,    # [B, L, V]
    K: int,
    n_per_step: int,
) -> Tuple[torch.Tensor, List[List[int]]]:
    """Hard-decode the most confident still-masked gen tokens (uniform schedule).

    The gen block is the K positions that end ``self._n_tail`` tokens before
    the end of the sequence. In each row, up to ``n_per_step`` MASK positions
    with the highest confidence are replaced by their argmax token. Rows with
    no MASK left, or with non-finite logits at the masked positions, are
    left unchanged.

    Returns:
        new_ids: copy of curr_ids with the selected positions decoded.
        newly_decoded: per row, the gen-block offsets (0..K-1) decoded now.
    """
    batch, seq_len = curr_ids.shape
    gen_start = seq_len - K - self._n_tail
    gen_end = gen_start + K
    updated = curr_ids.clone()
    decoded: List[List[int]] = []

    for row in range(batch):
        still_masked = updated[row, gen_start:gen_end] == self.mask_token_id
        remaining = still_masked.sum().item()
        if remaining == 0:
            decoded.append([])
            continue
        # Only the [remaining, V] slice is upcast to float32, not full logits.
        row_logits = logits[row, gen_start:gen_end][still_masked].detach().float()
        if not torch.isfinite(row_logits).all():
            decoded.append([])
            continue
        scores, preds = self._sample_with_confidence(row_logits)
        _, pick = torch.topk(scores, min(n_per_step, remaining))
        chosen = still_masked.nonzero(as_tuple=True)[0][pick]
        updated[row, gen_start + chosen] = preds[pick]
        decoded.append(chosen.tolist())

    return updated, decoded
| # ---------------------------------------------------------------- | |
| # Soft-token multi-step helpers | |
| # ---------------------------------------------------------------- | |
def _get_embed_layer(self) -> nn.Module:
    """Return the backbone's input token-embedding module.

    Goes through ``get_input_embeddings()`` so it also works when the
    backbone is wrapped by PEFT.
    """
    backbone = self.backbone
    embedding = backbone.get_input_embeddings()
    return embedding
def _soft_unmask_step(
    self,
    curr_embeds: torch.Tensor,  # [B, L, H_emb] — current embeddings (differentiable)
    curr_ids: torch.Tensor,     # [B, L] — tracking which positions are still MASK
    logits: torch.Tensor,       # [B, L, V] — logits from current forward pass
    K: int,
    n_per_step: int,
    temperature: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]]]:
    """Soft-unmask the n_per_step most confident gen tokens per row.

    Instead of a hard token replacement (argmax → embedding lookup), the
    selected positions receive differentiable soft embeddings
    softmax(logits / temperature) @ embedding_matrix. Gradient flows through
    logits → softmax → matmul → next-step embeddings. Selection uses the same
    confidence schedule as the hard path. Rows with no MASK left, or with
    non-finite logits at masked positions, are left unchanged.

    Returns:
        new_embeds: [B, L, H_emb] with soft embeddings at decoded positions.
        new_ids: [B, L] token IDs with hard argmax at decoded positions
            (for tracking which positions are still MASK only).
        newly_decoded: per row, gen-block offsets (0..K-1) decoded this step.
    """
    embed_weight = self._get_embed_layer().weight  # [V, H_emb]
    B, L = curr_ids.shape
    g_start = L - K - self._n_tail
    g_end = g_start + K
    new_ids = curr_ids.clone()
    newly_decoded: List[List[int]] = [[] for _ in range(B)]

    # Soft embeddings for all K gen positions (differentiable).
    gen_logits = logits[:, g_start:g_end, :]                  # [B, K, V]
    soft_probs = F.softmax(gen_logits / temperature, dim=-1)  # [B, K, V]
    # matmul / where / cat need matching dtypes, but logits and the embedding
    # table may differ (e.g. fp32 logits vs bf16 weights). .to() is a no-op
    # when the dtypes already match.
    soft_embs = soft_probs.to(embed_weight.dtype) @ embed_weight  # [B, K, H_emb]
    soft_embs = soft_embs.to(curr_embeds.dtype)

    # Choose which positions to unmask (most confident, same schedule as hard).
    replace_mask = torch.zeros(B, K, dtype=torch.bool, device=curr_embeds.device)
    for i in range(B):
        mask_bool = curr_ids[i, g_start:g_end] == self.mask_token_id
        n_masked = mask_bool.sum().item()
        if n_masked == 0:
            continue
        ml = gen_logits[i][mask_bool].detach().float()
        if not torch.isfinite(ml).all():
            continue
        conf, x0 = self._sample_with_confidence(ml)
        _, xfer = torch.topk(conf, min(n_per_step, n_masked))
        selected = torch.where(mask_bool)[0][xfer]
        replace_mask[i, selected] = True
        new_ids[i, g_start + selected] = x0[xfer]
        newly_decoded[i] = selected.tolist()

    # Mix: selected gen positions take the soft embedding, the rest keep theirs.
    curr_gen_embs = curr_embeds[:, g_start:g_end, :]  # [B, K, H_emb]
    new_gen_embs = torch.where(
        replace_mask.unsqueeze(-1),  # [B, K, 1]
        soft_embs,                   # differentiable
        curr_gen_embs,
    )
    # Rebuild the full sequence (avoids in-place ops on a grad-carrying tensor).
    new_embeds = torch.cat([
        curr_embeds[:, :g_start, :],
        new_gen_embs,
        curr_embeds[:, g_end:, :],
    ], dim=1)
    return new_embeds, new_ids, newly_decoded
| def _soft_multistep_forward( | |
| self, | |
| query_input_ids: torch.Tensor, | |
| query_attention_mask: torch.Tensor, | |
| passage_input_ids: torch.Tensor, | |
| passage_attention_mask: torch.Tensor, | |
| query_content_ids: Optional[List], | |
| passage_content_ids: Optional[List], | |
| n_steps: int, | |
| ) -> Dict[str, torch.Tensor]: | |
| """Multi-step denoising training with soft (differentiable) token replacement. | |
| Unlike hard multi-step where decoded positions are frozen (.detach()), | |
| soft-token mode keeps gradient flowing through ALL K positions at every | |
| step via softmax(logits/τ) @ embedding_matrix. | |
| Per-step contrastive loss (GIRCSE-style) provides direct supervision at | |
| each denoising step. A monotonicity regularizer penalizes regressions. | |
| """ | |
| K_q, K_p = self.n_gen_q_tokens, self.n_gen_p_tokens | |
| n_tail = self._n_tail | |
| n_per_step_q = max(1, K_q // n_steps) | |
| n_per_step_p = max(1, K_p // n_steps) | |
| device = query_input_ids.device | |
| B_q = query_input_ids.size(0) | |
| B_p = passage_input_ids.size(0) | |
| n_paq = B_p // B_q | |
| L_q = query_input_ids.size(1) | |
| L_p = passage_input_ids.size(1) | |
| need_sparse = self.sparse_weight > 0 | |
| # Pre-compute 4D masks (sequence structure doesn't change across steps) | |
| use_4d = self.adapter.needs_4d_mask() if self.adapter else True | |
| q_mask_4d = self._build_4d_mask(L_q, query_attention_mask) if use_4d else None | |
| p_mask_4d = self._build_4d_mask(L_p, passage_attention_mask) if use_4d else None | |
| # Initial embeddings from the backbone's embedding layer | |
| embed_layer = self._get_embed_layer() | |
| q_embeds = embed_layer(query_input_ids) | |
| p_embeds = embed_layer(passage_input_ids) | |
| q_curr_ids = query_input_ids.clone() | |
| p_curr_ids = passage_input_ids.clone() | |
| q_g = L_q - K_q - n_tail | |
| p_g = L_p - K_p - n_tail | |
| step_losses: List[torch.Tensor] = [] | |
| final_loss: Dict[str, torch.Tensor] = {} | |
| for step in range(n_steps): | |
| is_last = (step == n_steps - 1) | |
| # Forward pass using embeddings (step 0: original, step 1+: soft tokens) | |
| q_h, q_logits = self._fwd( | |
| q_curr_ids, query_attention_mask, | |
| need_logits=True, mask_4d=q_mask_4d, | |
| inputs_embeds=q_embeds if step > 0 else None, | |
| ) | |
| p_h, p_logits = self._fwd( | |
| p_curr_ids, passage_attention_mask, | |
| need_logits=True, mask_4d=p_mask_4d, | |
| inputs_embeds=p_embeds if step > 0 else None, | |
| ) | |
| # Extract representations — ALL K positions have gradient | |
| q_repr_hidden = q_h[:, q_g:q_g + K_q, :].float() | |
| q_quotation_emb = q_h[:, q_g - 1, :].float() | |
| p_repr_hidden = p_h[:, p_g:p_g + K_p, :].float() | |
| p_quotation_emb = p_h[:, p_g - 1, :].float() | |
| q_sparse_max = None | |
| p_sparse_max = None | |
| # Monotonic trick: log1p(relu(.)) is non-decreasing, so | |
| # max_k log1p(relu(x_k)) == log1p(relu(max_k x_k)) | |
| # Maxing in bf16 first avoids materialising [B, K, V] fp32 intermediate. | |
| if need_sparse and q_logits is not None: | |
| q_sparse_max = torch.log1p(torch.relu( | |
| q_logits[:, q_g:q_g + K_q, :].max(dim=1).values)) | |
| if need_sparse and p_logits is not None: | |
| p_sparse_max = torch.log1p(torch.relu( | |
| p_logits[:, p_g:p_g + K_p, :].max(dim=1).values)) | |
| if self.normalize: | |
| q_quotation_emb = F.normalize(q_quotation_emb, p=2, dim=-1) | |
| q_repr_hidden = F.normalize(q_repr_hidden, p=2, dim=-1) | |
| p_quotation_emb = F.normalize(p_quotation_emb, p=2, dim=-1) | |
| p_repr_hidden = F.normalize(p_repr_hidden, p=2, dim=-1) | |
| q_repr = {'repr_hidden': q_repr_hidden, 'quotation_emb': q_quotation_emb} | |
| p_repr = {'repr_hidden': p_repr_hidden, 'quotation_emb': p_quotation_emb} | |
| if q_sparse_max is not None: | |
| if query_content_ids is not None: | |
| q_sparse_max = filter_sparse(q_sparse_max, query_content_ids) | |
| q_repr['sparse_acts'] = q_sparse_max | |
| if p_sparse_max is not None: | |
| if passage_content_ids is not None: | |
| p_sparse_max = filter_sparse(p_sparse_max, passage_content_ids) | |
| p_repr['sparse_acts'] = p_sparse_max | |
| # Cross-GPU negative sharing | |
| q_repr = self._gather_repr(q_repr) | |
| p_repr = self._gather_repr(p_repr) | |
| B_q_g = q_repr['repr_hidden'].size(0) | |
| B_p_g = p_repr['repr_hidden'].size(0) | |
| labels = torch.arange(B_q_g, device=device) * (B_p_g // B_q_g) | |
| # Per-step contrastive loss | |
| step_loss_dict = self.compute_loss(q_repr, p_repr, labels) | |
| step_losses.append(step_loss_dict['loss']) | |
| if is_last: | |
| final_loss = step_loss_dict | |
| # Soft unmask for next step (per-side K) | |
| if not is_last: | |
| q_embeds, q_curr_ids, _ = self._soft_unmask_step( | |
| q_embeds, q_curr_ids, q_logits, K_q, n_per_step_q, | |
| self.soft_temperature) | |
| p_embeds, p_curr_ids, _ = self._soft_unmask_step( | |
| p_embeds, p_curr_ids, p_logits, K_p, n_per_step_p, | |
| self.soft_temperature) | |
| # Accumulate per-step losses (linearly increasing weight) | |
| if len(step_losses) > 1: | |
| progressive_loss = sum( | |
| (s + 1) / n_steps * loss | |
| for s, loss in enumerate(step_losses[:-1]) | |
| ) | |
| final_loss['loss'] = final_loss['loss'] + progressive_loss | |
| final_loss['loss_progressive'] = progressive_loss.detach() | |
| # Monotonicity regularizer: penalize if later step is worse | |
| mono_penalty = torch.stack([ | |
| torch.relu(step_losses[k + 1] - step_losses[k].detach()) | |
| for k in range(len(step_losses) - 1) | |
| ]).mean() | |
| final_loss['loss'] = final_loss['loss'] + 0.1 * mono_penalty | |
| final_loss['loss_monotonicity'] = mono_penalty.detach() | |
| # DDP scaling | |
| if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: | |
| final_loss['loss'] = final_loss['loss'] * torch.distributed.get_world_size() | |
| # Denoising auxiliary (same as other paths) | |
| if self.denoising_weight > 0 and final_loss: | |
| p_corrupted, p_denoise_targets, mask_ratio = self._apply_text_masking( | |
| passage_input_ids, passage_attention_mask) | |
| _, p_logits_dn = self._fwd(p_corrupted, passage_attention_mask, | |
| need_logits=True, mask_4d=p_mask_4d) | |
| denoise_loss = self.compute_denoising_loss( | |
| p_logits_dn, p_denoise_targets, mask_ratio) | |
| final_loss['loss'] = final_loss['loss'] + self.denoising_weight * denoise_loss | |
| final_loss['loss_denoising'] = denoise_loss.detach() | |
| return final_loss | |
| # ---------------------------------------------------------------- | |
| # K pre-encoder (two-stage variable-length training) | |
| # ---------------------------------------------------------------- | |
| def _pre_encoder_logits( | |
| self, | |
| input_ids: torch.Tensor, # [B, L] | |
| attention_mask: torch.Tensor, # [B, L] | |
| is_query: bool, | |
| ) -> torch.Tensor: | |
| """Tiny embedding-only K-router. Mean-pools the embedding-layer | |
| output over real (non-pad) tokens, then runs a small MLP head. | |
| Returns [B, n_K_pe] logits over self.k_pre_encoder_options. | |
| Cost ≈ embedding lookup + 2-layer MLP — under 1% of full forward. | |
| """ | |
| head = self.k_pre_encoder_q if is_query else self.k_pre_encoder_p | |
| # Use main encoder's embedding layer (parameter-shared). | |
| embed_layer = self._get_embed_layer() | |
| emb = embed_layer(input_ids) # [B, L, H] | |
| # Mean-pool over real tokens (attention_mask = 0 at pads). | |
| m = attention_mask.float().unsqueeze(-1) | |
| pooled = (emb.float() * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0) # [B, H] | |
| # Match head dtype to avoid AMP/DeepSpeed dtype mismatch. | |
| head_dtype = next(head.parameters()).dtype | |
| logits = head(pooled.to(dtype=head_dtype)).float() # [B, n_K_pe] | |
| return logits | |
| def _slice_to_per_item_K( | |
| self, | |
| input_ids: torch.Tensor, # [B, L_max] left-padded, K_max MASKs at end | |
| attention_mask: torch.Tensor, | |
| K_per_item: torch.Tensor, # [B] integers in self.k_pre_encoder_options | |
| K_max_side: int, # n_gen_q_tokens or n_gen_p_tokens for this side | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """Slice each item's input from K_max MASK tokens down to K MASK tokens | |
| (per-item K), preserving the structural tail [", chat_end, EOS]. | |
| Layout (all left-padded — content right-aligned): | |
| [PAD ... PAD | prefix | text | suffix | K_max MASKs | tail] | |
| After slicing item with K_i: | |
| [PAD ... PAD | prefix | text | suffix | K_i MASKs | tail] | |
| Batch is then padded to max-K-in-batch length. | |
| """ | |
| B = input_ids.size(0) | |
| L_in = input_ids.size(1) | |
| n_tail = self._n_tail | |
| K_per_item = K_per_item.tolist() if torch.is_tensor(K_per_item) else list(K_per_item) | |
| # Per-item sliced sequences. | |
| sliced_ids: List[torch.Tensor] = [] | |
| sliced_mask: List[torch.Tensor] = [] | |
| for i in range(B): | |
| K_i = int(K_per_item[i]) | |
| n_remove = K_max_side - K_i | |
| # Cut tokens [-(n_tail + n_remove) : -n_tail). | |
| # Equivalent: keep input_ids[: L_in - n_tail - n_remove] + input_ids[-n_tail:] | |
| keep_until = L_in - n_tail - n_remove | |
| new_ids = torch.cat([input_ids[i, :keep_until], input_ids[i, -n_tail:] | |
| if n_tail > 0 else input_ids.new_empty(0)]) | |
| new_mask = torch.cat([attention_mask[i, :keep_until], attention_mask[i, -n_tail:] | |
| if n_tail > 0 else attention_mask.new_empty(0)]) | |
| sliced_ids.append(new_ids) | |
| sliced_mask.append(new_mask) | |
| # Pad LEFT to max length so the structural tail stays at position -1. | |
| L_out = max(s.size(0) for s in sliced_ids) | |
| pad_id = (self.tokenizer.pad_token_id or self.tokenizer.eos_token_id or 0) | |
| out_ids = input_ids.new_full((B, L_out), pad_id) | |
| out_mask = attention_mask.new_zeros((B, L_out)) | |
| for i, (s_ids, s_mask) in enumerate(zip(sliced_ids, sliced_mask)): | |
| L_i = s_ids.size(0) | |
| out_ids[i, L_out - L_i:] = s_ids | |
| out_mask[i, L_out - L_i:] = s_mask | |
| return out_ids, out_mask | |
| def _extract_repr_per_K( | |
| self, | |
| hidden: torch.Tensor, # [B, L_out, H] | |
| K_per_item: torch.Tensor, # [B] int K per item | |
| n_K_max: int, # max K across self.k_pre_encoder_options | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """Extract per-item repr_hidden using each item's K. | |
| Returns: | |
| repr_hidden: [B, n_K_max, H] — first K_i positions filled, rest zero | |
| quotation_emb: [B, H] — position before first MASK per item | |
| With left-padding, the rightmost positions per item are: | |
| [..., gen_K_i MASKs, tail tokens (n_tail)] | |
| So mask block = positions [L_out - n_tail - K_i, L_out - n_tail). | |
| Quotation = position L_out - n_tail - K_i - 1. | |
| """ | |
| B, L_out, H = hidden.shape | |
| n_tail = self._n_tail | |
| K_per_item = K_per_item.tolist() if torch.is_tensor(K_per_item) else list(K_per_item) | |
| repr_hidden = hidden.new_zeros(B, n_K_max, H) | |
| quotation_emb = hidden.new_zeros(B, H) | |
| for i in range(B): | |
| K_i = int(K_per_item[i]) | |
| mask_end = L_out - n_tail | |
| mask_start = mask_end - K_i | |
| repr_hidden[i, :K_i] = hidden[i, mask_start:mask_end].float() | |
| if mask_start - 1 >= 0: | |
| quotation_emb[i] = hidden[i, mask_start - 1].float() | |
| return repr_hidden, quotation_emb | |
| def _gumbel_st_sample_K( | |
| self, | |
| logits: torch.Tensor, # [B, n_K_pe] | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Gumbel-softmax sample with straight-through. | |
| Returns: | |
| K_per_item: [B] int — actual K integer values to use for slicing | |
| soft_one_hot: [B, n_K_pe] — hard one-hot in fwd, soft in backward | |
| (used for cost regularizer + ST credit) | |
| chosen_prob: [B] — one-hot's value at chosen index (always 1.0 in fwd) | |
| used to inject gradient back through ST | |
| """ | |
| soft_one_hot = F.gumbel_softmax( | |
| logits, tau=max(self.gumbel_temperature, 1e-3), hard=True | |
| ) # [B, n_K_pe] | |
| # Integer K from one-hot | |
| K_idx = soft_one_hot.argmax(dim=-1) # [B] | |
| K_per_item = self.k_pe_options_tensor[K_idx].long() # [B] | |
| # ST credit: chosen_prob is always 1.0 in forward, has gradient via ST | |
| chosen_prob = soft_one_hot.gather(-1, K_idx.unsqueeze(-1)).squeeze(-1) | |
| return K_per_item, soft_one_hot, chosen_prob | |
| def _forward_with_k_pre_encoder( | |
| self, | |
| query_input_ids: torch.Tensor, | |
| query_attention_mask: torch.Tensor, | |
| passage_input_ids: torch.Tensor, | |
| passage_attention_mask: torch.Tensor, | |
| query_content_ids: Optional[List], | |
| passage_content_ids: Optional[List], | |
| labels: torch.Tensor, | |
| ) -> Dict[str, torch.Tensor]: | |
| """Two-stage forward: pre-encoder predicts K_q, K_p per item; main | |
| encoder runs at variable length (per-item K). Real encoding savings | |
| when batch's max-K is below K_max. | |
| Backprop: | |
| - Main encoder: gradient flows through standard InfoNCE on per-item | |
| sliced representations. | |
| - Pre-encoder: gradient flows via Gumbel straight-through on | |
| chosen_prob (always 1.0 in fwd; gradient injected through ST). | |
| Plus a cost regularizer λ * E[K] discourages collapsing to K=K_max. | |
| """ | |
| device = labels.device | |
| K_max_q = self.n_gen_q_tokens | |
| K_max_p = self.n_gen_p_tokens | |
| # ── 1. Pre-encoder forward (cheap — embedding lookup + small MLP). | |
| q_logits = self._pre_encoder_logits( | |
| query_input_ids, query_attention_mask, is_query=True) | |
| p_logits = self._pre_encoder_logits( | |
| passage_input_ids, passage_attention_mask, is_query=False) | |
| # ── 2. Gumbel-ST sample K per item (queries + passages). | |
| K_q_per_item, q_soft, q_chosen = self._gumbel_st_sample_K(q_logits) | |
| K_p_per_item, p_soft, p_chosen = self._gumbel_st_sample_K(p_logits) | |
| # Clamp at the per-side K_max in case the pre-encoder options exceed it. | |
| # (Shouldn't happen by construction but a safety net.) | |
| K_q_per_item = K_q_per_item.clamp(min=1, max=K_max_q) | |
| K_p_per_item = K_p_per_item.clamp(min=1, max=K_max_p) | |
| # ── 3. Slice each item's input to its K (variable-length batch). | |
| q_ids_sliced, q_mask_sliced = self._slice_to_per_item_K( | |
| query_input_ids, query_attention_mask, K_q_per_item, K_max_q) | |
| p_ids_sliced, p_mask_sliced = self._slice_to_per_item_K( | |
| passage_input_ids, passage_attention_mask, K_p_per_item, K_max_p) | |
| # ── 4. Main encoder forward (one pass each at variable length). | |
| need_sparse = self.sparse_weight > 0 | |
| q_hidden, q_logits_main = self._fwd( | |
| q_ids_sliced, q_mask_sliced, need_logits=need_sparse) | |
| p_hidden, p_logits_main = self._fwd( | |
| p_ids_sliced, p_mask_sliced, need_logits=need_sparse) | |
| # ── 5. Per-item repr extraction. | |
| n_K_max = max(self.k_pre_encoder_options) | |
| q_repr_hidden, q_quotation_emb = self._extract_repr_per_K( | |
| q_hidden, K_q_per_item, n_K_max) | |
| p_repr_hidden, p_quotation_emb = self._extract_repr_per_K( | |
| p_hidden, K_p_per_item, n_K_max) | |
| # ── 6. Sparse: max-pool over each item's K positions. | |
| q_sparse_max = None | |
| p_sparse_max = None | |
| if need_sparse and q_logits_main is not None: | |
| # For each item, compute log(1+relu(...)) over its K MASK positions, | |
| # then max-pool across positions. Same as legacy sparse computation. | |
| B_q = q_ids_sliced.size(0) | |
| L_q_out = q_ids_sliced.size(1) | |
| q_sparse_max = q_logits_main.new_zeros(B_q, q_logits_main.size(-1)) | |
| for i in range(B_q): | |
| K_i = int(K_q_per_item[i].item()) | |
| mask_end = L_q_out - self._n_tail | |
| mask_start = mask_end - K_i | |
| slogits = q_logits_main[i, mask_start:mask_end] | |
| # Monotonic trick: max over K positions first → log1p(relu()) on [V]. | |
| q_sparse_max[i] = torch.log1p(torch.relu(slogits.max(dim=0).values)) | |
| if need_sparse and p_logits_main is not None: | |
| B_p = p_ids_sliced.size(0) | |
| L_p_out = p_ids_sliced.size(1) | |
| p_sparse_max = p_logits_main.new_zeros(B_p, p_logits_main.size(-1)) | |
| for i in range(B_p): | |
| K_i = int(K_p_per_item[i].item()) | |
| mask_end = L_p_out - self._n_tail | |
| mask_start = mask_end - K_i | |
| slogits = p_logits_main[i, mask_start:mask_end] | |
| # Monotonic trick: max over K positions first → log1p(relu()) on [V]. | |
| p_sparse_max[i] = torch.log1p(torch.relu(slogits.max(dim=0).values)) | |
| # ── 7. Normalize. | |
| if self.normalize: | |
| q_repr_hidden = F.normalize(q_repr_hidden, p=2, dim=-1) | |
| q_quotation_emb = F.normalize(q_quotation_emb, p=2, dim=-1) | |
| p_repr_hidden = F.normalize(p_repr_hidden, p=2, dim=-1) | |
| p_quotation_emb = F.normalize(p_quotation_emb, p=2, dim=-1) | |
| # ── 8. ST credit injection — multiply repr by chosen_prob (always 1.0 | |
| # in fwd, has gradient through ST in backward). This is the standard | |
| # Gumbel-ST trick for credit assignment when downstream operations | |
| # (slicing, reshape) are non-differentiable. | |
| q_repr_hidden = q_repr_hidden * q_chosen.float().unsqueeze(-1).unsqueeze(-1) | |
| p_repr_hidden = p_repr_hidden * p_chosen.float().unsqueeze(-1).unsqueeze(-1) | |
| q_quotation_emb = q_quotation_emb * q_chosen.float().unsqueeze(-1) | |
| p_quotation_emb = p_quotation_emb * p_chosen.float().unsqueeze(-1) | |
| # Build repr dicts and apply content filter to sparse. | |
| q_repr = {'repr_hidden': q_repr_hidden, 'quotation_emb': q_quotation_emb} | |
| p_repr = {'repr_hidden': p_repr_hidden, 'quotation_emb': p_quotation_emb} | |
| if q_sparse_max is not None: | |
| if query_content_ids is not None: | |
| q_sparse_max = filter_sparse(q_sparse_max, query_content_ids) | |
| q_repr['sparse_acts'] = q_sparse_max | |
| if p_sparse_max is not None: | |
| if passage_content_ids is not None: | |
| p_sparse_max = filter_sparse(p_sparse_max, passage_content_ids) | |
| p_repr['sparse_acts'] = p_sparse_max | |
| # ── 9. Cross-GPU gather + recompute labels for gathered batch. | |
| q_repr = self._gather_repr(q_repr) | |
| p_repr = self._gather_repr(p_repr) | |
| B_q_all = q_repr['repr_hidden'].size(0) | |
| B_p_all = p_repr['repr_hidden'].size(0) | |
| n_paq_g = B_p_all // B_q_all | |
| labels_g = torch.arange(B_q_all, device=device) * n_paq_g | |
| # ── 10. Standard InfoNCE on per-item-K representations (zero-padded | |
| # positions contribute nothing thanks to MaxSim's max-over-positions). | |
| loss_dict = self.compute_loss(q_repr, p_repr, labels_g) | |
| # ── 11. Cost regularizer — λ * E[K_q + K_p] using soft probs (so it's | |
| # differentiable w.r.t. logits). Discourages collapsing to K_max. | |
| if self.k_cost_lambda > 0: | |
| K_options_t = self.k_pe_options_tensor # [n_K_pe] | |
| expected_K_q = (q_soft * K_options_t).sum(dim=-1).mean() | |
| expected_K_p = (p_soft * K_options_t).sum(dim=-1).mean() | |
| cost_loss = (expected_K_q + expected_K_p) / float(max(K_options_t.max().item(), 1.0)) | |
| loss_dict['loss'] = loss_dict['loss'] + self.k_cost_lambda * cost_loss | |
| loss_dict['k_cost'] = cost_loss.detach() | |
| loss_dict['expected_K_q'] = expected_K_q.detach() | |
| loss_dict['expected_K_p'] = expected_K_p.detach() | |
| # Diagnostics: marginal K distributions. | |
| with torch.no_grad(): | |
| for i, K in enumerate(self.k_pre_encoder_options): | |
| loss_dict[f'pe_kq_p{K}'] = q_soft[:, i].mean().detach() | |
| loss_dict[f'pe_kp_p{K}'] = p_soft[:, i].mean().detach() | |
| # DDP loss scaling (counters HF Trainer's gradient averaging). | |
| if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: | |
| loss_dict['loss'] = loss_dict['loss'] * torch.distributed.get_world_size() | |
| return loss_dict | |
| # ---------------------------------------------------------------- | |
| # Encode — single-pass or multi-step denoising (inference) | |
| # ---------------------------------------------------------------- | |
def encode(
    self,
    input_ids: torch.Tensor,  # [B, L]
    attention_mask: torch.Tensor,  # [B, L]
    is_query: bool = False,
    compute_sparse: Optional[bool] = None,
    content_token_ids: Optional[List] = None,
) -> Dict[str, torch.Tensor]:
    """
    PromptReps encoding.

    num_denoise_steps <= 1: single forward pass (fast path, differentiable,
        used in both training and inference).
    num_denoise_steps > 1: iterative denoising loop (inference).
        repr_hidden[i, k] is captured at the step when gen token k transitions
        MASK -> decoded, benefiting from previously-decoded token context.
        Schedule: per step, the n_per_step = max(1, K // n_steps) masked
        positions with the highest negative entropy are decoded (greedy
        argmax). At the final step every remaining MASK is decoded.
        quotation_emb is captured at the final step.
    Call under torch.no_grad() for inference; use forward() for training.

    Args:
        compute_sparse: overrides whether sparse activations are computed
            (default: ``self.sparse_weight > 0``).
        content_token_ids: if given, ``filter_sparse`` is applied to sparse_acts.

    Returns:
        repr_hidden   [B, K, H]  ColBERT repr (one vector per gen token)
        quotation_emb [B, H]     closing " token (single-dense repr)
        sparse_acts   [B, V]     max over the K positions of log(1+relu(logit))
                                 (present only when sparse is computed)
        sparse_acts_per_pos [B, K, V]  per-position log(1+relu(logit)),
                                 unfiltered; single-pass path with K > 1 only
    """
    K = self._k(is_query)
    B = input_ids.size(0)
    L = input_ids.size(1)
    n_steps = self.num_denoise_steps
    device = input_ids.device
    g_start = L - K - self._n_tail  # left-padded: MASK block then structural tail
    if compute_sparse is None:
        need_sparse = self.sparse_weight > 0
    else:
        need_sparse = compute_sparse
    # Initialised here so both the single- and multi-step branches always leave
    # these bound — the multi-step path only ever populates sparse_max.
    sparse_max: Optional[torch.Tensor] = None
    sparse_per_pos: Optional[torch.Tensor] = None
    # ── Single-pass fast path (n_steps ≤ 1) ──────────────────────────────
    if n_steps <= 1:
        hidden, logits = self._fwd(input_ids, attention_mask, need_logits=need_sparse)
        repr_hidden = hidden[:, g_start:g_start + K, :].float()  # [B, K, H]
        quotation_emb = hidden[:, g_start - 1, :].float()  # [B, H]
        if need_sparse and logits is not None:
            # Per-position sparse acts for ColBERT-style multi-vector sparse
            # (caller may pop result['sparse_acts_per_pos'] to do per-K topk).
            # Kept in the logits dtype (bf16) to halve memory bandwidth; the
            # downstream topk -> x100 -> int round path doesn't need fp32.
            gen_logits = logits[:, g_start:g_start + K, :]  # [B, K, V]
            sparse_per_pos = torch.log1p(torch.relu(gen_logits))  # [B, K, V]
            # sparse_max stays fp32 (small, [B, V]) for downstream consumers.
            sparse_max = sparse_per_pos.max(dim=1).values.float()  # [B, V] fp32
    # ── Multi-step denoising loop (n_steps > 1, inference) ───────────────
    else:
        curr_ids = input_ids.clone()
        n_per_step = max(1, K // n_steps)
        repr_buf = torch.zeros(B, K, self.hidden_size, device=device)
        repr_saved = torch.zeros(B, K, dtype=torch.bool, device=device)
        quotation_emb = None
        use_4d = self.adapter.needs_4d_mask() if self.adapter else True
        mask_4d = self._build_4d_mask(L, attention_mask) if use_4d else None
        for step in range(n_steps):
            hidden, logits = self._fwd(curr_ids, attention_mask, need_logits=True, mask_4d=mask_4d)
            is_last = (step == n_steps - 1)
            if need_sparse and sparse_max is None:
                # Size the accumulator from the logits actually produced:
                # backbone.config.vocab_size may differ from the LM head's
                # output width, which would break the torch.max below.
                sparse_max = torch.zeros(B, logits.size(-1), device=device)
            # Vectorized over batch: greedy argmax, neg-entropy confidence,
            # top-n_per_step transitions per item; on the last step all
            # remaining masks transition.
            gen_ids_all = curr_ids[:, g_start:g_start + K]  # [B, K] int
            mask_bool_all = (gen_ids_all == self.mask_token_id)  # [B, K] bool
            ml_all = logits[:, g_start:g_start + K, :]  # [B, K, V]
            # Confidence + argmax (matches _sample_with_confidence(alg='entropy'))
            probs = F.softmax(ml_all.float(), dim=-1)  # [B, K, V] fp32
            log_probs = torch.log(probs + 1e-10)
            conf = (probs * log_probs).sum(dim=-1)  # [B, K] (neg entropy)
            x0_pred = probs.argmax(dim=-1)  # [B, K]
            # Restrict topk to currently-masked positions
            conf = torch.where(mask_bool_all, conf, torch.full_like(conf, -float('inf')))
            if is_last:
                xfer_mask = mask_bool_all  # all remaining masks
            else:
                n_tr = min(n_per_step, K)
                _, top_idx = conf.topk(n_tr, dim=-1)  # [B, n_tr]
                xfer_mask = torch.zeros_like(mask_bool_all)
                xfer_mask.scatter_(1, top_idx, True)
                xfer_mask = xfer_mask & mask_bool_all  # only actually-masked positions
            # 1. curr_ids: replace MASK with predicted at xfer positions
            if xfer_mask.any():
                new_K = torch.where(xfer_mask, x0_pred, gen_ids_all)
                curr_ids = curr_ids.clone()
                curr_ids[:, g_start:g_start + K] = new_K
            # 2. repr_buf: save hidden at to_save positions (xfer & not yet saved)
            to_save = xfer_mask & ~repr_saved
            if to_save.any():
                hidden_chunk = hidden[:, g_start:g_start + K, :]  # [B, K, H]
                repr_buf = torch.where(to_save.unsqueeze(-1), hidden_chunk, repr_buf)
                # 3. sparse_max: max over to_save positions of log(1+relu(logits))
                if sparse_max is not None:
                    sparse_chunk = torch.log1p(torch.relu(ml_all))  # [B, K, V]
                    masked_sparse = torch.where(
                        to_save.unsqueeze(-1).expand_as(sparse_chunk),
                        sparse_chunk,
                        torch.full_like(sparse_chunk, -float('inf')))
                    new_max = masked_sparse.max(dim=1).values.float()  # [B, V] fp32
                    sparse_max = torch.max(sparse_max, new_max)
                repr_saved = repr_saved | to_save
            if is_last:
                quotation_emb = hidden[:, g_start - 1, :].float()
                # Catch positions that were never masked (possible if the
                # input already had a non-MASK token in the K block).
                unsaved = ~repr_saved
                if unsaved.any():
                    hidden_chunk = hidden[:, g_start:g_start + K, :]
                    repr_buf = torch.where(unsaved.unsqueeze(-1), hidden_chunk, repr_buf)
                    if sparse_max is not None:
                        sparse_chunk = torch.log1p(torch.relu(ml_all))
                        masked_sparse = torch.where(
                            unsaved.unsqueeze(-1).expand_as(sparse_chunk),
                            sparse_chunk,
                            torch.full_like(sparse_chunk, -float('inf')))
                        new_max = masked_sparse.max(dim=1).values.float()
                        sparse_max = torch.max(sparse_max, new_max)
                    repr_saved = repr_saved | unsaved
        repr_buf = torch.nan_to_num(repr_buf)
        repr_hidden = repr_buf
        quotation_emb = torch.nan_to_num(quotation_emb)
    # ── Normalize ─────────────────────────────────────────────────────────
    if self.normalize:
        quotation_emb = F.normalize(quotation_emb, p=2, dim=-1)
        repr_hidden = F.normalize(repr_hidden, p=2, dim=-1)
    result: Dict[str, torch.Tensor] = {
        'repr_hidden': repr_hidden,
        'quotation_emb': quotation_emb,
    }
    if sparse_max is not None:
        if content_token_ids is not None:
            sparse_max = filter_sparse(sparse_max, content_token_ids)
        result['sparse_acts'] = sparse_max
    # Per-position sparse [B, K, V] — single-pass only, K>1
    if sparse_per_pos is not None and K > 1:
        result['sparse_acts_per_pos'] = sparse_per_pos  # unfiltered; caller applies content filter
    return result
| # ---------------------------------------------------------------- | |
| # Cross-GPU negative sharing | |
| # ---------------------------------------------------------------- | |
| def _dist_gather(t: torch.Tensor) -> torch.Tensor: | |
| """All-gather tensors across GPUs with gradient passthrough.""" | |
| if not (torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1): | |
| return t | |
| gathered = [torch.zeros_like(t) for _ in range(torch.distributed.get_world_size())] | |
| torch.distributed.all_gather(gathered, t.contiguous()) | |
| # Replace own shard with original (keeps gradient) | |
| gathered[torch.distributed.get_rank()] = t | |
| return torch.cat(gathered, dim=0) | |
| def _gather_repr(self, repr_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |
| """Gather all representation tensors across GPUs.""" | |
| out = {} | |
| for k, v in repr_dict.items(): | |
| out[k] = self._dist_gather(v) | |
| return out | |
| # ---------------------------------------------------------------- | |
| # MaxSim (ColBERT-style) | |
| # ---------------------------------------------------------------- | |
def maxsim(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """ColBERT-style late-interaction scores for every (query, passage) pair.

    q: [B_q, k_q, H], p: [B_p, k_p, H] (k_q and k_p may differ) -> [B_q, B_p]
    with score[i, j] = sum_k max_l <q[i, k], p[j, l]>.

    NOTE(review): this takes no ``self`` but is invoked as ``self.maxsim(...)``
    elsewhere in this file; a ``@staticmethod`` decorator is presumably present
    in the real source but absent from this view — confirm.
    """
    # token_sims[i, j, k, l] = <q[i, k], p[j, l]>
    token_sims = torch.einsum('ikh,jlh->ijkl', q, p)
    best_match = token_sims.max(dim=-1).values  # [B_q, B_p, k_q]
    return best_match.sum(dim=-1)
| # ---------------------------------------------------------------- | |
| # Auxiliary losses (diffusion-native) | |
| # ---------------------------------------------------------------- | |
| def _corrupt_text( | |
| self, | |
| input_ids: torch.Tensor, # [B, L] | |
| attention_mask: torch.Tensor, # [B, L] | |
| rate: float, | |
| ) -> torch.Tensor: | |
| """Randomly replace rate% of text tokens with MASK for corruption augmentation. | |
| Only corrupts real text tokens (not padding, not gen MASKs, not EOS). | |
| Returns a new tensor (does not modify input in-place). | |
| """ | |
| B, L = input_ids.shape | |
| K = self.n_gen_p_tokens # only called on passages | |
| g_start = L - K - self._n_tail | |
| corrupted = input_ids.clone() | |
| is_real = attention_mask.bool() | |
| is_gen = torch.zeros(L, dtype=torch.bool, device=input_ids.device) | |
| is_gen[g_start:] = True | |
| is_candidate = is_real & ~is_gen.unsqueeze(0) & (input_ids != self.mask_token_id) | |
| for i in range(B): | |
| cand_idx = torch.where(is_candidate[i])[0] | |
| if cand_idx.numel() == 0: | |
| continue | |
| n_mask = max(1, int(cand_idx.numel() * rate)) | |
| perm = torch.randperm(cand_idx.numel(), device=input_ids.device)[:n_mask] | |
| corrupted[i, cand_idx[perm]] = self.mask_token_id | |
| return corrupted | |
| def _apply_text_masking( | |
| self, | |
| input_ids: torch.Tensor, # [B, L] | |
| attention_mask: torch.Tensor, # [B, L] | |
| ) -> Tuple[torch.Tensor, torch.Tensor, float]: | |
| """Randomly mask text tokens (NOT generation MASKs) for denoising auxiliary loss. | |
| Returns: | |
| corrupted_ids: input_ids with some text tokens replaced by MASK | |
| denoise_targets: original token IDs at masked positions (-100 elsewhere) | |
| mask_ratio: fraction of tokens masked (t for 1/t weighting) | |
| """ | |
| B, L = input_ids.shape | |
| K = self.n_gen_p_tokens # called on passages | |
| g_start = L - K - self._n_tail # left-padded: gen block + structural tail | |
| corrupted = input_ids.clone() | |
| targets = torch.full_like(input_ids, -100) # -100 = ignore in CE | |
| # Candidate text positions: not padding, not in gen/EOS block, not already MASK | |
| is_real = attention_mask.bool() # [B, L] | |
| is_gen = torch.zeros(L, dtype=torch.bool, device=input_ids.device) | |
| is_gen[g_start:] = True # marks MASK positions + EOS as non-candidate | |
| is_candidate = is_real & ~is_gen.unsqueeze(0) & (input_ids != self.mask_token_id) | |
| for i in range(B): | |
| cand_idx = torch.where(is_candidate[i])[0] | |
| if cand_idx.numel() == 0: | |
| continue | |
| n_mask = max(1, int(cand_idx.numel() * self.denoise_mask_ratio)) | |
| perm = torch.randperm(cand_idx.numel(), device=input_ids.device)[:n_mask] | |
| mask_positions = cand_idx[perm] | |
| targets[i, mask_positions] = input_ids[i, mask_positions] | |
| corrupted[i, mask_positions] = self.mask_token_id | |
| mask_ratio = self.denoise_mask_ratio | |
| return corrupted, targets, mask_ratio | |
def compute_denoising_loss(
    self,
    logits: torch.Tensor,  # [B, L, V] from forward pass on corrupted input
    targets: torch.Tensor,  # [B, L] with -100 for non-masked positions
    mask_ratio: float,  # t for 1/t weighting
) -> torch.Tensor:
    """Dream/LLaDA-style denoising loss: CE on masked text tokens, weighted 1/t.

    The CE is averaged over positions whose target is not -100, then divided
    by max(mask_ratio, 1e-3) to match the pre-training ELBO weighting. If no
    position is supervised (e.g. every row lacked maskable text), the result
    is 0 rather than the NaN that a mean over zero elements produces.
    ``reshape`` (not ``view``) is used so non-contiguous logits are accepted.

    Returns:
        Scalar loss tensor.
    """
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_targets = targets.reshape(-1)
    # Sum / clamped count equals the 'mean' reduction whenever count > 0 and
    # yields 0 (with finite gradients) when count == 0.
    loss_sum = F.cross_entropy(
        flat_logits, flat_targets, ignore_index=-100, reduction='sum')
    n_valid = (flat_targets != -100).sum().clamp(min=1)
    loss = loss_sum / n_valid
    # Weight by 1/t (matching Dream/LLaDA pre-training ELBO)
    return loss / max(mask_ratio, 1e-3)
def compute_diversity_loss(
    self,
    repr_hidden: torch.Tensor,  # [B, K, H]
) -> torch.Tensor:
    """Hinge penalty pushing the K vectors of each item apart.

    Returns the mean of relu(<v_a, v_b>) over all pairs a < b of the K
    vectors, across the batch (0 when K <= 1). Positive overlaps are
    penalized, pushing the vectors toward orthogonal. The dot product equals
    cosine similarity only if repr_hidden is already L2-normalized — the
    callers are expected to ensure that.
    """
    n_vec = repr_hidden.size(1)
    dev = repr_hidden.device
    if n_vec <= 1:
        return torch.tensor(0.0, device=dev)
    gram = torch.bmm(repr_hidden, repr_hidden.transpose(1, 2))  # [B, K, K]
    above_diag = torch.ones(n_vec, n_vec, device=dev).triu(diagonal=1).bool()
    pair_sims = gram[:, above_diag]  # [B, K*(K-1)/2]
    return torch.relu(pair_sims).mean()
@staticmethod
def _mean_offdiag_cos(repr_hidden: torch.Tensor) -> torch.Tensor:
    """Mean off-diagonal cosine similarity over the K vectors of each item.

    A staticmethod so that calls of the form ``self._mean_offdiag_cos(x)``
    bind ``x`` to ``repr_hidden`` (the function takes no ``self``).

    Args:
        repr_hidden: [B, K, H]; cast to fp32 and L2-normalised here.

    Returns:
        Scalar tensor: mean cosine over items and the K*(K-1)/2 distinct
        pairs above the diagonal. Zero when K <= 1.
    """
    K = repr_hidden.size(1)
    if K <= 1:
        return torch.tensor(0.0, device=repr_hidden.device)
    normed = F.normalize(repr_hidden.float(), p=2, dim=-1)
    sim = torch.bmm(normed, normed.transpose(1, 2))  # [B, K, K]
    mask = torch.triu(torch.ones(K, K, device=repr_hidden.device), diagonal=1).bool()
    return sim[:, mask].mean()
def _dense_debug_stats(
    self,
    q_repr: Dict[str, torch.Tensor],
    p_repr: Dict[str, torch.Tensor],
    labels: torch.Tensor,
    q_raw: Optional[torch.Tensor] = None,
    p_raw: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
    """Detached diagnostics that separate "small score gap" from "collapsed vectors".

    Args:
        q_repr: query dict holding 'repr_hidden'.
        p_repr: passage dict holding 'repr_hidden'.
        labels: [B_q] column index of each query's positive in the
            self.maxsim(...) score matrix.
        q_raw: optional query vectors whose mean L2 norm and mean off-diagonal
            cosine are reported as well.
        p_raw: same, for passages.

    Returns:
        Dict of detached scalar tensors keyed 'debug_*'.
    """
    score_mat = self.maxsim(q_repr['repr_hidden'], p_repr['repr_hidden'])
    rows = torch.arange(score_mat.size(0), device=score_mat.device)
    positive_scores = score_mat[rows, labels]

    # Hardest in-batch negative = best column once the positive is hidden.
    positive_cells = torch.zeros_like(score_mat, dtype=torch.bool)
    positive_cells[rows, labels] = True
    hardest_negative = score_mat.masked_fill(positive_cells, float('-inf')).max(dim=1).values
    margin = positive_scores - hardest_negative

    summary = {
        'debug_pos_score': positive_scores.mean(),
        'debug_hardneg_score': hardest_negative.mean(),
        'debug_score_gap': margin.mean(),
        'debug_pos_beats_hardneg': (positive_scores > hardest_negative).float().mean(),
        'debug_q_mask_cos': self._mean_offdiag_cos(q_repr['repr_hidden']),
        'debug_p_mask_cos': self._mean_offdiag_cos(p_repr['repr_hidden']),
    }
    if q_raw is not None:
        summary['debug_q_mask_norm'] = q_raw.float().norm(dim=-1).mean()
        summary['debug_q_mask_raw_cos'] = self._mean_offdiag_cos(q_raw)
    if p_raw is not None:
        summary['debug_p_mask_norm'] = p_raw.float().norm(dim=-1).mean()
        summary['debug_p_mask_raw_cos'] = self._mean_offdiag_cos(p_raw)
    return {key: value.detach() for key, value in summary.items()}
| # ---------------------------------------------------------------- | |
| # Loss | |
| # ---------------------------------------------------------------- | |
def compute_loss(
    self,
    q_repr: Dict[str, torch.Tensor],
    p_repr: Dict[str, torch.Tensor],
    labels: torch.Tensor,  # [B_q] index of positive in B_p
) -> Dict[str, torch.Tensor]:
    """Contrastive training loss for one batch of query/passage representations.

    Terms:
      * dense: InfoNCE on ColBERT MaxSim over repr_hidden, divided by the
        number of query vectors K and by self.temperature, weighted by
        self.dense_weight (K=1 reduces to a single-vector dot product);
      * sparse (sparse_weight > 0 and both sides carry 'sparse_acts'):
        InfoNCE on the raw sparse dot product, weighted by sparse_weight;
      * diversity (diversity_weight > 0 and K > 1): mean of the query- and
        passage-side compute_diversity_loss, weighted by diversity_weight.

    When self.use_k_adapter is set, the whole computation is delegated to
    _compute_loss_with_k_adapter.

    Returns:
        {'loss', 'loss_dense', 'loss_sparse'} plus 'loss_diversity' when
        diversity_weight > 0; all entries except 'loss' are detached.
    """
    if self.use_k_adapter:
        return self._compute_loss_with_k_adapter(q_repr, p_repr, labels)

    dev = labels.device
    q_vecs = q_repr['repr_hidden']
    p_vecs = p_repr['repr_hidden']
    n_vec = q_vecs.size(1)

    # Dense: MaxSim averaged over query vectors, then temperature-scaled.
    dense_logits = self.maxsim(q_vecs, p_vecs) / n_vec / self.temperature
    weighted_dense = self.dense_weight * F.cross_entropy(dense_logits, labels)
    loss = weighted_dense

    # Sparse: raw dot product, same scoring as inference. Content-token
    # filtering applied in forward() keeps these scores small, so no
    # clamping or normalisation is applied here.
    both_sparse = 'sparse_acts' in q_repr and 'sparse_acts' in p_repr
    if self.sparse_weight > 0 and both_sparse:
        sparse_logits = q_repr['sparse_acts'] @ p_repr['sparse_acts'].T
        sparse_term = F.cross_entropy(sparse_logits, labels)
        loss = loss + self.sparse_weight * sparse_term
    else:
        sparse_term = torch.tensor(0.0, device=dev)

    # Auxiliary: both sides should keep their K vectors diverse.
    diversity_term = torch.tensor(0.0, device=dev)
    if self.diversity_weight > 0 and n_vec > 1:
        diversity_term = 0.5 * (
            self.compute_diversity_loss(q_vecs)
            + self.compute_diversity_loss(p_vecs)
        )
        loss = loss + self.diversity_weight * diversity_term

    out = {
        'loss': loss,
        'loss_dense': weighted_dense.detach(),
        'loss_sparse': sparse_term.detach(),
    }
    if self.diversity_weight > 0:
        out['loss_diversity'] = diversity_term.detach()
    # loss_denoising is added in forward() since it needs logits
    return out
| # ---------------------------------------------------------------- | |
| # K-adapter loss | |
| # ---------------------------------------------------------------- | |
def _compute_loss_with_k_adapter(
    self,
    q_repr: Dict[str, torch.Tensor],
    p_repr: Dict[str, torch.Tensor],
    labels: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Joint training of retriever + K-routing head.

    A cell (K_q, K_p) is taken from ``self.k_adapter_options x
    self.k_adapter_options``. It is *valid* when K_q / K_p do not exceed the
    number of encoded vectors on that side. For every valid cell:
      1. Slice the prefix of q_repr / p_repr at that K.
      2. MaxSim (divided by K_q), scaled by dense_weight / temperature, plus
         sparse_weight * sparse score -> per-query CE vs in-batch negatives.

    Cells that do not fit are excluded entirely. They get neither teacher
    nor adapter probability. They used to keep a loss of 0, the smallest
    possible CE, which made them the teacher's favourite and a free sink for
    the adapter.

    Teacher distribution per query: softmax(-per_cell_loss / teacher_temperature)
    over valid cells. The adapter (self.k_adapter applied to quotation_emb)
    emits n_K**2 logits; only the valid cells enter its softmax.
    Adapter loss = KL(teacher || adapter_probs): the adapter learns to mimic
    which cell the model itself prefers (lowest CE).
    Retriever loss = mean_i sum_c adapter_probs[i, c] * per_cell_loss[i, c]:
      - uniform adapter (early) ~ average over cells, every prefix gets gradient;
      - peaky adapter (late) -> gradient concentrates on the preferred cell.

    Returns:
        Dict with 'loss', 'loss_dense' (retriever term), 'loss_sparse'
        (sparse-only CE, diagnostic only), 'loss_adapter', 'adapter_entropy',
        and per-option marginals 'kq_p{K}' / 'kp_p{K}'.

    Raises:
        ValueError: if no option in k_adapter_options fits the encoded K.
        RuntimeError: if q_repr lacks 'quotation_emb'.
    """
    device = labels.device
    K_opts = self.k_adapter_options
    n_K = self.n_K
    K_max_q = q_repr['repr_hidden'].size(1)
    K_max_p = p_repr['repr_hidden'].size(1)
    # If the model was instantiated at a smaller K than max(k_adapter_options),
    # only the options we actually have positions for are usable.
    valid_q = [k for k in K_opts if k <= K_max_q]
    valid_p = [k for k in K_opts if k <= K_max_p]
    if not valid_q or not valid_p:
        raise ValueError(
            f"K-adapter: no option in k_adapter_options={list(K_opts)} fits the "
            f"encoded representations (K_q={K_max_q}, K_p={K_max_p}).")
    # Positions of the usable options inside the n_K-wide cell grid.
    idx_q = [K_opts.index(k) for k in valid_q]
    idx_p = [K_opts.index(k) for k in valid_p]

    # ── 1. Sparse score (K-invariant, max-pooled): the same [B_q, B_p] matrix
    # is added to every cell. It is a per-(query, passage) score, so it changes
    # both the in-batch ranking and each cell's CE value; the per-cell loss
    # therefore reflects the fused dense+sparse contrast the model is trained
    # on. Only the dense part is temperature-scaled; sparse is added raw.
    sparse_scores = None
    if (self.sparse_weight > 0
            and 'sparse_acts' in q_repr
            and 'sparse_acts' in p_repr):
        sparse_scores = q_repr['sparse_acts'] @ p_repr['sparse_acts'].T

    B_q = q_repr['repr_hidden'].size(0)

    # ── 2. Per-cell dense scores -> per-cell CE losses.
    # sims_full[b_q, b_p, i, j] = q[b_q, i] · p[b_p, j]; a cell's MaxSim is
    # sims_full[:, :, :K_q, :K_p].max(-1).sum(-1).
    sims_full = torch.einsum('bih,djh->bdij',
                             q_repr['repr_hidden'].float(),
                             p_repr['repr_hidden'].float())  # [B_q, B_p, K_max_q, K_max_p]
    cell_losses = []
    valid_cols = []  # flat cell index (ki * n_K + kj) of each entry in cell_losses
    for ki, K_q in zip(idx_q, valid_q):
        for kj, K_p in zip(idx_p, valid_p):
            # Slice -> max over passage positions -> sum over query positions.
            cell_dense = sims_full[:, :, :K_q, :K_p].max(dim=-1).values.sum(dim=-1)
            cell_dense = cell_dense / K_q  # MaxSim normalisation, as in compute_loss
            cell_score = self.dense_weight * cell_dense / self.temperature
            if sparse_scores is not None:
                cell_score = cell_score + self.sparse_weight * sparse_scores
            # Per-query CE (reduction='none') -> [B_q].
            cell_losses.append(F.cross_entropy(cell_score, labels, reduction='none'))
            valid_cols.append(ki * n_K + kj)
    valid_losses = torch.stack(cell_losses, dim=-1)  # [B_q, n_valid]

    # ── 3. Teacher: softmax(-loss / τ_T) over valid cells (lower loss -> higher weight).
    teacher_probs = F.softmax(
        -valid_losses.detach() / max(self.teacher_temperature, 1e-6), dim=-1
    )

    # ── 4. Adapter prediction from the query's quotation_emb.
    if 'quotation_emb' not in q_repr:
        raise RuntimeError(
            "K-adapter requires 'quotation_emb' in q_repr. "
            "encode() must populate it before compute_loss()."
        )
    # Match the adapter weight dtype (DeepSpeed/AMP may cast it to bf16), then
    # bring logits back to fp32 for stable log_softmax + KL arithmetic.
    adapter_dtype = next(self.k_adapter.parameters()).dtype
    adapter_input = q_repr['quotation_emb'].to(dtype=adapter_dtype)  # [B_q, H]
    adapter_logits = self.k_adapter(adapter_input).float()  # [B_q, n_K^2]
    col_index = torch.tensor(valid_cols, device=adapter_logits.device)
    # Invalid cells are dropped before the softmax so they get zero probability.
    adapter_log_probs = F.log_softmax(adapter_logits.index_select(1, col_index), dim=-1)
    adapter_probs = adapter_log_probs.exp()  # [B_q, n_valid]

    # ── 5. Adapter loss: KL(teacher || adapter).
    L_adapter = F.kl_div(adapter_log_probs, teacher_probs, reduction='batchmean')

    # ── 6. Retriever loss: adapter-weighted mixture of the per-cell losses.
    L_retriever = (adapter_probs * valid_losses).sum(dim=-1).mean()
    total_loss = L_retriever + self.adapter_weight * L_adapter

    # Diagnostics: marginal K_q and K_p distributions averaged over the batch.
    with torch.no_grad():
        full_probs = adapter_probs.new_zeros(B_q, n_K * n_K)
        full_probs[:, col_index] = adapter_probs
        cell_grid = full_probs.view(B_q, n_K, n_K)  # [B_q, n_K_q, n_K_p]
        kq_dist = cell_grid.sum(dim=-1).mean(dim=0)  # [n_K]
        kp_dist = cell_grid.sum(dim=-2).mean(dim=0)  # [n_K]
        entropy = -(adapter_probs * adapter_log_probs).sum(dim=-1).mean()
        # Sparse-only InfoNCE for monitoring. Sparse already contributes to
        # the gradient through the fused per-cell scores above.
        if sparse_scores is not None:
            sparse_only_loss = F.cross_entropy(sparse_scores, labels)
        else:
            sparse_only_loss = torch.tensor(0.0, device=device)

    result = {
        'loss': total_loss,
        'loss_dense': L_retriever.detach(),
        'loss_sparse': sparse_only_loss.detach(),  # diagnostic only — not used for gradient
        'loss_adapter': L_adapter.detach(),
        'adapter_entropy': entropy.detach(),
    }
    # log marginals: kq_p1, kq_p2, kq_p4, kq_p8, kq_p16, kp_p1, ...
    for i, K in enumerate(K_opts):
        if i < kq_dist.shape[0]:
            result[f'kq_p{K}'] = kq_dist[i].detach()
            result[f'kp_p{K}'] = kp_dist[i].detach()
    return result
| # ---------------------------------------------------------------- | |
| # HF Trainer-compatible forward | |
| # ---------------------------------------------------------------- | |
def forward(
    self,
    query_input_ids: torch.Tensor,
    query_attention_mask: torch.Tensor,
    passage_input_ids: torch.Tensor,
    passage_attention_mask: torch.Tensor,
    query_content_ids: Optional[List] = None,
    passage_content_ids: Optional[List] = None,
) -> Dict[str, torch.Tensor]:
    """HF Trainer-compatible forward: encode a training batch and return the loss dict.

    Passage layout: [pos_0, neg_0_0, …, neg_0_M, pos_1, neg_1_0, …]; the
    positive for query i is at index i * (1 + n_neg).

    Dispatch order:
      * soft_denoising and num_denoise_steps > 1 -> self._soft_multistep_forward
      * use_k_pre_encoder -> self._forward_with_k_pre_encoder
      * num_denoise_steps <= 1 -> single-pass fast path
      * num_denoise_steps > 1 -> multi-step denoising training: each step
        decodes n_per_step tokens and stores their hidden states (detached)
        in a per-side buffer; the retrieval loss is computed at the final
        step (and, when progressive_step_weight > 0, at every earlier step
        too). Uniform unmasking: n_per_step = max(1, K // n_steps).

    Returns:
        The dict produced by compute_loss on the representations after
        self._gather_repr, optionally extended with 'loss_denoising',
        'loss_progressive' and 'debug_*' entries. Under torch.distributed with
        world size > 1 the retrieval part of 'loss' is multiplied by the world
        size to counter DDP gradient averaging; the denoising term is added
        after that scaling.
    """
    B_q = query_input_ids.size(0)
    B_p = passage_input_ids.size(0)
    assert B_p % B_q == 0, f"B_p ({B_p}) must be divisible by B_q ({B_q})"
    n_paq = B_p // B_q
    device = query_input_ids.device
    labels = torch.arange(B_q, device=device) * n_paq
    n_steps = self.num_denoise_steps

    # ── Corruption augmentation (denoising-conditioned training) ─────────
    # Randomly mask text tokens (NOT gen MASKs) in passages at a rate drawn
    # uniformly from [0, corruption_rate].
    if self.training and self.corruption_rate > 0:
        import random as _rng
        t = _rng.uniform(0.0, self.corruption_rate)
        if t > 0.01:  # skip very small rates
            passage_input_ids = self._corrupt_text(
                passage_input_ids, passage_attention_mask, t)

    # ── Soft-token multi-step (differentiable denoising) ─────────────────
    if self.soft_denoising and n_steps > 1:
        return self._soft_multistep_forward(
            query_input_ids, query_attention_mask,
            passage_input_ids, passage_attention_mask,
            query_content_ids, passage_content_ids,
            n_steps,
        )

    # ── K pre-encoder (two-stage encoding) ───────────────────────────────
    # Predict K per item BEFORE the main encoder, slice each input to its K,
    # then run the main encoder on the variable-length batch.
    if self.use_k_pre_encoder:
        return self._forward_with_k_pre_encoder(
            query_input_ids, query_attention_mask,
            passage_input_ids, passage_attention_mask,
            query_content_ids, passage_content_ids,
            labels,
        )

    # ── Single-step fast path ─────────────────────────────────────────────
    if n_steps <= 1:
        K_q, K_p = self.n_gen_q_tokens, self.n_gen_p_tokens
        n_tail = self._n_tail
        need_sparse = self.sparse_weight > 0

        # With the denoising auxiliary loss enabled, the retrieval forward
        # runs on the masked passages too (one shared forward pass).
        if self.denoising_weight > 0:
            p_corrupted, p_denoise_targets, mask_ratio = self._apply_text_masking(
                passage_input_ids, passage_attention_mask)
            p_ids_for_fwd = p_corrupted
        else:
            p_ids_for_fwd = passage_input_ids
        need_logits = need_sparse or (self.denoising_weight > 0)
        L_q, L_p = query_input_ids.size(1), p_ids_for_fwd.size(1)

        if K_q == K_p:
            # ── Symmetric path: left-pad the shorter side and run Q+P in one _fwd.
            if L_q < L_p:
                pad = L_p - L_q
                q_ids_cat = F.pad(query_input_ids, (pad, 0))  # left-pad
                q_mask_cat = F.pad(query_attention_mask, (pad, 0))
                p_ids_cat, p_mask_cat = p_ids_for_fwd, passage_attention_mask
            elif L_p < L_q:
                pad = L_q - L_p
                q_ids_cat, q_mask_cat = query_input_ids, query_attention_mask
                p_ids_cat = F.pad(p_ids_for_fwd, (pad, 0))  # left-pad
                p_mask_cat = F.pad(passage_attention_mask, (pad, 0))
            else:
                q_ids_cat, q_mask_cat = query_input_ids, query_attention_mask
                p_ids_cat, p_mask_cat = p_ids_for_fwd, passage_attention_mask
            all_ids = torch.cat([q_ids_cat, p_ids_cat], dim=0)
            all_mask = torch.cat([q_mask_cat, p_mask_cat], dim=0)
            all_hidden, all_logits = self._fwd(all_ids, all_mask,
                                               need_logits=need_logits)
            L_all = all_ids.size(1)
            # Left padding keeps the gen block + tail at the end of every row.
            g_q_in_all = L_all - K_q - n_tail
            g_p_in_all = L_all - K_p - n_tail
            q_hidden, p_hidden = all_hidden[:B_q], all_hidden[B_q:]
            q_logits = all_logits[:B_q] if all_logits is not None else None
            p_logits = all_logits[B_q:] if all_logits is not None else None
        else:
            # ── Asymmetric path: K_q != K_p means different gen-block sizes,
            # so a concatenated forward would mis-align the slices. Run two
            # separate _fwd calls; each side keeps its own length.
            q_hidden, q_logits = self._fwd(
                query_input_ids, query_attention_mask,
                need_logits=need_logits)
            p_hidden, p_logits = self._fwd(
                p_ids_for_fwd, passage_attention_mask,
                need_logits=need_logits)
            g_q_in_all = L_q - K_q - n_tail
            g_p_in_all = L_p - K_p - n_tail

        # ── Extract Q repr ─────────────────────────────────────────────
        q_repr_hidden = q_hidden[:, g_q_in_all:g_q_in_all + K_q, :].float()
        q_quotation_emb = q_hidden[:, g_q_in_all - 1, :].float()
        q_sparse_max = None
        if need_sparse and q_logits is not None:
            # Monotonic trick: max in bf16 first → log1p(relu()) on [B, V].
            q_sparse_max = torch.log1p(torch.relu(
                q_logits[:, g_q_in_all:g_q_in_all + K_q, :].max(dim=1).values))
        # Pre-normalisation vectors for the 'raw' debug metrics; passing the
        # normalised tensor would make debug_q_mask_norm constant at 1.
        q_raw_hidden = q_repr_hidden
        if self.normalize:
            q_quotation_emb = F.normalize(q_quotation_emb, p=2, dim=-1)
            q_repr_hidden = F.normalize(q_repr_hidden, p=2, dim=-1)
        q_repr = {'repr_hidden': q_repr_hidden, 'quotation_emb': q_quotation_emb}
        if q_sparse_max is not None:
            if query_content_ids is not None:
                q_sparse_max = filter_sparse(q_sparse_max, query_content_ids)
            q_repr['sparse_acts'] = q_sparse_max

        # ── Extract P repr ─────────────────────────────────────────────
        p_repr_hidden = p_hidden[:, g_p_in_all:g_p_in_all + K_p, :].float()
        p_quotation_emb = p_hidden[:, g_p_in_all - 1, :].float()
        p_sparse_max = None
        if need_sparse and p_logits is not None:
            # Monotonic trick: max in bf16 first → log1p(relu()) on [B, V].
            p_sparse_max = torch.log1p(torch.relu(
                p_logits[:, g_p_in_all:g_p_in_all + K_p, :].max(dim=1).values))
        p_raw_hidden = p_repr_hidden
        if self.normalize:
            p_quotation_emb = F.normalize(p_quotation_emb, p=2, dim=-1)
            p_repr_hidden = F.normalize(p_repr_hidden, p=2, dim=-1)
        p_repr = {'repr_hidden': p_repr_hidden, 'quotation_emb': p_quotation_emb}
        if p_sparse_max is not None:
            if passage_content_ids is not None:
                p_sparse_max = filter_sparse(p_sparse_max, passage_content_ids)
            p_repr['sparse_acts'] = p_sparse_max

        # ── Cross-GPU negative sharing ────────────────────────────────
        q_repr = self._gather_repr(q_repr)
        p_repr = self._gather_repr(p_repr)
        # Recompute labels for the gathered batch.
        B_q_all = q_repr['repr_hidden'].size(0)
        B_p_all = p_repr['repr_hidden'].size(0)
        n_paq_g = B_p_all // B_q_all
        labels_g = torch.arange(B_q_all, device=device) * n_paq_g

        # ── Loss ──────────────────────────────────────────────────────
        loss_dict = self.compute_loss(q_repr, p_repr, labels_g)
        if self.debug_dense_metrics:
            loss_dict.update(self._dense_debug_stats(
                q_repr, p_repr, labels_g,
                q_raw=q_raw_hidden, p_raw=p_raw_hidden,
            ))
        # Scale loss to counter DDP gradient averaging.
        if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
            loss_dict['loss'] = loss_dict['loss'] * torch.distributed.get_world_size()
        if self.denoising_weight > 0:
            # Denoising targets cover the original passage length; if the
            # passages were left-padded to match the queries, drop the pad
            # columns. In asymmetric mode no padding was applied (dn_start=0).
            p_seq_len = p_logits.size(1) if p_logits is not None else L_p
            dn_start = p_seq_len - L_p
            # The slice is non-contiguous when dn_start > 0; make it contiguous
            # so flattening it in compute_denoising_loss is always valid.
            p_logits_dn = p_logits[:, dn_start:, :].contiguous()
            denoise_loss = self.compute_denoising_loss(
                p_logits_dn, p_denoise_targets, mask_ratio)
            loss_dict['loss'] = loss_dict['loss'] + self.denoising_weight * denoise_loss
            loss_dict['loss_denoising'] = denoise_loss.detach()
        return loss_dict

    # ── Multi-step: denoising training ───────────────────────────────────
    K_q, K_p = self.n_gen_q_tokens, self.n_gen_p_tokens
    n_tail = self._n_tail
    n_per_step_q = max(1, K_q // n_steps)
    n_per_step_p = max(1, K_p // n_steps)
    H = self.hidden_size
    V = self.backbone.config.vocab_size
    L_q = query_input_ids.size(1)
    L_p = passage_input_ids.size(1)

    # Pre-compute 4D masks once (only for models that need them).
    needs_4d = self.adapter.needs_4d_mask() if self.adapter else True
    q_mask_4d = self._build_4d_mask(L_q, query_attention_mask) if needs_4d else None
    p_mask_4d = self._build_4d_mask(L_p, passage_attention_mask) if needs_4d else None

    q_curr = query_input_ids.clone()
    p_curr = passage_input_ids.clone()
    need_sparse = self.sparse_weight > 0

    # Frozen buffers for positions already decoded in prior steps.
    q_repr_buf = torch.zeros(B_q, K_q, H, device=device)
    p_repr_buf = torch.zeros(B_p, K_p, H, device=device)
    q_decoded_mask = torch.zeros(B_q, K_q, dtype=torch.bool, device=device)
    p_decoded_mask = torch.zeros(B_p, K_p, dtype=torch.bool, device=device)
    q_sparse_decoded = torch.zeros(B_q, V, device=device) if need_sparse else None
    p_sparse_decoded = torch.zeros(B_p, V, device=device) if need_sparse else None

    progressive = self.progressive_step_weight > 0
    final_loss: Dict[str, torch.Tensor] = {}
    progressive_loss_sum = torch.tensor(0.0, device=device)

    def _build_repr_at_step(h, logits, repr_buf, decoded_mask, sparse_decoded,
                            K_side: int,
                            content_ids=None, use_fresh=False):
        """Build the retrieval representation from the current hidden states.

        Used for both intermediate (progressive) and final steps. Buffers are
        snapshotted with .clone() so later in-place updates cannot alter it.

        K_side: per-side K (n_gen_q_tokens for queries, n_gen_p_tokens for
            passages — needed under asymmetric configs).
        use_fresh: if True, ALL K positions use the current hidden states
            (ignoring the frozen repr_buf), so every token benefits from the
            fully-decoded context at the final step.
        """
        Ls = h.size(1)
        gs = Ls - K_side - n_tail  # left-padded: gen block then structural tail
        curr_gen_h = h[:, gs:gs + K_side, :].float()  # [B_loc, K_side, H]
        decoded_snap = decoded_mask.clone()
        repr_snap = repr_buf.detach().clone()
        mask_sparse = None
        if need_sparse and logits is not None:
            sp_all = torch.log(1.0 + torch.relu(logits[:, gs:gs + K_side, :].float()))
            if use_fresh:
                mask_sparse = sp_all.max(dim=1).values
            else:
                # Only still-MASK positions contribute fresh sparse activations.
                is_mask = (~decoded_snap).float().unsqueeze(-1)
                mask_sparse = (sp_all * is_mask).max(dim=1).values
        if use_fresh:
            mixed = curr_gen_h
        else:
            mixed = torch.where(decoded_snap.unsqueeze(-1), repr_snap, curr_gen_h)
        quotation_emb = h[:, gs - 1, :].float()
        if self.normalize:
            mixed = F.normalize(mixed, p=2, dim=-1)
            quotation_emb = F.normalize(quotation_emb, p=2, dim=-1)
        result = {'repr_hidden': mixed, 'quotation_emb': quotation_emb}
        if need_sparse and sparse_decoded is not None:
            sparse_snap = sparse_decoded.detach().clone()
            combined_sparse = torch.max(sparse_snap, mask_sparse)
            if content_ids is not None:
                combined_sparse = filter_sparse(combined_sparse, content_ids)
            result['sparse_acts'] = combined_sparse
        return result

    q_g = L_q - K_q - n_tail  # left-padded: gen block starts here
    p_g = L_p - K_p - n_tail
    for step in range(n_steps):
        is_last = (step == n_steps - 1)
        q_h, q_logits = self._fwd(q_curr, query_attention_mask,
                                  need_logits=True, mask_4d=q_mask_4d)
        p_h, p_logits = self._fwd(p_curr, passage_attention_mask,
                                  need_logits=True, mask_4d=p_mask_4d)

        if progressive or is_last:
            _fresh = self.use_fresh_final and is_last
            q_repr_s = _build_repr_at_step(
                q_h, q_logits,
                q_repr_buf, q_decoded_mask, q_sparse_decoded, K_q,
                content_ids=query_content_ids, use_fresh=_fresh)
            p_repr_s = _build_repr_at_step(
                p_h, p_logits,
                p_repr_buf, p_decoded_mask, p_sparse_decoded, K_p,
                content_ids=passage_content_ids, use_fresh=_fresh)
            # Cross-GPU gather for multi-step.
            q_repr_s = self._gather_repr(q_repr_s)
            p_repr_s = self._gather_repr(p_repr_s)
            B_q_g = q_repr_s['repr_hidden'].size(0)
            B_p_g = p_repr_s['repr_hidden'].size(0)
            labels_ms = torch.arange(B_q_g, device=device) * (B_p_g // B_q_g)
            step_loss_dict = self.compute_loss(q_repr_s, p_repr_s, labels_ms)
            if is_last:
                final_loss = step_loss_dict
                if self.debug_dense_metrics:
                    final_loss.update(self._dense_debug_stats(q_repr_s, p_repr_s, labels_ms))
            elif progressive:
                step_weight = (step + 1) / n_steps
                progressive_loss_sum = progressive_loss_sum + step_weight * step_loss_dict['loss']

        if not is_last:
            q_curr, q_newly = self._unmask_step(
                q_curr, q_logits, K_q, n_per_step_q)
            p_curr, p_newly = self._unmask_step(
                p_curr, p_logits, K_p, n_per_step_p)
            # Freeze the hidden states / sparse activations of newly decoded
            # positions; they contribute without gradient from now on.
            with torch.no_grad():
                for i, pos_list in enumerate(q_newly):
                    for tok_pos in pos_list:
                        q_repr_buf[i, tok_pos] = q_h[i, q_g + tok_pos].detach()
                        q_decoded_mask[i, tok_pos] = True
                        if q_sparse_decoded is not None and q_logits is not None:
                            q_sparse_decoded[i] = torch.max(
                                q_sparse_decoded[i],
                                torch.log(1.0 + torch.relu(q_logits[i, q_g + tok_pos])).detach())
                for i, pos_list in enumerate(p_newly):
                    for tok_pos in pos_list:
                        p_repr_buf[i, tok_pos] = p_h[i, p_g + tok_pos].detach()
                        p_decoded_mask[i, tok_pos] = True
                        if p_sparse_decoded is not None and p_logits is not None:
                            p_sparse_decoded[i] = torch.max(
                                p_sparse_decoded[i],
                                torch.log(1.0 + torch.relu(p_logits[i, p_g + tok_pos])).detach())

    # Add progressive step loss to final loss.
    if progressive and final_loss:
        final_loss['loss'] = final_loss['loss'] + self.progressive_step_weight * progressive_loss_sum
        final_loss['loss_progressive'] = progressive_loss_sum.detach()

    # Scale loss for DDP (multi-step path).
    if final_loss and torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
        final_loss['loss'] = final_loss['loss'] * torch.distributed.get_world_size()

    # Denoising auxiliary for multi-step. Unlike the single-step path, which
    # reuses its retrieval forward, this runs a separate forward pass on
    # freshly masked passages.
    if self.denoising_weight > 0 and final_loss:
        p_corrupted, p_denoise_targets, mask_ratio = self._apply_text_masking(
            passage_input_ids, passage_attention_mask)
        _, p_logits_dn = self._fwd(p_corrupted, passage_attention_mask,
                                   need_logits=True, mask_4d=p_mask_4d)
        denoise_loss = self.compute_denoising_loss(p_logits_dn, p_denoise_targets, mask_ratio)
        final_loss['loss'] = final_loss['loss'] + self.denoising_weight * denoise_loss
        final_loss['loss_denoising'] = denoise_loss.detach()

    return final_loss
| # ---------------------------------------------------------------- | |
| # Save / load helpers | |
| # ---------------------------------------------------------------- | |
def _save_retriever_config(self, output_dir: str):
    """Persist the retriever hyper-parameters and small non-LoRA heads to *output_dir*.

    Writes ``retriever_config.json``. Also writes ``k_adapter.bin`` when the
    K-adapter is enabled and present, and ``k_pre_encoder.bin`` (query and
    passage state dicts) when the K pre-encoder is. PEFT's save_pretrained
    only stores LoRA weights, so these heads are saved here. *output_dir* is
    created if it does not exist yet.
    """
    import json
    import os

    os.makedirs(output_dir, exist_ok=True)
    config = {
        'model_type': self.model_type,
        'mask_token_id': self.mask_token_id,
        'hidden_size': self.hidden_size,
        'max_length': self.max_length,
        'n_gen_tokens': self.n_gen_tokens,
        'n_gen_q_tokens': self.n_gen_q_tokens,
        'n_gen_p_tokens': self.n_gen_p_tokens,
        'temperature': self.temperature,
        'flops_weight': 0.0,  # deprecated, kept for backwards compat
        'num_denoise_steps': self.num_denoise_steps,
        'sparse_weight': self.sparse_weight,
        'normalize': self.normalize,
        'query_prefix_ids': self._query_prefix_ids,
        'query_suffix_ids': self._query_suffix_ids,
        'passage_prefix_ids': self._passage_prefix_ids,
        'passage_suffix_ids': self._passage_suffix_ids,
        'lora_rank': getattr(self, 'lora_rank', 0),
        'lora_alpha': getattr(self, 'lora_alpha', 64),
        'use_eos': self.use_eos,
        'n_tail': self._n_tail,
        # K-adapter persistence
        'use_k_adapter': self.use_k_adapter,
        'adapter_weight': self.adapter_weight,
        'teacher_temperature': self.teacher_temperature,
        'k_adapter_options': list(self.k_adapter_options),
        # K pre-encoder persistence
        'use_k_pre_encoder': self.use_k_pre_encoder,
        'gumbel_temperature': self.gumbel_temperature,
        'k_cost_lambda': self.k_cost_lambda,
        'k_pre_encoder_options': list(self.k_pre_encoder_options),
    }
    with open(os.path.join(output_dir, 'retriever_config.json'), 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)
    if self.use_k_adapter and self.k_adapter is not None:
        torch.save(self.k_adapter.state_dict(),
                   os.path.join(output_dir, 'k_adapter.bin'))
    if self.use_k_pre_encoder and self.k_pre_encoder_q is not None:
        torch.save({'q': self.k_pre_encoder_q.state_dict(),
                    'p': self.k_pre_encoder_p.state_dict()},
                   os.path.join(output_dir, 'k_pre_encoder.bin'))
def load(cls, model_dir: str, **fallback_kwargs) -> 'TrainableDiffusionRetriever':
    """Load a fine-tuned TrainableDiffusionRetriever from a saved directory.

    Supported layouts:

    * ``retriever_config.json`` (or the legacy ``matryoshka_config.json``)
      present — the retriever is rebuilt from that config. The backbone is
      taken from, in order of detection:
        - a released PEFT adapter repo (``adapter_config.json`` +
          ``adapter_model.safetensors``/``.bin`` and no ``config.json``):
          the base ``hub_model_name`` model is loaded and the LoRA adapter is
          attached with ``PeftModel.from_pretrained``;
        - a checkpoint with its own ``config.json``;
        - otherwise the architecture of ``hub_model_name``, with the
          fine-tuned weights from ``model_dir`` loaded afterwards.
      A ``model.safetensors`` whose keys contain ``base_model.model`` is
      treated as a full PEFT state dict: LoRA layers are attached to the
      backbone first so those keys can be loaded.
    * No config file (e.g. a mid-training checkpoint) — the model is built
      with ``cls.from_backbone()`` from ``fallback_kwargs`` and the
      checkpoint weights are loaded on top.

    In both cases LoRA adapters are merged into the base weights before
    returning, where the backbone supports it.

    Args:
        model_dir: Directory containing the saved checkpoint.
        **fallback_kwargs: Only read when no config file exists. Required:
            ``model_type``, ``query_prompt``, ``passage_prompt``. Optional:
            ``original_model`` (architecture source; defaults to the
            backbone adapter's ``hub_model_name``), ``n_gen_tokens``,
            ``n_gen_q_tokens``, ``n_gen_p_tokens``, ``temperature``,
            ``num_denoise_steps``, ``sparse_weight``, ``normalize``,
            ``max_length`` and ``lora_rank`` (when absent or 0 it is detected
            from ``lora_A.default.weight`` keys in the checkpoint).

    Returns:
        The loaded retriever.

    Raises:
        FileNotFoundError: No config file exists and ``model_type``,
            ``query_prompt`` or ``passage_prompt`` was not supplied.
    """
    import json
    from transformers import AutoTokenizer

    def _read_checkpoint_state_dict(checkpoint_dir):
        """Return all weights stored in *checkpoint_dir* as one dict, or None.

        Preference order: ``model.safetensors``, then sharded
        ``model-*.safetensors``, then ``pytorch_model*.bin``. Every shard
        is merged: sharded checkpoints split the weights across files, so
        reading just one of them would silently drop the rest.
        """
        single_file = checkpoint_dir / 'model.safetensors'
        st_shards = sorted(checkpoint_dir.glob('model-*.safetensors'))
        pt_shards = sorted(checkpoint_dir.glob('pytorch_model*.bin'))
        if single_file.exists() or st_shards:
            from safetensors.torch import load_file
            if single_file.exists():
                return load_file(str(single_file))
            merged = {}
            for shard in st_shards:
                merged.update(load_file(str(shard)))
            return merged
        if pt_shards:
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources.
            merged = {}
            for shard in pt_shards:
                merged.update(torch.load(str(shard), map_location='cpu'))
            return merged
        return None

    checkpoint_dir = Path(model_dir)
    # Support both the new name and the old name (matryoshka_config.json)
    # for backwards compat.
    config_path = checkpoint_dir / 'retriever_config.json'
    if not config_path.exists():
        config_path = checkpoint_dir / 'matryoshka_config.json'

    if not config_path.exists():
        # ── No retriever config: rebuild the model from fallback kwargs ──
        model_type = fallback_kwargs.get('model_type')
        query_prompt = fallback_kwargs.get('query_prompt')
        passage_prompt = fallback_kwargs.get('passage_prompt')
        if not model_type or not query_prompt or not passage_prompt:
            raise FileNotFoundError(
                f"No retriever_config.json in {model_dir}. "
                "Pass model_type, query_prompt, passage_prompt as fallback kwargs.")
        _fallback_adapter = get_adapter(model_type)
        source_model = (fallback_kwargs.get('original_model')
                        or _fallback_adapter.hub_model_name)
        logger.info(f"No retriever_config.json — loading architecture from {source_model}, "
                    f"weights from {model_dir}")
        # Read the weights before building the model so the LoRA rank can be
        # detected and matching LoRA layers created up front.
        state_dict = _read_checkpoint_state_dict(checkpoint_dir)
        detected_lora_rank = fallback_kwargs.get('lora_rank', 0)
        if state_dict is not None and detected_lora_rank == 0:
            # Keys look like backbone.base_model.model.*; dim 0 of a
            # lora_A weight is taken as the rank.
            for k, v in state_dict.items():
                if 'lora_A.default.weight' in k:
                    detected_lora_rank = v.shape[0]
                    logger.info(f"Detected LoRA rank {detected_lora_rank} from checkpoint keys")
                    break
        model = cls.from_backbone(
            model_name=source_model,
            model_type=model_type,
            query_prompt=query_prompt,
            passage_prompt=passage_prompt,
            n_gen_tokens=fallback_kwargs.get('n_gen_tokens', 4),
            n_gen_q_tokens=fallback_kwargs.get('n_gen_q_tokens', None),
            n_gen_p_tokens=fallback_kwargs.get('n_gen_p_tokens', None),
            temperature=fallback_kwargs.get('temperature', 0.02),
            num_denoise_steps=fallback_kwargs.get('num_denoise_steps', None),
            sparse_weight=fallback_kwargs.get('sparse_weight', 1.0),
            normalize=fallback_kwargs.get('normalize', True),
            max_length=fallback_kwargs.get('max_length', 256),
            lora_rank=detected_lora_rank,
            device_map='auto',
        )
        if state_dict is not None:
            # Load the full state dict directly into the model (keys carry
            # the backbone. prefix).
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            backbone_keys_loaded = sum(1 for k in state_dict if k.startswith('backbone.'))
            logger.info(f"Loaded checkpoint weights from {model_dir} "
                        f"(lora_rank={detected_lora_rank}, "
                        f"{backbone_keys_loaded} backbone keys, "
                        f"{len(missing)} missing, {len(unexpected)} unexpected)")
        else:
            logger.warning(f"No weight file found in {model_dir} — using original model weights")
        # Merge LoRA weights into the base model for faster inference.
        if detected_lora_rank > 0 and hasattr(model.backbone, 'merge_and_unload'):
            model.backbone = model.backbone.merge_and_unload()
            logger.info("Merged LoRA adapters for inference")
        return model

    # ── Retriever config present ──────────────────────────────────────────
    with open(config_path) as f:
        cfg = json.load(f)
    model_type = cfg.get('model_type', 'dream')
    adapter = get_adapter(model_type)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    tokenizer.padding_side = 'left'
    load_weights_separately = False

    # A released checkpoint (the layout published on the Hub) ships only the
    # LoRA adapter (adapter_config.json + adapter_model.safetensors/.bin)
    # alongside retriever_config.json and the tokenizer — no base config.json
    # and no full-weight file. Load the base backbone from the Hub and attach
    # the adapter; PeftModel.from_pretrained restores the trained LoRA
    # weights, so no separate weight load is needed, and the merge step at
    # the end folds them in for fast inference.
    has_adapter_cfg = (checkpoint_dir / 'adapter_config.json').exists()
    has_base_cfg = (checkpoint_dir / 'config.json').exists()
    has_adapter_weights = ((checkpoint_dir / 'adapter_model.safetensors').exists()
                           or (checkpoint_dir / 'adapter_model.bin').exists())
    if has_adapter_cfg and has_adapter_weights and not has_base_cfg:
        from peft import PeftModel
        logger.info(f"Detected released PEFT adapter in {model_dir} — "
                    f"loading base {adapter.hub_model_name} from Hub and "
                    f"attaching the LoRA adapter")
        base = adapter.load_backbone(adapter.hub_model_name, device_map='auto')
        backbone = PeftModel.from_pretrained(base, str(model_dir))
    else:
        try:
            backbone = adapter.load_backbone(model_dir, device_map='auto')
        except (ValueError, OSError):
            # model_dir has no usable config.json. transformers typically
            # reports a missing config.json as OSError, so catching only
            # ValueError would crash instead of reaching this fallback.
            # Load the architecture from the Hub and the weights afterwards.
            logger.info(f"No valid config.json in {model_dir} — "
                        f"loading architecture from {adapter.hub_model_name}")
            backbone = adapter.load_backbone(adapter.hub_model_name, device_map='auto')
            load_weights_separately = True
        # Detect whether model.safetensors was saved as a full PEFT state
        # dict (e.g. by DeepSpeed); if so, LoRA layers must exist on the
        # backbone before those keys can be loaded.
        ckpt_file = checkpoint_dir / 'model.safetensors'
        if ckpt_file.exists():
            from safetensors import safe_open
            with safe_open(str(ckpt_file), framework='pt') as handle:
                ckpt_keys = list(handle.keys())
                is_peft_dump = any('base_model.model' in k for k in ckpt_keys[:8])
                lora_rank = cfg.get('lora_rank', 0)
                if is_peft_dump and lora_rank == 0:
                    # Read only the lora_A tensor's shape from the header
                    # instead of materialising the whole checkpoint.
                    lora_a_key = next(
                        (k for k in ckpt_keys if 'lora_A.default.weight' in k), None)
                    if lora_a_key is not None:
                        lora_rank = handle.get_slice(lora_a_key).get_shape()[0]
            if is_peft_dump:
                logger.info("Detected PEFT/LoRA state dict — "
                            "applying LoRA to backbone before loading weights")
                lora_alpha = cfg.get('lora_alpha', 64)
                already_peft = getattr(backbone, 'peft_config', None)
                if lora_rank > 0 and not already_peft:
                    from peft import get_peft_model
                    lora_cfg = adapter.get_lora_config(lora_rank, lora_alpha)
                    backbone = get_peft_model(backbone, lora_cfg)
                    logger.info(f"Applied LoRA rank={lora_rank} alpha={lora_alpha}")
                elif already_peft:
                    logger.info("Backbone already has PEFT — skipping duplicate LoRA")
                load_weights_separately = True

    model = cls(
        backbone=backbone,
        tokenizer=tokenizer,
        mask_token_id=cfg['mask_token_id'],
        hidden_size=cfg['hidden_size'],
        query_prefix_ids=cfg['query_prefix_ids'],
        query_suffix_ids=cfg['query_suffix_ids'],
        passage_prefix_ids=cfg['passage_prefix_ids'],
        passage_suffix_ids=cfg['passage_suffix_ids'],
        max_length=cfg['max_length'],
        n_gen_tokens=cfg['n_gen_tokens'],
        n_gen_q_tokens=cfg.get('n_gen_q_tokens'),  # back-compat: missing → fall through to n_gen_tokens
        n_gen_p_tokens=cfg.get('n_gen_p_tokens'),
        temperature=cfg['temperature'],
        num_denoise_steps=cfg['num_denoise_steps'],
        sparse_weight=cfg['sparse_weight'],
        normalize=cfg['normalize'],
        flash_attn=adapter.flash_attn,
        use_eos=cfg.get('use_eos', False),
        # K-adapter restore (back-compat: keys missing → defaults disable)
        use_k_adapter=cfg.get('use_k_adapter', False),
        adapter_weight=cfg.get('adapter_weight', 1.0),
        teacher_temperature=cfg.get('teacher_temperature', 1.0),
        k_adapter_options=tuple(cfg['k_adapter_options'])
        if cfg.get('k_adapter_options') is not None else None,
        # K pre-encoder restore
        use_k_pre_encoder=cfg.get('use_k_pre_encoder', False),
        gumbel_temperature=cfg.get('gumbel_temperature', 1.0),
        k_cost_lambda=cfg.get('k_cost_lambda', 0.001),
        k_pre_encoder_options=tuple(cfg['k_pre_encoder_options'])
        if cfg.get('k_pre_encoder_options') is not None else None,
    )

    # Restore the KAdapter weights if a saved file exists.
    if model.use_k_adapter and model.k_adapter is not None:
        adapter_path = checkpoint_dir / 'k_adapter.bin'
        if adapter_path.exists():
            model.k_adapter.load_state_dict(
                torch.load(str(adapter_path), map_location='cpu'))
            logger.info(f"Loaded KAdapter state from {adapter_path}")
        else:
            logger.warning(f"use_k_adapter=True but {adapter_path} missing; "
                           f"adapter starts from random init.")
    # Restore the query/passage pre-encoder weights if a saved file exists.
    if model.use_k_pre_encoder and model.k_pre_encoder_q is not None:
        pe_path = checkpoint_dir / 'k_pre_encoder.bin'
        if pe_path.exists():
            pe_state = torch.load(str(pe_path), map_location='cpu')
            model.k_pre_encoder_q.load_state_dict(pe_state['q'])
            model.k_pre_encoder_p.load_state_dict(pe_state['p'])
            logger.info(f"Loaded KPreEncoder state from {pe_path}")
        else:
            logger.warning(f"use_k_pre_encoder=True but {pe_path} missing; "
                           f"pre-encoder starts from random init.")

    # Restore _n_tail from config. Old checkpoints (before structural tail)
    # don't have 'n_tail' — infer it from use_eos for backward compat.
    if 'n_tail' in cfg:
        model._n_tail = cfg['n_tail']
        model._tail_ids = model._build_tail_ids()
    else:
        # Legacy: use_eos=True → 1 tail token (EOS only), use_eos=False → 0
        model._n_tail = 1 if cfg.get('use_eos', False) else 0
        eos_id = model.tokenizer.eos_token_id
        model._tail_ids = [eos_id] if model._n_tail == 1 else []
        logger.info(f"Legacy checkpoint: n_tail={model._n_tail} (from use_eos={cfg.get('use_eos', False)})")
    model.model_type = model_type
    model.adapter = adapter
    # Hook for efficient hidden state extraction
    model._last_hidden: Dict[str, torch.Tensor] = {}
    model._hook_registered = adapter.register_hidden_hook(
        backbone, model._last_hidden)

    if load_weights_separately:
        state_dict = _read_checkpoint_state_dict(checkpoint_dir)
        if state_dict is None:
            logger.warning(f"No weight file found in {model_dir} — using Hub weights")
        else:
            # Strip the wrapper's 'backbone.' prefix; fall back to the raw
            # keys when the file was saved from the backbone itself.
            backbone_dict = {k[len('backbone.'):]: v for k, v in state_dict.items()
                             if k.startswith('backbone.')}
            missing, unexpected = model.backbone.load_state_dict(
                backbone_dict or state_dict, strict=False)
            # strict=False silently skips mismatched keys, so surface the
            # counts as the fallback path does.
            logger.info(f"Loaded fine-tuned weights from {model_dir} "
                        f"({len(missing)} missing, {len(unexpected)} unexpected)")

    # Merge LoRA adapters into the base weights for inference. Training
    # keeps them separate for gradient flow; at inference each LoRA adds
    # per-layer overhead otherwise.
    #
    # Two shapes seen in saved checkpoints:
    #   (a) `model.backbone` is a PeftModel wrapper → call
    #       `merge_and_unload()` directly.
    #   (b) `model.backbone` is the underlying DreamModel / LLaDAModel
    #       with `LoraLinear` modules still in place (the top-level
    #       PeftModel wrapper was dropped during save). `merge_and_unload`
    #       is missing, but each inner `LoraLinear` still exposes `.merge()`,
    #       which folds its `lora_A @ lora_B` into `base_layer.weight` in
    #       place and flips `merged=True`, disabling the LoRA forward path.
    #       We walk the module tree and merge module by module.
    if hasattr(model.backbone, 'merge_and_unload'):
        try:
            model.backbone = model.backbone.merge_and_unload()
            logger.info("Merged LoRA adapters into base backbone for inference")
        except Exception as exc:
            # Best effort: an un-merged model is still correct, just slower.
            logger.warning(f"merge_and_unload failed ({exc}); "
                           f"continuing with un-merged LoRA (slower inference)")
    else:
        merged_modules = 0
        for _mod in model.backbone.modules():
            if (hasattr(_mod, 'lora_A') and hasattr(_mod, 'merge')
                    and callable(getattr(_mod, 'merge', None))
                    and not getattr(_mod, 'merged', False)):
                try:
                    _mod.merge()
                    merged_modules += 1
                except Exception as exc:
                    logger.warning(f"per-module LoRA merge failed ({exc}) "
                                   f"— continuing with un-merged layer")
        if merged_modules > 0:
            logger.info(f"Merged {merged_modules} LoraLinear modules into "
                        f"base weights for inference")
    logger.info(f"Loaded TrainableDiffusionRetriever from {model_dir}")
    return model
def save(self, output_dir: str):
    """Write the retriever to ``output_dir`` so it can be reloaded later.

    The directory is created if needed. It then receives:
      - the backbone's ``save_pretrained`` output (or that of
        ``backbone.base_model`` when the backbone lacks the method);
      - the tokenizer;
      - whatever ``self._save_retriever_config`` writes.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)
    # Fall back to the inner base model only when the wrapper itself cannot
    # save.
    target = self.backbone
    if not hasattr(target, 'save_pretrained'):
        target = target.base_model
    target.save_pretrained(output_dir)
    self.tokenizer.save_pretrained(output_dir)
    self._save_retriever_config(output_dir)
    logger.info(f"Saved to {output_dir}")
def config(self):
    """Return the wrapped backbone's model config.

    HF Trainer / DeepSpeed expect ``model.config`` to exist, so this simply
    forwards ``self.backbone.config``. NOTE(review): this is presumably
    exposed as a property; a decorator is not visible in this chunk —
    confirm.
    """
    wrapped = self.backbone
    return wrapped.config
def gradient_checkpointing_enable(self, **kwargs):
    """Turn on gradient checkpointing for the backbone.

    A truthy ``self.adapter`` takes over via
    ``adapter.enable_gradient_checkpointing(backbone, **kwargs)``.
    Otherwise the backbone's own ``gradient_checkpointing_enable`` is
    called with the same keyword arguments.
    """
    backbone_adapter = self.adapter
    if not backbone_adapter:
        self.backbone.gradient_checkpointing_enable(**kwargs)
        return
    backbone_adapter.enable_gradient_checkpointing(self.backbone, **kwargs)