Instructions to use BiliSakura/ProMoE-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/ProMoE-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/ProMoE-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.distributed as dist | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .modeling_promoe_common import ( | |
| Attention, | |
| FinalLayer, | |
| LabelEmbedder, | |
| Mlp, | |
| MoeMLP_DiffMoE as MoeMLP, | |
| PatchEmbed, | |
| TimestepEmbedder, | |
| get_2d_sincos_pos_embed, | |
| modulate, | |
| ) | |
| class SparseMoEBlock(nn.Module): | |
| def __init__( | |
| self, | |
| experts, | |
| hidden_dim, | |
| num_experts, | |
| n_shared_experts=0, | |
| capacity=2, | |
| mlp_ratio=4.0, | |
| use_diff_expert=False, | |
| ): | |
| super().__init__() | |
| self.gate_weight = nn.Parameter(torch.empty((num_experts, hidden_dim))) | |
| nn.init.normal_(self.gate_weight, std=0.006) | |
| self.experts = nn.ModuleList(experts) | |
| self.capacity = capacity | |
| self.num_experts = num_experts | |
| self.n_shared_experts = n_shared_experts | |
| self.use_diff_expert = use_diff_expert | |
| if use_diff_expert: | |
| self.diff_expert = MoeMLP(hidden_size=hidden_dim, intermediate_size=int(hidden_dim * mlp_ratio)) | |
| self.capacity_predictor = nn.Sequential( | |
| nn.Linear(hidden_dim, hidden_dim, bias=True), | |
| nn.SiLU(), | |
| nn.Linear(hidden_dim, self.num_experts, bias=True), | |
| ) | |
| if self.n_shared_experts > 0: | |
| mlp_hidden_dim = int(hidden_dim * mlp_ratio * 2) | |
| approx_gelu = lambda: nn.GELU(approximate="tanh") | |
| self.shared_experts = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) | |
| self.register_buffer("expert_threshold", torch.tensor([0.0] * num_experts)) | |
| self.register_buffer("ema_decay", torch.tensor([0.95])) | |
| def forward(self, x): | |
| if self.training: | |
| return self.forward_train(x) | |
| return self.forward_eval(x) | |
| def update_threshold(self, capacity_pred): | |
| if not self.training: | |
| return | |
| capacity_pred = torch.sigmoid(capacity_pred) | |
| seq_len = capacity_pred.size(0) | |
| topk = int((seq_len / self.num_experts) * self.capacity) | |
| threshold = self.expert_threshold | |
| ema_decay = self.ema_decay | |
| for i in range(self.num_experts): | |
| scores, _ = torch.topk(capacity_pred[:, i], k=topk, dim=-1, sorted=True) | |
| quantile = scores[-1].detach() | |
| threshold[i] = threshold[i] * ema_decay + (1 - ema_decay) * quantile | |
| if dist.is_available() and dist.is_initialized(): | |
| dist.all_reduce(threshold, op=dist.ReduceOp.SUM) | |
| threshold /= dist.get_world_size() | |
| self.expert_threshold = threshold | |
| def forward_train(self, x): | |
| bsz, seq_len, hidden_dim = x.shape | |
| identity = x | |
| x = x.view(-1, hidden_dim) | |
| total_tokens = x.shape[0] | |
| capacity_pred = self.capacity_predictor(x.detach()) | |
| k = int((total_tokens / self.num_experts) * self.capacity) | |
| logits = F.linear(x, self.gate_weight, None) | |
| scores = logits.softmax(dim=-1).permute(1, 0) | |
| gating, index = torch.topk(scores, k=k, dim=-1, sorted=False) | |
| mask = torch.zeros((self.num_experts, total_tokens), dtype=x.dtype, device=x.device) | |
| mask.scatter_(1, index, 1.0) | |
| expert_inputs = x[index] | |
| expert_outputs = torch.stack([expert(expert_inputs[i]) for i, expert in enumerate(self.experts)]) | |
| gated_outputs = gating.unsqueeze(-1) * expert_outputs | |
| y = torch.zeros((total_tokens * self.num_experts, hidden_dim), dtype=x.dtype, device=x.device) | |
| offset = torch.arange(0, self.num_experts, device=x.device).unsqueeze(1) * total_tokens | |
| flat_index = (index + offset.long()).view(-1) | |
| y = torch.scatter(y, 0, flat_index.unsqueeze(1).expand(-1, hidden_dim), gated_outputs.view(-1, hidden_dim)) | |
| y = y.view(self.num_experts, total_tokens, hidden_dim).sum(dim=0, keepdim=False) | |
| self.update_threshold(capacity_pred) | |
| x_out = y.view(bsz, seq_len, hidden_dim) | |
| ones = mask.permute(1, 0).view(bsz, seq_len, self.num_experts) | |
| capacity_pred = capacity_pred.view(bsz, seq_len, self.num_experts) | |
| if self.n_shared_experts > 0: | |
| x_out = x_out + self.shared_experts(identity) | |
| if self.use_diff_expert: | |
| x_out = x_out - self.diff_expert(identity) | |
| return x_out, ones, capacity_pred | |
| def forward_eval(self, x): | |
| bsz, seq_len, hidden_dim = x.shape | |
| identity = x | |
| x = x.view(-1, hidden_dim) | |
| total_tokens = x.shape[0] | |
| capacity_pred = torch.sigmoid(self.capacity_predictor(x.detach())) | |
| threshold = self.expert_threshold | |
| logits = F.linear(x, self.gate_weight, None) | |
| scores = logits.softmax(dim=-1).permute(-1, -2) | |
| y = torch.zeros_like(x, dtype=x.dtype) | |
| for i, expert in enumerate(self.experts): | |
| k_fixed = torch.where(capacity_pred[:, i] > threshold[i], 1, 0).sum() | |
| gating, index = torch.topk(scores[i], k=k_fixed, dim=-1, sorted=False) | |
| y[index, :] += gating.unsqueeze(-1) * expert(x[index, :]) | |
| x_out = y.view(bsz, seq_len, hidden_dim) | |
| if self.n_shared_experts > 0: | |
| x_out = x_out + self.shared_experts(identity) | |
| return x_out, None, None | |
| class DiTBlock(nn.Module): | |
| def __init__( | |
| self, | |
| hidden_size, | |
| num_heads, | |
| head_dim=None, | |
| mlp_ratio=4.0, | |
| use_swiglu=False, | |
| MoE_config=None, | |
| use_moe=False, | |
| qk_norm=False, | |
| **block_kwargs, | |
| ): | |
| super().__init__() | |
| self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | |
| self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=qk_norm, **block_kwargs) | |
| self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | |
| mlp_hidden_dim = int(hidden_size * mlp_ratio) | |
| self.use_moe = use_moe | |
| if use_moe: | |
| if not use_swiglu: | |
| approx_gelu = lambda: nn.GELU(approximate="tanh") | |
| experts = [ | |
| Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) | |
| for _ in range(MoE_config.num_experts) | |
| ] | |
| else: | |
| experts = [MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) for _ in range(MoE_config.num_experts)] | |
| self.mlp = SparseMoEBlock( | |
| experts=experts, | |
| hidden_dim=hidden_size, | |
| num_experts=MoE_config.num_experts, | |
| capacity=MoE_config.capacity, | |
| n_shared_experts=MoE_config.n_shared_experts, | |
| mlp_ratio=4.0, | |
| ) | |
| else: | |
| if not use_swiglu: | |
| approx_gelu = lambda: nn.GELU(approximate="tanh") | |
| self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) | |
| else: | |
| self.mlp = MoeMLP(hidden_size=hidden_size, intermediate_size=mlp_hidden_dim) | |
| self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) | |
| def forward(self, x, c): | |
| shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) | |
| x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) | |
| if self.use_moe: | |
| x_mlp, ones, pred_c = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) | |
| x = x + gate_mlp.unsqueeze(1) * x_mlp | |
| return x, ones, pred_c | |
| x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) | |
| return x, None, None | |
| class DiT(nn.Module): | |
| def __init__( | |
| self, | |
| input_size=32, | |
| patch_size=2, | |
| in_channels=4, | |
| hidden_size=1152, | |
| depth=28, | |
| num_heads=16, | |
| mlp_ratio=4.0, | |
| qk_norm=False, | |
| class_dropout_prob=0.1, | |
| num_classes=1000, | |
| learn_sigma=True, | |
| use_swiglu=False, | |
| MoE_config=None, | |
| head_dim=None, | |
| CapacityPred_loss_weight=0.01, | |
| ): | |
| super().__init__() | |
| self.learn_sigma = learn_sigma | |
| self.in_channels = in_channels | |
| self.out_channels = in_channels * 2 if learn_sigma else in_channels | |
| self.patch_size = patch_size | |
| self.num_heads = num_heads | |
| self.MoE_config = MoE_config | |
| use_moe_flag = [i % 2 == 1 for i in range(depth)] if self.MoE_config.interleave else [True] * depth | |
| self.CapacityPred_loss_weight = CapacityPred_loss_weight | |
| self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) | |
| self.t_embedder = TimestepEmbedder(hidden_size) | |
| self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) | |
| num_patches = self.x_embedder.num_patches | |
| self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) | |
| self.blocks = nn.ModuleList( | |
| [ | |
| DiTBlock( | |
| hidden_size, | |
| num_heads, | |
| head_dim=head_dim, | |
| mlp_ratio=mlp_ratio, | |
| qk_norm=qk_norm, | |
| use_swiglu=use_swiglu, | |
| MoE_config=MoE_config, | |
| use_moe=use_moe_flag[i], | |
| ) | |
| for i in range(depth) | |
| ] | |
| ) | |
| self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) | |
| self.init_MoeMLP = MoE_config.init_MoeMLP | |
| self.initialize_weights() | |
| self.capacity_schedule = MoE_config.get("capacity_schedule", None) | |
| if self.capacity_schedule: | |
| self.training_iters = -1 | |
| def initialize_weights(self): | |
| def _basic_init(module): | |
| if isinstance(module, nn.Linear): | |
| torch.nn.init.xavier_uniform_(module.weight) | |
| if module.bias is not None: | |
| nn.init.constant_(module.bias, 0) | |
| self.apply(_basic_init) | |
| pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) | |
| self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) | |
| w = self.x_embedder.proj.weight.data | |
| nn.init.xavier_uniform_(w.view([w.shape[0], -1])) | |
| nn.init.constant_(self.x_embedder.proj.bias, 0) | |
| nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) | |
| nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) | |
| nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) | |
| for block in self.blocks: | |
| nn.init.constant_(block.adaLN_modulation[-1].weight, 0) | |
| nn.init.constant_(block.adaLN_modulation[-1].bias, 0) | |
| nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) | |
| nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) | |
| nn.init.constant_(self.final_layer.linear.weight, 0) | |
| nn.init.constant_(self.final_layer.linear.bias, 0) | |
| def unpatchify(self, x): | |
| c = self.out_channels | |
| p = self.x_embedder.patch_size[0] | |
| h = w = int(x.shape[1] ** 0.5) | |
| x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) | |
| x = torch.einsum("nhwpqc->nchpwq", x) | |
| return x.reshape(shape=(x.shape[0], c, h * p, h * p)) | |
| def forward(self, x, t, context, **kwargs): | |
| y = context | |
| if len(x.shape) != 4: | |
| x = x.squeeze(2) | |
| if self.training and self.capacity_schedule: | |
| num_experts = self.MoE_config.num_experts | |
| capacity = self.MoE_config.capacity | |
| stage_i = self.MoE_config.capacity_schedule.capacity_schedule_stage_I_iters | |
| stage_ii = self.MoE_config.capacity_schedule.capacity_schedule_stage_II_iters | |
| if self.training_iters <= stage_i: | |
| capacity = num_experts | |
| elif self.training_iters <= stage_ii: | |
| capacity = capacity + (num_experts - capacity) * (stage_ii - self.training_iters) / (stage_ii - stage_i) | |
| for block in self.blocks: | |
| if hasattr(block.mlp, "capacity"): | |
| block.mlp.capacity = capacity | |
| x = self.x_embedder(x) + self.pos_embed | |
| t = self.t_embedder(t) | |
| y = self.y_embedder(y, self.training) | |
| c = t + y | |
| ones_list, pred_c_list, layer_idx_list = [], [], [] | |
| for layer_idx, block in enumerate(self.blocks): | |
| x, ones, pred_c = block(x, c) | |
| if ones is not None: | |
| ones_list.append(ones) | |
| pred_c_list.append(pred_c) | |
| layer_idx_list.append(layer_idx) | |
| x = self.final_layer(x, c) | |
| x = self.unpatchify(x) | |
| return x, "Capacity_Pred", layer_idx_list, ones_list, pred_c_list, self.CapacityPred_loss_weight | |