Instructions to use BryanW/43.wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BryanW/43.wm with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BryanW/43.wm", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| # ------------------------------------------------------------------------ | |
| # Copyright (c) 2024-present, BAAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ------------------------------------------------------------------------ | |
| """Vision Transformer.""" | |
| from typing import Tuple | |
| import torch | |
| from torch import nn | |
| from torch.utils.checkpoint import checkpoint as apply_ckpt | |
| from diffnext.models.embeddings import PatchEmbed, RotaryEmbed3D | |
| from diffnext.models.flex_attention import FlexAttentionCausal2D | |
class MLP(nn.Module):
    """Feed-forward network: expand with ``fc1``, apply GELU, project back with ``fc2``.

    Args:
        dim: Width of the input and output features.
        mlp_ratio: Multiplier for the hidden width; the hidden width is
            ``int(dim * mlp_ratio)``.
    """

    def __init__(self, dim, mlp_ratio=4):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        # Submodules are created in the order fc1, fc2, activation so that
        # parameter initialization and state-dict ordering stay unchanged.
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.activation = nn.GELU()

    def forward(self, x) -> torch.Tensor:
        """Return ``fc2(gelu(fc1(x)))``."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        return self.fc2(hidden)
class Attention(nn.Module):
    """Multihead attention.

    Args:
        dim: Width of the input and output tokens.
        num_heads: Number of attention heads; must divide ``dim`` evenly.
        qkv_bias: Whether the fused q/k/v projection uses a bias.

    Raises:
        ValueError: If ``dim`` is not a multiple of ``num_heads``.

    The following attributes start as ``None`` and are meant to be assigned
    from outside after construction:

    * ``attn_mask``: passed to ``scaled_dot_product_attention`` as ``attn_mask``.
    * ``cache_kv``: a falsy value disables caching. A truthy non-list value makes
      the next call store ``[k, v]``; once it is a list, each call appends
      the new keys/values to the stored ones along dim 2.
    * ``pe_func``: callable applied to both ``q`` and ``k`` when truthy.
    * ``flex_attn``: used instead of scaled dot-product attention when it
      and its ``offsets`` attribute are both truthy.
    """

    def __init__(self, dim, num_heads, qkv_bias=True):
        super().__init__()
        # Fail early: with a truncated head_dim the reshape in forward() would
        # otherwise break later with an unrelated-looking shape error.
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads}).")
        self.num_heads, self.head_dim = num_heads, dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.attn_mask, self.cache_kv, self.pe_func, self.flex_attn = None, None, None, None

    def forward(self, x) -> torch.Tensor:
        """Attend over dim 1 of ``x`` and return the projected result."""
        # Fused projection -> (batch, tokens, 3, heads, head_dim)
        # -> (3, batch, heads, tokens, head_dim) -> q, k, v.
        qkv = self.qkv(x).view(-1, x.size(1), 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
        if self.pe_func:
            q, k = self.pe_func(q), self.pe_func(k)
        if self.cache_kv:
            if isinstance(self.cache_kv, list):
                # Extend the cached keys/values along the token axis (dim 2).
                k = self.cache_kv[0] = torch.cat([self.cache_kv[0], k], dim=2)
                v = self.cache_kv[1] = torch.cat([self.cache_kv[1], v], dim=2)
            else:
                # Cache was only armed so far: store this call's k/v as the seed.
                self.cache_kv = [k, v]
        if self.flex_attn and self.flex_attn.offsets:
            o = self.flex_attn(q, k, v)
        else:
            o = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask)
        # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim).
        return self.proj(o.transpose(1, 2).flatten(2))
class Block(nn.Module):
    """Transformer block with two residual branches (attention, then MLP).

    Each branch output is layer-normalized before the residual is added.
    Setting ``attn_checkpointing`` / ``mlp_checkpointing`` to ``True`` runs the
    corresponding branch through activation checkpointing whenever the input
    requires grad.

    Args:
        dim: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-width multiplier of the MLP.
        qkv_bias: Whether the attention q/k/v projection uses a bias.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
        self.attn_checkpointing = False
        self.mlp_checkpointing = False

    def forward_attn(self, x) -> torch.Tensor:
        """Attention branch: ``norm1(attn(x))`` (no residual)."""
        attended = self.attn(x)
        return self.norm1(attended)

    def forward_mlp(self, x) -> torch.Tensor:
        """MLP branch: ``norm2(mlp(x))`` (no residual)."""
        transformed = self.mlp(x)
        return self.norm2(transformed)

    def forward_ckpt(self, x, name) -> torch.Tensor:
        """Run branch ``forward_<name>``, checkpointed if enabled and ``x`` needs grad."""
        use_ckpt = getattr(self, f"{name}_checkpointing", False) and x.requires_grad
        branch = getattr(self, f"forward_{name}")
        if use_ckpt:
            return apply_ckpt(branch, x, use_reentrant=False)
        return branch(x)

    def forward(self, x, pe_func: callable = None) -> torch.Tensor:
        """Apply both residual branches; ``pe_func`` is handed to the attention module."""
        self.attn.pe_func = pe_func
        # The residual is added in place into the fresh branch output, so the
        # incoming ``x`` is never modified.
        attn_out = self.forward_ckpt(x, "attn")
        x = attn_out.add_(x)
        mlp_out = self.forward_ckpt(x, "mlp")
        return mlp_out.add_(x)
