| #ifndef __MMDIT_HPP__ |
| #define __MMDIT_HPP__ |
|
|
| #include "ggml_extend.hpp" |
| #include "model.h" |
|
|
| #define MMDIT_GRAPH_SIZE 10240 |
|
|
| struct Mlp : public GGMLBlock { |
| public: |
| Mlp(int64_t in_features, |
| int64_t hidden_features = -1, |
| int64_t out_features = -1, |
| bool bias = true) { |
| |
| |
| |
| if (hidden_features == -1) { |
| hidden_features = in_features; |
| } |
| if (out_features == -1) { |
| out_features = in_features; |
| } |
| blocks["fc1"] = std::shared_ptr<GGMLBlock>(new Linear(in_features, hidden_features, bias)); |
| blocks["fc2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_features, out_features, bias)); |
| } |
|
|
| struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
| |
| auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]); |
| auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]); |
|
|
| x = fc1->forward(ctx, x); |
| x = ggml_gelu_inplace(ctx, x); |
| x = fc2->forward(ctx, x); |
| return x; |
| } |
| }; |
|
|
| struct PatchEmbed : public GGMLBlock { |
| |
| protected: |
| bool flatten; |
| bool dynamic_img_pad; |
| int patch_size; |
|
|
| public: |
| PatchEmbed(int64_t img_size = 224, |
| int patch_size = 16, |
| int64_t in_chans = 3, |
| int64_t embed_dim = 1536, |
| bool bias = true, |
| bool flatten = true, |
| bool dynamic_img_pad = true) |
| : patch_size(patch_size), |
| flatten(flatten), |
| dynamic_img_pad(dynamic_img_pad) { |
| |
| |
| |
| |
| |
|
|
| blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_chans, |
| embed_dim, |
| {patch_size, patch_size}, |
| {patch_size, patch_size}, |
| {0, 0}, |
| {1, 1}, |
| bias)); |
| } |
|
|
| struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
| |
| |
| auto proj = std::dynamic_pointer_cast<Conv2d>(blocks["proj"]); |
|
|
| if (dynamic_img_pad) { |
| int64_t W = x->ne[0]; |
| int64_t H = x->ne[1]; |
| int pad_h = (patch_size - H % patch_size) % patch_size; |
| int pad_w = (patch_size - W % patch_size) % patch_size; |
| x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); |
| } |
| x = proj->forward(ctx, x); |
|
|
| if (flatten) { |
| x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]); |
| x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); |
| } |
| return x; |
| } |
| }; |
|
|
| struct TimestepEmbedder : public GGMLBlock { |
| |
| protected: |
| int64_t frequency_embedding_size; |
|
|
| public: |
| TimestepEmbedder(int64_t hidden_size, |
| int64_t frequency_embedding_size = 256) |
| : frequency_embedding_size(frequency_embedding_size) { |
| blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(frequency_embedding_size, hidden_size, true, true)); |
| blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true)); |
| } |
|
|
| struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) { |
| |
| |
| auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]); |
| auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]); |
|
|
| auto t_freq = ggml_nn_timestep_embedding(ctx, t, frequency_embedding_size); |
|
|
| auto t_emb = mlp_0->forward(ctx, t_freq); |
| t_emb = ggml_silu_inplace(ctx, t_emb); |
| t_emb = mlp_2->forward(ctx, t_emb); |
| return t_emb; |
| } |
| }; |
|
|
| struct VectorEmbedder : public GGMLBlock { |
| |
| public: |
| VectorEmbedder(int64_t input_dim, |
| int64_t hidden_size) { |
| blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(input_dim, hidden_size, true, true)); |
| blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true)); |
| } |
|
|
| struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
| |
| |
| auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]); |
| auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]); |
|
|
| x = mlp_0->forward(ctx, x); |
| x = ggml_silu_inplace(ctx, x); |
| x = mlp_2->forward(ctx, x); |
| return x; |
| } |
| }; |
|
|
| class RMSNorm : public UnaryBlock { |
| protected: |
| int64_t hidden_size; |
| float eps; |
|
|
| void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") { |
| enum ggml_type wtype = GGML_TYPE_F32; |
| params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); |
| } |
|
|
| public: |
| RMSNorm(int64_t hidden_size, |
| float eps = 1e-06f) |
| : hidden_size(hidden_size), |
| eps(eps) {} |
|
|
| struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
| struct ggml_tensor* w = params["weight"]; |
| x = ggml_rms_norm(ctx, x, eps); |
| x = ggml_mul(ctx, x, w); |
| return x; |
| } |
| }; |
|
|
class SelfAttention : public GGMLBlock {
public:
    int64_t num_heads;
    bool pre_only;        // when true, only the qkv projection exists (no output proj)
    std::string qk_norm;  // "", "rms" or "ln": optional per-head q/k normalization

public:
    SelfAttention(int64_t dim,
                  int64_t num_heads = 8,
                  std::string qk_norm = "",
                  bool qkv_bias = false,
                  bool pre_only = false)
        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
        int64_t d_head = dim / num_heads;
        blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
        if (!pre_only) {
            blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
        }
        // The q/k norms operate per head (size d_head), not on the full dim.
        if (qk_norm == "rms") {
            blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
            blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
        } else if (qk_norm == "ln") {
            blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
            blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
        }
    }

    // Projects x into q, k, v. q and k are temporarily split into heads so the
    // optional qk-norm can act per head, then flattened back. Returns {q, k, v}.
    std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
        auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);

        auto qkv = qkv_proj->forward(ctx, x);
        auto qkv_vec = split_qkv(ctx, qkv);
        int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
        // Expose the head dim so ln_q/ln_k normalize each head independently.
        auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
        auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
        auto v = qkv_vec[2];  // v is never normalized, so it keeps its flat shape

        if (qk_norm == "rms" || qk_norm == "ln") {
            auto ln_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_q"]);
            auto ln_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_k"]);
            q = ln_q->forward(ctx, q);
            k = ln_k->forward(ctx, k);
        }

        // Fold heads back into the feature dim for the attention kernel.
        q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]);
        k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]);

        return {q, k, v};
    }

    // Output projection after attention; only valid when pre_only == false.
    struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
        GGML_ASSERT(!pre_only);

        auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);

        x = proj->forward(ctx, x);
        return x;
    }

    // Full self-attention: qkv projection -> attention -> output projection.
    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        auto qkv = pre_attention(ctx, x);
        x = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);
        x = post_attention(ctx, x);
        return x;
    }
};
|
|
| __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, |
| struct ggml_tensor* x, |
| struct ggml_tensor* shift, |
| struct ggml_tensor* scale) { |
| |
| |
| |
| scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); |
| shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); |
| x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); |
| x = ggml_add(ctx, x, shift); |
| return x; |
| } |
|
|
// One transformer stream of an MM-DiT joint block: adaLN-modulated attention
// plus (unless pre_only) an MLP. "Dismantled" because attention is split into
// pre/post halves so the caller can run a joint attention over both streams.
struct DismantledBlock : public GGMLBlock {
public:
    int64_t num_heads;
    bool pre_only;   // context block of the last layer: no proj/mlp, 2 mods only
    bool self_attn;  // MMDiT-x variant: adds a second attention (attn2), 9 mods

public:
    DismantledBlock(int64_t hidden_size,
                    int64_t num_heads,
                    float mlp_ratio = 4.0,
                    std::string qk_norm = "",
                    bool qkv_bias = false,
                    bool pre_only = false,
                    bool self_attn = false)
        : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
        // LayerNorms are elementwise-affine-free (bias/scale handled by adaLN).
        blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
        blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));

        if (self_attn) {
            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
        }

        if (!pre_only) {
            blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
            int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
            blocks["mlp"] = std::shared_ptr<GGMLBlock>(new Mlp(hidden_size, mlp_hidden_dim));
        }

        // Number of adaLN modulation chunks: 6 (shift/scale/gate for msa+mlp),
        // 2 if pre_only (shift/scale only), 9 if self_attn (extra msa2 triple).
        int64_t n_mods = 6;
        if (pre_only) {
            n_mods = 2;
        }
        if (self_attn) {
            n_mods = 9;
        }
        blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
    }

    // MMDiT-x pre-attention: computes the 9 adaLN chunks from c, modulates the
    // shared norm1 output twice, and returns the qkv of attn and attn2 plus the
    // intermediates needed later by post_attention_x.
    std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
                                                                                                                                     struct ggml_tensor* x,
                                                                                                                                     struct ggml_tensor* c) {
        GGML_ASSERT(self_attn);

        auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
        auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
        auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
        auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

        int64_t n_mods = 9;
        auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));
        // Reshape to [hidden, n_mods, batch] then permute so each chunk is a
        // contiguous 2d slab addressable by a fixed byte offset below.
        m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]);
        m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));

        // Byte stride of one modulation chunk in the permuted tensor.
        int64_t offset = m->nb[1] * m->ne[1];
        auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);
        auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);
        auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2);

        auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3);
        auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4);
        auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5);

        auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6);
        auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7);
        auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8);

        // norm1 is computed once and modulated separately for each attention.
        auto x_norm = norm1->forward(ctx, x);

        auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
        auto qkv = attn->pre_attention(ctx, attn_in);

        auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
        auto qkv2 = attn2->pre_attention(ctx, attn2_in);

        return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
    }

    // Standard pre-attention (6 or, when pre_only, 2 adaLN chunks). Returns the
    // qkv plus the intermediates consumed later by post_attention; when
    // pre_only the intermediates are all NULL since no post pass exists.
    std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
                                                                                                struct ggml_tensor* x,
                                                                                                struct ggml_tensor* c) {
        auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
        auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
        auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

        int64_t n_mods = 6;
        if (pre_only) {
            n_mods = 2;
        }
        auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));
        // Same chunk layout trick as pre_attention_x, just with fewer chunks.
        m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]);
        m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));

        int64_t offset = m->nb[1] * m->ne[1];
        auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);
        auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);
        if (!pre_only) {
            auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2);
            auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3);
            auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4);
            auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5);

            auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);

            auto qkv = attn->pre_attention(ctx, attn_in);

            return {qkv, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp}};
        } else {
            auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
            auto qkv = attn->pre_attention(ctx, attn_in);

            return {qkv, {NULL, NULL, NULL, NULL, NULL}};
        }
    }

    // MMDiT-x post pass: applies both output projections, adds the gated
    // residuals of attn and attn2, then the gated MLP residual.
    struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
                                         struct ggml_tensor* attn_out,
                                         struct ggml_tensor* attn2_out,
                                         struct ggml_tensor* x,
                                         struct ggml_tensor* gate_msa,
                                         struct ggml_tensor* shift_mlp,
                                         struct ggml_tensor* scale_mlp,
                                         struct ggml_tensor* gate_mlp,
                                         struct ggml_tensor* gate_msa2) {
        GGML_ASSERT(!pre_only);

        auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
        auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
        auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
        auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);

        // Insert a singleton dim so the per-batch gates broadcast over tokens.
        gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]);
        gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]);
        gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]);

        attn_out = attn->post_attention(ctx, attn_out);
        attn2_out = attn2->post_attention(ctx, attn2_out);

        x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
        x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
        auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
        x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));

        return x;
    }

    // Standard post pass: output projection, gated attention residual, then
    // modulated MLP with its own gated residual.
    struct ggml_tensor* post_attention(struct ggml_context* ctx,
                                       struct ggml_tensor* attn_out,
                                       struct ggml_tensor* x,
                                       struct ggml_tensor* gate_msa,
                                       struct ggml_tensor* shift_mlp,
                                       struct ggml_tensor* scale_mlp,
                                       struct ggml_tensor* gate_mlp) {
        GGML_ASSERT(!pre_only);

        auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
        auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
        auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);

        // Broadcast the per-batch gates across the token dimension.
        gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]);
        gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]);

        attn_out = attn->post_attention(ctx, attn_out);

        x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
        auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
        x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));

        return x;
    }

    // Standalone forward (pre + attention + post), for when this block is not
    // participating in a joint attention with another stream.
    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) {
        auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
        if (self_attn) {
            auto qkv_intermediates = pre_attention_x(ctx, x, c);
            auto qkv = std::get<0>(qkv_intermediates);
            auto qkv2 = std::get<1>(qkv_intermediates);
            auto intermediates = std::get<2>(qkv_intermediates);

            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);
            auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads);
            x = post_attention_x(ctx,
                                 attn_out,
                                 attn2_out,
                                 intermediates[0],
                                 intermediates[1],
                                 intermediates[2],
                                 intermediates[3],
                                 intermediates[4],
                                 intermediates[5]);
            return x;
        } else {
            auto qkv_intermediates = pre_attention(ctx, x, c);
            auto qkv = qkv_intermediates.first;
            auto intermediates = qkv_intermediates.second;

            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);
            x = post_attention(ctx,
                               attn_out,
                               intermediates[0],
                               intermediates[1],
                               intermediates[2],
                               intermediates[3],
                               intermediates[4]);
            return x;
        }
    }
};
|
|
// Runs one MM-DiT joint layer: the context and image streams each produce qkv,
// the two sequences are concatenated and attended to jointly, the result is
// split back per stream, and each stream runs its own post-attention pass.
// Returns {context, x}; context is NULL when context_block is pre_only
// (the final layer, where the context stream is discarded).
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
block_mixing(struct ggml_context* ctx,
             struct ggml_tensor* context,
             struct ggml_tensor* x,
             struct ggml_tensor* c,
             std::shared_ptr<DismantledBlock> context_block,
             std::shared_ptr<DismantledBlock> x_block) {
    auto context_qkv_intermediates = context_block->pre_attention(ctx, context, c);
    auto context_qkv = context_qkv_intermediates.first;
    auto context_intermediates = context_qkv_intermediates.second;

    std::vector<ggml_tensor*> x_qkv, x_qkv2, x_intermediates;

    // MMDiT-x blocks additionally produce a second, image-only qkv (x_qkv2).
    if (x_block->self_attn) {
        auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c);
        x_qkv = std::get<0>(x_qkv_intermediates);
        x_qkv2 = std::get<1>(x_qkv_intermediates);
        x_intermediates = std::get<2>(x_qkv_intermediates);
    } else {
        auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
        x_qkv = x_qkv_intermediates.first;
        x_intermediates = x_qkv_intermediates.second;
    }
    // Concatenate context and image tokens along the sequence dim for each of
    // q, k, v so both streams attend over the joint sequence.
    std::vector<struct ggml_tensor*> qkv;
    for (int i = 0; i < 3; i++) {
        qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
    }

    auto attn = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], x_block->num_heads);
    // Permute so the sequence dim is sliceable, split into the context part
    // (first context->ne[1] tokens) and the image part, then permute back.
    attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));
    auto context_attn = ggml_view_3d(ctx,
                                     attn,
                                     attn->ne[0],
                                     attn->ne[1],
                                     context->ne[1],
                                     attn->nb[1],
                                     attn->nb[2],
                                     0);
    context_attn = ggml_cont(ctx, ggml_permute(ctx, context_attn, 0, 2, 1, 3));
    auto x_attn = ggml_view_3d(ctx,
                               attn,
                               attn->ne[0],
                               attn->ne[1],
                               x->ne[1],
                               attn->nb[1],
                               attn->nb[2],
                               attn->nb[2] * context->ne[1]);  // skip the context tokens
    x_attn = ggml_cont(ctx, ggml_permute(ctx, x_attn, 0, 2, 1, 3));

    if (!context_block->pre_only) {
        context = context_block->post_attention(ctx,
                                                context_attn,
                                                context_intermediates[0],
                                                context_intermediates[1],
                                                context_intermediates[2],
                                                context_intermediates[3],
                                                context_intermediates[4]);
    } else {
        // Last layer: the context stream has no post pass and is dropped.
        context = NULL;
    }

    if (x_block->self_attn) {
        // The second attention runs over image tokens only.
        auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads);

        x = x_block->post_attention_x(ctx,
                                      x_attn,
                                      attn2,
                                      x_intermediates[0],
                                      x_intermediates[1],
                                      x_intermediates[2],
                                      x_intermediates[3],
                                      x_intermediates[4],
                                      x_intermediates[5]);
    } else {
        x = x_block->post_attention(ctx,
                                    x_attn,
                                    x_intermediates[0],
                                    x_intermediates[1],
                                    x_intermediates[2],
                                    x_intermediates[3],
                                    x_intermediates[4]);
    }

    return {context, x};
}
|
|
| struct JointBlock : public GGMLBlock { |
| public: |
| JointBlock(int64_t hidden_size, |
| int64_t num_heads, |
| float mlp_ratio = 4.0, |
| std::string qk_norm = "", |
| bool qkv_bias = false, |
| bool pre_only = false, |
| bool self_attn_x = false) { |
| blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only)); |
| blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x)); |
| } |
|
|
| std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx, |
| struct ggml_tensor* context, |
| struct ggml_tensor* x, |
| struct ggml_tensor* c) { |
| auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]); |
| auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]); |
|
|
| return block_mixing(ctx, context, x, c, context_block, x_block); |
| } |
| }; |
|
|
// Final projection of the MM-DiT: adaLN-modulated LayerNorm followed by a
// linear projection to patch_size^2 * out_channels per token.
struct FinalLayer : public GGMLBlock {
public:
    FinalLayer(int64_t hidden_size,
               int64_t patch_size,
               int64_t out_channels) {
        // norm_final has no learned affine; shift/scale come from adaLN below.
        blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
        blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels, true, true));
        blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx,
                                struct ggml_tensor* x,
                                struct ggml_tensor* c) {
        auto norm_final = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_final"]);
        auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
        auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

        // Produce the two modulation chunks (shift, scale) from c and lay them
        // out as contiguous slabs, same scheme as DismantledBlock.
        auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));
        m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]);
        m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));

        int64_t offset = m->nb[1] * m->ne[1];  // byte stride of one chunk
        auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);
        auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);

        x = modulate(ctx, norm_final->forward(ctx, x), shift, scale);
        x = linear->forward(ctx, x);

        return x;
    }
};
|
|
| struct MMDiT : public GGMLBlock { |
| |
| protected: |
| int64_t input_size = -1; |
| int64_t patch_size = 2; |
| int64_t in_channels = 16; |
| int64_t d_self = -1; |
| int64_t depth = 24; |
| float mlp_ratio = 4.0f; |
| int64_t adm_in_channels = 2048; |
| int64_t out_channels = 16; |
| int64_t pos_embed_max_size = 192; |
| int64_t num_patchs = 36864; |
| int64_t context_size = 4096; |
| int64_t context_embedder_out_dim = 1536; |
| int64_t hidden_size; |
| std::string qk_norm; |
|
|
| void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") { |
| enum ggml_type wtype = GGML_TYPE_F32; |
| params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1); |
| } |
|
|
| public: |
| MMDiT(std::map<std::string, enum ggml_type>& tensor_types) { |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| for (auto pair : tensor_types) { |
| std::string tensor_name = pair.first; |
| if (tensor_name.find("model.diffusion_model.") == std::string::npos) |
| continue; |
| size_t jb = tensor_name.find("joint_blocks."); |
| if (jb != std::string::npos) { |
| tensor_name = tensor_name.substr(jb); |
| int block_depth = atoi(tensor_name.substr(13, tensor_name.find(".", 13)).c_str()); |
| if (block_depth + 1 > depth) { |
| depth = block_depth + 1; |
| } |
| if (tensor_name.find("attn.ln") != std::string::npos) { |
| if (tensor_name.find(".bias") != std::string::npos) { |
| qk_norm = "ln"; |
| } else { |
| qk_norm = "rms"; |
| } |
| } |
| if (tensor_name.find("attn2") != std::string::npos) { |
| if (block_depth > d_self) { |
| d_self = block_depth; |
| } |
| } |
| } |
| } |
|
|
| if (d_self >= 0) { |
| pos_embed_max_size *= 2; |
| num_patchs *= 4; |
| } |
|
|
| LOG_INFO("MMDiT layers: %d (including %d MMDiT-x layers)", depth, d_self + 1); |
|
|
| int64_t default_out_channels = in_channels; |
| hidden_size = 64 * depth; |
| context_embedder_out_dim = 64 * depth; |
| int64_t num_heads = depth; |
|
|
| blocks["x_embedder"] = std::shared_ptr<GGMLBlock>(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true)); |
| blocks["t_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedder(hidden_size)); |
|
|
| if (adm_in_channels != -1) { |
| blocks["y_embedder"] = std::shared_ptr<GGMLBlock>(new VectorEmbedder(adm_in_channels, hidden_size)); |
| } |
|
|
| blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, context_embedder_out_dim, true, true)); |
|
|
| for (int i = 0; i < depth; i++) { |
| blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new JointBlock(hidden_size, |
| num_heads, |
| mlp_ratio, |
| qk_norm, |
| true, |
| i == depth - 1, |
| i <= d_self)); |
| } |
|
|
| blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels)); |
| } |
|
|
| struct ggml_tensor* |
| cropped_pos_embed(struct ggml_context* ctx, |
| int64_t h, |
| int64_t w) { |
| auto pos_embed = params["pos_embed"]; |
|
|
| h = (h + 1) / patch_size; |
| w = (w + 1) / patch_size; |
|
|
| GGML_ASSERT(h <= pos_embed_max_size && h > 0); |
| GGML_ASSERT(w <= pos_embed_max_size && w > 0); |
|
|
| int64_t top = (pos_embed_max_size - h) / 2; |
| int64_t left = (pos_embed_max_size - w) / 2; |
|
|
| auto spatial_pos_embed = ggml_reshape_3d(ctx, pos_embed, hidden_size, pos_embed_max_size, pos_embed_max_size); |
|
|
| |
| spatial_pos_embed = ggml_view_3d(ctx, |
| spatial_pos_embed, |
| hidden_size, |
| pos_embed_max_size, |
| h, |
| spatial_pos_embed->nb[1], |
| spatial_pos_embed->nb[2], |
| spatial_pos_embed->nb[2] * top); |
| spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); |
| spatial_pos_embed = ggml_view_3d(ctx, |
| spatial_pos_embed, |
| hidden_size, |
| h, |
| w, |
| spatial_pos_embed->nb[1], |
| spatial_pos_embed->nb[2], |
| spatial_pos_embed->nb[2] * left); |
| spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); |
| spatial_pos_embed = ggml_reshape_3d(ctx, spatial_pos_embed, hidden_size, h * w, 1); |
| return spatial_pos_embed; |
| } |
|
|
| struct ggml_tensor* unpatchify(struct ggml_context* ctx, |
| struct ggml_tensor* x, |
| int64_t h, |
| int64_t w) { |
| |
| |
| int64_t n = x->ne[2]; |
| int64_t c = out_channels; |
| int64_t p = patch_size; |
| h = (h + 1) / p; |
| w = (w + 1) / p; |
|
|
| GGML_ASSERT(h * w == x->ne[1]); |
|
|
| x = ggml_reshape_4d(ctx, x, c, p * p, w * h, n); |
| x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); |
| x = ggml_reshape_4d(ctx, x, p, p, w, h * c * n); |
| x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); |
| x = ggml_reshape_4d(ctx, x, p * w, p * h, c, n); |
| return x; |
| } |
|
|
| struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx, |
| struct ggml_tensor* x, |
| struct ggml_tensor* c_mod, |
| struct ggml_tensor* context, |
| std::vector<int> skip_layers = std::vector<int>()) { |
| |
| |
| |
| |
| auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]); |
|
|
| for (int i = 0; i < depth; i++) { |
| |
| if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { |
| continue; |
| } |
|
|
| auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]); |
|
|
| auto context_x = block->forward(ctx, context, x, c_mod); |
| context = context_x.first; |
| x = context_x.second; |
| } |
|
|
| x = final_layer->forward(ctx, x, c_mod); |
|
|
| return x; |
| } |
|
|
| struct ggml_tensor* forward(struct ggml_context* ctx, |
| struct ggml_tensor* x, |
| struct ggml_tensor* t, |
| struct ggml_tensor* y = NULL, |
| struct ggml_tensor* context = NULL, |
| std::vector<int> skip_layers = std::vector<int>()) { |
| |
| |
| |
| |
| |
| |
| auto x_embedder = std::dynamic_pointer_cast<PatchEmbed>(blocks["x_embedder"]); |
| auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]); |
|
|
| int64_t w = x->ne[0]; |
| int64_t h = x->ne[1]; |
|
|
| auto patch_embed = x_embedder->forward(ctx, x); |
| auto pos_embed = cropped_pos_embed(ctx, h, w); |
| x = ggml_add(ctx, patch_embed, pos_embed); |
|
|
| auto c = t_embedder->forward(ctx, t); |
| if (y != NULL && adm_in_channels != -1) { |
| auto y_embedder = std::dynamic_pointer_cast<VectorEmbedder>(blocks["y_embedder"]); |
|
|
| y = y_embedder->forward(ctx, y); |
| c = ggml_add(ctx, c, y); |
| } |
|
|
| if (context != NULL) { |
| auto context_embedder = std::dynamic_pointer_cast<Linear>(blocks["context_embedder"]); |
|
|
| context = context_embedder->forward(ctx, context); |
| } |
|
|
| x = forward_core_with_concat(ctx, x, c, context, skip_layers); |
|
|
| x = unpatchify(ctx, x, h, w); |
|
|
| return x; |
| } |
| }; |
// Backend runner for MMDiT: owns the parameter context, builds the compute
// graph for a forward pass and executes it via GGMLRunner.
struct MMDiTRunner : public GGMLRunner {
    MMDiT mmdit;

