| #ifndef __ESRGAN_HPP__ |
| #define __ESRGAN_HPP__ |
|
|
| #include "ggml_extend.hpp" |
| #include "model.h" |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| class ResidualDenseBlock : public GGMLBlock { |
| protected: |
| int num_feat; |
| int num_grow_ch; |
|
|
| public: |
| ResidualDenseBlock(int num_feat = 64, int num_grow_ch = 32) |
| : num_feat(num_feat), num_grow_ch(num_grow_ch) { |
| blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_grow_ch, {3, 3}, {1, 1}, {1, 1})); |
| blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1})); |
| blocks["conv3"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1})); |
| blocks["conv4"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1})); |
| blocks["conv5"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 4 * num_grow_ch, num_feat, {3, 3}, {1, 1}, {1, 1})); |
| } |
|
|
| struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) { |
| return ggml_leaky_relu(ctx, x, 0.2f, true); |
| } |
|
|
| struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
| |
| |
|
|
| auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]); |
| auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]); |
| auto conv3 = std::dynamic_pointer_cast<Conv2d>(blocks["conv3"]); |
| auto conv4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv4"]); |
| auto conv5 = std::dynamic_pointer_cast<Conv2d>(blocks["conv5"]); |
|
|
| auto x1 = lrelu(ctx, conv1->forward(ctx, x)); |
| auto x_cat = ggml_concat(ctx, x, x1, 2); |
| auto x2 = lrelu(ctx, conv2->forward(ctx, x_cat)); |
| x_cat = ggml_concat(ctx, x_cat, x2, 2); |
| auto x3 = lrelu(ctx, conv3->forward(ctx, x_cat)); |
| x_cat = ggml_concat(ctx, x_cat, x3, 2); |
| auto x4 = lrelu(ctx, conv4->forward(ctx, x_cat)); |
| x_cat = ggml_concat(ctx, x_cat, x4, 2); |
| auto x5 = conv5->forward(ctx, x_cat); |
|
|
| x5 = ggml_add(ctx, ggml_scale(ctx, x5, 0.2f), x); |
| return x5; |
| } |
| }; |
|
|
| class RRDB : public GGMLBlock { |
| public: |
| RRDB(int num_feat, int num_grow_ch = 32) { |
| blocks["rdb1"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch)); |
| blocks["rdb2"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch)); |
| blocks["rdb3"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch)); |
| } |
|
|
| struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
| |
| |
|
|
| auto rdb1 = std::dynamic_pointer_cast<ResidualDenseBlock>(blocks["rdb1"]); |
| auto rdb2 = std::dynamic_pointer_cast<ResidualDenseBlock>(blocks["rdb2"]); |
| auto rdb3 = std::dynamic_pointer_cast<ResidualDenseBlock>(blocks["rdb3"]); |
|
|
| auto out = rdb1->forward(ctx, x); |
| out = rdb2->forward(ctx, out); |
| out = rdb3->forward(ctx, out); |
|
|
| out = ggml_add(ctx, ggml_scale(ctx, out, 0.2f), x); |
| return out; |
| } |
| }; |
|
|
| class RRDBNet : public GGMLBlock { |
| protected: |
| int scale = 4; |
| int num_block = 6; |
| int num_in_ch = 3; |
| int num_out_ch = 3; |
| int num_feat = 64; |
| int num_grow_ch = 32; |
|
|
| public: |
| RRDBNet() { |
| blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1})); |
| for (int i = 0; i < num_block; i++) { |
| std::string name = "body." + std::to_string(i); |
| blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB(num_feat, num_grow_ch)); |
| } |
| blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1})); |
| |
| blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1})); |
| blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1})); |
| blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1})); |
| blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1})); |
| } |
|
|
| struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) { |
| return ggml_leaky_relu(ctx, x, 0.2f, true); |
| } |
|
|
| struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
| |
| |
| auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]); |
| auto conv_body = std::dynamic_pointer_cast<Conv2d>(blocks["conv_body"]); |
| auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]); |
| auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]); |
| auto conv_hr = std::dynamic_pointer_cast<Conv2d>(blocks["conv_hr"]); |
| auto conv_last = std::dynamic_pointer_cast<Conv2d>(blocks["conv_last"]); |
|
|
| auto feat = conv_first->forward(ctx, x); |
| auto body_feat = feat; |
| for (int i = 0; i < num_block; i++) { |
| std::string name = "body." + std::to_string(i); |
| auto block = std::dynamic_pointer_cast<RRDB>(blocks[name]); |
|
|
| body_feat = block->forward(ctx, body_feat); |
| } |
| body_feat = conv_body->forward(ctx, body_feat); |
| feat = ggml_add(ctx, feat, body_feat); |
| |
| feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2))); |
| feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2))); |
| auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat))); |
| return out; |
| } |
| }; |
|
|
| struct ESRGAN : public GGMLRunner { |
| RRDBNet rrdb_net; |
| int scale = 4; |
| int tile_size = 128; |
|
|
| ESRGAN(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types) |
| : GGMLRunner(backend) { |
| rrdb_net.init(params_ctx, tensor_types, ""); |
| } |
|
|
| std::string get_desc() { |
| return "esrgan"; |
| } |
|
|
| bool load_from_file(const std::string& file_path) { |
| LOG_INFO("loading esrgan from '%s'", file_path.c_str()); |
|
|
| alloc_params_buffer(); |
| std::map<std::string, ggml_tensor*> esrgan_tensors; |
| rrdb_net.get_param_tensors(esrgan_tensors); |
|
|
| ModelLoader model_loader; |
| if (!model_loader.init_from_file(file_path)) { |
| LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str()); |
| return false; |
| } |
|
|
| bool success = model_loader.load_tensors(esrgan_tensors, backend); |
|
|
| if (!success) { |
| LOG_ERROR("load esrgan tensors from model loader failed"); |
| return false; |
| } |
|
|
| LOG_INFO("esrgan model loaded"); |
| return success; |
| } |
|
|
| struct ggml_cgraph* build_graph(struct ggml_tensor* x) { |
| struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); |
| x = to_backend(x); |
| struct ggml_tensor* out = rrdb_net.forward(compute_ctx, x); |
| ggml_build_forward_expand(gf, out); |
| return gf; |
| } |
|
|
| void compute(const int n_threads, |
| struct ggml_tensor* x, |
| ggml_tensor** output, |
| ggml_context* output_ctx = NULL) { |
| auto get_graph = [&]() -> struct ggml_cgraph* { |
| return build_graph(x); |
| }; |
| GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); |
| } |
| }; |
|
|
| #endif |