class VisionTransformer(nn.Module):
    """Vision transformer.

    The first ``encoder_depth`` blocks run on the tokens selected by
    ``prev_ids`` (or on every token when no ids are given); the remaining
    blocks run on the full token sequence. A single ``FlexAttentionCausal2D``
    instance is shared by the attention module of every block.

    Args:
        depth: Number of transformer blocks.
        embed_dim: Token width.
        num_heads: Number of attention heads per block.
        mlp_ratio: Hidden-width multiplier of each block's MLP.
        patch_size: Third argument of ``PatchEmbed``.
        image_size: Stored on the instance; not otherwise used by this class.
        image_dim: First argument of ``PatchEmbed`` (presumably the number of
            input channels; confirm against ``PatchEmbed``).
        encoder_depth: Number of leading blocks forming the first stage;
            defaults to half of the blocks.
    """

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio=4,
        patch_size=2,
        image_size=32,
        image_dim=4,
        encoder_depth=None,
    ):
        super().__init__()
        self.embed_dim, self.image_size, self.image_dim = embed_dim, image_size, image_dim
        self.patch_embed = PatchEmbed(image_dim, embed_dim, patch_size)
        self.pos_embed, self.rope = nn.Identity(), RotaryEmbed3D(embed_dim // num_heads)
        self.blocks = nn.ModuleList(Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth))
        self.norm, self.mixer = nn.LayerNorm(embed_dim), nn.Identity()
        self.encoder_depth = len(self.blocks) // 2 if encoder_depth is None else encoder_depth
        self.flex_attn = FlexAttentionCausal2D()
        # Every attention layer points at the same flex-attention object.
        for blk in self.blocks:
            blk.attn.flex_attn = self.flex_attn

    def prepare_pe(self, c=None, ids=None, pos=None) -> Tuple[callable, callable]:
        """Build the rotary functions for the two stages.

        Args:
            c: Optional prefix tokens; their length along dim 1 is passed to
                the rotary module as the pad amount (0 when absent).
            ids: Optional token ids; when given, the first-stage function is
                built with the ids expanded to 3 entries on the last dim.
            pos: Positions forwarded to ``self.rope.get_func``.

        Returns:
            ``(pe1, pe2)``: functions for the first-stage and the remaining
            blocks. Both are the same object when ``ids`` is ``None``.
        """
        pad = 0 if c is None else c.size(1)
        pe2 = self.rope.get_func(pos, pad)
        if ids is None:
            return pe2, pe2
        pe1 = self.rope.get_func(pos, pad, ids.expand(-1, -1, 3))
        return pe1, pe2

    def enable_kvcache(self, mode=True):
        """Set ``cache_kv`` to ``mode`` on the attention module of every block."""
        for blk in self.blocks:
            blk.attn.cache_kv = mode

    def forward(self, x, c=None, prev_ids=None, pos=None) -> torch.Tensor:
        """Run all blocks and return the normalized tokens without the ``c`` prefix.

        Args:
            x: Input for ``patch_embed``, or an ``(x, prev_ids)`` tuple/list,
                in which case its second item replaces the ``prev_ids`` argument.
            c: Optional tokens concatenated in front of the patch tokens
                along dim 1 for every block.
            prev_ids: Optional indices along dim 1 of the tokens the first
                stage should process; ignored when ``encoder_depth`` is 0.
                NOTE(review): presumably shaped (batch, num_kept, 1) since it
                is expanded on the last dim only; confirm with the caller.
            pos: Optional positions; rotary functions are built only when given.
        """
        if isinstance(x, (tuple, list)):
            x, prev_ids = x
        if not self.encoder_depth:
            prev_ids = None
        x = x_masked = self.pos_embed(self.patch_embed(x))
        if pos is None:
            pe1 = pe2 = None
        else:
            pe1, pe2 = self.prepare_pe(c, prev_ids, pos)
        has_ids, has_c = prev_ids is not None, c is not None

        if has_ids:  # Split mask from x.
            prev_ids = prev_ids.expand(-1, -1, x.size(-1))
            x = x.gather(1, prev_ids)
        if has_c:
            x = torch.cat([c, x], dim=1)
        for blk in self.blocks[: self.encoder_depth]:
            x = blk(x, pe1)

        if has_ids:
            if has_c:  # Split c from x.
                c, x = x.split((c.size(1), x.size(1) - c.size(1)), dim=1)
            # Merge mask with x: processed tokens overwrite their original
            # slots in the full embedded sequence.
            x = x_masked.to(dtype=x.dtype).scatter(1, prev_ids, x)
            if has_c:
                x = torch.cat([c, x], dim=1)
        for blk in self.blocks[self.encoder_depth :]:
            x = blk(x, pe2)

        if has_c:
            x = x[:, c.size(1) :]
        return self.norm(x)