| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include "save_onnx.h" |
| |
|
| | #include "onnx.pb.h" |
| |
|
| | #include <string.h> |
| | #include <fstream> |
| | #include <iostream> |
| |
|
| | #include "utils.h" |
| |
|
| | namespace pnnx { |
| |
|
| | |
| | extern const char* get_operand_name(const Operand* x); |
| | extern const char* get_operator_type(const Operator* op); |
| | extern const char* get_operator_name(const Operator* op); |
| | extern std::vector<const char*> get_operator_params_keys(const Operator* op); |
| | extern std::vector<const char*> get_operator_attrs_keys(const Operator* op); |
| | extern const Parameter& get_operator_param(const Operator* op, const char* key); |
| | extern const Attribute& get_operator_attr(const Operator* op, const char* key); |
| | extern const char* get_param_s(const Parameter& p); |
| | extern std::vector<const char*> get_param_as(const Parameter& p); |
| |
|
| | int save_onnx(const Graph& g, const char* onnxpath, int fp16) |
| | { |
| | onnx::ModelProto model; |
| |
|
| | onnx::GraphProto* gp = model.mutable_graph(); |
| |
|
| | for (const Operand* x : g.operands) |
| | { |
| | onnx::ValueInfoProto* vip = gp->add_value_info(); |
| |
|
| | vip->set_name(get_operand_name(x)); |
| |
|
| | onnx::TypeProto* tp = vip->mutable_type(); |
| |
|
| | onnx::TypeProto_Tensor* tpt = tp->mutable_tensor_type(); |
| |
|
| | switch (x->type) |
| | { |
| | case 1: |
| | tpt->set_elem_type(fp16 ? 10 : 1); |
| | break; |
| | case 2: |
| | tpt->set_elem_type(fp16 ? 10 : 11); |
| | break; |
| | case 3: |
| | tpt->set_elem_type(10); |
| | break; |
| | case 4: |
| | tpt->set_elem_type(6); |
| | break; |
| | case 5: |
| | tpt->set_elem_type(7); |
| | break; |
| | case 6: |
| | tpt->set_elem_type(5); |
| | break; |
| | case 7: |
| | tpt->set_elem_type(3); |
| | break; |
| | case 8: |
| | tpt->set_elem_type(2); |
| | break; |
| | case 9: |
| | tpt->set_elem_type(9); |
| | break; |
| | case 10: |
| | tpt->set_elem_type(14); |
| | break; |
| | case 11: |
| | tpt->set_elem_type(15); |
| | break; |
| | case 12: |
| | tpt->set_elem_type(0); |
| | break; |
| | default: |
| | tpt->set_elem_type(0); |
| | break; |
| | } |
| |
|
| | onnx::TensorShapeProto* tsp = tpt->mutable_shape(); |
| |
|
| | for (auto s : x->shape) |
| | { |
| | onnx::TensorShapeProto_Dimension* tspd = tsp->add_dim(); |
| |
|
| | tspd->set_dim_value(s); |
| | } |
| | } |
| |
|
| | for (const Operator* op : g.ops) |
| | { |
| | onnx::NodeProto* np = gp->add_node(); |
| |
|
| | np->set_op_type(get_operator_type(op)); |
| | np->set_name(get_operator_name(op)); |
| |
|
| | for (const Operand* oprand : op->inputs) |
| | { |
| | np->add_input(get_operand_name(oprand)); |
| | } |
| |
|
| | for (const Operand* oprand : op->outputs) |
| | { |
| | np->add_output(get_operand_name(oprand)); |
| | } |
| |
|
| | std::vector<const char*> params_keys = get_operator_params_keys(op); |
| | for (const char* param_name : params_keys) |
| | { |
| | const Parameter& param = get_operator_param(op, param_name); |
| |
|
| | onnx::AttributeProto* ap = np->add_attribute(); |
| |
|
| | ap->set_name(param_name); |
| |
|
| | if (param.type == 0) |
| | { |
| | ap->set_s("None"); |
| | } |
| | if (param.type == 1) |
| | { |
| | if (param.b) |
| | ap->set_i(1); |
| | else |
| | ap->set_i(0); |
| | } |
| | if (param.type == 2) |
| | { |
| | ap->set_i(param.i); |
| | } |
| | if (param.type == 3) |
| | { |
| | ap->set_f(param.f); |
| | } |
| | if (param.type == 4) |
| | { |
| | ap->set_s(get_param_s(param)); |
| | } |
| | if (param.type == 5) |
| | { |
| | for (auto i : param.ai) |
| | { |
| | ap->add_ints(i); |
| | } |
| | } |
| | if (param.type == 6) |
| | { |
| | for (auto f : param.af) |
| | { |
| | ap->add_floats(f); |
| | } |
| | } |
| | if (param.type == 7) |
| | { |
| | std::vector<const char*> as = get_param_as(param); |
| | for (auto s : as) |
| | { |
| | ap->add_strings(s); |
| | } |
| | } |
| | } |
| |
|
| | std::vector<const char*> attrs_keys = get_operator_attrs_keys(op); |
| | for (const char* attr_name : attrs_keys) |
| | { |
| | onnx::TensorProto* tp = gp->add_initializer(); |
| |
|
| | tp->set_name(std::string(get_operator_name(op)) + "." + attr_name); |
| |
|
| | np->add_input(std::string(get_operator_name(op)) + "." + attr_name); |
| |
|
| | const Attribute& attr = get_operator_attr(op, attr_name); |
| | for (auto s : attr.shape) |
| | { |
| | tp->add_dims(s); |
| | } |
| |
|
| | switch (attr.type) |
| | { |
| | case 1: |
| | tp->set_data_type(fp16 ? 10 : 1); |
| | break; |
| | case 2: |
| | tp->set_data_type(fp16 ? 10 : 11); |
| | break; |
| | case 3: |
| | tp->set_data_type(10); |
| | break; |
| | case 4: |
| | tp->set_data_type(6); |
| | break; |
| | case 5: |
| | tp->set_data_type(7); |
| | break; |
| | case 6: |
| | tp->set_data_type(5); |
| | break; |
| | case 7: |
| | tp->set_data_type(3); |
| | break; |
| | case 8: |
| | tp->set_data_type(2); |
| | break; |
| | case 9: |
| | tp->set_data_type(9); |
| | break; |
| | case 10: |
| | tp->set_data_type(14); |
| | break; |
| | case 11: |
| | tp->set_data_type(15); |
| | break; |
| | case 12: |
| | tp->set_data_type(0); |
| | break; |
| | default: |
| | tp->set_data_type(0); |
| | break; |
| | } |
| |
|
| | std::string* d = tp->mutable_raw_data(); |
| | if (fp16 && attr.type == 1) |
| | { |
| | |
| | const float* p = (const float*)attr.data.data(); |
| | int len = attr.data.size() / 4; |
| | d->resize(len * 2); |
| | unsigned short* p_fp16 = (unsigned short*)d->data(); |
| | for (int i = 0; i < len; i++) |
| | { |
| | p_fp16[i] = float32_to_float16(p[i]); |
| | } |
| | } |
| | else if (fp16 && attr.type == 2) |
| | { |
| | |
| | const double* p = (const double*)attr.data.data(); |
| | int len = attr.data.size() / 4; |
| | d->resize(len); |
| | unsigned short* p_fp16 = (unsigned short*)d->data(); |
| | for (int i = 0; i < len; i++) |
| | { |
| | p_fp16[i] = float32_to_float16((float)p[i]); |
| | } |
| | } |
| | else |
| | { |
| | d->resize(attr.data.size()); |
| | memcpy((void*)d->data(), attr.data.data(), attr.data.size()); |
| | } |
| | } |
| | } |
| |
|
| | std::fstream output(onnxpath, std::ios::out | std::ios::trunc | std::ios::binary); |
| | if (!model.SerializeToOstream(&output)) |
| | { |
| | fprintf(stderr, "write onnx failed\n"); |
| | return -1; |
| | } |
| |
|
| | return 0; |
| | } |
| |
|
| | } |
| |
|