    static std::map<std::string, enum ggml_type> empty_tensor_types;

    MMDiTRunner(ggml_backend_t backend,
                std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
                const std::string prefix = "")
        : GGMLRunner(backend), mmdit(tensor_types) {
        mmdit.init(params_ctx, tensor_types, prefix);
    }

    std::string get_desc() {
        return "mmdit";
    }

    // Exposes the model's parameter tensors (for the loader), keyed by
    // prefix-qualified name.
    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        mmdit.get_param_tensors(tensors, prefix);
    }

    // Builds a fresh compute graph for one forward pass; all inputs are
    // copied/mapped to the backend first.
    struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                    struct ggml_tensor* timesteps,
                                    struct ggml_tensor* context,
                                    struct ggml_tensor* y,
                                    std::vector<int> skip_layers = std::vector<int>()) {
        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false);

        x = to_backend(x);
        context = to_backend(context);
        y = to_backend(y);
        timesteps = to_backend(timesteps);

        struct ggml_tensor* out = mmdit.forward(compute_ctx,
                                                x,
                                                timesteps,
                                                y,
                                                context,
                                                skip_layers);

        ggml_build_forward_expand(gf, out);

        return gf;
    }

    // Runs one forward pass; result is written to *output (allocated in
    // output_ctx when provided).
    void compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
                 struct ggml_tensor* y,
                 struct ggml_tensor** output = NULL,
                 struct ggml_context* output_ctx = NULL,
                 std::vector<int> skip_layers = std::vector<int>()) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(x, timesteps, context, y, skip_layers);
        };

        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }

    // Smoke test: runs a forward pass on constant-filled dummy inputs and
    // prints the result tensor.
    void test() {
        struct ggml_init_params params;
        params.mem_size = static_cast<size_t>(10 * 1024 * 1024);
        params.mem_buffer = NULL;
        params.no_alloc = false;

        struct ggml_context* work_ctx = ggml_init(params);
        GGML_ASSERT(work_ctx != NULL);

        {
            // x: 128x128 latent, 16 channels, batch 1, all values 0.01
            auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 128, 128, 16, 1);
            std::vector<float> timesteps_vec(1, 999.f);
            auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
            ggml_set_f32(x, 0.01f);

            // context: 154 tokens of width 4096 (text-encoder output shape)
            auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 154, 1);
            ggml_set_f32(context, 0.01f);

            // y: pooled conditioning vector (adm_in_channels = 2048)
            auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 2048, 1);
            ggml_set_f32(y, 0.01f);

            struct ggml_tensor* out = NULL;

            // NOTE(review): ggml_time_ms() likely returns a 64-bit value;
            // storing in int truncates — harmless for a short test, but verify.
            int t0 = ggml_time_ms();
            compute(8, x, timesteps, context, y, &out, work_ctx);
            int t1 = ggml_time_ms();

            print_ggml_tensor(out);
            LOG_DEBUG("mmdit test done in %dms", t1 - t0);
        }
    }

    // Loads weights from file into a CPU-backed runner, then runs test().
    static void load_from_file_and_test(const std::string& file_path) {
        ggml_backend_t backend = ggml_backend_cpu_init();
        // NOTE(review): model_data_type is unused here — loading keeps the
        // checkpoint's own tensor types.
        ggml_type model_data_type = GGML_TYPE_F16;
        std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend));
        {
            LOG_INFO("loading from '%s'", file_path.c_str());

            mmdit->alloc_params_buffer();
            std::map<std::string, ggml_tensor*> tensors;
            mmdit->get_param_tensors(tensors, "model.diffusion_model");

            ModelLoader model_loader;
            if (!model_loader.init_from_file(file_path)) {
                LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
                return;
            }

            bool success = model_loader.load_tensors(tensors, backend);

            if (!success) {
                LOG_ERROR("load tensors from model loader failed");
                return;
            }

            LOG_INFO("mmdit model loaded");
        }
        mmdit->test();
    }
};
|
|
| #endif |