# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------
"""Utility functions for constructing URSA model inputs during distillation.

This module mirrors the token-splicing and RoPE-position logic from
URSAPipeline.__call__ and URSATrainPipeline.process_inputs so that
student/aux/teacher always see the exact same input distribution.

Key design facts (verified from source):
  - transformer.config.lm_head_size  = 64000  -> logit output dim (codebook_size)
  - transformer.config.lm_vocab_size = 151669 -> text-vocab offset for visual tokens
  - transformer.config.bov_token_id  = 151652 -> beginning-of-video sentinel
  - Input visual token IDs are shifted: stored as (raw_code + lm_vocab_size)
  - BOV sentinel is prepended to the visual token block
  - Causal slice to recover visual logits: logits[:, -(N+1):-1] where N = T*H*W
"""

from typing import Tuple

import torch
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Latent shape helpers
# ---------------------------------------------------------------------------
def compute_latents_shape(
    num_frames: int,
    height: int,
    width: int,
    temporal_stride: int = 4,
    spatial_stride: int = 8,
) -> Tuple[int, int, int]:
    """Return the VQ-token grid (T, H, W) matching URSAPipeline's convention.

    Matches the formula in URSAPipeline.__call__:
        T = (num_frames - 1) // temporal_stride + 1
        H = height // spatial_stride
        W = width // spatial_stride
    """
    T = (num_frames - 1) // temporal_stride + 1
    H = height // spatial_stride
    W = width // spatial_stride
    return T, H, W
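

# Worked example (illustration only; the clip geometry below is not taken from
# any URSA config): with the default strides (temporal_stride=4, spatial_stride=8)
# a 17-frame 256x256 clip maps to
#
#     >>> compute_latents_shape(17, 256, 256)
#     (5, 32, 32)
#
# so one sample contributes N = 5 * 32 * 32 = 5120 visual tokens, and the spliced
# sequence built by build_ursa_inputs below has length L + N + 1
# (prompt + BOV + visual block).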
# ---------------------------------------------------------------------------
# Core input builder
# ---------------------------------------------------------------------------
def build_ursa_inputs(
    transformer,
    txt_ids: torch.Tensor,
    visual_tokens: torch.Tensor,
    latents_shape: Tuple[int, int, int],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Construct (input_ids, rope_pos, N) exactly as URSAPipeline does.

    This is the single source of truth for all three models (teacher / aux /
    student), so their input distributions match.

    Args:
        transformer: The URSATransformer3DModel (config values are read from it).
        txt_ids: Tokenised prompts, shape [B, L].
        visual_tokens: Raw codebook indices, shape [B, T, H, W] or [B, N], dtype long.
        latents_shape: (T, H, W) tuple, the shape of one video's latent grid.
        device: Target device.

    Returns:
        input_ids: [B, L + N + 1], long (N = T*H*W)
        rope_pos:  [B, L + N + 1, 3], long
        N:         number of visual tokens per sample (T*H*W)

    Notes:
        - BOV token is inserted at position L (just before the visual tokens).
        - Visual token IDs are shifted by lm_vocab_size before being concatenated.
        - rope_pos is batched (training convention), not the 2-D inference convention.
    """
    B, L = txt_ids.shape

    # -- Config values ----------------------------------------------------
    # PITFALL 1: always read from config, never hard-code.
    bov_token_id = transformer.config.bov_token_id
    # lm_vocab_size == len(tokenizer): the visual-token vocab offset.
    latent_shift = transformer.config.lm_vocab_size
    T, H, W = latents_shape
    N = T * H * W

    # -- Input validation ---------------------------------------------------
    assert visual_tokens.dtype == torch.long, \
        f"build_ursa_inputs: visual_tokens must be long, got {visual_tokens.dtype}"
    assert visual_tokens.numel() == B * N, (
        f"build_ursa_inputs: visual_tokens has {visual_tokens.numel()} elements, "
        f"expected B*N = {B}*{N} = {B*N}"
    )

    # -- Visual token block ---------------------------------------------------
    # Flatten to [B, N] so pad/cat are straightforward (reshape rather than view,
    # so non-contiguous inputs are accepted).
    latents_flat = visual_tokens.reshape(B, N).to(device)  # [B, N], long
    # Shift raw codebook indices into the visual-vocab region and prepend BOV.
    # Mirrors: img_ids = pad(latents_flat + latent_shift, (1,0), value=bov_token_id)
    img_ids = F.pad(latents_flat + latent_shift, (1, 0), value=bov_token_id)  # [B, N+1]

    # -- Full input sequence: [txt | bov | vis_0 ... vis_{N-1}] ----------------
    input_ids = torch.cat([txt_ids.to(device), img_ids], dim=1)  # [B, L+N+1]

    # -- RoPE positions ---------------------------------------------------------
    # Mirrors URSAPipeline:
    #   txt_pos  = arange(L).view(-1,1).expand(-1,3)    -> [L, 3]
    #   blk_pos  = flex_rope.get_pos(latents_shape, L)  -> [1, N+1, 3]
    #   rope_pos = cat([txt_pos, blk_pos[0]])           -> [L+N+1, 3]
    # Then batch-expand (training convention):
    #   rope_pos = rope_pos.unsqueeze(0).expand(B,-1,-1).contiguous() -> [B, L+N+1, 3]
    txt_pos = torch.arange(L, device=device).view(-1, 1).expand(-1, 3)             # [L, 3]
    blk_pos = transformer.model.flex_rope.get_pos(latents_shape, txt_pos.size(0))  # [1, N+1, 3]
    rope_pos_1d = torch.cat([txt_pos, blk_pos[0].to(device)], dim=0)               # [L+N+1, 3]
    rope_pos = rope_pos_1d.unsqueeze(0).expand(B, -1, -1).contiguous()             # [B, L+N+1, 3]

    # -- Output shape assertions -------------------------------------------------
    expected_seq_len = L + N + 1
    assert input_ids.shape == (B, expected_seq_len), (
        f"build_ursa_inputs: input_ids shape={input_ids.shape} "
        f"expected ({B},{expected_seq_len}). "
        "txt_ids length or latents_shape may be wrong."
    )
    assert rope_pos.shape == (B, expected_seq_len, 3), (
        f"build_ursa_inputs: rope_pos shape={rope_pos.shape} "
        f"expected ({B},{expected_seq_len},3). "
        "BOV/blk_pos alignment is off; check flex_rope.get_pos return shape."
    )

    return input_ids, rope_pos, N
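

# Hedged usage sketch: how the helpers in this module chain together for one
# distillation forward pass. The transformer call shown here (positional
# ``input_ids`` plus a ``rope_pos`` keyword, returning an object with ``.logits``)
# is an assumption made for illustration, not the verified
# URSATransformer3DModel interface; adapt it to the real forward signature.
#
#     latents_shape = compute_latents_shape(num_frames, height, width)
#     input_ids, rope_pos, N = build_ursa_inputs(
#         transformer, txt_ids, visual_tokens, latents_shape, device)
#     out = transformer(input_ids, rope_pos=rope_pos)           # assumed signature
#     z = extract_visual_logits(out.logits, N, codebook_size)   # [B, N, K]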
# ---------------------------------------------------------------------------
# Visual logit extractor
# ---------------------------------------------------------------------------
def extract_visual_logits(
    logits: torch.Tensor,
    N: int,
    codebook_size: int,
    lm_head_size: int = None,
) -> torch.Tensor:
    """Slice and (if needed) project the transformer logits to [B, N, K].

    PITFALL 2: The lm_head projects hidden states to lm_head_size (=64000),
    NOT to the full vocab_size. We must never confuse text-vocab indices with
    codebook indices. This function is the single gate that converts raw
    transformer output to visual-codebook logits.

    Slicing convention (mirrors URSAPipeline):
        z = logits[:, -(N+1) : -1]
    i.e. the causal shift: position -(N+1) is the BOV token (its logits predict
    the first visual token), and the final visual position is dropped.

    If the last dimension already equals codebook_size, return z directly.
    If the last dimension is larger (e.g. full vocab), slice the visual region.
    Otherwise raise a descriptive error so the caller can fix the config.

    Args:
        logits: Raw transformer output, shape [B, L+N+1, D].
        N: Number of visual tokens (T*H*W).
        codebook_size: Expected number of codebook entries (scheduler.codebook_size).
        lm_head_size: Deprecated and unused; kept only so older call sites do not break.

    Returns:
        Tensor of shape [B, N, codebook_size].
    """
    B_in = logits.size(0)

    # PITFALL 2: causal slice, exactly as URSAPipeline uses it.
    # logits[:, -(N+1):-1] keeps the N positions starting at the BOV token and
    # drops the final visual position; under the causal shift these logits
    # predict vis_0 .. vis_{N-1}.
    z = logits[:, -(N + 1) : -1]  # [B, N, D]

    # Verify sliced sequence length matches N.
    assert z.size(1) == N, (
        f"extract_visual_logits: slice produced seq_len={z.size(1)}, expected N={N}. "
        "Logit sequence length may be shorter than N+1. "
        "Check that input_ids was built with the correct latents_shape."
    )

    D = z.size(-1)
    if D == codebook_size:
        # Happy path: lm_head_size == codebook_size (default URSA config).
        assert z.shape == (B_in, N, codebook_size), \
            f"extract_visual_logits: z.shape={z.shape} expected ({B_in},{N},{codebook_size})"
        return z

    # If the head includes a text prefix (shouldn't happen with the default
    # config, but guard anyway).
    if D > codebook_size:
        lm_vocab_size = D - codebook_size
        z_vis = z[..., lm_vocab_size:]
        assert z_vis.shape == (B_in, N, codebook_size), \
            f"extract_visual_logits (sliced): z_vis.shape={z_vis.shape}"
        return z_vis

    raise ValueError(
        f"extract_visual_logits: unexpected logit last-dim={D} < codebook_size={codebook_size}. "
        "Check transformer.config.lm_head_size and scheduler.codebook_size. "
        f"logits.shape={logits.shape}"
    )


# ---------------------------------------------------------------------------
# Corrupt helper (for p_init mixing)
# ---------------------------------------------------------------------------
def corrupt_tokens(tokens: torch.Tensor, r: float, K: int) -> torch.Tensor:
    """Replace a random fraction r of tokens with uniform random codes.

    Used for the 20% p_init mixing strategy:
        mask      = Bernoulli(r)
        corrupted = mask * randint(K) + (1-mask) * tokens

    Args:
        tokens: Long tensor of codebook indices, any shape.
        r: Fraction of tokens to corrupt (0 < r < 1).
        K: Codebook size.

    Returns:
        Corrupted token tensor, same shape and dtype as ``tokens``.
    """
    mask = torch.bernoulli(torch.full_like(tokens, r, dtype=torch.float)).bool()
    rand_codes = torch.randint(0, K, tokens.shape, device=tokens.device, dtype=tokens.dtype)
    return torch.where(mask, rand_codes, tokens)


# ---------------------------------------------------------------------------
# KL / Jeffrey divergence helpers
# ---------------------------------------------------------------------------
def kl_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """KL(p || q) summed over the last dimension, per-sample mean over tokens.

    Args:
        p: [B, N, K] probability tensor.
        q: [B, N, K] probability tensor.
        eps: Numerical floor applied to both tensors before taking logs.

    Returns:
        [B] per-sample KL divergence (mean over N tokens).
    """
    p = p.clamp(min=eps)
    q = q.clamp(min=eps)
    return (p * (p.log() - q.log())).sum(-1).mean(-1)  # [B]


def jeffrey_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Symmetric KL (Jeffrey): KL(p||q) + KL(q||p), per-sample mean over tokens.

    Returns:
        [B] per-sample Jeffrey divergence.
    """
    return kl_divergence(p, q, eps) + kl_divergence(q, p, eps)
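

# Illustrative sketch (not part of the training pipeline): reduce student and
# teacher visual logits to a scalar Jeffrey loss with the helpers above. The
# float32 cast, unit softmax temperature, and batch-mean reduction are
# assumptions made for this example only.
def _jeffrey_loss_sketch(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
) -> torch.Tensor:
    """Turn two [B, N, K] visual-logit tensors into a scalar Jeffrey loss."""
    p = F.softmax(student_logits.float(), dim=-1)  # [B, N, K] probabilities
    q = F.softmax(teacher_logits.float(), dim=-1)  # [B, N, K] probabilities
    return jeffrey_divergence(p, q).mean()         # mean over the batch -> scalar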
# ---------------------------------------------------------------------------
# Timestep curriculum
# ---------------------------------------------------------------------------
def sample_t_curriculum(
    B: int,
    device: torch.device,
    step: int,
    warmup_steps: int = 10_000,
) -> torch.Tensor:
    """Sample training timesteps with a curriculum biased toward large t early on.

    - For the first ``warmup_steps`` steps, use t = 1 - (1-u)^2 (biased high).
    - After warmup, sample t uniformly from [0, 1).
    - t is clamped to [0.05, 0.995] to avoid degenerate paths.

    Returns:
        [B] float tensor of continuous timesteps.
    """
    u = torch.rand(B, device=device)
    if step < warmup_steps:
        t = 1.0 - (1.0 - u) ** 2  # squish toward 1 (the data end)
    else:
        t = u
    return t.clamp(0.05, 0.995)
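

# Hedged smoke test: quick, self-contained checks of the shape conventions in
# this module. The batch size, prompt length, clip geometry, and corruption
# rate below are arbitrary illustration values, not URSA defaults; K = 64000
# matches the codebook size noted in the module docstring.
if __name__ == "__main__":
    T, H, W = compute_latents_shape(num_frames=17, height=256, width=256)
    assert (T, H, W) == (5, 32, 32)

    K = 64000
    tokens = torch.randint(0, K, (2, T * H * W))
    corrupted = corrupt_tokens(tokens, r=0.2, K=K)
    assert corrupted.shape == tokens.shape and corrupted.dtype == tokens.dtype

    t_early = sample_t_curriculum(B=4, device=torch.device("cpu"), step=0)
    t_late = sample_t_curriculum(B=4, device=torch.device("cpu"), step=20_000)
    assert t_early.shape == t_late.shape == (4,)
    assert bool((t_early >= 0.05).all()) and bool((t_early <= 0.995).all())
    print("shape/curriculum smoke checks passed")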