diffretriever-llada-8b-single / diffretriever_trainable.py
wshuai190's picture
Add self-contained DiffRetriever (trust_remote_code: code + config + adapter/)
a7c784c verified
Raw
History Blame Contribute Delete
123 kB
"""
DiffRetriever — fine-tunable variant (paper §3.4 supervised fine-tuning).
Contrastively fine-tunable wrapper around Dream / LLaDA backbones,
sharing the masked-position prediction encoder with the zero-shot
DreamRetriever / LLaDA2Retriever classes. Supports Dream, LLaDA1/1.5,
LLaDA2 backbones via backbone_adapters.py.
Encoding:
steps=1 (fast): Single forward pass over [prefix][text][suffix][MASK×K].
repr_hidden[:, s, :] from MASK position s; quotation_emb from the
closing " token.
steps>1 (rich): Iterative denoising loop with mixed representations.
At each step, gen positions already decoded in a prior step contribute
their frozen hidden state (no gradient); positions still MASK contribute
the current step's hidden (with gradient).
Uniform LLaDA-style unmasking schedule: n_per_step = K // n_steps
tokens are decoded per step.
Loss (training):
K > 1: ColBERT MaxSim InfoNCE over all K mixed vectors.
steps>1: loss computed only at the final denoising step so gradient
flows from the fully-contextualized representation without dilution.
If progressive_step_weight > 0, retrieval loss is also applied at
each intermediate step with linearly increasing weight (t/T),
giving direct supervision to early denoising steps.
Optional: sparse InfoNCE + FLOPS L1 regularization.
At eval time the same encode() output (repr_hidden, quotation_emb,
sparse_acts) supports all zero-shot retrieval modes: single_dense,
multi_dense_k*, multi-denoise-step variants, and sparse versions of each.
"""
import re
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
from pathlib import Path
import logging
from .sparse_utils import filter_sparse
from .backbone_adapters import get_adapter, BackboneAdapter
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# English number words for 1..10, looked up as _NUM_WORDS[k - 1]. Used when
# rewriting prompt text (e.g. "one word" -> "four words") to match K.
_NUM_WORDS = [
    'one',
    'two',
    'three',
    'four',
    'five',
    'six',
    'seven',
    'eight',
    'nine',
    'ten',
]
class TrainableDiffusionRetriever(nn.Module):
"""
Trainable Diffusion Retriever.
Build via from_backbone(); use tokenize() in your data collator, then pass
pre-tokenized tensors to forward() during training.
"""
def __init__(
self,
backbone: nn.Module,
tokenizer,
mask_token_id: int,
hidden_size: int,
query_prefix_ids: List[int],
query_suffix_ids: List[int],
passage_prefix_ids: List[int],
passage_suffix_ids: List[int],
max_length: int = 512,
n_gen_tokens: int = 4,
n_gen_q_tokens: Optional[int] = None, # asymmetric K_q (defaults to n_gen_tokens)
n_gen_p_tokens: Optional[int] = None, # asymmetric K_p (defaults to n_gen_tokens)
temperature: float = 0.01,
num_denoise_steps: int = 4,
sparse_weight: float = 1.0,
normalize: bool = True,
flash_attn: bool = False,
use_eos: bool = False,
# K-adapter (joint training of K-router + retriever). When enabled,
# the model encodes at K_max and at training time computes per-cell
# InfoNCE losses; a tiny MLP head maps the query's quotation_emb to a
# softmax over K-cell choices and is supervised by KL(teacher || head)
# where teacher = softmax(-per_cell_loss / τ_T). See compute_loss().
use_k_adapter: bool = False,
adapter_weight: float = 1.0,
teacher_temperature: float = 1.0,
k_adapter_options: Optional[Tuple[int, ...]] = None,
# Two-stage K-pre-encoder (true encoding savings via per-item K).
# When enabled, a tiny MLP runs over the embedding-layer output and
# predicts K_q (for queries) and K_p (for passages) per item. Each
# item's input is sliced to its predicted K (removing K_max-K MASK
# tokens), the batch is padded to max-K-in-batch, and main encoder
# runs at variable length. See _forward_with_k_pre_encoder().
use_k_pre_encoder: bool = False,
gumbel_temperature: float = 1.0,
k_cost_lambda: float = 0.001,
k_pre_encoder_options: Optional[Tuple[int, ...]] = None,
):
super().__init__()
self.backbone = backbone
self.tokenizer = tokenizer
self.mask_token_id = mask_token_id
self.hidden_size = hidden_size
# Prompt IDs (cached lists, not parameters)
self.register_buffer('_dummy', torch.zeros(1)) # for .device
self._query_prefix_ids = list(query_prefix_ids)
self._query_suffix_ids = list(query_suffix_ids)
self._passage_prefix_ids = list(passage_prefix_ids)
self._passage_suffix_ids = list(passage_suffix_ids)
self.max_length = max_length
# Per-side K. When the user passes only `n_gen_tokens`, both sides
# use it (symmetric — original behaviour). Pass `n_gen_q_tokens` /
# `n_gen_p_tokens` to train asymmetric (e.g. K_q=4, K_p=16).
self.n_gen_tokens = n_gen_tokens # legacy, kept for back-compat
self.n_gen_q_tokens = n_gen_q_tokens if n_gen_q_tokens is not None else n_gen_tokens
self.n_gen_p_tokens = n_gen_p_tokens if n_gen_p_tokens is not None else n_gen_tokens
self.use_eos = use_eos # kept for config compat; _n_tail is authoritative
# _n_tail is the structural tail (quote + chat_end + EOS); it's the
# same regardless of K_q or K_p as long as at least one side has K>0.
max_k = max(self.n_gen_q_tokens, self.n_gen_p_tokens)
self._n_tail = 3 if max_k > 0 else 0
self._tail_ids: List[int] = self._build_tail_ids() # cached at init
self.temperature = temperature
# Denoise step count is capped by the larger K (the side with more
# mask tokens has more positions to iteratively decode).
self.num_denoise_steps = min(num_denoise_steps, max_k)
self.sparse_weight = sparse_weight
self.dense_weight = 1.0
self.normalize = normalize
self.flash_attn = flash_attn
self.model_type = 'unknown' # set by from_backbone() / load()
self.adapter: Optional[BackboneAdapter] = None # set by factory methods
self._hook_registered = False # set by _setup_hook()
# Auxiliary losses (activated by setting weight > 0)
self.denoising_weight = 0.0 # diffusion-native masked text denoising
self.diversity_weight = 0.0 # explicit multi-vector diversity
self.denoise_mask_ratio = 0.15 # fraction of text tokens to mask for denoising
self.progressive_step_weight = 0.0 # progressive step supervision (multi-step only)
self.use_fresh_final = False # "fresh" ablation: use current hidden for ALL K at final step
self.soft_denoising = False # soft-token multi-step: differentiable embeddings instead of hard tokens
self.soft_temperature = 1.0 # temperature for softmax in soft-token denoising
self.corruption_rate = 0.0 # max text corruption rate for denoising-conditioned training (0 = off)
self.debug_dense_metrics = False
self.debug_compare_hidden_once = False
self._debug_hidden_compared = False
# ── K-adapter (joint K-router + retriever) ─────────────────────────
self.use_k_adapter = bool(use_k_adapter)
self.adapter_weight = float(adapter_weight)
self.teacher_temperature = float(teacher_temperature)
# Default K options: factors of K_max that fall within the active K
# range. E.g., K_max=16 → {1,2,4,8,16}; K_max=4 → {1,2,4}.
if k_adapter_options is None:
k_max_active = max(self.n_gen_q_tokens, self.n_gen_p_tokens)
base = (1, 2, 4, 8, 16)
self.k_adapter_options: Tuple[int, ...] = tuple(
k for k in base if k <= k_max_active
)
else:
self.k_adapter_options = tuple(int(k) for k in k_adapter_options)
self.n_K = len(self.k_adapter_options)
if self.use_k_adapter:
# Tiny MLP: hidden_size → hidden_size//4 → n_K * n_K.
# Trained jointly with the backbone via the per-cell loss teacher.
# Kept in fp32 to match the .float() conversion of representations
# in encode(); HF Trainer's bf16 autocast handles mixed precision
# at the call site without needing a manual dtype cast here.
head_dim = max(64, hidden_size // 4)
self.k_adapter = nn.Sequential(
nn.Linear(hidden_size, head_dim),
nn.GELU(),
nn.Linear(head_dim, self.n_K * self.n_K),
)
logger.info(
f"KAdapter enabled: K options={self.k_adapter_options}, "
f"adapter_weight={self.adapter_weight}, "
f"teacher_temperature={self.teacher_temperature}"
)
else:
self.k_adapter = None
# ── K-pre-encoder (two-stage encoding) ─────────────────────────────
self.use_k_pre_encoder = bool(use_k_pre_encoder)
self.gumbel_temperature = float(gumbel_temperature)
self.k_cost_lambda = float(k_cost_lambda)
if k_pre_encoder_options is None:
k_max_active_pe = max(self.n_gen_q_tokens, self.n_gen_p_tokens)
base_pe = (1, 2, 4, 8, 16)
self.k_pre_encoder_options: Tuple[int, ...] = tuple(
k for k in base_pe if k <= k_max_active_pe
)
else:
self.k_pre_encoder_options = tuple(int(k) for k in k_pre_encoder_options)
self.n_K_pe = len(self.k_pre_encoder_options)
# Buffer of K values for differentiable expected-K (cost regularizer).
self.register_buffer(
'k_pe_options_tensor',
torch.tensor(self.k_pre_encoder_options, dtype=torch.float32),
persistent=False,
)
if self.use_k_pre_encoder:
head_dim_pe = max(64, hidden_size // 4)
# Separate q + p heads — query and passage have different optimal K
# distributions and texts differ enough to justify dedicated heads.
# Both heads operate on the (shared) embedding layer's output —
# mean-pooled over real tokens. Cheap (~1% of full forward).
self.k_pre_encoder_q = nn.Sequential(
nn.Linear(hidden_size, head_dim_pe),
nn.GELU(),
nn.Linear(head_dim_pe, self.n_K_pe),
)
self.k_pre_encoder_p = nn.Sequential(
nn.Linear(hidden_size, head_dim_pe),
nn.GELU(),
nn.Linear(head_dim_pe, self.n_K_pe),
)
logger.info(
f"K pre-encoder enabled: K options={self.k_pre_encoder_options}, "
f"gumbel_τ={self.gumbel_temperature}, "
f"cost_λ={self.k_cost_lambda}"
)
else:
self.k_pre_encoder_q = None
self.k_pre_encoder_p = None
def _k(self, is_query: bool) -> int:
"""Return the appropriate per-side K. Lets the rest of the code
say `K = self._k(is_query)` and stay correct under both symmetric
(K_q == K_p) and asymmetric configurations."""
return self.n_gen_q_tokens if is_query else self.n_gen_p_tokens
# ----------------------------------------------------------------
# Build from pretrained backbone
# ----------------------------------------------------------------
@classmethod
def from_backbone(
    cls,
    model_name: str,
    model_type: str,
    query_prompt: str,
    passage_prompt: str,
    max_length: int = 512,
    n_gen_tokens: int = 4,
    n_gen_q_tokens: Optional[int] = None,
    n_gen_p_tokens: Optional[int] = None,
    temperature: float = 0.01,
    num_denoise_steps: Optional[int] = 4,
    sparse_weight: float = 1.0,
    normalize: bool = True,
    gradient_checkpointing: bool = True,
    lora_rank: int = 0,
    lora_alpha: int = 64,
    lora_dropout: float = 0.05,
    device_map=None,
    use_eos: bool = False,
    disable_hidden_hook: bool = False,
    # K-adapter passthrough
    use_k_adapter: bool = False,
    adapter_weight: float = 1.0,
    teacher_temperature: float = 1.0,
    k_adapter_options: Optional[Tuple[int, ...]] = None,
    # K pre-encoder passthrough
    use_k_pre_encoder: bool = False,
    gumbel_temperature: float = 1.0,
    k_cost_lambda: float = 0.001,
    k_pre_encoder_options: Optional[Tuple[int, ...]] = None,
) -> 'TrainableDiffusionRetriever':
    """Load a backbone + tokenizer and wrap them in a retriever.

    Args:
        model_name: Identifier passed to the tokenizer and backbone loaders.
        model_type: Key passed to ``get_adapter`` to pick the backbone adapter.
        query_prompt / passage_prompt: Paths to prompt YAML files.
        num_denoise_steps: ``None`` means "use max(K_q, K_p)".
        gradient_checkpointing: Enable input grads and adapter-level
            gradient checkpointing on the backbone.
        lora_rank: When > 0, wrap the backbone with PEFT using the
            adapter's LoRA config (``lora_alpha`` / ``lora_dropout``).
        device_map: Forwarded to ``adapter.load_backbone``.
        disable_hidden_hook: Skip registering the hidden-state hook.
        Remaining arguments are forwarded unchanged to ``__init__``.

    Returns:
        A configured ``TrainableDiffusionRetriever`` with ``model_type``,
        ``adapter``, ``lora_rank``, ``lora_alpha``, ``_last_hidden`` and
        ``_hook_registered`` set.
    """
    from transformers import AutoTokenizer

    adapter = get_adapter(model_type)
    # NOTE(review): trust_remote_code=True runs code shipped with the model
    # repository; only use with sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.padding_side = 'left'
    backbone = adapter.load_backbone(model_name, device_map=device_map)

    if lora_rank > 0:
        from peft import get_peft_model
        lora_cfg = adapter.get_lora_config(lora_rank, lora_alpha, lora_dropout)
        backbone = get_peft_model(backbone, lora_cfg)
        backbone.print_trainable_parameters()
    if gradient_checkpointing:
        backbone.enable_input_require_grads()
        adapter.enable_gradient_checkpointing(backbone)

    mask_token_id = adapter.mask_token_id
    logger.info(f"Mask token ({model_type}): {mask_token_id}")

    # Resolve per-side K (default to symmetric n_gen_tokens). The prompt
    # text adaptation ('one word' -> 'four words') uses the per-side K so
    # the natural-language prompt matches the MASK count on that side.
    k_q = n_gen_q_tokens if n_gen_q_tokens is not None else n_gen_tokens
    k_p = n_gen_p_tokens if n_gen_p_tokens is not None else n_gen_tokens

    # Resolve the step count before it is logged, so the log shows the
    # value that is actually used rather than None.
    if num_denoise_steps is None:
        num_denoise_steps = max(k_q, k_p)

    # Build prompt token IDs (per-side adaptation). Copy the loaded dicts
    # before editing them.
    q_yaml = dict(cls._load_yaml(query_prompt))
    p_yaml = dict(cls._load_yaml(passage_prompt))
    q_yaml['user_suffix'] = cls._adapt_for_k(q_yaml.get('user_suffix', ''), k_q)
    q_yaml['assistant_prefix'] = cls._adapt_for_k(q_yaml.get('assistant_prefix', ''), k_q)
    p_yaml['user_suffix'] = cls._adapt_for_k(p_yaml.get('user_suffix', ''), k_p)
    p_yaml['assistant_prefix'] = cls._adapt_for_k(p_yaml.get('assistant_prefix', ''), k_p)
    q_prefix_ids, q_suffix_ids = cls._build_prompt_ids(tokenizer, q_yaml)
    p_prefix_ids, p_suffix_ids = cls._build_prompt_ids(tokenizer, p_yaml)

    logger.info(f"Query prompt: {len(q_prefix_ids)} prefix + {len(q_suffix_ids)} suffix tokens")
    logger.info(f"Passage prompt: {len(p_prefix_ids)} prefix + {len(p_suffix_ids)} suffix tokens")
    if k_q == k_p:
        logger.info(f"n_gen_tokens={k_q} (symmetric), num_denoise_steps={num_denoise_steps}")
    else:
        logger.info(f"n_gen_q_tokens={k_q}, n_gen_p_tokens={k_p} (asymmetric), "
                    f"num_denoise_steps={num_denoise_steps}")
    if sparse_weight > 0:
        logger.info(f"Sparse loss: sparse_weight={sparse_weight}")

    hidden_size = backbone.config.hidden_size
    model = cls(
        backbone=backbone,
        tokenizer=tokenizer,
        mask_token_id=mask_token_id,
        hidden_size=hidden_size,
        query_prefix_ids=q_prefix_ids,
        query_suffix_ids=q_suffix_ids,
        passage_prefix_ids=p_prefix_ids,
        passage_suffix_ids=p_suffix_ids,
        max_length=max_length,
        n_gen_tokens=n_gen_tokens,
        n_gen_q_tokens=k_q,
        n_gen_p_tokens=k_p,
        temperature=temperature,
        num_denoise_steps=num_denoise_steps,
        sparse_weight=sparse_weight,
        normalize=normalize,
        flash_attn=adapter.flash_attn,
        use_eos=use_eos,
        use_k_adapter=use_k_adapter,
        adapter_weight=adapter_weight,
        teacher_temperature=teacher_temperature,
        k_adapter_options=k_adapter_options,
        use_k_pre_encoder=use_k_pre_encoder,
        gumbel_temperature=gumbel_temperature,
        k_cost_lambda=k_cost_lambda,
        k_pre_encoder_options=k_pre_encoder_options,
    )
    model.model_type = model_type
    model.adapter = adapter
    model.lora_rank = lora_rank
    model.lora_alpha = lora_alpha

    # Hook on the backbone for hidden-state extraction; the hook writes
    # into this dict and _fwd() reads (and pops) key 'h' from it.
    model._last_hidden: Dict[str, torch.Tensor] = {}
    if disable_hidden_hook:
        logger.info(f"{model_type}: hidden hook disabled by flag")
        model._hook_registered = False
    else:
        model._hook_registered = adapter.register_hidden_hook(
            backbone, model._last_hidden)
        if not model._hook_registered:
            logger.info(f"{model_type}: no hook — will use output_hidden_states=True")
    return model
# ----------------------------------------------------------------
# Prompt helpers (mirrors DreamRetriever)
# ----------------------------------------------------------------
@staticmethod
def _load_yaml(path: str) -> dict:
import yaml
p = Path(path)
if not p.exists():
raise FileNotFoundError(f"Prompt YAML not found: {path}")
return yaml.safe_load(p.read_text())
@staticmethod
def _adapt_for_k(text: str, k: int) -> str:
if k <= 1 or not text:
return text
count = _NUM_WORDS[k - 1] if k <= len(_NUM_WORDS) else str(k)
result = re.sub(
r'\b(?:' + '|'.join(_NUM_WORDS) + r')\b(\s+words?)',
lambda m: f'{count} words', text,
)
return re.sub(r'\bword is\b', 'words are', result)
@staticmethod
def _build_prompt_ids(tokenizer, yaml_dict: dict) -> Tuple[List[int], List[int]]:
system = yaml_dict.get('system', '')
user_prefix = yaml_dict.get('user_prefix', '')
user_suffix = yaml_dict.get('user_suffix', '')
assistant_prefix = yaml_dict.get('assistant_prefix', '')
SENTINEL = "XSENTINELX"
messages = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": user_prefix + SENTINEL + user_suffix})
full_str = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
pre_str, post_str = full_str.split(SENTINEL, 1)
prefix_ids = tokenizer.encode(pre_str, add_special_tokens=False)
suffix_ids = tokenizer.encode(post_str + assistant_prefix, add_special_tokens=False)
return prefix_ids, suffix_ids
# ----------------------------------------------------------------
# Structural tail tokens (matches zero-shot dream/llada retrievers)
# ----------------------------------------------------------------
def _build_tail_ids(self) -> List[int]:
"""Build structural tail tokens: [quote, chat_end, EOS].
Matches the zero-shot DreamRetriever._build_tail_ids and
LLaDA2Retriever._build_tail_ids so that training sees the same
sequence layout as zero-shot inference.
"""
if self._n_tail <= 0:
return []
eos_id = self.tokenizer.eos_token_id
tail: List[int] = []
# Closing quote "
quote_ids = self.tokenizer.encode('"', add_special_tokens=False)
if len(quote_ids) == 1:
tail.append(quote_ids[0])
else:
logger.warning('quote \'"\' tokenises to %d tokens; using EOS in slot 0', len(quote_ids))
tail.append(eos_id)
# Chat-template end token (im_end for Dream/Qwen, eot_id for LLaDA)
_KNOWN_EOT = {'<|im_end|>': None, '<|eot_id|>': 126348} # fallback for LLaDA
for tok_str, fallback_id in _KNOWN_EOT.items():
tid = self.tokenizer.convert_tokens_to_ids(tok_str)
unk = getattr(self.tokenizer, 'unk_token_id', None)
if tid is not None and tid >= 0 and tid != unk and tid != eos_id:
tail.append(tid)
break
elif fallback_id is not None and fallback_id != eos_id:
tail.append(fallback_id)
break
# Fill remaining slots with EOS
while len(tail) < self._n_tail:
tail.append(eos_id)
return tail[:self._n_tail]
# ----------------------------------------------------------------
# Tokenization (called from the data collator, runs on CPU)
# ----------------------------------------------------------------
def tokenize(
self, texts: List[str], is_query: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns (input_ids, attention_mask) tensors on CPU."""
prefix_ids = self._query_prefix_ids if is_query else self._passage_prefix_ids
suffix_ids = self._query_suffix_ids if is_query else self._passage_suffix_ids
gen_ids = [self.mask_token_id] * self._k(is_query)
mask_block = gen_ids + self._tail_ids # [MASK×K | " | chat_end | EOS]
max_text_len = self.max_length - len(prefix_ids) - len(suffix_ids)
enc = self.tokenizer(
texts,
padding=False,
truncation=True,
max_length=max_text_len,
return_attention_mask=False,
return_token_type_ids=False,
add_special_tokens=False,
)
enc['input_ids'] = [
prefix_ids + ids + suffix_ids + mask_block
for ids in enc['input_ids']
]
collated = self.tokenizer.pad(
enc,
padding=True,
return_attention_mask=True,
return_tensors='pt',
)
return collated['input_ids'], collated['attention_mask']
# ----------------------------------------------------------------
# Attention mask: bidirectional (padding-only masking)
# ----------------------------------------------------------------
def _build_4d_mask(
self, seq_len: int, attention_mask: torch.Tensor
) -> torch.Tensor:
"""4D [B, 1, S, S] bidirectional attention mask."""
dtype = next(self.backbone.parameters()).dtype
min_val = torch.finfo(dtype).min
B = attention_mask.size(0)
mask_4d = torch.zeros(B, 1, seq_len, seq_len,
device=attention_mask.device, dtype=dtype)
pad = ~attention_mask.bool()
mask_4d = mask_4d.masked_fill(pad.unsqueeze(1).unsqueeze(1), min_val)
mask_4d = mask_4d.masked_fill(pad.unsqueeze(1).unsqueeze(3), min_val)
return mask_4d
# ----------------------------------------------------------------
# Backbone forward
# ----------------------------------------------------------------
def _fwd(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
need_logits: bool = False,
mask_4d: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Single backbone pass. Returns (hidden [B,L,H], logits [B,L,V] or None).
mask_4d: pre-computed 4D attention mask — pass it in the multi-step loop
to avoid recomputing the same mask at every denoising step.
inputs_embeds: if provided, use these embeddings instead of input_ids.
"""
# Adapter determines mask format: 4D [B,1,L,L] or 2D [B,L]
use_4d = self.adapter.needs_4d_mask() if self.adapter else True
if use_4d:
if mask_4d is None:
seq_len = inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
mask_4d = self._build_4d_mask(seq_len, attention_mask)
fwd_mask = mask_4d
else:
fwd_mask = attention_mask
# If hook is registered, skip output_hidden_states (saves ~1.2GB).
need_hidden_states = not self._hook_registered
fwd_kwargs = dict(
attention_mask=fwd_mask,
output_hidden_states=need_hidden_states,
return_dict=True,
)
if inputs_embeds is not None:
fwd_kwargs['inputs_embeds'] = inputs_embeds
else:
fwd_kwargs['input_ids'] = input_ids
outputs = self.backbone(**fwd_kwargs)
# Keep hidden in native dtype (bf16). Callers convert only small slices.
if self._hook_registered and 'h' in self._last_hidden:
hidden = self._last_hidden.pop('h')
elif getattr(outputs, 'last_hidden_state', None) is not None:
hidden = outputs.last_hidden_state
elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
hidden = outputs.hidden_states[-1]
else:
hidden = outputs[0]
# Keep logits in native dtype (bf16) to avoid doubling memory.
# Downstream consumers convert small slices to float32 as needed.
logits = outputs.logits if need_logits and hasattr(outputs, 'logits') else None
if (self.debug_compare_hidden_once
and self._hook_registered
and not self._debug_hidden_compared):
should_log = (not torch.distributed.is_initialized()
or torch.distributed.get_rank() == 0)
if should_log:
with torch.no_grad():
ref_outputs = self.backbone(
input_ids=input_ids,
attention_mask=fwd_mask,
output_hidden_states=True,
return_dict=True,
)
if getattr(ref_outputs, 'last_hidden_state', None) is not None:
ref_hidden = ref_outputs.last_hidden_state
elif hasattr(ref_outputs, 'hidden_states') and ref_outputs.hidden_states is not None:
ref_hidden = ref_outputs.hidden_states[-1]
else:
ref_hidden = ref_outputs[0]
diff = (hidden.float() - ref_hidden.float()).abs()
logger.warning(
"Hidden-hook check: model=%s shape=%s mean_abs_diff=%.6g max_abs_diff=%.6g",
self.model_type, tuple(hidden.shape),
diff.mean().item(), diff.max().item(),
)
self._debug_hidden_compared = True
return hidden, logits
# ----------------------------------------------------------------
# Confidence-based unmasking
# ----------------------------------------------------------------
@staticmethod
def _sample_with_confidence(logits: torch.Tensor, alg: str = 'entropy'):
"""Greedy decode + confidence score for discrete diffusion unmasking.
Args:
logits: [N, V] logits at the N currently-masked positions
alg: 'entropy' (neg entropy; higher = more certain) or 'max_prob'
Returns:
(confidence [N], x0 [N]) — confidence scores and predicted token IDs
"""
probs = torch.softmax(logits.float(), dim=-1)
x0 = probs.argmax(dim=-1)
if alg == 'entropy':
log_probs = torch.log(probs + 1e-10)
confidence = (probs * log_probs).sum(dim=-1) # neg entropy
else:
confidence = probs.max(dim=-1).values
return confidence, x0
@torch.no_grad()
def _unmask_step(
self,
curr_ids: torch.Tensor, # [B, L]
logits: torch.Tensor, # [B, L, V]
K: int,
n_per_step: int,
) -> Tuple[torch.Tensor, List[List[int]]]:
"""Unmask n_per_step most-confident gen tokens (uniform LLaDA-style schedule).
Returns:
new_ids: updated token IDs (same shape as curr_ids)
newly_decoded: list-of-lists; newly_decoded[i] = gen-block positions
[0..K-1] that were just decoded for example i
"""
new_ids = curr_ids.clone()
B, L = curr_ids.shape
g_start = L - K - self._n_tail
newly_decoded: List[List[int]] = [[] for _ in range(B)]
for i in range(B):
gen_ids = new_ids[i, g_start:g_start + K]
mask_bool = (gen_ids == self.mask_token_id)
n_masked = mask_bool.sum().item()
if n_masked == 0:
continue
# Convert only the small slice [n_masked, V] to float32, not full logits
ml = logits[i, g_start:g_start + K][mask_bool].detach().float()
if not torch.isfinite(ml).all():
continue
conf, x0 = self._sample_with_confidence(ml)
n_tr = min(n_per_step, n_masked)
_, xfer = torch.topk(conf, n_tr)
masked_abs = torch.where(mask_bool)[0]
selected = masked_abs[xfer]
new_ids[i, g_start + selected] = x0[xfer]
newly_decoded[i] = selected.tolist()
return new_ids, newly_decoded
# ----------------------------------------------------------------
# Soft-token multi-step helpers
# ----------------------------------------------------------------
def _get_embed_layer(self) -> nn.Module:
"""Return the input embedding layer (works through PEFT wrapper)."""
return self.backbone.get_input_embeddings()
def _soft_unmask_step(
self,
curr_embeds: torch.Tensor, # [B, L, H_emb] — current embeddings (differentiable)
curr_ids: torch.Tensor, # [B, L] — tracking which positions are still MASK
logits: torch.Tensor, # [B, L, V] — logits from current forward pass
K: int,
n_per_step: int,
temperature: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]]]:
"""Soft-unmask n_per_step most confident gen tokens.
Instead of hard token replacement (argmax → embed lookup), computes
differentiable soft embeddings: softmax(logits / τ) @ embedding_matrix.
Gradient flows through logits → softmax → matmul → next-step embeddings.
Returns:
new_embeds: [B, L, H_emb] with soft tokens at decoded positions
new_ids: [B, L] updated token IDs (hard argmax, for tracking only)
newly_decoded: list of gen-block positions [0..K-1] decoded this step
"""
embed_weight = self._get_embed_layer().weight # [V, H_emb]
B, L = curr_ids.shape
g_start = L - K - self._n_tail
new_ids = curr_ids.clone()
newly_decoded: List[List[int]] = [[] for _ in range(B)]
# Compute soft embeddings for all K gen positions (differentiable)
gen_logits = logits[:, g_start:g_start + K, :] # [B, K, V]
soft_probs = F.softmax(gen_logits / temperature, dim=-1) # [B, K, V]
soft_embs = soft_probs @ embed_weight # [B, K, H_emb]
# Determine which positions to unmask (most confident, same schedule as hard)
replace_mask = torch.zeros(B, K, dtype=torch.bool, device=curr_embeds.device)
for i in range(B):
gen_ids = curr_ids[i, g_start:g_start + K]
mask_bool = (gen_ids == self.mask_token_id)
n_masked = mask_bool.sum().item()
if n_masked == 0:
continue
ml = gen_logits[i][mask_bool].detach().float()
if not torch.isfinite(ml).all():
continue
conf, x0 = self._sample_with_confidence(ml)
n_tr = min(n_per_step, n_masked)
_, xfer = torch.topk(conf, n_tr)
masked_abs = torch.where(mask_bool)[0]
selected = masked_abs[xfer]
replace_mask[i, selected] = True
new_ids[i, g_start + selected] = x0[xfer]
newly_decoded[i] = selected.tolist()
# Mix: replace selected gen positions with soft embeddings, keep rest
curr_gen_embs = curr_embeds[:, g_start:g_start + K, :] # [B, K, H_emb]
new_gen_embs = torch.where(
replace_mask.unsqueeze(-1), # [B, K, 1]
soft_embs, # [B, K, H_emb] — differentiable
curr_gen_embs, # [B, K, H_emb]
)
# Rebuild full sequence (avoids in-place ops on grad-carrying tensor)
new_embeds = torch.cat([
curr_embeds[:, :g_start, :],
new_gen_embs,
curr_embeds[:, g_start + K:, :],
], dim=1)
return new_embeds, new_ids, newly_decoded
def _soft_multistep_forward(
self,
query_input_ids: torch.Tensor,
query_attention_mask: torch.Tensor,
passage_input_ids: torch.Tensor,
passage_attention_mask: torch.Tensor,
query_content_ids: Optional[List],
passage_content_ids: Optional[List],
n_steps: int,
) -> Dict[str, torch.Tensor]:
"""Multi-step denoising training with soft (differentiable) token replacement.
Unlike hard multi-step where decoded positions are frozen (.detach()),
soft-token mode keeps gradient flowing through ALL K positions at every
step via softmax(logits/τ) @ embedding_matrix.
Per-step contrastive loss (GIRCSE-style) provides direct supervision at
each denoising step. A monotonicity regularizer penalizes regressions.
"""
K_q, K_p = self.n_gen_q_tokens, self.n_gen_p_tokens
n_tail = self._n_tail
n_per_step_q = max(1, K_q // n_steps)
n_per_step_p = max(1, K_p // n_steps)
device = query_input_ids.device
B_q = query_input_ids.size(0)
B_p = passage_input_ids.size(0)
n_paq = B_p // B_q
L_q = query_input_ids.size(1)
L_p = passage_input_ids.size(1)
need_sparse = self.sparse_weight > 0
# Pre-compute 4D masks (sequence structure doesn't change across steps)
use_4d = self.adapter.needs_4d_mask() if self.adapter else True
q_mask_4d = self._build_4d_mask(L_q, query_attention_mask) if use_4d else None
p_mask_4d = self._build_4d_mask(L_p, passage_attention_mask) if use_4d else None
# Initial embeddings from the backbone's embedding layer
embed_layer = self._get_embed_layer()
q_embeds = embed_layer(query_input_ids)
p_embeds = embed_layer(passage_input_ids)
q_curr_ids = query_input_ids.clone()
p_curr_ids = passage_input_ids.clone()
q_g = L_q - K_q - n_tail
p_g = L_p - K_p - n_tail
step_losses: List[torch.Tensor] = []
final_loss: Dict[str, torch.Tensor] = {}
for step in range(n_steps):
is_last = (step == n_steps - 1)
# Forward pass using embeddings (step 0: original, step 1+: soft tokens)
q_h, q_logits = self._fwd(
q_curr_ids, query_attention_mask,
need_logits=True, mask_4d=q_mask_4d,
inputs_embeds=q_embeds if step > 0 else None,
)
p_h, p_logits = self._fwd(
p_curr_ids, passage_attention_mask,
need_logits=True, mask_4d=p_mask_4d,
inputs_embeds=p_embeds if step > 0 else None,
)
# Extract representations — ALL K positions have gradient
q_repr_hidden = q_h[:, q_g:q_g + K_q, :].float()
q_quotation_emb = q_h[:, q_g - 1, :].float()
p_repr_hidden = p_h[:, p_g:p_g + K_p, :].float()
p_quotation_emb = p_h[:, p_g - 1, :].float()
q_sparse_max = None
p_sparse_max = None
# Monotonic trick: log1p(relu(.)) is non-decreasing, so
# max_k log1p(relu(x_k)) == log1p(relu(max_k x_k))
# Maxing in bf16 first avoids materialising [B, K, V] fp32 intermediate.
if need_sparse and q_logits is not None:
q_sparse_max = torch.log1p(torch.relu(
q_logits[:, q_g:q_g + K_q, :].max(dim=1).values))
if need_sparse and p_logits is not None:
p_sparse_max = torch.log1p(torch.relu(
p_logits[:, p_g:p_g + K_p, :].max(dim=1).values))
if self.normalize:
q_quotation_emb = F.normalize(q_quotation_emb, p=2, dim=-1)
q_repr_hidden = F.normalize(q_repr_hidden, p=2, dim=-1)
p_quotation_emb = F.normalize(p_quotation_emb, p=2, dim=-1)
p_repr_hidden = F.normalize(p_repr_hidden, p=2, dim=-1)
q_repr = {'repr_hidden': q_repr_hidden, 'quotation_emb': q_quotation_emb}
p_repr = {'repr_hidden': p_repr_hidden, 'quotation_emb': p_quotation_emb}
if q_sparse_max is not None:
if query_content_ids is not None:
q_sparse_max = filter_sparse(q_sparse_max, query_content_ids)
q_repr['sparse_acts'] = q_sparse_max
if p_sparse_max is not None:
if passage_content_ids is not None:
p_sparse_max = filter_sparse(p_sparse_max, passage_content_ids)
p_repr['sparse_acts'] = p_sparse_max
# Cross-GPU negative sharing
q_repr = self._gather_repr(q_repr)
p_repr = self._gather_repr(p_repr)
B_q_g = q_repr['repr_hidden'].size(0)
B_p_g = p_repr['repr_hidden'].size(0)
labels = torch.arange(B_q_g, device=device) * (B_p_g // B_q_g)
# Per-step contrastive loss
step_loss_dict = self.compute_loss(q_repr, p_repr, labels)
step_losses.append(step_loss_dict['loss'])
if is_last:
final_loss = step_loss_dict
# Soft unmask for next step (per-side K)
if not is_last:
q_embeds, q_curr_ids, _ = self._soft_unmask_step(
q_embeds, q_curr_ids, q_logits, K_q, n_per_step_q,
self.soft_temperature)
p_embeds, p_curr_ids, _ = self._soft_unmask_step(
p_embeds, p_curr_ids, p_logits, K_p, n_per_step_p,
self.soft_temperature)
# Accumulate per-step losses (linearly increasing weight)
if len(step_losses) > 1:
progressive_loss = sum(
(s + 1) / n_steps * loss
for s, loss in enumerate(step_losses[:-1])
)
final_loss['loss'] = final_loss['loss'] + progressive_loss
final_loss['loss_progressive'] = progressive_loss.detach()
# Monotonicity regularizer: penalize if later step is worse
mono_penalty = torch.stack([
torch.relu(step_losses[k + 1] - step_losses[k].detach())
for k in range(len(step_losses) - 1)
]).mean()
final_loss['loss'] = final_loss['loss'] + 0.1 * mono_penalty
final_loss['loss_monotonicity'] = mono_penalty.detach()
# DDP scaling
if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
final_loss['loss'] = final_loss['loss'] * torch.distributed.get_world_size()
# Denoising auxiliary (same as other paths)
if self.denoising_weight > 0 and final_loss:
p_corrupted, p_denoise_targets, mask_ratio = self._apply_text_masking(
passage_input_ids, passage_attention_mask)
_, p_logits_dn = self._fwd(p_corrupted, passage_attention_mask,
need_logits=True, mask_4d=p_mask_4d)
denoise_loss = self.compute_denoising_loss(
p_logits_dn, p_denoise_targets, mask_ratio)
final_loss['loss'] = final_loss['loss'] + self.denoising_weight * denoise_loss
final_loss['loss_denoising'] = denoise_loss.detach()
return final_loss
# ----------------------------------------------------------------
# K pre-encoder (two-stage variable-length training)
# ----------------------------------------------------------------
def _pre_encoder_logits(
self,
input_ids: torch.Tensor, # [B, L]
attention_mask: torch.Tensor, # [B, L]
is_query: bool,
) -> torch.Tensor:
"""Tiny embedding-only K-router. Mean-pools the embedding-layer
output over real (non-pad) tokens, then runs a small MLP head.
Returns [B, n_K_pe] logits over self.k_pre_encoder_options.
Cost ≈ embedding lookup + 2-layer MLP — under 1% of full forward.
"""
head = self.k_pre_encoder_q if is_query else self.k_pre_encoder_p
# Use main encoder's embedding layer (parameter-shared).
embed_layer = self._get_embed_layer()
emb = embed_layer(input_ids) # [B, L, H]
# Mean-pool over real tokens (attention_mask = 0 at pads).
m = attention_mask.float().unsqueeze(-1)
pooled = (emb.float() * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0) # [B, H]
# Match head dtype to avoid AMP/DeepSpeed dtype mismatch.
head_dtype = next(head.parameters()).dtype
logits = head(pooled.to(dtype=head_dtype)).float() # [B, n_K_pe]
return logits
def _slice_to_per_item_K(
self,
input_ids: torch.Tensor, # [B, L_max] left-padded, K_max MASKs at end
attention_mask: torch.Tensor,
K_per_item: torch.Tensor, # [B] integers in self.k_pre_encoder_options
K_max_side: int, # n_gen_q_tokens or n_gen_p_tokens for this side
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Slice each item's input from K_max MASK tokens down to K MASK tokens
(per-item K), preserving the structural tail [", chat_end, EOS].
Layout (all left-padded — content right-aligned):
[PAD ... PAD | prefix | text | suffix | K_max MASKs | tail]
After slicing item with K_i:
[PAD ... PAD | prefix | text | suffix | K_i MASKs | tail]
Batch is then padded to max-K-in-batch length.
"""
B = input_ids.size(0)
L_in = input_ids.size(1)
n_tail = self._n_tail
K_per_item = K_per_item.tolist() if torch.is_tensor(K_per_item) else list(K_per_item)
# Per-item sliced sequences.
sliced_ids: List[torch.Tensor] = []
sliced_mask: List[torch.Tensor] = []
for i in range(B):
K_i = int(K_per_item[i])
n_remove = K_max_side - K_i
# Cut tokens [-(n_tail + n_remove) : -n_tail).
# Equivalent: keep input_ids[: L_in - n_tail - n_remove] + input_ids[-n_tail:]
keep_until = L_in - n_tail - n_remove
new_ids = torch.cat([input_ids[i, :keep_until], input_ids[i, -n_tail:]
if n_tail > 0 else input_ids.new_empty(0)])
new_mask = torch.cat([attention_mask[i, :keep_until], attention_mask[i, -n_tail:]
if n_tail > 0 else attention_mask.new_empty(0)])
sliced_ids.append(new_ids)
sliced_mask.append(new_mask)
# Pad LEFT to max length so the structural tail stays at position -1.
L_out = max(s.size(0) for s in sliced_ids)
pad_id = (self.tokenizer.pad_token_id or self.tokenizer.eos_token_id or 0)
out_ids = input_ids.new_full((B, L_out), pad_id)
out_mask = attention_mask.new_zeros((B, L_out))
for i, (s_ids, s_mask) in enumerate(zip(sliced_ids, sliced_mask)):
L_i = s_ids.size(0)
out_ids[i, L_out - L_i:] = s_ids
out_mask[i, L_out - L_i:] = s_mask
return out_ids, out_mask
def _extract_repr_per_K(
self,
hidden: torch.Tensor, # [B, L_out, H]
K_per_item: torch.Tensor, # [B] int K per item
n_K_max: int, # max K across self.k_pre_encoder_options
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Extract per-item repr_hidden using each item's K.
Returns:
repr_hidden: [B, n_K_max, H] — first K_i positions filled, rest zero
quotation_emb: [B, H] — position before first MASK per item
With left-padding, the rightmost positions per item are:
[..., gen_K_i MASKs, tail tokens (n_tail)]
So mask block = positions [L_out - n_tail - K_i, L_out - n_tail).
Quotation = position L_out - n_tail - K_i - 1.
"""
B, L_out, H = hidden.shape
n_tail = self._n_tail
K_per_item = K_per_item.tolist() if torch.is_tensor(K_per_item) else list(K_per_item)
repr_hidden = hidden.new_zeros(B, n_K_max, H)
quotation_emb = hidden.new_zeros(B, H)
for i in range(B):
K_i = int(K_per_item[i])
mask_end = L_out - n_tail
mask_start = mask_end - K_i
repr_hidden[i, :K_i] = hidden[i, mask_start:mask_end].float()
if mask_start - 1 >= 0:
quotation_emb[i] = hidden[i, mask_start - 1].float()
return repr_hidden, quotation_emb
def _gumbel_st_sample_K(
self,
logits: torch.Tensor, # [B, n_K_pe]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Gumbel-softmax sample with straight-through.
Returns:
K_per_item: [B] int — actual K integer values to use for slicing
soft_one_hot: [B, n_K_pe] — hard one-hot in fwd, soft in backward
(used for cost regularizer + ST credit)
chosen_prob: [B] — one-hot's value at chosen index (always 1.0 in fwd)
used to inject gradient back through ST
"""
soft_one_hot = F.gumbel_softmax(
logits, tau=max(self.gumbel_temperature, 1e-3), hard=True
) # [B, n_K_pe]
# Integer K from one-hot
K_idx = soft_one_hot.argmax(dim=-1) # [B]
K_per_item = self.k_pe_options_tensor[K_idx].long() # [B]
# ST credit: chosen_prob is always 1.0 in forward, has gradient via ST
chosen_prob = soft_one_hot.gather(-1, K_idx.unsqueeze(-1)).squeeze(-1)
return K_per_item, soft_one_hot, chosen_prob
def _forward_with_k_pre_encoder(
    self,
    query_input_ids: torch.Tensor,
    query_attention_mask: torch.Tensor,
    passage_input_ids: torch.Tensor,
    passage_attention_mask: torch.Tensor,
    query_content_ids: Optional[List],
    passage_content_ids: Optional[List],
    labels: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Two-stage training forward: a cheap pre-encoder picks K per item, then
    the main encoder runs on inputs trimmed to that K.

    Pipeline:
        1. ``_pre_encoder_logits`` scores the K options for queries/passages.
        2. ``_gumbel_st_sample_K`` samples one K per item.
        3. ``_slice_to_per_item_K`` trims each input's MASK block to its K.
        4. ``_fwd`` encodes the (shorter) batches.
        5. ``_extract_repr_per_K`` gathers per-item MASK vectors, zero-padded
           to ``max(self.k_pre_encoder_options)`` rows.
        6. Optional sparse activations, max-pooled over each item's K positions.
        7-8. Optional L2 normalisation, then multiplication by ``chosen_prob``
           so the router receives gradient through the straight-through sample.
        9-10. Cross-GPU gather and ``compute_loss`` on the gathered batch.
        11. Optional cost regulariser on the chosen K values.

    Args:
        query_input_ids / query_attention_mask: query batch laid out with
            ``self.n_gen_q_tokens`` MASK slots.
        passage_input_ids / passage_attention_mask: passage batch laid out
            with ``self.n_gen_p_tokens`` MASK slots.
        query_content_ids / passage_content_ids: forwarded to
            ``filter_sparse`` when sparse activations are produced; may be None.
        labels: only its ``device`` is read here — the positive indices are
            recomputed after the cross-GPU gather.

    Returns:
        The dict from ``compute_loss`` with ``'loss'`` augmented by the cost
        term and DDP scaling, plus ``k_cost`` / ``expected_K_q`` /
        ``expected_K_p`` (when ``self.k_cost_lambda > 0``) and the
        ``pe_kq_p{K}`` / ``pe_kp_p{K}`` selection frequencies.
    """
    device = labels.device
    K_max_q = self.n_gen_q_tokens
    K_max_p = self.n_gen_p_tokens
    # ── 1. Pre-encoder forward (cheap — embedding lookup + small head).
    q_logits = self._pre_encoder_logits(
        query_input_ids, query_attention_mask, is_query=True)
    p_logits = self._pre_encoder_logits(
        passage_input_ids, passage_attention_mask, is_query=False)
    # ── 2. Gumbel-ST sample K per item (queries + passages).
    K_q_per_item, q_soft, q_chosen = self._gumbel_st_sample_K(q_logits)
    K_p_per_item, p_soft, p_chosen = self._gumbel_st_sample_K(p_logits)
    # Clamp at the per-side K_max in case the pre-encoder options exceed it.
    # (Shouldn't happen by construction but a safety net.)
    K_q_per_item = K_q_per_item.clamp(min=1, max=K_max_q)
    K_p_per_item = K_p_per_item.clamp(min=1, max=K_max_p)
    # ── 3. Slice each item's input to its K (variable-length batch).
    q_ids_sliced, q_mask_sliced = self._slice_to_per_item_K(
        query_input_ids, query_attention_mask, K_q_per_item, K_max_q)
    p_ids_sliced, p_mask_sliced = self._slice_to_per_item_K(
        passage_input_ids, passage_attention_mask, K_p_per_item, K_max_p)
    # ── 4. Main encoder forward (one pass each at variable length).
    # Logits are only requested when the sparse branch is active.
    need_sparse = self.sparse_weight > 0
    q_hidden, q_logits_main = self._fwd(
        q_ids_sliced, q_mask_sliced, need_logits=need_sparse)
    p_hidden, p_logits_main = self._fwd(
        p_ids_sliced, p_mask_sliced, need_logits=need_sparse)
    # ── 5. Per-item repr extraction (rows beyond an item's K stay zero).
    n_K_max = max(self.k_pre_encoder_options)
    q_repr_hidden, q_quotation_emb = self._extract_repr_per_K(
        q_hidden, K_q_per_item, n_K_max)
    p_repr_hidden, p_quotation_emb = self._extract_repr_per_K(
        p_hidden, K_p_per_item, n_K_max)
    # ── 6. Sparse: max-pool over each item's K positions.
    q_sparse_max = None
    p_sparse_max = None
    if need_sparse and q_logits_main is not None:
        # For each item, take the logits at its K MASK positions and reduce
        # them to one [V] vector of log(1+relu(...)) activations.
        B_q = q_ids_sliced.size(0)
        L_q_out = q_ids_sliced.size(1)
        q_sparse_max = q_logits_main.new_zeros(B_q, q_logits_main.size(-1))
        for i in range(B_q):
            K_i = int(K_q_per_item[i].item())
            # MASK block sits right before the structural tail (left-padded rows).
            mask_end = L_q_out - self._n_tail
            mask_start = mask_end - K_i
            slogits = q_logits_main[i, mask_start:mask_end]
            # Monotonic trick: log1p(relu(.)) is non-decreasing, so taking the
            # max over K positions first and transforming the [V] result is
            # equivalent to transforming every position and then taking the max.
            q_sparse_max[i] = torch.log1p(torch.relu(slogits.max(dim=0).values))
    if need_sparse and p_logits_main is not None:
        # Passage side: same computation as for queries above.
        B_p = p_ids_sliced.size(0)
        L_p_out = p_ids_sliced.size(1)
        p_sparse_max = p_logits_main.new_zeros(B_p, p_logits_main.size(-1))
        for i in range(B_p):
            K_i = int(K_p_per_item[i].item())
            mask_end = L_p_out - self._n_tail
            mask_start = mask_end - K_i
            slogits = p_logits_main[i, mask_start:mask_end]
            # Monotonic trick: max over K positions first → log1p(relu()) on [V].
            p_sparse_max[i] = torch.log1p(torch.relu(slogits.max(dim=0).values))
    # ── 7. Normalize.
    if self.normalize:
        q_repr_hidden = F.normalize(q_repr_hidden, p=2, dim=-1)
        q_quotation_emb = F.normalize(q_quotation_emb, p=2, dim=-1)
        p_repr_hidden = F.normalize(p_repr_hidden, p=2, dim=-1)
        p_quotation_emb = F.normalize(p_quotation_emb, p=2, dim=-1)
    # ── 8. ST credit injection — multiply repr by chosen_prob (1.0 in the
    # forward pass, carries the straight-through gradient in backward). The
    # slicing above is not differentiable w.r.t. the router, so this product
    # is the path by which the retrieval loss reaches the pre-encoder.
    q_repr_hidden = q_repr_hidden * q_chosen.float().unsqueeze(-1).unsqueeze(-1)
    p_repr_hidden = p_repr_hidden * p_chosen.float().unsqueeze(-1).unsqueeze(-1)
    q_quotation_emb = q_quotation_emb * q_chosen.float().unsqueeze(-1)
    p_quotation_emb = p_quotation_emb * p_chosen.float().unsqueeze(-1)
    # Build repr dicts and apply content filter to sparse.
    q_repr = {'repr_hidden': q_repr_hidden, 'quotation_emb': q_quotation_emb}
    p_repr = {'repr_hidden': p_repr_hidden, 'quotation_emb': p_quotation_emb}
    if q_sparse_max is not None:
        if query_content_ids is not None:
            q_sparse_max = filter_sparse(q_sparse_max, query_content_ids)
        q_repr['sparse_acts'] = q_sparse_max
    if p_sparse_max is not None:
        if passage_content_ids is not None:
            p_sparse_max = filter_sparse(p_sparse_max, passage_content_ids)
        p_repr['sparse_acts'] = p_sparse_max
    # ── 9. Cross-GPU gather + recompute labels for gathered batch.
    q_repr = self._gather_repr(q_repr)
    p_repr = self._gather_repr(p_repr)
    B_q_all = q_repr['repr_hidden'].size(0)
    B_p_all = p_repr['repr_hidden'].size(0)
    # Positive of query i is the first passage of its group of n_paq_g.
    n_paq_g = B_p_all // B_q_all
    labels_g = torch.arange(B_q_all, device=device) * n_paq_g
    # ── 10. Standard InfoNCE on per-item-K representations.
    # NOTE(review): zero-padded rows produce similarity 0, so they are
    # invisible to MaxSim's max only while some real position scores >= 0;
    # if every real similarity is negative the padded 0 wins — confirm this
    # is acceptable.
    loss_dict = self.compute_loss(q_repr, p_repr, labels_g)
    # ── 11. Cost regularizer — λ * (K_q + K_p) averaged over the batch,
    # computed from the straight-through one-hots so it is differentiable
    # w.r.t. the router logits. Discourages collapsing to K_max.
    if self.k_cost_lambda > 0:
        K_options_t = self.k_pe_options_tensor  # [n_K_pe]
        expected_K_q = (q_soft * K_options_t).sum(dim=-1).mean()
        expected_K_p = (p_soft * K_options_t).sum(dim=-1).mean()
        # Normalise by the largest option so the term is scale-free.
        cost_loss = (expected_K_q + expected_K_p) / float(max(K_options_t.max().item(), 1.0))
        loss_dict['loss'] = loss_dict['loss'] + self.k_cost_lambda * cost_loss
        loss_dict['k_cost'] = cost_loss.detach()
        loss_dict['expected_K_q'] = expected_K_q.detach()
        loss_dict['expected_K_p'] = expected_K_p.detach()
    # Diagnostics: fraction of the batch that selected each K option.
    with torch.no_grad():
        for i, K in enumerate(self.k_pre_encoder_options):
            loss_dict[f'pe_kq_p{K}'] = q_soft[:, i].mean().detach()
            loss_dict[f'pe_kp_p{K}'] = p_soft[:, i].mean().detach()
    # DDP loss scaling (counters HF Trainer's gradient averaging).
    if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
        loss_dict['loss'] = loss_dict['loss'] * torch.distributed.get_world_size()
    return loss_dict
# ----------------------------------------------------------------
# Encode — single-pass or multi-step denoising (inference)
# ----------------------------------------------------------------
def encode(
    self,
    input_ids: torch.Tensor,  # [B, L]
    attention_mask: torch.Tensor,  # [B, L]
    is_query: bool = False,
    compute_sparse: Optional[bool] = None,
    content_token_ids: Optional[List] = None,
) -> Dict[str, torch.Tensor]:
    """
    PromptReps encoding.

    num_denoise_steps <= 1: single forward pass (fast path).
    num_denoise_steps > 1: iterative denoising loop (inference).
        repr_hidden[i, k] is captured at the step when gen token k transitions
        MASK → decoded, benefiting from previously-decoded token context.
        Uniform schedule: n_per_step = max(1, K // n_steps) tokens are decoded
        per step (highest negative-entropy confidence first); the last step
        decodes every position that is still MASK.
        quotation_emb is captured at the final step.

    Call under torch.no_grad() for inference; use forward() for training.

    Args:
        input_ids: [B, L] left-padded ids whose last positions are the K gen
            MASK tokens followed by ``self._n_tail`` structural tokens.
        attention_mask: [B, L].
        is_query: selects the query- or passage-side K via ``self._k``.
        compute_sparse: force sparse activations on/off; ``None`` means
            "on iff ``self.sparse_weight > 0``".
        content_token_ids: if given, passed to ``filter_sparse`` together
            with the pooled sparse activations.

    Returns:
        repr_hidden          [B, K, H]  ColBERT repr (one vector per gen token)
        quotation_emb        [B, H]     hidden state at the position just
                                        before the MASK block (single-dense repr)
        sparse_acts          [B, V]     MAX over the K positions of
                                        log(1+relu(logit)); only when sparse
                                        is enabled
        sparse_acts_per_pos  [B, K, V]  unfiltered per-position activations;
                                        single-pass path only, and only if K > 1
    """
    K = self._k(is_query)
    B = input_ids.size(0)
    L = input_ids.size(1)
    n_steps = self.num_denoise_steps
    device = input_ids.device
    g_start = L - K - self._n_tail  # left-padded: MASK block then structural tail
    if compute_sparse is None:
        need_sparse = self.sparse_weight > 0
    else:
        need_sparse = compute_sparse
    # Initialised here so both the single- and multi-step branches always leave
    # these bound — the multi-step path only ever populates sparse_max.
    sparse_max: Optional[torch.Tensor] = None
    sparse_per_pos: Optional[torch.Tensor] = None
    # ── Single-pass fast path (n_steps ≤ 1) ──────────────────────────────
    if n_steps <= 1:
        hidden, logits = self._fwd(input_ids, attention_mask, need_logits=need_sparse)
        repr_hidden = hidden[:, g_start:g_start + K, :].float()  # [B, K, H]
        quotation_emb = hidden[:, g_start - 1, :].float()  # [B, H]
        if need_sparse and logits is not None:
            # Per-position sparse acts for ColBERT-style multi-vector sparse
            # (caller may pop result['sparse_acts_per_pos'] to do per-K topk).
            # No fp32 cast here: the per-position tensor stays in the logits'
            # dtype to save memory bandwidth (original note: fp32 cast was the
            # dominant cost at K=16, 312 MB → 156 MB/batch for V=152k, B=32).
            gen_logits = logits[:, g_start:g_start + K, :]  # [B, K, V]
            sparse_per_pos = torch.log1p(torch.relu(gen_logits))  # [B, K, V]
            # sparse_max is cast to fp32 (small, [B, V]) for downstream consumers.
            sparse_max = sparse_per_pos.max(dim=1).values.float()  # [B, V] fp32
    # ── Multi-step denoising loop (n_steps > 1, inference) ───────────────
    else:
        curr_ids = input_ids.clone()
        n_per_step = max(1, K // n_steps)
        vocab_size = self.backbone.config.vocab_size
        # NOTE(review): sparse_max below is sized from config.vocab_size while
        # new_max comes from the logits' last dim — assumes the two agree;
        # confirm for backbones with a padded output vocabulary.
        repr_buf = torch.zeros(B, K, self.hidden_size, device=device)
        # repr_saved[b, k] is True once position k's hidden state was captured.
        repr_saved = torch.zeros(B, K, dtype=torch.bool, device=device)
        sparse_max = (torch.zeros(B, vocab_size, device=device)
                      if need_sparse else None)
        quotation_emb = None
        # Build the 4D mask once and reuse it for every step; when no adapter
        # is set the 4D mask is built by default.
        mask_4d = self._build_4d_mask(L, attention_mask) if (self.adapter.needs_4d_mask() if self.adapter else True) else None
        for step in range(n_steps):
            # Logits are always needed here: they drive the unmasking decision.
            hidden, logits = self._fwd(curr_ids, attention_mask, need_logits=True, mask_4d=mask_4d)
            is_last = (step == n_steps - 1)
            # ── Vectorized over batch.
            # Greedy argmax, neg-entropy confidence, top-n_per_step
            # transitions per item; on last step all remaining masks transition.
            gen_ids_all = curr_ids[:, g_start:g_start + K]  # [B, K] int
            mask_bool_all = (gen_ids_all == self.mask_token_id)  # [B, K] bool
            ml_all = logits[:, g_start:g_start + K, :]  # [B, K, V]
            # Confidence + argmax
            probs = F.softmax(ml_all.float(), dim=-1)  # [B, K, V] fp32
            log_probs = torch.log(probs + 1e-10)
            conf = (probs * log_probs).sum(dim=-1)  # [B, K] (neg entropy)
            x0_pred = probs.argmax(dim=-1)  # [B, K]
            # Restrict topk to currently-masked positions
            conf = torch.where(mask_bool_all, conf, torch.full_like(conf, -float('inf')))
            if is_last:
                xfer_mask = mask_bool_all  # all remaining masks
            else:
                n_tr = min(n_per_step, K)
                _, top_idx = conf.topk(n_tr, dim=-1)  # [B, n_tr]
                xfer_mask = torch.zeros_like(mask_bool_all)
                xfer_mask.scatter_(1, top_idx, True)
                # topk may return -inf (already decoded) slots when fewer than
                # n_tr masks remain; drop them.
                xfer_mask = xfer_mask & mask_bool_all  # only actually-masked positions
            # ── Vectorized state updates ────────────────────────────────────
            # 1. curr_ids: replace MASK with predicted at xfer positions
            if xfer_mask.any():
                new_K = torch.where(xfer_mask, x0_pred, gen_ids_all)
                # Write into a fresh copy rather than mutating the tensor that
                # gen_ids_all is a view of.
                curr_ids = curr_ids.clone()
                curr_ids[:, g_start:g_start + K] = new_K
            # 2. repr_buf: save hidden at to_save positions (xfer & not yet saved)
            to_save = xfer_mask & ~repr_saved
            if to_save.any():
                hidden_chunk = hidden[:, g_start:g_start + K, :]  # [B, K, H]
                repr_buf = torch.where(to_save.unsqueeze(-1), hidden_chunk, repr_buf)
                # 3. sparse_max: max over to_save positions of log(1+relu(logits))
                if sparse_max is not None:
                    sparse_chunk = torch.log1p(torch.relu(ml_all))  # [B, K, V]
                    # -inf at positions not saved this step so they cannot win
                    # the max; the running sparse_max (>= 0) absorbs rows where
                    # nothing was saved.
                    masked_sparse = torch.where(
                        to_save.unsqueeze(-1).expand_as(sparse_chunk),
                        sparse_chunk,
                        torch.full_like(sparse_chunk, -float('inf')))
                    new_max = masked_sparse.max(dim=1).values.float()  # [B, V] fp32
                    sparse_max = torch.max(sparse_max, new_max)
                repr_saved = repr_saved | to_save
            if is_last:
                quotation_emb = hidden[:, g_start - 1, :].float()
                # Catch positions that were never masked (rare but possible
                # if input already had a non-MASK token in the K block).
                unsaved = ~repr_saved
                if unsaved.any():
                    hidden_chunk = hidden[:, g_start:g_start + K, :]
                    repr_buf = torch.where(unsaved.unsqueeze(-1), hidden_chunk, repr_buf)
                    if sparse_max is not None:
                        sparse_chunk = torch.log1p(torch.relu(ml_all))
                        masked_sparse = torch.where(
                            unsaved.unsqueeze(-1).expand_as(sparse_chunk),
                            sparse_chunk,
                            torch.full_like(sparse_chunk, -float('inf')))
                        new_max = masked_sparse.max(dim=1).values.float()
                        sparse_max = torch.max(sparse_max, new_max)
                    repr_saved = repr_saved | unsaved
        # Replace any NaN/inf left in the captured states before normalising.
        repr_buf = torch.nan_to_num(repr_buf)
        repr_hidden = repr_buf
        quotation_emb = torch.nan_to_num(quotation_emb)
    # ── Normalize ─────────────────────────────────────────────────────────
    if self.normalize:
        quotation_emb = F.normalize(quotation_emb, p=2, dim=-1)
        repr_hidden = F.normalize(repr_hidden, p=2, dim=-1)
    result: Dict[str, torch.Tensor] = {
        'repr_hidden': repr_hidden,
        'quotation_emb': quotation_emb,
    }
    if sparse_max is not None:
        if content_token_ids is not None:
            sparse_max = filter_sparse(sparse_max, content_token_ids)
        result['sparse_acts'] = sparse_max
    # Per-position sparse [B, K, V] — single-pass only, K>1
    if sparse_per_pos is not None and K > 1:
        result['sparse_acts_per_pos'] = sparse_per_pos  # unfiltered; caller applies content filter
    return result
# ----------------------------------------------------------------
# Cross-GPU negative sharing
# ----------------------------------------------------------------
@staticmethod
def _dist_gather(t: torch.Tensor) -> torch.Tensor:
"""All-gather tensors across GPUs with gradient passthrough."""
if not (torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1):
return t
gathered = [torch.zeros_like(t) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(gathered, t.contiguous())
# Replace own shard with original (keeps gradient)
gathered[torch.distributed.get_rank()] = t
return torch.cat(gathered, dim=0)
def _gather_repr(self, repr_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Gather all representation tensors across GPUs."""
out = {}
for k, v in repr_dict.items():
out[k] = self._dist_gather(v)
return out
# ----------------------------------------------------------------
# MaxSim (ColBERT-style)
# ----------------------------------------------------------------
@staticmethod
def maxsim(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
"""
q: [B_q, k, H], p: [B_p, k, H] → [B_q, B_p] MaxSim scores.
"""
sims = torch.einsum('ikh,jlh->ijkl', q, p) # [B_q, B_p, k_q, k_p]
return sims.max(dim=-1).values.sum(dim=-1) # [B_q, B_p]
# ----------------------------------------------------------------
# Auxiliary losses (diffusion-native)
# ----------------------------------------------------------------
@torch.no_grad()
def _corrupt_text(
self,
input_ids: torch.Tensor, # [B, L]
attention_mask: torch.Tensor, # [B, L]
rate: float,
) -> torch.Tensor:
"""Randomly replace rate% of text tokens with MASK for corruption augmentation.
Only corrupts real text tokens (not padding, not gen MASKs, not EOS).
Returns a new tensor (does not modify input in-place).
"""
B, L = input_ids.shape
K = self.n_gen_p_tokens # only called on passages
g_start = L - K - self._n_tail
corrupted = input_ids.clone()
is_real = attention_mask.bool()
is_gen = torch.zeros(L, dtype=torch.bool, device=input_ids.device)
is_gen[g_start:] = True
is_candidate = is_real & ~is_gen.unsqueeze(0) & (input_ids != self.mask_token_id)
for i in range(B):
cand_idx = torch.where(is_candidate[i])[0]
if cand_idx.numel() == 0:
continue
n_mask = max(1, int(cand_idx.numel() * rate))
perm = torch.randperm(cand_idx.numel(), device=input_ids.device)[:n_mask]
corrupted[i, cand_idx[perm]] = self.mask_token_id
return corrupted
def _apply_text_masking(
self,
input_ids: torch.Tensor, # [B, L]
attention_mask: torch.Tensor, # [B, L]
) -> Tuple[torch.Tensor, torch.Tensor, float]:
"""Randomly mask text tokens (NOT generation MASKs) for denoising auxiliary loss.
Returns:
corrupted_ids: input_ids with some text tokens replaced by MASK
denoise_targets: original token IDs at masked positions (-100 elsewhere)
mask_ratio: fraction of tokens masked (t for 1/t weighting)
"""
B, L = input_ids.shape
K = self.n_gen_p_tokens # called on passages
g_start = L - K - self._n_tail # left-padded: gen block + structural tail
corrupted = input_ids.clone()
targets = torch.full_like(input_ids, -100) # -100 = ignore in CE
# Candidate text positions: not padding, not in gen/EOS block, not already MASK
is_real = attention_mask.bool() # [B, L]
is_gen = torch.zeros(L, dtype=torch.bool, device=input_ids.device)
is_gen[g_start:] = True # marks MASK positions + EOS as non-candidate
is_candidate = is_real & ~is_gen.unsqueeze(0) & (input_ids != self.mask_token_id)
for i in range(B):
cand_idx = torch.where(is_candidate[i])[0]
if cand_idx.numel() == 0:
continue
n_mask = max(1, int(cand_idx.numel() * self.denoise_mask_ratio))
perm = torch.randperm(cand_idx.numel(), device=input_ids.device)[:n_mask]
mask_positions = cand_idx[perm]
targets[i, mask_positions] = input_ids[i, mask_positions]
corrupted[i, mask_positions] = self.mask_token_id
mask_ratio = self.denoise_mask_ratio
return corrupted, targets, mask_ratio
def compute_denoising_loss(
    self,
    logits: torch.Tensor,   # [B, L, V] from forward pass on corrupted input
    targets: torch.Tensor,  # [B, L] with -100 for non-masked positions
    mask_ratio: float,      # t for 1/t weighting
) -> torch.Tensor:
    """Dream/LLaDA-style denoising loss: 1/t-weighted CE on masked text tokens.

    Args:
        logits: vocabulary logits for every position, [B, L, V].
        targets: original token id at corrupted positions, -100 elsewhere.
        mask_ratio: masking rate t; the mean CE is divided by
            ``max(t, 1e-3)``.

    Returns:
        Scalar loss. If no position is supervised (all targets are -100) the
        result is 0 instead of the NaN that a mean over zero elements would
        produce, so a batch without maskable text cannot poison training.
    """
    # reshape (not view) so non-contiguous logits/targets are accepted.
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_targets = targets.reshape(-1)

    # Mean over supervised positions, written as sum / count so that an empty
    # selection yields 0 / 1 = 0 rather than 0 / 0 = NaN.
    n_supervised = (flat_targets != -100).sum().clamp(min=1)
    summed = F.cross_entropy(
        flat_logits,
        flat_targets,
        ignore_index=-100,
        reduction='sum',
    )
    loss = summed / n_supervised

    # Weight by 1/t (matching Dream/LLaDA pre-training ELBO)
    return loss / max(mask_ratio, 1e-3)
def compute_diversity_loss(
    self,
    repr_hidden: torch.Tensor,  # [B, K, H]
) -> torch.Tensor:
    """Penalise positive pairwise similarity among an item's K vectors.

    The loss is the mean of ``relu(v_i · v_j)`` over all pairs i < j and all
    items, so it is zero once the vectors are mutually orthogonal or
    anti-aligned. The dot product equals cosine similarity only when the
    caller passes L2-normalised vectors. Returns 0 when K <= 1.
    """
    n_vec = repr_hidden.size(1)
    if n_vec <= 1:
        return torch.tensor(0.0, device=repr_hidden.device)

    # Gram matrix of the K vectors of each item.
    gram = torch.bmm(repr_hidden, repr_hidden.transpose(1, 2))      # [B, K, K]

    # Strict upper triangle = each unordered pair once, diagonal excluded.
    upper = torch.ones(n_vec, n_vec, device=repr_hidden.device).triu(diagonal=1).bool()
    pair_sims = gram[:, upper]                                       # [B, K*(K-1)/2]

    # Hinge at 0: only similarities above zero are pushed down.
    return torch.relu(pair_sims).mean()
@staticmethod
def _mean_offdiag_cos(repr_hidden: torch.Tensor) -> torch.Tensor:
"""Mean off-diagonal cosine similarity over K vectors."""
K = repr_hidden.size(1)
if K <= 1:
return torch.tensor(0.0, device=repr_hidden.device)
normed = F.normalize(repr_hidden.float(), p=2, dim=-1)
sim = torch.bmm(normed, normed.transpose(1, 2))
mask = torch.triu(torch.ones(K, K, device=repr_hidden.device), diagonal=1).bool()
return sim[:, mask].mean()
def _dense_debug_stats(
self,
q_repr: Dict[str, torch.Tensor],
p_repr: Dict[str, torch.Tensor],
labels: torch.Tensor,
q_raw: Optional[torch.Tensor] = None,
p_raw: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""Summaries that explain whether dense loss comes from poor score gaps or collapsed vectors."""
scores = self.maxsim(q_repr['repr_hidden'], p_repr['repr_hidden'])
row_idx = torch.arange(scores.size(0), device=scores.device)
pos = scores[row_idx, labels]
neg = scores.clone()
neg[row_idx, labels] = float('-inf')
hard_neg = neg.max(dim=1).values
stats = {
'debug_pos_score': pos.mean().detach(),
'debug_hardneg_score': hard_neg.mean().detach(),
'debug_score_gap': (pos - hard_neg).mean().detach(),
'debug_pos_beats_hardneg': (pos > hard_neg).float().mean().detach(),
'debug_q_mask_cos': self._mean_offdiag_cos(q_repr['repr_hidden']).detach(),
'debug_p_mask_cos': self._mean_offdiag_cos(p_repr['repr_hidden']).detach(),
}
if q_raw is not None:
stats['debug_q_mask_norm'] = q_raw.float().norm(dim=-1).mean().detach()
stats['debug_q_mask_raw_cos'] = self._mean_offdiag_cos(q_raw).detach()
if p_raw is not None:
stats['debug_p_mask_norm'] = p_raw.float().norm(dim=-1).mean().detach()
stats['debug_p_mask_raw_cos'] = self._mean_offdiag_cos(p_raw).detach()
return stats
# ----------------------------------------------------------------
# Loss
# ----------------------------------------------------------------
def compute_loss(
    self,
    q_repr: Dict[str, torch.Tensor],
    p_repr: Dict[str, torch.Tensor],
    labels: torch.Tensor,  # [B_q] index of positive in B_p
) -> Dict[str, torch.Tensor]:
    """
    Contrastive training loss over in-batch candidates.

    Dense term: cross-entropy over ColBERT MaxSim scores of ``repr_hidden``
    (MASK positions), divided by the query-side K and by
    ``self.temperature``, weighted by ``self.dense_weight``. For K=1 this is
    the dot product of the single MASK vector.
    Sparse term (if ``sparse_weight > 0`` and both sides carry
    ``'sparse_acts'``): cross-entropy over the raw dot product, no temperature.
    Diversity term (if ``diversity_weight > 0`` and K > 1): mean of the query
    and passage diversity losses.

    K-adapter: when ``self.use_k_adapter`` is True the call is delegated to
    ``_compute_loss_with_k_adapter``.

    Returns a dict with ``'loss'`` (differentiable) and detached
    ``'loss_dense'`` / ``'loss_sparse'``, plus ``'loss_diversity'`` whenever
    ``diversity_weight > 0``.
    """
    if self.use_k_adapter:
        return self._compute_loss_with_k_adapter(q_repr, p_repr, labels)

    device = labels.device
    q_vecs = q_repr['repr_hidden']
    p_vecs = p_repr['repr_hidden']
    K = q_vecs.size(1)

    # Primary dense: ColBERT MaxSim on repr_hidden (MASK positions).
    dense_logits = self.maxsim(q_vecs, p_vecs) / K / self.temperature
    colbert_loss = F.cross_entropy(dense_logits, labels)
    total_loss = self.dense_weight * colbert_loss

    # Sparse InfoNCE — raw dot product, same as inference. The activations
    # are expected to be content-filtered upstream, which keeps the scores
    # small enough that no clamping or normalisation is applied here.
    sparse_loss = torch.tensor(0.0, device=device)
    has_sparse = 'sparse_acts' in q_repr and 'sparse_acts' in p_repr
    if self.sparse_weight > 0 and has_sparse:
        sparse_scores = torch.matmul(q_repr['sparse_acts'], p_repr['sparse_acts'].T)
        sparse_loss = F.cross_entropy(sparse_scores, labels)
        total_loss = total_loss + self.sparse_weight * sparse_loss

    # Auxiliary: diversity on both sides (queries and passages both need
    # diverse vectors), averaged.
    diversity_loss = torch.tensor(0.0, device=device)
    if self.diversity_weight > 0 and K > 1:
        q_div = self.compute_diversity_loss(q_vecs)
        p_div = self.compute_diversity_loss(p_vecs)
        diversity_loss = 0.5 * (q_div + p_div)
        total_loss = total_loss + self.diversity_weight * diversity_loss

    result = {
        'loss': total_loss,
        'loss_dense': (self.dense_weight * colbert_loss).detach(),
        'loss_sparse': sparse_loss.detach(),
    }
    if self.diversity_weight > 0:
        result['loss_diversity'] = diversity_loss.detach()
    # loss_denoising is added by the caller since it needs logits
    return result
# ----------------------------------------------------------------
# K-adapter loss
# ----------------------------------------------------------------
def _compute_loss_with_k_adapter(
self,
q_repr: Dict[str, torch.Tensor],
p_repr: Dict[str, torch.Tensor],
labels: torch.Tensor,
) -> Dict[str, torch.Tensor]:
"""Joint training of retriever + K-routing head.
For each (K_q, K_p) cell in `self.k_adapter_options × self.k_adapter_options`:
1. Slice prefix of q_repr / p_repr at that K.
2. Compute MaxSim → fuse with sparse → CE loss vs in-batch negatives.
3. Per-query loss tensor of shape [B_q, n_K_q, n_K_p].
Teacher distribution per query: softmax(-per_cell_loss / τ_T).
Adapter (small MLP on quotation_emb) predicts a softmax over cells.
Adapter loss = KL(teacher || adapter_probs) — the adapter learns to
mimic which cell the model itself prefers (lowest CE) without any
oracle/MRR involvement.
Retriever loss = sum_c adapter_probs[i, c] * per_cell_loss[i, c]:
- early in training (uniform adapter): ~ matryoshka-style avg over
all cells, every prefix gets gradient → all K positions trained.
- late in training (peaky adapter): gradient concentrates on
adapter-preferred cell → cell specialization.
"""
device = labels.device
K_opts = self.k_adapter_options
n_K = self.n_K
K_max_q = q_repr['repr_hidden'].size(1)
K_max_p = p_repr['repr_hidden'].size(1)
# If model was instantiated at smaller K than max(k_adapter_options),
# filter the option list down to what we actually have positions for.
valid_q = [k for k in K_opts if k <= K_max_q]
valid_p = [k for k in K_opts if k <= K_max_p]
# Index lookup back into the n_K-wide cell grid (so options outside
# the model's K-range get teacher prob 0 / adapter logits ignored).
idx_q = [K_opts.index(k) for k in valid_q]
idx_p = [K_opts.index(k) for k in valid_p]
# ── 1. Sparse score (K-invariant — added to every cell uniformly).
# Stays max-pooled so it doesn't depend on K choice.
sparse_scores = None
if (self.sparse_weight > 0
and 'sparse_acts' in q_repr
and 'sparse_acts' in p_repr):
sparse_scores = q_repr['sparse_acts'] @ p_repr['sparse_acts'].T
# Note: sparse uses raw dot product (no temperature), matching the
# legacy compute_loss path. Adding it BEFORE softmax in CE won't
# change ranking since it shifts every doc by the same per-query
# amount — but it DOES affect per-cell CE values (different cells
# have different dense magnitudes; sparse adds a constant). This
# is intentional: the per-cell loss reflects the actual fused-score
# contrast the model is trained on.
#
# We deliberately apply temperature scaling to the dense part only,
# mirroring how inference fuses dense (MaxSim/K) + sparse_weight*sparse.
B_q = q_repr['repr_hidden'].size(0)
# ── 2. Per-cell dense scores → per-cell CE losses.
# [B_q, B_p, K_max_q, K_max_p] is the full per-position dot product.
# We compute MaxSim per (K_q, K_p) by slicing prefixes.
per_cell_losses = q_repr['repr_hidden'].new_zeros(B_q, n_K, n_K)
# Precompute the "all positions" inner-product tensor once:
# sims_full[b_q, b_p, i, j] = q_repr[b_q, i] · p_repr[b_q_p_pair, j]
# then maxsim(K_q, K_p) = sims_full[:, :, :K_q, :K_p].max(-1).sum(-1)
sims_full = torch.einsum('bih,djh->bdij',
q_repr['repr_hidden'].float(),
p_repr['repr_hidden'].float()) # [B_q, B_p, K_max_q, K_max_p]
for ii, K_q in enumerate(valid_q):
ki = idx_q[ii]
for jj, K_p in enumerate(valid_p):
kj = idx_p[jj]
# Slice → max over passage positions → sum over query positions.
cell_dense = sims_full[:, :, :K_q, :K_p].max(dim=-1).values.sum(dim=-1)
cell_dense = cell_dense / K_q # MaxSim normalization, matches legacy
cell_score = self.dense_weight * cell_dense / self.temperature
if sparse_scores is not None:
cell_score = cell_score + self.sparse_weight * sparse_scores
# Per-query CE — reduction='none' for [B_q] tensor, no batch mean.
cell_loss = F.cross_entropy(cell_score, labels, reduction='none') # [B_q]
per_cell_losses[:, ki, kj] = cell_loss
# ── 3. Teacher distribution: softmax(-loss/τ_T) over cells.
# Lower loss (better separation) → higher teacher weight.
flat_losses = per_cell_losses.view(B_q, n_K * n_K) # [B_q, n_K^2]
teacher_probs = F.softmax(
-flat_losses.detach() / max(self.teacher_temperature, 1e-6), dim=-1
)
# ── 4. Adapter prediction. Input: query's quotation_emb (K-invariant
# content-aware summary). Output: distribution over n_K^2 cells.
if 'quotation_emb' not in q_repr:
raise RuntimeError(
"K-adapter requires 'quotation_emb' in q_repr. "
"encode() must populate it before compute_loss()."
)
# Match adapter weight dtype (DeepSpeed/AMP may cast adapter to bf16).
# quotation_emb is fp32 from encode(); cast it to whatever the linear
# weight is, run the MLP, then cast logits back to fp32 for stable
# log_softmax + KL arithmetic.
adapter_dtype = next(self.k_adapter.parameters()).dtype
adapter_input = q_repr['quotation_emb'].to(dtype=adapter_dtype) # [B_q, H]
adapter_logits = self.k_adapter(adapter_input).float() # [B_q, n_K^2]
adapter_log_probs = F.log_softmax(adapter_logits, dim=-1)
adapter_probs = adapter_log_probs.exp()
# ── 5. Adapter loss: KL( teacher || adapter ).
L_adapter = F.kl_div(adapter_log_probs, teacher_probs, reduction='batchmean')
# ── 6. Retriever loss: soft mixture over cells weighted by adapter.
# Early training: adapter ~uniform → every cell gets gradient
# (matryoshka-style — all K prefixes become useful).
# Late training: adapter peaky → mostly the chosen cell gets
# gradient → specialization at that K.
L_retriever = (adapter_probs * flat_losses).sum(dim=-1).mean()
total_loss = L_retriever + self.adapter_weight * L_adapter
# Diagnostics: marginal K_q and K_p distributions averaged over batch.
with torch.no_grad():
cell_grid = adapter_probs.view(B_q, n_K, n_K) # [B_q, n_K_q, n_K_p]
kq_dist = cell_grid.sum(dim=-1).mean(dim=0) # [n_K]
kp_dist = cell_grid.sum(dim=-2).mean(dim=0) # [n_K]
entropy = -(adapter_probs * adapter_log_probs).sum(dim=-1).mean()
# Standalone sparse-only CE for logging (mirrors the legacy
# `loss_sparse` field). This is the InfoNCE on raw sparse scores
# — it's NOT used for the gradient (sparse already contributes
# via the per-cell dense+sparse fused score). Reported just so
# users can monitor sparse-only retrieval quality during training.
if sparse_scores is not None:
sparse_only_loss = F.cross_entropy(sparse_scores, labels)
else:
sparse_only_loss = torch.tensor(0.0, device=device)
result = {
'loss': total_loss,
'loss_dense': L_retriever.detach(),
'loss_sparse': sparse_only_loss.detach(), # diagnostic only — not used for gradient
'loss_adapter': L_adapter.detach(),
'adapter_entropy': entropy.detach(),
}
# log marginals: kq_p1, kq_p2, kq_p4, kq_p8, kq_p16, kp_p1, ...
for i, K in enumerate(K_opts):
if i < kq_dist.shape[0]:
result[f'kq_p{K}'] = kq_dist[i].detach()
result[f'kp_p{K}'] = kp_dist[i].detach()
return result
# ----------------------------------------------------------------
# HF Trainer-compatible forward
# ----------------------------------------------------------------
def forward(
    self,
    query_input_ids: torch.Tensor,
    query_attention_mask: torch.Tensor,
    passage_input_ids: torch.Tensor,
    passage_attention_mask: torch.Tensor,
    query_content_ids: Optional[List] = None,
    passage_content_ids: Optional[List] = None,
) -> Dict[str, torch.Tensor]:
    """
    Compute the contrastive training loss for a batch of queries and passages.

    passages layout: [pos_0, neg_0_0, …, neg_0_M, pos_1, neg_1_0, …]
    positive for query i is at index i * (1 + n_neg).
    num_denoise_steps == 1: single-pass fast path.
    num_denoise_steps > 1: multi-step denoising training.
    Runs the full denoising loop (each step decodes n_per_step tokens,
    storing frozen hidden states in repr_buf), then computes loss once
    at the final step. Gradient flows from the final mixed representation
    only — no dilution across steps. Uniform unmasking: n_per_step = K // n_steps.

    Dispatch order (first match wins):
      1. soft_denoising and num_denoise_steps > 1 → _soft_multistep_forward
      2. use_k_pre_encoder                        → _forward_with_k_pre_encoder
      3. num_denoise_steps <= 1                   → single-step fast path
      4. otherwise                                → multi-step denoising loop

    Args:
        query_input_ids: query token ids, indexed as [B_q, L_q].
        query_attention_mask: attention mask matching query_input_ids.
        passage_input_ids: passage token ids, indexed as [B_p, L_p];
            B_p must be a multiple of B_q.
        passage_attention_mask: attention mask matching passage_input_ids.
        query_content_ids: optional; when given, passed to filter_sparse
            together with the query sparse activations.
        passage_content_ids: optional; same, for passages.

    Returns:
        The dict produced by compute_loss, with 'loss' possibly rescaled by
        the distributed world size and augmented with the progressive and/or
        denoising terms ('loss_progressive', 'loss_denoising' are added as
        detached diagnostics). The two early-dispatch paths return whatever
        their helper returns.

    Raises:
        AssertionError: if B_p is not divisible by B_q.
    """
    B_q = query_input_ids.size(0)
    B_p = passage_input_ids.size(0)
    assert B_p % B_q == 0, f"B_p ({B_p}) must be divisible by B_q ({B_q})"
    n_paq = B_p // B_q  # passages per query (1 positive + negatives)
    device = query_input_ids.device
    # Index of each query's positive inside the flat passage batch.
    labels = torch.arange(B_q, device=device) * n_paq
    n_steps = self.num_denoise_steps
    # ── Corruption augmentation (denoising-conditioned training) ─────────
    # Randomly mask text tokens (NOT gen MASKs) in passages to force robust,
    # diverse multi-vector representations. Unique to diffusion: the model
    # was pretrained on this exact corruption — zero distribution shift.
    if self.training and self.corruption_rate > 0:
        import random as _rng
        # One corruption rate is drawn per forward call (shared by the batch).
        t = _rng.uniform(0.0, self.corruption_rate)
        if t > 0.01:  # skip very small rates
            passage_input_ids = self._corrupt_text(
                passage_input_ids, passage_attention_mask, t)
    # ── Soft-token multi-step (differentiable denoising) ─────────────────
    if self.soft_denoising and n_steps > 1:
        return self._soft_multistep_forward(
            query_input_ids, query_attention_mask,
            passage_input_ids, passage_attention_mask,
            query_content_ids, passage_content_ids,
            n_steps,
        )
    # ── K pre-encoder (two-stage encoding) ───────────────────────────────
    # Predict K per item BEFORE main encoder, slice each input to its K,
    # then run main encoder at variable-length batch. True encoding
    # savings (vs the K-adapter which always encodes at K_max).
    if self.use_k_pre_encoder:
        return self._forward_with_k_pre_encoder(
            query_input_ids, query_attention_mask,
            passage_input_ids, passage_attention_mask,
            query_content_ids, passage_content_ids,
            labels,
        )
    # ── Single-step fast path ─────────────────────────────────────────────
    if n_steps <= 1:
        K_q, K_p = self.n_gen_q_tokens, self.n_gen_p_tokens
        n_tail = self._n_tail
        need_sparse = self.sparse_weight > 0
        # Optionally corrupt passages for denoising auxiliary loss
        if self.denoising_weight > 0:
            p_corrupted, p_denoise_targets, mask_ratio = self._apply_text_masking(
                passage_input_ids, passage_attention_mask)
            p_ids_for_fwd = p_corrupted
        else:
            p_ids_for_fwd = passage_input_ids
        # Logits are only requested when something downstream consumes them.
        need_logits = need_sparse or (self.denoising_weight > 0)
        L_q, L_p = query_input_ids.size(1), p_ids_for_fwd.size(1)
        if K_q == K_p:
            # ── Symmetric path: concat Q+P into one _fwd (memory-efficient).
            # The shorter side is left-padded (ids and mask both padded with 0)
            # so the gen block stays aligned at the right edge of every row.
            if L_q < L_p:
                pad = L_p - L_q
                q_ids_cat = F.pad(query_input_ids, (pad, 0))  # left-pad
                q_mask_cat = F.pad(query_attention_mask, (pad, 0))
                p_ids_cat, p_mask_cat = p_ids_for_fwd, passage_attention_mask
            elif L_p < L_q:
                pad = L_q - L_p
                q_ids_cat, q_mask_cat = query_input_ids, query_attention_mask
                p_ids_cat = F.pad(p_ids_for_fwd, (pad, 0))  # left-pad
                p_mask_cat = F.pad(passage_attention_mask, (pad, 0))
            else:
                q_ids_cat, q_mask_cat = query_input_ids, query_attention_mask
                p_ids_cat, p_mask_cat = p_ids_for_fwd, passage_attention_mask
            all_ids = torch.cat([q_ids_cat, p_ids_cat], dim=0)
            all_mask = torch.cat([q_mask_cat, p_mask_cat], dim=0)
            all_hidden, all_logits = self._fwd(all_ids, all_mask,
                                               need_logits=need_logits)
            L_all = all_ids.size(1)
            # Start index of the gen block, counted from the padded length.
            g_q_in_all = L_all - K_q - n_tail
            g_p_in_all = L_all - K_p - n_tail
            # Rows [0, B_q) are queries, the rest are passages.
            q_hidden, p_hidden = all_hidden[:B_q], all_hidden[B_q:]
            q_logits = all_logits[:B_q] if all_logits is not None else None
            p_logits = all_logits[B_q:] if all_logits is not None else None
        else:
            # ── Asymmetric path: K_q != K_p means different gen-block sizes,
            # so a concatenated forward would mis-align the slices. Run
            # two separate _fwd calls (~2× compute but correct).
            q_hidden, q_logits = self._fwd(
                query_input_ids, query_attention_mask,
                need_logits=need_logits)
            p_hidden, p_logits = self._fwd(
                p_ids_for_fwd, passage_attention_mask,
                need_logits=need_logits)
            g_q_in_all = L_q - K_q - n_tail
            g_p_in_all = L_p - K_p - n_tail
            # No padding/concat: each side keeps its own length.
            # NOTE(review): L_all is not read again in this function after
            # this point — the denoising slice below uses p_logits.size(1).
            L_all = max(L_q, L_p)  # only used for denoising slice
        # ── Extract Q repr ─────────────────────────────────────────────
        # K_q gen-position hidden states, plus the hidden state of the token
        # immediately before the gen block (the "quotation" embedding).
        q_repr_hidden = q_hidden[:, g_q_in_all:g_q_in_all + K_q, :].float()
        q_quotation_emb = q_hidden[:, g_q_in_all - 1, :].float()
        q_sparse_max = None
        if need_sparse and q_logits is not None:
            # Monotonic trick: max in bf16 first → log1p(relu()) on [B, V].
            q_sparse_max = torch.log1p(torch.relu(
                q_logits[:, g_q_in_all:g_q_in_all + K_q, :].max(dim=1).values))
        if self.normalize:
            q_quotation_emb = F.normalize(q_quotation_emb, p=2, dim=-1)
            q_repr_hidden = F.normalize(q_repr_hidden, p=2, dim=-1)
        q_repr = {'repr_hidden': q_repr_hidden, 'quotation_emb': q_quotation_emb}
        if q_sparse_max is not None:
            if query_content_ids is not None:
                q_sparse_max = filter_sparse(q_sparse_max, query_content_ids)
            q_repr['sparse_acts'] = q_sparse_max
        # ── Extract P repr ─────────────────────────────────────────────
        # Mirrors the query extraction above with the passage-side offsets.
        p_repr_hidden = p_hidden[:, g_p_in_all:g_p_in_all + K_p, :].float()
        p_quotation_emb = p_hidden[:, g_p_in_all - 1, :].float()
        p_sparse_max = None
        if need_sparse and p_logits is not None:
            # Monotonic trick: max in bf16 first → log1p(relu()) on [B, V].
            p_sparse_max = torch.log1p(torch.relu(
                p_logits[:, g_p_in_all:g_p_in_all + K_p, :].max(dim=1).values))
        if self.normalize:
            p_quotation_emb = F.normalize(p_quotation_emb, p=2, dim=-1)
            p_repr_hidden = F.normalize(p_repr_hidden, p=2, dim=-1)
        p_repr = {'repr_hidden': p_repr_hidden, 'quotation_emb': p_quotation_emb}
        if p_sparse_max is not None:
            if passage_content_ids is not None:
                p_sparse_max = filter_sparse(p_sparse_max, passage_content_ids)
            p_repr['sparse_acts'] = p_sparse_max
        # ── Cross-GPU negative sharing ────────────────────────────────
        q_repr = self._gather_repr(q_repr)
        p_repr = self._gather_repr(p_repr)
        # Recompute labels for gathered batch
        B_q_all = q_repr['repr_hidden'].size(0)
        B_p_all = p_repr['repr_hidden'].size(0)
        n_paq_g = B_p_all // B_q_all
        labels_g = torch.arange(B_q_all, device=device) * n_paq_g
        # ── Loss ──────────────────────────────────────────────────────
        loss_dict = self.compute_loss(q_repr, p_repr, labels_g)
        if self.debug_dense_metrics:
            loss_dict.update(self._dense_debug_stats(
                q_repr, p_repr, labels_g,
                q_raw=q_repr_hidden, p_raw=p_repr_hidden,
            ))
        # Scale loss to counter DDP gradient averaging
        # (applied before the denoising term is added, so that term is not scaled).
        if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
            loss_dict['loss'] = loss_dict['loss'] * torch.distributed.get_world_size()
        if self.denoising_weight > 0:
            # Denoising targets were created from original passage length;
            # if passages were left-padded to match queries, slice
            # accordingly. In asymmetric mode (separate forward) no
            # padding was applied, so the slice is a no-op (dn_start=0).
            p_seq_len = p_logits.size(1) if p_logits is not None else L_p
            dn_start = p_seq_len - L_p
            p_logits_dn = p_logits[:, dn_start:, :]
            denoise_loss = self.compute_denoising_loss(
                p_logits_dn, p_denoise_targets, mask_ratio)
            loss_dict['loss'] = loss_dict['loss'] + self.denoising_weight * denoise_loss
            loss_dict['loss_denoising'] = denoise_loss.detach()
        return loss_dict
    # ── Multi-step: denoising training ───────────────────────────────────
    K_q, K_p = self.n_gen_q_tokens, self.n_gen_p_tokens
    n_tail = self._n_tail
    # At least one position is decoded per step even when K < n_steps.
    n_per_step_q = max(1, K_q // n_steps)
    n_per_step_p = max(1, K_p // n_steps)
    device = query_input_ids.device
    H = self.hidden_size
    V = self.backbone.config.vocab_size
    L_q = query_input_ids.size(1)
    L_p = passage_input_ids.size(1)
    # Pre-compute 4D masks once (only for models that need them)
    # When no adapter is set, the 4D mask is always built.
    q_mask_4d = self._build_4d_mask(query_input_ids.size(1), query_attention_mask) if (self.adapter.needs_4d_mask() if self.adapter else True) else None
    p_mask_4d = self._build_4d_mask(passage_input_ids.size(1), passage_attention_mask) if (self.adapter.needs_4d_mask() if self.adapter else True) else None
    # Working copies of the ids: gen positions get overwritten as they decode.
    q_curr = query_input_ids.clone()
    p_curr = passage_input_ids.clone()
    need_sparse = self.sparse_weight > 0
    # Frozen buffers for positions already decoded in prior steps
    q_repr_buf = torch.zeros(B_q, K_q, H, device=device)
    p_repr_buf = torch.zeros(B_p, K_p, H, device=device)
    q_decoded_mask = torch.zeros(B_q, K_q, dtype=torch.bool, device=device)
    p_decoded_mask = torch.zeros(B_p, K_p, dtype=torch.bool, device=device)
    # Running element-wise max of sparse activations over decoded positions.
    q_sparse_decoded = torch.zeros(B_q, V, device=device) if need_sparse else None
    p_sparse_decoded = torch.zeros(B_p, V, device=device) if need_sparse else None
    progressive = self.progressive_step_weight > 0
    final_loss: Dict[str, torch.Tensor] = {}
    progressive_loss_sum = torch.tensor(0.0, device=device)
    def _build_repr_at_step(h, logits, repr_buf, decoded_mask, sparse_decoded,
                            K_side: int,
                            content_ids=None, use_fresh=False):
        """Build retrieval representation from current hidden states.
        Used for both intermediate (progressive) and final steps.
        Snapshot buffers with .clone() to isolate from later in-place updates.
        K_side: per-side K (n_gen_q_tokens for queries, n_gen_p_tokens
        for passages — needed under asymmetric configs).
        use_fresh: if True, ALL K positions use current hidden states
        (ignoring frozen repr_buf). Lets all tokens benefit from
        fully-decoded context at the final step.
        Returns a dict with 'repr_hidden' and 'quotation_emb', plus
        'sparse_acts' when sparse scoring is enabled.
        """
        Ls = h.size(1)
        gs = Ls - K_side - n_tail  # left-padded: gen block then structural tail
        curr_gen_h = h[:, gs:gs + K_side, :].float()  # [B_loc, K_side, H]
        decoded_snap = decoded_mask.clone()
        repr_snap = repr_buf.detach().clone()
        mask_sparse = None
        if need_sparse and logits is not None:
            sp_all = torch.log(1.0 + torch.relu(logits[:, gs:gs + K_side, :].float()))
            if use_fresh:
                mask_sparse = sp_all.max(dim=1).values
            else:
                # Zero out already-decoded positions before pooling: their
                # contribution comes from sparse_decoded instead.
                is_mask = (~decoded_snap).float().unsqueeze(-1)
                mask_sparse = (sp_all * is_mask).max(dim=1).values
        if use_fresh:
            mixed = curr_gen_h
        else:
            # Decoded positions take the frozen (detached) vector; positions
            # that are still MASK take the current hidden state.
            mixed = torch.where(decoded_snap.unsqueeze(-1), repr_snap, curr_gen_h)
        quotation_emb = h[:, gs - 1, :].float()
        if self.normalize:
            mixed = F.normalize(mixed, p=2, dim=-1)
            quotation_emb = F.normalize(quotation_emb, p=2, dim=-1)
        result = {'repr_hidden': mixed, 'quotation_emb': quotation_emb}
        if need_sparse and sparse_decoded is not None:
            sparse_snap = sparse_decoded.detach().clone()
            # NOTE(review): mask_sparse is None if logits is None; every call
            # below requests logits (need_logits=True) — confirm _fwd honours it.
            combined_sparse = torch.max(sparse_snap, mask_sparse)
            if content_ids is not None:
                combined_sparse = filter_sparse(combined_sparse, content_ids)
            result['sparse_acts'] = combined_sparse
        return result
    q_g = L_q - K_q - n_tail  # left-padded: gen block starts here
    p_g = L_p - K_p - n_tail
    for step in range(n_steps):
        is_last = (step == n_steps - 1)
        q_h, q_logits = self._fwd(q_curr, query_attention_mask,
                                  need_logits=True, mask_4d=q_mask_4d)
        p_h, p_logits = self._fwd(p_curr, passage_attention_mask,
                                  need_logits=True, mask_4d=p_mask_4d)
        # A retrieval loss is computed at the last step, and at every step
        # when progressive supervision is enabled.
        if progressive or is_last:
            _fresh = self.use_fresh_final and is_last
            q_repr_s = _build_repr_at_step(
                q_h, q_logits,
                q_repr_buf, q_decoded_mask, q_sparse_decoded, K_q,
                content_ids=query_content_ids, use_fresh=_fresh)
            p_repr_s = _build_repr_at_step(
                p_h, p_logits,
                p_repr_buf, p_decoded_mask, p_sparse_decoded, K_p,
                content_ids=passage_content_ids, use_fresh=_fresh)
            # Cross-GPU gather for multi-step
            q_repr_s = self._gather_repr(q_repr_s)
            p_repr_s = self._gather_repr(p_repr_s)
            B_q_g = q_repr_s['repr_hidden'].size(0)
            B_p_g = p_repr_s['repr_hidden'].size(0)
            labels_ms = torch.arange(B_q_g, device=device) * (B_p_g // B_q_g)
            step_loss_dict = self.compute_loss(q_repr_s, p_repr_s, labels_ms)
            if is_last:
                final_loss = step_loss_dict
                if self.debug_dense_metrics:
                    final_loss.update(self._dense_debug_stats(q_repr_s, p_repr_s, labels_ms))
            elif progressive:
                # Linearly increasing weight (t/T) for intermediate steps.
                step_weight = (step + 1) / n_steps
                progressive_loss_sum = progressive_loss_sum + step_weight * step_loss_dict['loss']
        if not is_last:
            q_curr, q_newly = self._unmask_step(
                q_curr, q_logits, K_q, n_per_step_q)
            p_curr, p_newly = self._unmask_step(
                p_curr, p_logits, K_p, n_per_step_p)
            # Freeze this step's hidden state (and sparse activation) for
            # every position that was just decoded; no gradient is recorded.
            # q_newly / p_newly are iterated as one list of gen-block-relative
            # positions per batch row.
            with torch.no_grad():
                for i, pos_list in enumerate(q_newly):
                    for tok_pos in pos_list:
                        q_repr_buf[i, tok_pos] = q_h[i, q_g + tok_pos].detach()
                        q_decoded_mask[i, tok_pos] = True
                        if q_sparse_decoded is not None and q_logits is not None:
                            q_sparse_decoded[i] = torch.max(q_sparse_decoded[i],
                                torch.log(1.0 + torch.relu(q_logits[i, q_g + tok_pos])).detach())
                for i, pos_list in enumerate(p_newly):
                    for tok_pos in pos_list:
                        p_repr_buf[i, tok_pos] = p_h[i, p_g + tok_pos].detach()
                        p_decoded_mask[i, tok_pos] = True
                        if p_sparse_decoded is not None and p_logits is not None:
                            p_sparse_decoded[i] = torch.max(p_sparse_decoded[i],
                                torch.log(1.0 + torch.relu(p_logits[i, p_g + tok_pos])).detach())
    # Add progressive step loss to final loss
    if progressive and final_loss:
        final_loss['loss'] = final_loss['loss'] + self.progressive_step_weight * progressive_loss_sum
        final_loss['loss_progressive'] = progressive_loss_sum.detach()
    # Scale loss for DDP (multi-step path)
    if final_loss and torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
        final_loss['loss'] = final_loss['loss'] * torch.distributed.get_world_size()
    # Denoising auxiliary for multi-step (same as single-step: separate corrupted forward)
    if self.denoising_weight > 0 and final_loss:
        p_corrupted, p_denoise_targets, mask_ratio = self._apply_text_masking(
            passage_input_ids, passage_attention_mask)
        _, p_logits_dn = self._fwd(p_corrupted, passage_attention_mask,
                                   need_logits=True, mask_4d=p_mask_4d)
        denoise_loss = self.compute_denoising_loss(p_logits_dn, p_denoise_targets, mask_ratio)
        final_loss['loss'] = final_loss['loss'] + self.denoising_weight * denoise_loss
        final_loss['loss_denoising'] = denoise_loss.detach()
    return final_loss
# ----------------------------------------------------------------
# Save / load helpers
# ----------------------------------------------------------------
def _save_retriever_config(self, output_dir: str):
import json, os
config = {
'model_type': self.model_type,
'mask_token_id': self.mask_token_id,
'hidden_size': self.hidden_size,
'max_length': self.max_length,
'n_gen_tokens': self.n_gen_tokens,
'n_gen_q_tokens': self.n_gen_q_tokens,
'n_gen_p_tokens': self.n_gen_p_tokens,
'temperature': self.temperature,
'flops_weight': 0.0, # deprecated, kept for backwards compat
'num_denoise_steps': self.num_denoise_steps,
'sparse_weight': self.sparse_weight,
'normalize': self.normalize,
'query_prefix_ids': self._query_prefix_ids,
'query_suffix_ids': self._query_suffix_ids,
'passage_prefix_ids': self._passage_prefix_ids,
'passage_suffix_ids': self._passage_suffix_ids,
'lora_rank': getattr(self, 'lora_rank', 0),
'lora_alpha': getattr(self, 'lora_alpha', 64),
'use_eos': self.use_eos,
'n_tail': self._n_tail,
# K-adapter persistence
'use_k_adapter': self.use_k_adapter,
'adapter_weight': self.adapter_weight,
'teacher_temperature': self.teacher_temperature,
'k_adapter_options': list(self.k_adapter_options),
# K pre-encoder persistence
'use_k_pre_encoder': self.use_k_pre_encoder,
'gumbel_temperature': self.gumbel_temperature,
'k_cost_lambda': self.k_cost_lambda,
'k_pre_encoder_options': list(self.k_pre_encoder_options),
}
with open(os.path.join(output_dir, 'retriever_config.json'), 'w') as f:
json.dump(config, f, indent=2)
# Save adapter state separately (PEFT save_pretrained only saves LoRA;
# the small KAdapter MLP isn't a LoRA module so we persist it ourselves).
if self.use_k_adapter and self.k_adapter is not None:
torch.save(self.k_adapter.state_dict(),
os.path.join(output_dir, 'k_adapter.bin'))
if self.use_k_pre_encoder and self.k_pre_encoder_q is not None:
torch.save({'q': self.k_pre_encoder_q.state_dict(),
'p': self.k_pre_encoder_p.state_dict()},
os.path.join(output_dir, 'k_pre_encoder.bin'))
@staticmethod
def _read_checkpoint_state_dict(checkpoint_dir) -> Optional[Dict[str, torch.Tensor]]:
    """Merge every weight file found in *checkpoint_dir* into one state dict.

    safetensors files (``model.safetensors`` and ``model-*.safetensors``
    shards) take precedence; only when none exist are the
    ``pytorch_model*.bin`` files read. All shards are merged, so a sharded
    ``.bin`` checkpoint is no longer truncated to its first shard.

    Returns:
        The merged state dict, or ``None`` when no weight file is present.
    """
    checkpoint_dir = Path(checkpoint_dir)
    weight_file = checkpoint_dir / 'model.safetensors'
    st_shards = sorted(checkpoint_dir.glob('model-*.safetensors'))
    if weight_file.exists() or st_shards:
        from safetensors.torch import load_file
        state_dict = load_file(str(weight_file)) if weight_file.exists() else {}
        for shard in st_shards:
            state_dict.update(load_file(str(shard)))
        return state_dict

    pt_shards = sorted(checkpoint_dir.glob('pytorch_model*.bin'))
    if pt_shards:
        state_dict = {}
        for shard in pt_shards:
            # weights_only=True: these files only need to contain tensors, so
            # refuse to unpickle arbitrary objects from a downloaded checkpoint.
            state_dict.update(
                torch.load(str(shard), map_location='cpu', weights_only=True))
        return state_dict
    return None

@classmethod
def load(cls, model_dir: str, **fallback_kwargs) -> 'TrainableDiffusionRetriever':
    """Load a fine-tuned TrainableDiffusionRetriever from a saved directory.

    If retriever_config.json is missing (e.g. mid-training checkpoint or old checkpoint
    with matryoshka_config.json), falls back to from_backbone() using fallback_kwargs:
    model_type, query_prompt, passage_prompt, n_gen_tokens,
    sparse_weight, max_length.

    With a config file present, the backbone is obtained in one of three ways:
      * released PEFT layout (adapter_config.json + adapter weights, no
        config.json): base backbone from the Hub + ``PeftModel.from_pretrained``;
      * ``adapter.load_backbone(model_dir)`` succeeds: weights come with it;
      * that call fails: architecture from the Hub, weights loaded separately
        from the checkpoint files in *model_dir*.
    LoRA weights are merged into the base weights before returning
    (best-effort: a failed merge is logged and the un-merged model is kept).

    Args:
        model_dir: directory containing the saved retriever.
        **fallback_kwargs: only consulted when no config file exists.

    Returns:
        The restored retriever instance.

    Raises:
        FileNotFoundError: no config file and the required fallback kwargs
            (model_type, query_prompt, passage_prompt) were not all supplied.
    """
    import json
    from transformers import AutoTokenizer

    # Support both new name and old name (matryoshka_config.json) for backwards compat
    config_path = Path(model_dir) / 'retriever_config.json'
    if not config_path.exists():
        config_path = Path(model_dir) / 'matryoshka_config.json'

    if not config_path.exists():
        model_type = fallback_kwargs.get('model_type')
        query_prompt = fallback_kwargs.get('query_prompt')
        passage_prompt = fallback_kwargs.get('passage_prompt')
        if not model_type or not query_prompt or not passage_prompt:
            raise FileNotFoundError(
                f"No retriever_config.json in {model_dir}. "
                "Pass model_type, query_prompt, passage_prompt as fallback kwargs.")
        _fallback_adapter = get_adapter(model_type)
        source_model = (fallback_kwargs.get('original_model')
                        or _fallback_adapter.hub_model_name)
        logger.info(f"No retriever_config.json — loading architecture from {source_model}, "
                    f"weights from {model_dir}")

        # Load state dict first so we can detect LoRA rank before building model
        state_dict = cls._read_checkpoint_state_dict(model_dir)

        # Detect LoRA rank from key names (keys look like backbone.base_model.model.*)
        detected_lora_rank = fallback_kwargs.get('lora_rank', 0)
        if state_dict is not None and detected_lora_rank == 0:
            for k, v in state_dict.items():
                if 'lora_A.default.weight' in k:
                    detected_lora_rank = v.shape[0]
                    logger.info(f"Detected LoRA rank {detected_lora_rank} from checkpoint keys")
                    break

        model = cls.from_backbone(
            model_name=source_model,
            model_type=model_type,
            query_prompt=query_prompt,
            passage_prompt=passage_prompt,
            n_gen_tokens=fallback_kwargs.get('n_gen_tokens', 4),
            n_gen_q_tokens=fallback_kwargs.get('n_gen_q_tokens', None),
            n_gen_p_tokens=fallback_kwargs.get('n_gen_p_tokens', None),
            temperature=fallback_kwargs.get('temperature', 0.02),
            num_denoise_steps=fallback_kwargs.get('num_denoise_steps', None),
            sparse_weight=fallback_kwargs.get('sparse_weight', 1.0),
            normalize=fallback_kwargs.get('normalize', True),
            max_length=fallback_kwargs.get('max_length', 256),
            lora_rank=detected_lora_rank,
            device_map='auto',
        )
        if state_dict is not None:
            # Load full state dict directly into the model (keys have backbone. prefix)
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            backbone_keys_loaded = sum(1 for k in state_dict if k.startswith('backbone.'))
            logger.info(f"Loaded checkpoint weights from {model_dir} "
                        f"(lora_rank={detected_lora_rank}, "
                        f"{backbone_keys_loaded} backbone keys, "
                        f"{len(missing)} missing, {len(unexpected)} unexpected)")
        else:
            logger.warning(f"No weight file found in {model_dir} — using original model weights")
        # Merge LoRA weights into base model for faster inference
        if detected_lora_rank > 0 and hasattr(model.backbone, 'merge_and_unload'):
            model.backbone = model.backbone.merge_and_unload()
            logger.info("Merged LoRA adapters for inference")
        return model

    with open(config_path, encoding='utf-8') as f:
        cfg = json.load(f)
    model_type = cfg.get('model_type', 'dream')
    adapter = get_adapter(model_type)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    tokenizer.padding_side = 'left'
    load_weights_separately = False

    # ── Released PEFT adapter repo (the layout published on the Hub) ──────
    # A released checkpoint ships only the LoRA adapter
    # (adapter_config.json + adapter_model.safetensors/.bin) alongside
    # retriever_config.json and the tokenizer — there is no base config.json
    # or full-weight file. Detect that and load the base backbone from the
    # Hub, then attach the adapter in place. PeftModel.from_pretrained
    # restores the trained LoRA weights, so no separate weight load is
    # needed and the merge step below folds them in for fast inference.
    _adapter_cfg = Path(model_dir) / 'adapter_config.json'
    _base_cfg = Path(model_dir) / 'config.json'
    _adapter_weights = ((Path(model_dir) / 'adapter_model.safetensors').exists()
                        or (Path(model_dir) / 'adapter_model.bin').exists())
    if _adapter_cfg.exists() and _adapter_weights and not _base_cfg.exists():
        from peft import PeftModel
        logger.info(f"Detected released PEFT adapter in {model_dir} — "
                    f"loading base {adapter.hub_model_name} from Hub and "
                    f"attaching the LoRA adapter")
        base = adapter.load_backbone(adapter.hub_model_name, device_map='auto')
        backbone = PeftModel.from_pretrained(base, str(model_dir))
    else:
        try:
            backbone = adapter.load_backbone(model_dir, device_map='auto')
        except (ValueError, OSError):
            # model_dir has no (usable) config.json — load architecture from Hub.
            # OSError is what Transformers raises when the file is absent;
            # ValueError covers a config that cannot be resolved to a model.
            logger.info(f"No valid config.json in {model_dir} — "
                        f"loading architecture from {adapter.hub_model_name}")
            backbone = adapter.load_backbone(adapter.hub_model_name, device_map='auto')
            load_weights_separately = True

    # Detect if model.safetensors was saved as a full PEFT state dict by DeepSpeed
    _ckpt = Path(model_dir) / 'model.safetensors'
    if _ckpt.exists():
        from safetensors import safe_open as _so
        _lora_rank = cfg.get('lora_rank', 0)
        with _so(str(_ckpt), framework='pt') as _f:
            _keys = list(_f.keys())
            # Only the first few keys are sampled to recognise the PEFT prefix.
            _is_peft_ckpt = any('base_model.model' in k for k in _keys[:8])
            if _is_peft_ckpt and _lora_rank == 0:
                # Read a single LoRA matrix to recover the rank instead of
                # materialising the whole checkpoint in memory.
                _la_key = next(
                    (k for k in _keys if 'lora_A.default.weight' in k), None)
                if _la_key is not None:
                    _lora_rank = _f.get_tensor(_la_key).shape[0]
        if _is_peft_ckpt:
            logger.info("Detected PEFT/LoRA state dict — "
                        "applying LoRA to backbone before loading weights")
            _lora_alpha = cfg.get('lora_alpha', 64)
            if _lora_rank > 0 and not getattr(backbone, 'peft_config', None):
                from peft import get_peft_model
                lora_cfg = adapter.get_lora_config(_lora_rank, _lora_alpha)
                backbone = get_peft_model(backbone, lora_cfg)
                logger.info(f"Applied LoRA rank={_lora_rank} alpha={_lora_alpha}")
            elif getattr(backbone, 'peft_config', None):
                logger.info("Backbone already has PEFT — skipping duplicate LoRA")
            load_weights_separately = True

    model = cls(
        backbone=backbone,
        tokenizer=tokenizer,
        mask_token_id=cfg['mask_token_id'],
        hidden_size=cfg['hidden_size'],
        query_prefix_ids=cfg['query_prefix_ids'],
        query_suffix_ids=cfg['query_suffix_ids'],
        passage_prefix_ids=cfg['passage_prefix_ids'],
        passage_suffix_ids=cfg['passage_suffix_ids'],
        max_length=cfg['max_length'],
        n_gen_tokens=cfg['n_gen_tokens'],
        n_gen_q_tokens=cfg.get('n_gen_q_tokens'),  # back-compat: missing → fall through to n_gen_tokens
        n_gen_p_tokens=cfg.get('n_gen_p_tokens'),
        temperature=cfg['temperature'],
        num_denoise_steps=cfg['num_denoise_steps'],
        sparse_weight=cfg['sparse_weight'],
        normalize=cfg['normalize'],
        flash_attn=adapter.flash_attn,
        use_eos=cfg.get('use_eos', False),
        # K-adapter restore (back-compat: keys missing → defaults disable)
        use_k_adapter=cfg.get('use_k_adapter', False),
        adapter_weight=cfg.get('adapter_weight', 1.0),
        teacher_temperature=cfg.get('teacher_temperature', 1.0),
        k_adapter_options=tuple(cfg['k_adapter_options'])
        if cfg.get('k_adapter_options') is not None else None,
        # K pre-encoder restore
        use_k_pre_encoder=cfg.get('use_k_pre_encoder', False),
        gumbel_temperature=cfg.get('gumbel_temperature', 1.0),
        k_cost_lambda=cfg.get('k_cost_lambda', 0.001),
        k_pre_encoder_options=tuple(cfg['k_pre_encoder_options'])
        if cfg.get('k_pre_encoder_options') is not None else None,
    )

    # Load adapter weights if a saved file exists.
    if model.use_k_adapter and model.k_adapter is not None:
        adapter_path = Path(model_dir) / 'k_adapter.bin'
        if adapter_path.exists():
            model.k_adapter.load_state_dict(
                torch.load(str(adapter_path), map_location='cpu',
                           weights_only=True))
            logger.info(f"Loaded KAdapter state from {adapter_path}")
        else:
            logger.warning(f"use_k_adapter=True but {adapter_path} missing; "
                           f"adapter starts from random init.")

    # Load pre-encoder weights if a saved file exists.
    if model.use_k_pre_encoder and model.k_pre_encoder_q is not None:
        pe_path = Path(model_dir) / 'k_pre_encoder.bin'
        if pe_path.exists():
            pe_state = torch.load(str(pe_path), map_location='cpu',
                                  weights_only=True)
            model.k_pre_encoder_q.load_state_dict(pe_state['q'])
            model.k_pre_encoder_p.load_state_dict(pe_state['p'])
            logger.info(f"Loaded KPreEncoder state from {pe_path}")
        else:
            logger.warning(f"use_k_pre_encoder=True but {pe_path} missing; "
                           f"pre-encoder starts from random init.")

    # Restore _n_tail from config. Old checkpoints (before structural tail)
    # don't have 'n_tail' — infer from use_eos for backward compat.
    if 'n_tail' in cfg:
        model._n_tail = cfg['n_tail']
        model._tail_ids = model._build_tail_ids()
    else:
        # Legacy: use_eos=True → 1 tail token (EOS only), use_eos=False → 0
        model._n_tail = 1 if cfg.get('use_eos', False) else 0
        eos_id = model.tokenizer.eos_token_id
        model._tail_ids = [eos_id] if model._n_tail == 1 else []
        logger.info(f"Legacy checkpoint: n_tail={model._n_tail} (from use_eos={cfg.get('use_eos', False)})")

    model.model_type = model_type
    model.adapter = adapter
    # Hook for efficient hidden state extraction
    model._last_hidden = {}
    model._hook_registered = adapter.register_hidden_hook(
        backbone, model._last_hidden)

    if load_weights_separately:
        state_dict = cls._read_checkpoint_state_dict(model_dir)
        if state_dict is None:
            logger.warning(f"No weight file found in {model_dir} — using Hub weights")
        else:
            # Checkpoints saved from the wrapper prefix backbone keys with
            # 'backbone.'; if none carry it, load the dict as-is.
            backbone_dict = {k[len('backbone.'):]: v for k, v in state_dict.items()
                             if k.startswith('backbone.')}
            model.backbone.load_state_dict(backbone_dict or state_dict, strict=False)
            logger.info(f"Loaded fine-tuned weights from {model_dir}")

    # Merge LoRA adapters into the base weights for inference. Training
    # keeps them separate for gradient flow; at inference each LoRA adds
    # per-layer overhead otherwise.
    #
    # Two shapes seen in saved checkpoints:
    #   (a) `model.backbone` is a PeftModel wrapper → call
    #       `merge_and_unload()` directly.
    #   (b) `model.backbone` is the underlying DreamModel / LLaDAModel
    #       with `LoraLinear` modules still in place (the top-level
    #       PeftModel wrapper was dropped during save). `merge_and_unload`
    #       is missing but each inner `LoraLinear` still exposes `.merge()`
    #       which folds its `lora_A @ lora_B` into `base_layer.weight` in
    #       place and flips `merged=True`, disabling the LoRA forward path.
    #       We walk the module tree and merge case-by-case.
    if hasattr(model.backbone, 'merge_and_unload'):
        # Deliberately best-effort: an un-merged model is still usable.
        try:
            model.backbone = model.backbone.merge_and_unload()
            logger.info("Merged LoRA adapters into base backbone for inference")
        except Exception as exc:
            logger.warning(f"merge_and_unload failed ({exc}); "
                           f"continuing with un-merged LoRA (slower inference)")
    else:
        merged_modules = 0
        for _mod in model.backbone.modules():
            if (hasattr(_mod, 'lora_A')
                    and callable(getattr(_mod, 'merge', None))
                    and not getattr(_mod, 'merged', False)):
                try:
                    _mod.merge()
                    merged_modules += 1
                except Exception as exc:
                    logger.warning(f"per-module LoRA merge failed ({exc}) "
                                   f"— continuing with un-merged layer")
        if merged_modules > 0:
            logger.info(f"Merged {merged_modules} LoraLinear modules into "
                        f"base weights for inference")

    logger.info(f"Loaded TrainableDiffusionRetriever from {model_dir}")
    return model
def save(self, output_dir: str):
    """Persist backbone, tokenizer and retriever config into *output_dir*.

    The directory is created when missing. The backbone is written through
    its own ``save_pretrained`` when it has one, otherwise through
    ``backbone.base_model.save_pretrained``.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)

    # Pick whichever object actually exposes save_pretrained.
    net = self.backbone
    writer = net if hasattr(net, 'save_pretrained') else net.base_model
    writer.save_pretrained(output_dir)

    self.tokenizer.save_pretrained(output_dir)
    self._save_retriever_config(output_dir)
    logger.info(f"Saved to {output_dir}")
@property
def config(self):
    """Expose the backbone's config as ``model.config``.

    HF Trainer / DeepSpeed expect this attribute to exist on the model.
    """
    wrapped = self.backbone
    return wrapped.config
def gradient_checkpointing_enable(self, **kwargs):
    """Turn on gradient checkpointing for the backbone.

    Delegates to the backbone adapter when one is set; otherwise calls the
    backbone's own ``gradient_checkpointing_enable``. All keyword arguments
    are forwarded unchanged.
    """
    helper = self.adapter
    if not helper:
        self.backbone.gradient_checkpointing_enable(**kwargs)
        return
    helper.enable_gradient_checkpointing(self.backbone, **kwargs)