Instructions to use BiliSakura/ProMoE-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/ProMoE-diffusers with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```

```python
import torch
from diffusers import DiffusionPipeline

# switch device_map to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/ProMoE-diffusers", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn as nn | |
| from .modeling_promoe_common import ( | |
| Attention, | |
| FinalLayer, | |
| LabelEmbedder, | |
| Mlp, | |
| PatchEmbed, | |
| TimestepEmbedder, | |
| get_2d_sincos_pos_embed, | |
| modulate, | |
| ) | |
class DiTBlock(nn.Module):
    """Transformer block whose attention and MLP branches are modulated by a conditioning tensor.

    A SiLU + linear projection of ``c`` yields six chunks: shift/scale/gate for
    the attention branch and shift/scale/gate for the MLP branch. Each branch
    normalises the residual stream with an affine-free LayerNorm, applies
    ``modulate`` with its shift/scale, runs the sub-layer, and adds the result
    back to the residual stream scaled by its gate.

    Args:
        hidden_size: Width of the token features and of the conditioning tensor.
        num_heads: Number of attention heads, forwarded to ``Attention``.
        head_dim: Optional per-head width, forwarded to ``Attention``.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``
            (truncated with ``int``).
        use_swiglu: Accepted for signature compatibility but not referenced in
            this block; the MLP always uses a tanh-approximated GELU.
        **block_kwargs: Extra keyword arguments forwarded to ``Attention``.
    """

    def __init__(self, hidden_size, num_heads, head_dim=None, mlp_ratio=4.0, use_swiglu=False, **block_kwargs):
        super().__init__()

        def _tanh_gelu():
            # ``Mlp`` receives a callable that builds the activation module.
            return nn.GELU(approximate="tanh")

        mlp_width = int(hidden_size * mlp_ratio)

        # Registration order below is kept stable: weight initialisation walks
        # submodules in this order, which fixes the sequence of random draws.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            head_dim=head_dim,
            qkv_bias=True,
            **block_kwargs,
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_width,
            act_layer=_tanh_gelu,
            drop=0,
        )
        # One projection produces all six modulation vectors at once.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Apply the gated, modulated attention and MLP residual updates.

        Args:
            x: Token features whose last dimension is ``hidden_size``
                (presumably shaped ``(N, T, hidden_size)`` — confirm with callers).
            c: Conditioning tensor, expected 2-D ``(N, hidden_size)``: its
                projection is chunked along dim 1 and each chunk is unsqueezed
                at dim 1 to broadcast over the token axis.

        Returns:
            The residual stream ``x`` after both gated branch updates.
        """
        (
            attn_shift,
            attn_scale,
            attn_gate,
            mlp_shift,
            mlp_scale,
            mlp_gate,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)

        attn_out = self.attn(modulate(self.norm1(x), attn_shift, attn_scale))
        x = x + attn_gate.unsqueeze(1) * attn_out

        mlp_out = self.mlp(modulate(self.norm2(x), mlp_shift, mlp_scale))
        x = x + mlp_gate.unsqueeze(1) * mlp_out
        return x
class DiT(nn.Module):
    """Label-conditioned diffusion transformer over patchified inputs.

    ``PatchEmbed`` turns the input into tokens and a fixed (non-trainable) 2-D
    sin-cos positional table is added. The tokens then pass through ``depth``
    ``DiTBlock`` layers, each conditioned on the sum of the timestep embedding
    and the label embedding. ``FinalLayer`` maps tokens back to per-patch
    values, which ``unpatchify`` reassembles into an image-shaped tensor.

    Args:
        input_size: Spatial input size, forwarded to ``PatchEmbed``.
        patch_size: Patch edge length, forwarded to ``PatchEmbed`` and ``FinalLayer``.
        in_channels: Number of input channels.
        hidden_size: Token width.
        depth: Number of ``DiTBlock`` layers.
        num_heads: Attention heads per block.
        mlp_ratio: MLP width multiplier per block.
        qk_norm: Forwarded to every block (and on to its ``Attention``).
        class_dropout_prob: Forwarded to ``LabelEmbedder``.
        num_classes: Number of label classes, forwarded to ``LabelEmbedder``.
        learn_sigma: When true the output has ``2 * in_channels`` channels,
            otherwise ``in_channels``.
        head_dim: Optional per-head width, forwarded to every block.
        use_swiglu: Forwarded to every block.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        qk_norm=False,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        head_dim=None,
        use_swiglu=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # learn_sigma adds a second in_channels-sized set of outputs
        # (presumably the learned variance — confirm against the sampler).
        self.out_channels = 2 * in_channels if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        # Submodules are registered in a fixed order: initialize_weights() visits
        # them through self.apply, so this order determines the random draws.
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)

        # Non-trainable positional table; its values are written in initialize_weights().
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.x_embedder.num_patches, hidden_size),
            requires_grad=False,
        )

        self.blocks = nn.ModuleList(
            DiTBlock(
                hidden_size,
                num_heads,
                head_dim=head_dim,
                mlp_ratio=mlp_ratio,
                qk_norm=qk_norm,
                use_swiglu=use_swiglu,
            )
            for _ in range(depth)
        )
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """(Re)initialise every parameter of the model.

        - every ``nn.Linear``: Xavier-uniform weight, zero bias;
        - ``pos_embed``: the 2-D sin-cos table for a square patch grid;
        - patch projection: Xavier-uniform on its weight flattened to 2-D, zero bias;
        - label-embedding table and both timestep-MLP linears: N(0, 0.02^2);
        - every adaLN modulation projection and the final linear layer: zeros,
          so each block's gated branches and the model output start at zero.
        """

        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_init_linear)

        # Fixed positional table; assumes the patch grid is square.
        side = int(self.x_embedder.num_patches ** 0.5)
        table = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], side)
        self.pos_embed.data.copy_(torch.from_numpy(table).float().unsqueeze(0))

        # Xavier-initialise the patch projection as a matrix: keep dim 0, flatten the rest.
        kernel = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(kernel.view([kernel.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Order matters here: each normal_ call consumes random numbers.
        for weight in (
            self.y_embedder.embedding_table.weight,
            self.t_embedder.mlp[0].weight,
            self.t_embedder.mlp[2].weight,
        ):
            nn.init.normal_(weight, std=0.02)

        # Zero the modulation projections and the output head.
        zeroed = []
        for block in self.blocks:
            zeroed.extend((block.adaLN_modulation[-1].weight, block.adaLN_modulation[-1].bias))
        head = self.final_layer
        zeroed.extend(
            (
                head.adaLN_modulation[-1].weight,
                head.adaLN_modulation[-1].bias,
                head.linear.weight,
                head.linear.bias,
            )
        )
        for tensor in zeroed:
            nn.init.constant_(tensor, 0)

    def unpatchify(self, x):
        """Reassemble per-patch token values into an image-shaped tensor.

        Args:
            x: Tokens of shape ``(N, T, p * p * C)``, where ``p`` is the patch
                size, ``C`` is ``self.out_channels`` and ``T`` is assumed to be a
                perfect square (square patch grid).

        Returns:
            Tensor of shape ``(N, C, sqrt(T) * p, sqrt(T) * p)``.
        """
        channels = self.out_channels
        patch = self.x_embedder.patch_size[0]
        batch, num_tokens = x.shape[0], x.shape[1]
        grid_h = grid_w = int(num_tokens ** 0.5)

        patches = x.reshape(batch, grid_h, grid_w, patch, patch, channels)
        # (N, h, w, p, q, C) -> (N, C, h, p, w, q); same as einsum "nhwpqc->nchpwq".
        patches = patches.permute(0, 5, 1, 3, 2, 4)
        return patches.reshape(batch, channels, grid_h * patch, grid_w * patch)

    def forward(self, x, t, context, **kwargs):
        """Run the transformer on input ``x`` at timesteps ``t`` with labels ``context``.

        Args:
            x: Input tensor. If it is not 4-D, dim 2 is squeezed first
                (presumably a singleton frame axis — confirm with callers).
            t: Timesteps, embedded by ``TimestepEmbedder``.
            context: Labels, embedded by ``LabelEmbedder``; ``self.training`` is
                passed alongside them.
            **kwargs: Accepted and ignored.

        Returns:
            The image-shaped tensor produced by ``unpatchify``.
        """
        if x.dim() != 4:
            x = x.squeeze(2)

        tokens = self.x_embedder(x) + self.pos_embed
        cond = self.t_embedder(t) + self.y_embedder(context, self.training)

        for block in self.blocks:
            tokens = block(tokens, cond)

        return self.unpatchify(self.final_layer(tokens, cond))