Instructions to use BryanW/43.wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BryanW/43.wm with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BryanW/43.wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2024-present, BAAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ----------------------------------------------------------------------- | |
| """Utility functions for constructing URSA model inputs during distillation. | |
| This module mirrors the token-splicing and RoPE-position logic from | |
| URSAPipeline.__call__ and URSATrainPipeline.process_inputs so that | |
| student/aux/teacher always see the exact same input distribution. | |
| Key design facts (verified from source): | |
| - transformer.config.lm_head_size = 64000 -> logit output dim (codebook_size) | |
| - transformer.config.lm_vocab_size = 151669 -> text-vocab offset for visual tokens | |
| - transformer.config.bov_token_id = 151652 -> beginning-of-video sentinel | |
| - Input visual token IDs are shifted: stored as (raw_code + lm_vocab_size) | |
| - BOV sentinel is prepended to the visual token block | |
| - Causal slice to recover visual logits: logits[:, -(N+1):-1] where N = T*H*W | |
| """ | |
| from typing import Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| # --------------------------------------------------------------------------- | |
| # Latent shape helpers | |
| # --------------------------------------------------------------------------- | |
def compute_latents_shape(
    num_frames: int,
    height: int,
    width: int,
    temporal_stride: int = 4,
    spatial_stride: int = 8,
) -> Tuple[int, int, int]:
    """Map pixel-space video dimensions to the VQ-token grid ``(T, H, W)``.

    The arithmetic is intended to mirror the formula used in
    ``URSAPipeline.__call__``::

        T = (num_frames - 1) // temporal_stride + 1
        H = height // spatial_stride
        W = width // spatial_stride

    Args:
        num_frames: Number of video frames in pixel space.
        height: Frame height in pixels.
        width: Frame width in pixels.
        temporal_stride: Temporal downsampling factor (default 4).
        spatial_stride: Spatial downsampling factor (default 8).

    Returns:
        ``(T, H, W)`` as plain Python ints. All divisions are floor
        divisions, so remainders are discarded.
    """
    # The first frame always yields one latent step; every further full
    # group of ``temporal_stride`` frames adds one more.
    latent_frames = 1 + (num_frames - 1) // temporal_stride
    latent_height, latent_width = (
        side // spatial_stride for side in (height, width)
    )
    return latent_frames, latent_height, latent_width
