Instructions to use BryanW/43.wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BryanW/43.wm with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BryanW/43.wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2024-present, BAAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ------------------------------------------------------------------------ | |
| """Simple implementation of AutoEncoderVQ.""" | |
| import torch | |
| from torch import nn | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_outputs import AutoencoderKLOutput | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffnext.models.autoencoders.autoencoder_kl import Attention, Decoder, Encoder | |
| from diffnext.models.autoencoders.modeling_utils import DecoderOutput, IdentityDistribution | |
| from diffnext.models.autoencoders import quantizers | |
class AutoencoderVQ(ModelMixin, ConfigMixin):
    """Vector-quantized autoencoder.

    Combines an ``Encoder``/``Decoder`` pair with two 1x1 convolutions that
    project between ``latent_channels`` and ``vq_embed_dim``, and a quantizer
    class looked up by name in ``diffnext.models.autoencoders.quantizers``.
    """

    # ``register_to_config`` records the constructor arguments on ``self.config``;
    # ``to`` depends on that to read ``self.config.decoder_dtype``.
    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",) * 4,
        up_block_types=("UpDecoderBlock2D",) * 4,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        act_fn="silu",
        latent_channels=16,
        norm_num_groups=32,
        sample_size=1024,
        num_vq_embeddings=16384,
        vq_embed_dim=8,
        attn_down_block=False,
        attn_up_block=False,
        force_upcast=False,
        temporal_stride=1,
        spatial_stride=16,
        decoder_dtype=None,
        _quantizer_name="VQuantizer",
    ):
        """Build the encoder, decoder, projection convolutions and quantizer.

        Only the arguments listed below are consumed by this constructor. The
        remaining ones (``down_block_types``, ``up_block_types``, ``act_fn``,
        ``norm_num_groups``, ``sample_size``, ``force_upcast``,
        ``temporal_stride``, ``spatial_stride``) are not used here and are only
        kept as configuration values.

        Args:
            in_channels: Forwarded to ``Encoder`` as its first argument.
            out_channels: Forwarded to ``Decoder`` as its second argument.
            block_out_channels: Forwarded to both ``Encoder`` and ``Decoder``;
                its last entry sizes the optional extra attention layers.
            layers_per_block: Forwarded to both ``Encoder`` and ``Decoder``;
                also sets how many extra attention layers are appended.
            latent_channels: Encoder output / decoder input channel count, and
                one side of the two 1x1 projection convolutions.
            num_vq_embeddings: First argument to the quantizer constructor.
            vq_embed_dim: Second argument to the quantizer constructor, and the
                other side of the two 1x1 projection convolutions.
            attn_down_block: If true, append ``layers_per_block`` attention
                layers to the last encoder down block.
            attn_up_block: If true, append ``layers_per_block + 1`` attention
                layers to the first decoder up block.
            decoder_dtype: Optional name of a ``torch`` dtype attribute
                (e.g. ``"bfloat16"``); when set, ``to`` re-casts the decoder
                to it.
            _quantizer_name: Attribute name of the quantizer class inside the
                ``quantizers`` module.
        """
        super().__init__()
        channels, layers = block_out_channels, layers_per_block
        self.encoder = Encoder(in_channels, latent_channels, channels, layers)
        self.decoder = Decoder(latent_channels, out_channels, channels, layers)
        # 1x1 convolutions mapping latent channels <-> codebook embedding dim.
        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
        if attn_down_block:
            attentions = [Attention(block_out_channels[-1]) for _ in range(layers_per_block)]
            self.encoder.down_blocks[-1].attentions += attentions
        if attn_up_block:
            attentions = [Attention(block_out_channels[-1]) for _ in range(layers_per_block + 1)]
            self.decoder.up_blocks[0].attentions += attentions
        quantizer_cls = getattr(quantizers, _quantizer_name)
        self.quantizer = quantizer_cls(num_vq_embeddings, vq_embed_dim)
        # Stored as a class (not an instance); ``encode`` calls it on the quantizer output.
        self.latent_dist = IdentityDistribution

    def to(self, *args, **kwargs):
        """Move/cast the model, then re-apply the configured decoder dtype.

        All arguments are forwarded to the parent ``to``. If
        ``config.decoder_dtype`` is set, the decoder is cast to that dtype
        afterwards, so it takes precedence over any dtype given here.

        Returns:
            This module.
        """
        super().to(*args, **kwargs)
        if self.config.decoder_dtype:
            self.decoder.to(dtype=getattr(torch, self.config.decoder_dtype))
        return self

    def scale_(self, x) -> torch.Tensor:
        """Scale the input latents (identity for this model)."""
        return x

    def unscale_(self, x) -> torch.Tensor:
        """Unscale the input latents (identity for this model)."""
        return x

    def encode(self, x) -> AutoencoderKLOutput:
        """Encode the input samples.

        Runs the encoder and the 1x1 ``quant_conv``, quantizes the result and
        wraps it in ``self.latent_dist``.

        Args:
            x: Input samples accepted by the encoder.

        Returns:
            ``AutoencoderKLOutput`` whose ``latent_dist`` wraps the quantizer output.
        """
        z = self.encoder(self.forward(x))
        z = self.quant_conv(z)
        posterior = self.latent_dist(self.quantizer.quantize(z))
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, ids) -> DecoderOutput:
        """Decode the input indices.

        Args:
            ids: Indices accepted by ``quantizer.dequantize``.

        Returns:
            ``DecoderOutput`` holding the decoded sample. When the dequantized
            tensor is 5-D with more than one entry along axis 2, the output is
            reshaped back to 5-D with that axis restored.
        """
        z = self.quantizer.dequantize(ids)
        # A 5-D tensor carries an extra axis at index 2
        # (presumably time, i.e. (B, C, T, H, W) -- TODO confirm).
        t = z.size(2) if z.dim() == 5 else 1
        if t > 1:
            # Fold axis 2 into the batch axis so the 2-D convs see 4-D input.
            z = z.transpose(1, 2).flatten(0, 1)
        if z.dim() == 5:
            # Singleton extra axis: drop it in place.
            z = z.squeeze_(2)
        x = self.post_quant_conv(self.forward(z))
        # Match the decoder's device/dtype, which may differ (see ``to``).
        x = self.decoder(x.to(self.decoder.conv_in.weight))
        if t > 1:
            # Undo the batch folding and move the extra axis back to index 2.
            x = x.view(-1, t, *x.shape[1:]).transpose(1, 2)
        return DecoderOutput(sample=x)

    def forward(self, x):  # NOOP.
        """Return ``x`` unchanged."""
        return x