"""Utility functions for constructing URSA model inputs during distillation.

This module mirrors the token-splicing and RoPE-position logic from
URSAPipeline.__call__ and URSATrainPipeline.process_inputs so that
student/aux/teacher always see the exact same input distribution.

Key design facts (verified from source):
- transformer.config.lm_head_size = 64000 -> logit output dim (codebook_size)
- transformer.config.lm_vocab_size = 151669 -> text-vocab offset for visual tokens
- transformer.config.bov_token_id = 151652 -> beginning-of-video sentinel
- Input visual token IDs are shifted: stored as (raw_code + lm_vocab_size)
- BOV sentinel is prepended to the visual token block
- Causal slice to recover visual logits: logits[:, -(N+1):-1] where N = T*H*W
"""

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
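
# Quick arithmetic check of the ID layout implied by the constants above: raw
# visual code c is stored as c + 151669, so code 0 becomes input ID 151669 and
# the text vocab [0, 151668] never collides with the visual range. The BOV
# sentinel (151652) sits inside the text range.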


def compute_latents_shape(
    num_frames: int,
    height: int,
    width: int,
    temporal_stride: int = 4,
    spatial_stride: int = 8,
) -> Tuple[int, int, int]:
    """Return the VQ-token grid (T, H, W) matching URSAPipeline's convention.

    Matches the formula in URSAPipeline.__call__:
        T = (num_frames - 1) // temporal_stride + 1
        H = height // spatial_stride
        W = width // spatial_stride
    """
    T = (num_frames - 1) // temporal_stride + 1
    H = height // spatial_stride
    W = width // spatial_stride
    return T, H, W
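
# Worked example with hypothetical clip dimensions: a 17-frame 256x256 video
# maps to a (5, 32, 32) token grid, i.e. N = 5 * 32 * 32 = 5120 visual tokens:
#   compute_latents_shape(17, 256, 256)  # -> (5, 32, 32)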


def build_ursa_inputs(
    transformer,
    txt_ids: torch.Tensor,
    visual_tokens: torch.Tensor,
    latents_shape: Tuple[int, int, int],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Construct (input_ids, rope_pos, N) exactly as URSAPipeline does.

    This is the single source of truth for all three models
    (teacher / aux / student) so their input distributions match.

    Args:
        transformer: The URSATransformer3DModel (its config is read here).
        txt_ids: Tokenised prompts, shape [B, L].
        visual_tokens: Raw codebook indices, shape [B, T, H, W] or [B, N], dtype long.
        latents_shape: (T, H, W) tuple, the shape of one video's latent grid.
        device: Target device.

    Returns:
        input_ids: [B, L + N + 1], long (N = T*H*W)
        rope_pos: [B, L + N + 1, 3], int32
        N: number of visual tokens per sample (T*H*W)

    Notes:
        - The BOV token is inserted at position L (just before the visual tokens).
        - Visual token IDs are shifted by lm_vocab_size before being concatenated.
        - rope_pos is batched (training convention), not the 2-D inference convention.
    """
    B, L = txt_ids.shape

    bov_token_id = transformer.config.bov_token_id
    latent_shift = transformer.config.lm_vocab_size
    T, H, W = latents_shape
    N = T * H * W

    assert visual_tokens.dtype == torch.long, \
        f"build_ursa_inputs: visual_tokens must be long, got {visual_tokens.dtype}"
    assert visual_tokens.numel() == B * N, (
        f"build_ursa_inputs: visual_tokens has {visual_tokens.numel()} elements, "
        f"expected B*N = {B}*{N} = {B * N}"
    )

    # Shift raw codebook indices into the visual ID range, then prepend BOV
    # (F.pad with (1, 0) adds one element on the left of the last dim).
    latents_flat = visual_tokens.view(B, N).to(device)
    img_ids = F.pad(latents_flat + latent_shift, (1, 0), value=bov_token_id)
    input_ids = torch.cat([txt_ids.to(device), img_ids], dim=1)

    # Text tokens use their 1-D position replicated across all three RoPE axes;
    # the BOV + visual block gets its 3-D grid positions from flex_rope
    # (get_pos returns a batched tensor, hence the [0]).
    txt_pos = torch.arange(L, device=device).view(-1, 1).expand(-1, 3)
    blk_pos = transformer.model.flex_rope.get_pos(latents_shape, txt_pos.size(0))
    rope_pos_1d = torch.cat([txt_pos, blk_pos[0].to(device)], dim=0)
    rope_pos = rope_pos_1d.unsqueeze(0).expand(B, -1, -1).contiguous()

    expected_seq_len = L + N + 1
    assert input_ids.shape == (B, expected_seq_len), (
        f"build_ursa_inputs: input_ids shape={input_ids.shape} "
        f"expected ({B},{expected_seq_len}). "
        "txt_ids length or latents_shape may be wrong."
    )
    assert rope_pos.shape == (B, expected_seq_len, 3), (
        f"build_ursa_inputs: rope_pos shape={rope_pos.shape} "
        f"expected ({B},{expected_seq_len},3). "
        "BOV/blk_pos alignment is off; check flex_rope.get_pos return shape."
    )

    return input_ids, rope_pos, N
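
# Usage sketch (assumes `transformer`, `txt_ids`, and `visual_tokens` come from
# the surrounding training script; the clip dimensions are illustrative):
#   T, H, W = compute_latents_shape(num_frames=17, height=256, width=256)
#   input_ids, rope_pos, N = build_ursa_inputs(
#       transformer, txt_ids, visual_tokens, (T, H, W), device=txt_ids.device,
#   )
#   # input_ids: [B, L + N + 1]; rope_pos: [B, L + N + 1, 3]; N == T * H * W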


def extract_visual_logits(
    logits: torch.Tensor,
    N: int,
    codebook_size: int,
    lm_head_size: Optional[int] = None,
) -> torch.Tensor:
    """Slice and (if needed) project the transformer logits to [B, N, K].

    PITFALL 2: The lm_head projects hidden states to lm_head_size (=64000),
    NOT to the full vocab_size. We must never confuse text-vocab indices with
    codebook indices. This function is the single gate that converts raw
    transformer output to visual-codebook logits.

    Slicing convention (mirrors URSAPipeline):
        z = logits[:, -(N + 1) : -1]  # causal shift: BOV at -(N+1), last is EOS

    If the last dimension already equals codebook_size, return z directly.
    If the last dimension is larger (e.g. the full vocab), slice out the visual
    region. Otherwise raise a descriptive error so the caller can fix the config.

    Args:
        logits: Raw transformer output, shape [B, L+N+1, D].
        N: Number of visual tokens (T*H*W).
        codebook_size: Expected number of codebook entries (scheduler.codebook_size).
        lm_head_size: Deprecated alias for codebook_size; ignored if None.

    Returns:
        Tensor of shape [B, N, codebook_size].
    """
    if lm_head_size is not None:
        # Deprecated alias: honour it so older call sites keep working.
        codebook_size = lm_head_size

    B_in = logits.size(0)

    # Causal shift: the logit at position i predicts token i + 1, so the N
    # visual predictions sit just before the final position.
    z = logits[:, -(N + 1) : -1]

    assert z.size(1) == N, (
        f"extract_visual_logits: slice produced seq_len={z.size(1)}, expected N={N}. "
        "Logit sequence length may be shorter than N+1. "
        "Check that input_ids was built with the correct latents_shape."
    )

    D = z.size(-1)
    if D == codebook_size:
        assert z.shape == (B_in, N, codebook_size), \
            f"extract_visual_logits: z.shape={z.shape} expected ({B_in},{N},{codebook_size})"
        return z

    if D > codebook_size:
        # Full-vocab head: the visual codebook occupies the last codebook_size slots.
        lm_vocab_size = D - codebook_size
        z_vis = z[..., lm_vocab_size:]
        assert z_vis.shape == (B_in, N, codebook_size), \
            f"extract_visual_logits (sliced): z_vis.shape={z_vis.shape}"
        return z_vis

    raise ValueError(
        f"extract_visual_logits: unexpected logit last-dim={D} < codebook_size={codebook_size}. "
        "Check transformer.config.lm_head_size and scheduler.codebook_size. "
        f"logits.shape={logits.shape}"
    )
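
# Shape-only self-check with random values (assumes a codebook-sized head,
# i.e. D == codebook_size == 64000):
#   logits = torch.randn(2, 10 + 6 + 1, 64000)  # B=2, L=10, N=6
#   z = extract_visual_logits(logits, N=6, codebook_size=64000)
#   assert z.shape == (2, 6, 64000)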


def corrupt_tokens(tokens: torch.Tensor, r: float, K: int) -> torch.Tensor:
    """Replace a random fraction r of tokens with uniform random codes.

    Used for the 20% p_init mixing strategy:
        mask = Bernoulli(r)
        corrupted = mask * randint(K) + (1 - mask) * tokens

    Args:
        tokens: Long tensor of codebook indices, any shape.
        r: Fraction of tokens to corrupt (0 < r < 1).
        K: Codebook size.

    Returns:
        Corrupted token tensor, same shape and dtype as ``tokens``.
    """
    mask = torch.bernoulli(torch.full_like(tokens, r, dtype=torch.float)).bool()
    rand_codes = torch.randint(0, K, tokens.shape, device=tokens.device, dtype=tokens.dtype)
    return torch.where(mask, rand_codes, tokens)
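
# Example with a hypothetical `clean_tokens` long tensor, corrupting roughly
# 20% of entries (the p_init mixing ratio above):
#   noisy = corrupt_tokens(clean_tokens, r=0.2, K=64000)
#   (noisy != clean_tokens).float().mean()  # ~0.2, minus accidental matches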


def kl_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """KL(p || q), summed over the last dimension and averaged over tokens.

    Args:
        p: [B, N, K] probability tensor.
        q: [B, N, K] probability tensor.
        eps: Clamp floor that keeps the logs finite; the clamped
            distributions are not renormalised.

    Returns:
        [B] per-sample KL divergence (mean over the N tokens).
    """
    p = p.clamp(min=eps)
    q = q.clamp(min=eps)
    return (p * (p.log() - q.log())).sum(-1).mean(-1)


def jeffrey_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Symmetric KL (Jeffreys divergence): KL(p||q) + KL(q||p), mean over tokens.

    Returns:
        [B] per-sample Jeffreys divergence.
    """
    return kl_divergence(p, q, eps) + kl_divergence(q, p, eps)
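
# Sanity check: both divergences vanish when the two distributions coincide.
#   p = torch.softmax(torch.randn(2, 6, 64000), dim=-1)
#   assert kl_divergence(p, p).abs().max() < 1e-6
#   assert jeffrey_divergence(p, p).abs().max() < 1e-6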


def sample_t_curriculum(
    B: int,
    device: torch.device,
    step: int,
    warmup_steps: int = 10_000,
) -> torch.Tensor:
    """Sample training timesteps with a curriculum biased toward large t early on.

    - For the first ``warmup_steps`` steps, use t = 1 - (1 - u)^2 (biased high).
    - After warmup, fall back to t = u, i.e. uniform on [0, 1) before clamping.
    - t is clamped to [0.05, 0.995] to avoid degenerate paths.

    Returns:
        [B] float tensor of continuous timesteps.
    """
    u = torch.rand(B, device=device)
    if step < warmup_steps:
        t = 1.0 - (1.0 - u) ** 2
    else:
        t = u
    return t.clamp(0.05, 0.995)
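
# Why the warmup warp biases high: for u ~ U[0, 1), t = 1 - (1 - u)^2 has
# density 1 / (2 * sqrt(1 - t)) on [0, 1), so mass concentrates near t = 1
# (E[t] = 2/3, versus 1/2 for the uniform fallback).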