| # --------------------------------------------------------------------------- | |
| # Core input builder | |
| # --------------------------------------------------------------------------- | |
def build_ursa_inputs(
    transformer,
    txt_ids: torch.Tensor,
    visual_tokens: torch.Tensor,
    latents_shape: Tuple[int, int, int],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Construct ``(input_ids, rope_pos, N)`` for one batch of prompts + video tokens.

    Intended as the single source of truth for teacher / aux / student so
    that all three models see identically constructed inputs.

    Args:
        transformer: Object exposing ``config.bov_token_id``,
            ``config.lm_vocab_size`` and ``model.flex_rope.get_pos``
            (presumably the URSATransformer3DModel).
        txt_ids: Tokenised prompts, shape ``[B, L]``.
        visual_tokens: Raw codebook indices, dtype long, holding exactly
            ``B * T * H * W`` elements (e.g. ``[B, T, H, W]`` or ``[B, N]``).
            The tensor does not need to be contiguous.
        latents_shape: ``(T, H, W)`` latent grid of one video.
        device: Device the returned tensors are placed on.

    Returns:
        input_ids: ``[B, L + N + 1]`` token ids laid out as
            ``[txt | BOV | vis_0 ... vis_{N-1}]``.
        rope_pos: ``[B, L + N + 1, 3]`` positions, batch-expanded and
            contiguous. Its dtype is whatever ``torch.cat`` produces from the
            int64 text positions and the output of ``flex_rope.get_pos``.
        N: Number of visual tokens per sample (``T * H * W``).

    Raises:
        AssertionError: If ``visual_tokens`` is not long, has the wrong number
            of elements, or the assembled tensors have unexpected shapes.
    """
    B, L = txt_ids.shape

    # -- Config values ---------------------------------------------------
    # Always read special ids from the config; never hard-code them.
    bov_token_id = transformer.config.bov_token_id
    # Offset added to raw codebook indices to move them into the visual
    # region of the token vocabulary.
    latent_shift = transformer.config.lm_vocab_size

    T, H, W = latents_shape
    N = T * H * W

    # -- Input validation ------------------------------------------------
    assert visual_tokens.dtype == torch.long, \
        f"build_ursa_inputs: visual_tokens must be long, got {visual_tokens.dtype}"
    assert visual_tokens.numel() == B * N, (
        f"build_ursa_inputs: visual_tokens has {visual_tokens.numel()} elements, "
        f"expected B*N = {B}*{N} = {B*N}"
    )

    # -- Visual token block ----------------------------------------------
    # ``reshape`` (not ``view``) so that non-contiguous inputs, such as a
    # permuted [B, T, H, W] tensor, are flattened instead of raising.
    latents_flat = visual_tokens.reshape(B, N).to(device)  # [B, N]
    # Shift into the visual-vocab region, then left-pad one BOV per row.
    img_ids = F.pad(latents_flat + latent_shift, (1, 0), value=bov_token_id)  # [B, N+1]

    # -- Full input sequence: [txt | bov | vis_0 ... vis_{N-1}] ----------
    input_ids = torch.cat([txt_ids.to(device), img_ids], dim=1)  # [B, L+N+1]

    # -- RoPE positions --------------------------------------------------
    # Text tokens use the same index on all three axes.
    txt_pos = torch.arange(L, device=device).view(-1, 1).expand(-1, 3)  # [L, 3]
    # The visual block's positions come from the model's own helper, offset
    # by the text length; expected shape [1, N+1, 3] (checked below).
    blk_pos = transformer.model.flex_rope.get_pos(latents_shape, L)
    rope_pos_1d = torch.cat([txt_pos, blk_pos[0].to(device)], dim=0)  # [L+N+1, 3]
    # Batch-expand: every sample shares the same positions.
    rope_pos = rope_pos_1d.unsqueeze(0).expand(B, -1, -1).contiguous()  # [B, L+N+1, 3]

    # -- Output shape assertions -----------------------------------------
    expected_seq_len = L + N + 1
    assert input_ids.shape == (B, expected_seq_len), (
        f"build_ursa_inputs: input_ids shape={input_ids.shape} "
        f"expected ({B},{expected_seq_len}). "
        "txt_ids length or latents_shape may be wrong."
    )
    assert rope_pos.shape == (B, expected_seq_len, 3), (
        f"build_ursa_inputs: rope_pos shape={rope_pos.shape} "
        f"expected ({B},{expected_seq_len},3). "
        "BOV/blk_pos alignment is off — check flex_rope.get_pos return shape."
    )
    return input_ids, rope_pos, N
| # --------------------------------------------------------------------------- | |
| # Visual logit extractor | |
| # --------------------------------------------------------------------------- | |
def extract_visual_logits(
    logits: torch.Tensor,
    N: int,
    codebook_size: int,
    lm_head_size: int = None,
) -> torch.Tensor:
    """Reduce raw transformer logits to visual-codebook logits ``[B, N, K]``.

    Two steps are applied:

    1. Sequence slice ``logits[:, -(N+1):-1]``: the ``N`` positions that end
       one before the last position. Assuming the sequence ends with
       ``[BOV, vis_0 ... vis_{N-1}]`` and the model predicts the next token,
       these are the predictions for the ``N`` visual tokens.
    2. Channel handling on the last dimension ``D``:
       - ``D == codebook_size``: returned as is.
       - ``D > codebook_size``: the trailing ``codebook_size`` channels are
         kept (the leading ``D - codebook_size`` are treated as a text prefix).
       - ``D < codebook_size``: ``ValueError``.

    Args:
        logits: Raw transformer output, ``[B, S, D]`` with ``S >= N + 1``.
        N: Number of visual tokens per sample.
        codebook_size: Number of codebook entries ``K`` expected in the output.
        lm_head_size: Accepted for backward compatibility; this function
            never reads it.

    Returns:
        Tensor of shape ``[B, N, codebook_size]`` (a slice of ``logits``).

    Raises:
        AssertionError: If the sequence slice does not contain ``N`` positions
            or the result has an unexpected shape.
        ValueError: If the last dimension is smaller than ``codebook_size``.
    """
    batch = logits.size(0)

    # Causal window over the sequence axis.
    window = logits[:, -(N + 1) : -1]  # [B, N, D]
    assert window.size(1) == N, (
        f"extract_visual_logits: slice produced seq_len={window.size(1)}, expected N={N}. "
        "Logit sequence length may be shorter than N+1. "
        "Check that input_ids was built with the correct latents_shape."
    )

    width = window.size(-1)

    # Too few channels: nothing sensible can be returned.
    if width < codebook_size:
        raise ValueError(
            f"extract_visual_logits: unexpected logit last-dim={width} < codebook_size={codebook_size}. "
            "Check transformer.config.lm_head_size and scheduler.codebook_size. "
            f"logits.shape={logits.shape}"
        )

    target_shape = (batch, N, codebook_size)

    # Exact match: the head already emits codebook logits.
    if width == codebook_size:
        assert window.shape == target_shape, \
            f"extract_visual_logits: z.shape={window.shape} expected ({batch},{N},{codebook_size})"
        return window

    # Wider head: drop the leading (text) channels, keep the trailing block.
    text_prefix = width - codebook_size
    visual = window[..., text_prefix:]
    assert visual.shape == target_shape, \
        f"extract_visual_logits (sliced): z_vis.shape={visual.shape}"
    return visual
| # --------------------------------------------------------------------------- | |
| # Corrupt helper (for p_init mixing) | |
| # --------------------------------------------------------------------------- | |
def corrupt_tokens(tokens: torch.Tensor, r: float, K: int) -> torch.Tensor:
    """Overwrite a random subset of ``tokens`` with uniformly drawn codes.

    Each position is independently selected with probability ``r``; selected
    positions receive a fresh integer drawn uniformly from ``[0, K)``, all
    others keep their original value. A replacement may coincide with the
    original code.

    Args:
        tokens: Integer tensor of codebook indices, any shape.
        r: Per-position replacement probability.
        K: Codebook size (exclusive upper bound of the random codes).

    Returns:
        Tensor with the same shape, dtype and device as ``tokens``.
    """
    # Per-position Bernoulli(r) draw, done in float because ``tokens`` is integer.
    replace_prob = torch.full_like(tokens, r, dtype=torch.float)
    replace_here = torch.bernoulli(replace_prob).bool()

    # Candidate replacements for every position; only the selected ones are used.
    noise = torch.randint(
        0,
        K,
        tokens.shape,
        device=tokens.device,
        dtype=tokens.dtype,
    )
    return torch.where(replace_here, noise, tokens)
| # --------------------------------------------------------------------------- | |
| # KL / Jeffrey divergence helpers | |
| # --------------------------------------------------------------------------- | |
def kl_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Compute KL(p || q) per sample, averaged over the token axis.

    Both inputs are clamped from below at ``eps`` (no renormalisation) before
    taking logs, the divergence is summed over the last dimension and then
    averaged over the second-to-last one.

    Args:
        p: Probability tensor, expected ``[B, N, K]``.
        q: Probability tensor with the same shape as ``p``.
        eps: Lower clamp applied to both inputs to keep ``log`` finite.

    Returns:
        ``[B]`` tensor for ``[B, N, K]`` inputs.
    """
    p_safe = p.clamp(min=eps)
    q_safe = q.clamp(min=eps)
    log_ratio = p_safe.log() - q_safe.log()
    per_token = (p_safe * log_ratio).sum(dim=-1)  # [B, N]
    return per_token.mean(dim=-1)  # [B]
def jeffrey_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Compute the symmetric (Jeffrey) divergence KL(p||q) + KL(q||p).

    Both directions are evaluated with ``kl_divergence`` using the same
    ``eps``, so the result has that function's reduction.

    Args:
        p: Probability tensor, expected ``[B, N, K]``.
        q: Probability tensor with the same shape as ``p``.
        eps: Lower clamp forwarded to ``kl_divergence``.

    Returns:
        ``[B]`` per-sample Jeffrey divergence for ``[B, N, K]`` inputs.
    """
    forward_kl = kl_divergence(p, q, eps)
    reverse_kl = kl_divergence(q, p, eps)
    return forward_kl + reverse_kl
| # --------------------------------------------------------------------------- | |
| # Timestep curriculum | |
| # --------------------------------------------------------------------------- | |
def sample_t_curriculum(
    B: int,
    device: torch.device,
    step: int,
    warmup_steps: int = 10_000,
) -> torch.Tensor:
    """Draw ``B`` training timesteps, biased toward large ``t`` during warm-up.

    With ``u ~ Uniform[0, 1)``:

    - ``step < warmup_steps``: ``t = 1 - (1 - u)^2``, which skews samples
      toward 1.
    - otherwise: ``t = u``.

    In both cases ``t`` is finally clamped to ``[0.05, 0.995]``.

    Args:
        B: Number of timesteps to sample.
        device: Device on which the samples are created.
        step: Current training step.
        warmup_steps: Number of initial steps that use the biased schedule.

    Returns:
        ``[B]`` float tensor of timesteps.
    """
    uniform = torch.rand(B, device=device)
    in_warmup = step < warmup_steps
    # Squaring (1 - u) pushes mass toward t = 1 while warm-up is active.
    timesteps = 1.0 - (1.0 - uniform) ** 2 if in_warmup else uniform
    return timesteps.clamp(0.05, 0.995)