| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
#include "fuse_expression.h"

#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <set>
#include <string>
#include <vector>

#include "storezip.h"
| |
|
| | namespace pnnx { |
| |
|
| | static bool operand_maybe_tensor(const Operand* operand) |
| | { |
| | const Operator* op = operand->producer; |
| |
|
| | if (op->type == "prim::Constant") |
| | { |
| | const Parameter& param = op->params.at("value"); |
| | if (param.type == 0 || param.type == 1 || param.type == 2 || param.type == 3 || param.type == 4 || param.type == 10) |
| | { |
| | return false; |
| | } |
| | else |
| | { |
| | return true; |
| | } |
| | } |
| |
|
| | if (op->type == "prim::NumToTensor") |
| | { |
| | return operand_maybe_tensor(op->inputs[0]); |
| | } |
| |
|
| | if (op->type == "prim::ListConstruct") |
| | { |
| | return false; |
| | } |
| |
|
| | if (op->type == "torch.unbind" && op->inputs[0]->shape.size() == 1) |
| | { |
| | return false; |
| | } |
| |
|
| | if (op->type == "aten::size") |
| | { |
| | return false; |
| | } |
| |
|
| | if (op->type == "aten::Int") |
| | { |
| | return operand_maybe_tensor(op->inputs[0]); |
| | } |
| |
|
| | if (op->type == "aten::to" || op->type == "aten::detach") |
| | { |
| | return operand_maybe_tensor(op->inputs[0]); |
| | } |
| |
|
| | if (op->type == "aten::ScalarImplicit") |
| | { |
| | return false; |
| | } |
| |
|
| | if (op->type == "aten::abs" |
| | || op->type == "aten::acos" |
| | || op->type == "aten::acosh" |
| | || op->type == "aten::asin" |
| | || op->type == "aten::asinh" |
| | || op->type == "aten::atan" |
| | || op->type == "aten::atanh" |
| | || op->type == "aten::ceil" |
| | || op->type == "aten::cos" |
| | || op->type == "aten::cosh" |
| | || op->type == "aten::exp" |
| | || op->type == "aten::floor" |
| | || op->type == "aten::log" |
| | || op->type == "aten::log10" |
| | || op->type == "aten::neg" |
| | || op->type == "aten::reciprocal" |
| | || op->type == "aten::round" |
| | || op->type == "aten::rsqrt" |
| | || op->type == "aten::sign" |
| | || op->type == "aten::sin" |
| | || op->type == "aten::sinh" |
| | || op->type == "aten::sqrt" |
| | || op->type == "aten::square" |
| | || op->type == "aten::tan" |
| | || op->type == "aten::tanh" |
| | || op->type == "aten::trunc") |
| | { |
| | return operand_maybe_tensor(op->inputs[0]); |
| | } |
| |
|
| | if (op->type == "aten::atan2" |
| | || op->type == "aten::div" |
| | || op->type == "aten::floor_divide" |
| | || op->type == "aten::fmod" |
| | || op->type == "aten::max" |
| | || op->type == "aten::maximum" |
| | || op->type == "aten::min" |
| | || op->type == "aten::minimum" |
| | || op->type == "aten::mul" |
| | || op->type == "aten::pow" |
| | || op->type == "aten::remainder") |
| | { |
| | return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]); |
| | } |
| |
|
| | if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__") |
| | { |
| | return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]); |
| | } |
| |
|
| | if (op->type == "aten::add" || op->type == "aten::sub" || op->type == "aten::rsub") |
| | { |
| | return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]) || operand_maybe_tensor(op->inputs[2]); |
| | } |
| |
|
| | return true; |
| | } |
| |
|
| | static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector<Operand*>& inputs, const std::set<std::string>& foldable_constants, StoreZipReader& zip, bool checksubgraph = true) |
| | { |
| | |
| |
|
| | Operator* op = operand->producer; |
| |
|
| | if (checksubgraph && operand_maybe_tensor(operand)) |
| | { |
| | if (op->outputs.size() > 1 || op->outputs[0]->consumers.size() > 1) |
| | { |
| | goto DEFAULT; |
| | } |
| | } |
| |
|
| | if (op->type == "prim::Constant") |
| | { |
| | const Parameter& param = op->params["value"]; |
| | |
| | if (param.type == 0) |
| | { |
| | expr += "None"; |
| | } |
| | else if (param.type == 1) |
| | { |
| | expr += param.b ? "True" : "False"; |
| | } |
| | else if (param.type == 2) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%d", param.i); |
| | expr += tmp; |
| | } |
| | else if (param.type == 3) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%e", param.f); |
| | expr += tmp; |
| | } |
| | else if (param.type == 4) |
| | { |
| | expr += param.s; |
| | } |
| | else if (param.type == 10) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%e%+ej", param.c.real(), param.c.imag()); |
| | expr += tmp; |
| | } |
| | else |
| | { |
| | goto DEFAULT; |
| | } |
| | } |
| | else if (op->type == "pnnx.Attribute") |
| | { |
| | |
| |
|
| | const Attribute& data = op->attrs["data"]; |
| | if (data.shape.size() == 1 && data.shape[0] == 1 && data.type != -1) |
| | { |
| | if (data.type == 0) |
| | { |
| | expr += "None"; |
| | } |
| | else if (data.type == 1) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%e", ((const float*)data.data.data())[0]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 2) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%e", ((const double*)data.data.data())[0]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 4) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%d", ((const int*)data.data.data())[0]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 5) |
| | { |
| | int64_t v = ((const int64_t*)data.data.data())[0]; |
| | if (v == std::numeric_limits<int64_t>::max()) v = INT_MAX; |
| | if (v == std::numeric_limits<int64_t>::min()) v = INT_MIN; |
| |
|
| | char tmp[32]; |
| | sprintf(tmp, "%d", (int)v); |
| | expr += tmp; |
| | } |
| | else if (data.type == 6) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%d", ((const short*)data.data.data())[0]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 7) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%d", ((const signed char*)data.data.data())[0]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 8) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%u", ((const unsigned char*)data.data.data())[0]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 9) |
| | { |
| | expr += ((const char*)data.data.data())[0] ? "True" : "False"; |
| | } |
| | else |
| | { |
| | |
| | fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type); |
| | } |
| | } |
| | else |
| | { |
| | goto DEFAULT; |
| | } |
| | } |
| | else if (op->type == "torch.unbind") |
| | { |
| | |
| | |
| | |
| | Operand* operand2 = op->inputs[0]; |
| | if (operand2->producer->type == "pnnx.Attribute") |
| | { |
| | const Attribute& data = operand2->producer->attrs["data"]; |
| |
|
| | if (data.shape.size() == 1 && data.type != -1) |
| | { |
| | |
| | int si = 0; |
| | for (size_t i = 0; i < op->outputs.size(); i++) |
| | { |
| | if (op->outputs[i] == operand) |
| | { |
| | si = (int)i; |
| | break; |
| | } |
| | } |
| |
|
| | if (data.type == 0) |
| | { |
| | expr += "None"; |
| | } |
| | else if (data.type == 1) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%e", ((const float*)data.data.data())[si]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 2) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%e", ((const double*)data.data.data())[si]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 4) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%d", ((const int*)data.data.data())[si]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 5) |
| | { |
| | int64_t v = ((const int64_t*)data.data.data())[si]; |
| | if (v == std::numeric_limits<int64_t>::max()) v = INT_MAX; |
| | if (v == std::numeric_limits<int64_t>::min()) v = INT_MIN; |
| |
|
| | char tmp[32]; |
| | sprintf(tmp, "%d", (int)v); |
| | expr += tmp; |
| | } |
| | else if (data.type == 6) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%d", ((const short*)data.data.data())[si]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 7) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%d", ((const signed char*)data.data.data())[si]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 8) |
| | { |
| | char tmp[32]; |
| | sprintf(tmp, "%u", ((const unsigned char*)data.data.data())[si]); |
| | expr += tmp; |
| | } |
| | else if (data.type == 9) |
| | { |
| | expr += ((const char*)data.data.data())[si] ? "True" : "False"; |
| | } |
| | else |
| | { |
| | |
| | fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type); |
| | goto DEFAULT; |
| | } |
| | return; |
| | } |
| | } |
| |
|
| | goto DEFAULT; |
| | } |
| | else if (checksubgraph && operand_maybe_tensor(operand) && foldable_constants.find(operand->name) != foldable_constants.end()) |
| | { |
| | |
| |
|
| | if (operand->shape.size() == 0 && operand->type != -1) |
| | { |
| | |
| | if (operand->type == 0) |
| | { |
| | expr += "None"; |
| | } |
| | else if (operand->type == 1) |
| | { |
| | float v; |
| | zip.read_file(operand->name, (char*)&v); |
| |
|
| | char tmp[32]; |
| | sprintf(tmp, "%e", v); |
| | expr += tmp; |
| | } |
| | else if (operand->type == 2) |
| | { |
| | double v; |
| | zip.read_file(operand->name, (char*)&v); |
| |
|
| | char tmp[32]; |
| | sprintf(tmp, "%e", v); |
| | expr += tmp; |
| | } |
| | else if (operand->type == 4) |
| | { |
| | int v; |
| | zip.read_file(operand->name, (char*)&v); |
| |
|
| | char tmp[32]; |
| | sprintf(tmp, "%d", v); |
| | expr += tmp; |
| | } |
| | else if (operand->type == 5) |
| | { |
| | int64_t v; |
| | zip.read_file(operand->name, (char*)&v); |
| |
|
| | if (v == std::numeric_limits<int64_t>::max()) v = INT_MAX; |
| | if (v == std::numeric_limits<int64_t>::min()) v = INT_MIN; |
| |
|
| | char tmp[32]; |
| | sprintf(tmp, "%ld", v); |
| | expr += tmp; |
| | } |
| | else if (operand->type == 6) |
| | { |
| | short v; |
| | zip.read_file(operand->name, (char*)&v); |
| |
|
| | char tmp[32]; |
| | sprintf(tmp, "%d", v); |
| | expr += tmp; |
| | } |
| | else if (operand->type == 7) |
| | { |
| | signed char v; |
| | zip.read_file(operand->name, (char*)&v); |
| |
|
| | char tmp[32]; |
| | sprintf(tmp, "%d", v); |
| | expr += tmp; |
| | } |
| | else if (operand->type == 8) |
| | { |
| | unsigned char v; |
| | zip.read_file(operand->name, (char*)&v); |
| |
|
| | char tmp[32]; |
| | sprintf(tmp, "%u", v); |
| | expr += tmp; |
| | } |
| | else if (operand->type == 9) |
| | { |
| | char v; |
| | zip.read_file(operand->name, &v); |
| |
|
| | expr += v ? "True" : "False"; |
| | } |
| | else |
| | { |
| | |
| | auto it = std::find(inputs.begin(), inputs.end(), operand); |
| | if (it == inputs.end()) |
| | { |
| | |
| | char tmp[32]; |
| | sprintf(tmp, "@%d", (int)inputs.size()); |
| | expr += tmp; |
| |
|
| | inputs.push_back(operand); |
| | } |
| | else |
| | { |
| | |
| | char tmp[32]; |
| | sprintf(tmp, "@%d", (int)(it - inputs.begin())); |
| | expr += tmp; |
| | } |
| | } |
| | } |
| | else |
| | { |
| | goto DEFAULT; |
| | } |
| | } |
| | else if (op->type == "prim::NumToTensor") |
| | { |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | } |
| | else if (op->type == "prim::ListConstruct") |
| | { |
| | expr += "["; |
| | for (int i = 0; i < (int)op->inputs.size() - 1; i++) |
| | { |
| | fuse_expression(graph, op->inputs[i], expr, inputs, foldable_constants, zip); |
| | expr += ","; |
| | } |
| | if (op->inputs.size() > 0) |
| | { |
| | fuse_expression(graph, op->inputs[op->inputs.size() - 1], expr, inputs, foldable_constants, zip); |
| | } |
| | expr += "]"; |
| | } |
| | else if (op->type == "aten::size") |
| | { |
| | expr += "size("; |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | expr += ","; |
| | fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip); |
| | expr += ")"; |
| | } |
| | else if (op->type == "aten::Int") |
| | { |
| | expr += "int("; |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | expr += ")"; |
| | } |
| | else if (op->type == "Tensor.to") |
| | { |
| | bool noop_type_cast = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type); |
| | if (noop_type_cast) |
| | { |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | } |
| | else |
| | { |
| | goto DEFAULT; |
| | } |
| | } |
| | else if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit") |
| | { |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | } |
| | else if (op->type == "aten::abs" |
| | || op->type == "aten::acos" |
| | || op->type == "aten::acosh" |
| | || op->type == "aten::asin" |
| | || op->type == "aten::asinh" |
| | || op->type == "aten::atan" |
| | || op->type == "aten::atanh" |
| | || op->type == "aten::ceil" |
| | || op->type == "aten::cos" |
| | || op->type == "aten::cosh" |
| | || op->type == "aten::exp" |
| | || op->type == "aten::floor" |
| | || op->type == "aten::log" |
| | || op->type == "aten::log10" |
| | || op->type == "aten::neg" |
| | || op->type == "aten::reciprocal" |
| | || op->type == "aten::round" |
| | || op->type == "aten::rsqrt" |
| | || op->type == "aten::sign" |
| | || op->type == "aten::sin" |
| | || op->type == "aten::sinh" |
| | || op->type == "aten::sqrt" |
| | || op->type == "aten::square" |
| | || op->type == "aten::tan" |
| | || op->type == "aten::tanh" |
| | || op->type == "aten::trunc") |
| | { |
| | std::string mathop = op->type.substr(6); |
| |
|
| | expr += mathop; |
| | expr += "("; |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | expr += ")"; |
| | } |
| | else if (op->type == "aten::atan2" |
| | || op->type == "aten::floor_divide" |
| | || op->type == "aten::fmod" |
| | || op->type == "aten::max" |
| | || op->type == "aten::maximum" |
| | || op->type == "aten::min" |
| | || op->type == "aten::minimum" |
| | || op->type == "aten::mul" |
| | || op->type == "aten::pow" |
| | || op->type == "aten::remainder") |
| | { |
| | std::string mathop = op->type.substr(6); |
| |
|
| | expr += mathop; |
| | expr += "("; |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | expr += ","; |
| | fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip); |
| | expr += ")"; |
| | } |
| | else if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__") |
| | { |
| | std::string mathop = op->type.substr(8, op->type.size() - 10); |
| |
|
| | expr += mathop; |
| | expr += "("; |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | expr += ","; |
| | fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip); |
| | expr += ")"; |
| | } |
| | else if (op->type == "aten::add" || op->type == "aten::sub") |
| | { |
| | std::string mathop = op->type.substr(6); |
| |
|
| | expr += mathop; |
| | expr += "("; |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | expr += ","; |
| |
|
| | std::string expr1; |
| | std::string expr2; |
| | fuse_expression(graph, op->inputs[1], expr1, inputs, foldable_constants, zip); |
| | fuse_expression(graph, op->inputs[2], expr2, inputs, foldable_constants, zip); |
| |
|
| | if (expr2 == "1") |
| | { |
| | expr += expr1; |
| | } |
| | else |
| | { |
| | expr += ","; |
| | expr += "mul("; |
| | expr += expr1; |
| | expr += ","; |
| | expr += expr2; |
| | expr += ")"; |
| | } |
| |
|
| | expr += ")"; |
| | } |
| | else if (op->type == "aten::rsub") |
| | { |
| | expr += "sub("; |
| | std::string expr1; |
| | std::string expr2; |
| | fuse_expression(graph, op->inputs[1], expr1, inputs, foldable_constants, zip); |
| | fuse_expression(graph, op->inputs[2], expr2, inputs, foldable_constants, zip); |
| |
|
| | if (expr2 == "1") |
| | { |
| | expr += expr1; |
| | } |
| | else |
| | { |
| | expr += ","; |
| | expr += "mul("; |
| | expr += expr1; |
| | expr += ","; |
| | expr += expr2; |
| | expr += ")"; |
| | } |
| |
|
| | expr += ","; |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | expr += ")"; |
| | } |
| | else if (op->type == "aten::div") |
| | { |
| | std::string rounding_mode; |
| | if (op->inputs.size() == 3) |
| | fuse_expression(graph, op->inputs[2], rounding_mode, inputs, foldable_constants, zip); |
| |
|
| | if (rounding_mode == "trunc") |
| | { |
| | expr += "floor_divide"; |
| | } |
| | else |
| | { |
| | expr += "div"; |
| | } |
| |
|
| | expr += "("; |
| | fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); |
| | expr += ","; |
| | fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip); |
| | expr += ")"; |
| | } |
| | else |
| | { |
| | goto DEFAULT; |
| | } |
| |
|
| | return; |
| |
|
| | DEFAULT: |
| | auto it = std::find(inputs.begin(), inputs.end(), operand); |
| | if (it == inputs.end()) |
| | { |
| | |
| | char tmp[32]; |
| | sprintf(tmp, "@%d", (int)inputs.size()); |
| | expr += tmp; |
| |
|
| | inputs.push_back(operand); |
| | } |
| | else |
| | { |
| | |
| | char tmp[32]; |
| | sprintf(tmp, "@%d", (int)(it - inputs.begin())); |
| | expr += tmp; |
| | } |
| | } |
| |
|
| | void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath) |
| | { |
| | StoreZipReader zip; |
| | zip.open(foldable_constants_zippath); |
| |
|
| | int pnnx_expr_index = 0; |
| |
|
| | for (;;) |
| | { |
| | bool need_fuse = false; |
| |
|
| | |
| | for (int i = (int)graph.ops.size() - 1; i >= 0; i--) |
| | { |
| | Operator* op = graph.ops[i]; |
| |
|
| | if (op->type == "prim::Constant") |
| | { |
| | need_fuse = true; |
| | } |
| | if (op->type == "prim::NumToTensor") |
| | { |
| | need_fuse = true; |
| | } |
| | if (op->type == "prim::ListConstruct") |
| | { |
| | need_fuse = true; |
| | } |
| | if (op->type == "aten::size") |
| | { |
| | need_fuse = true; |
| | } |
| | if (op->type == "aten::Int") |
| | { |
| | need_fuse = true; |
| | } |
| | if (op->type == "Tensor.to") |
| | { |
| | |
| | bool noop_to = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type); |
| | need_fuse = noop_to; |
| | } |
| | if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit") |
| | { |
| | need_fuse = true; |
| | } |
| | if (op->type == "aten::abs" |
| | || op->type == "aten::acos" |
| | || op->type == "aten::acosh" |
| | || op->type == "aten::add" |
| | || op->type == "aten::asin" |
| | || op->type == "aten::asinh" |
| | || op->type == "aten::atan" |
| | || op->type == "aten::atanh" |
| | || op->type == "aten::atan2" |
| | || op->type == "aten::ceil" |
| | || op->type == "aten::cos" |
| | || op->type == "aten::cosh" |
| | || op->type == "aten::div" |
| | || op->type == "aten::exp" |
| | || op->type == "aten::floor" |
| | || op->type == "aten::floor_divide" |
| | || op->type == "aten::fmod" |
| | || op->type == "aten::log" |
| | || op->type == "aten::log10" |
| | || op->type == "aten::max" |
| | || op->type == "aten::maximum" |
| | || op->type == "aten::min" |
| | || op->type == "aten::minimum" |
| | || op->type == "aten::mul" |
| | || op->type == "aten::neg" |
| | || op->type == "aten::pow" |
| | || op->type == "aten::reciprocal" |
| | || op->type == "aten::remainder" |
| | || op->type == "aten::round" |
| | || op->type == "aten::rsqrt" |
| | || op->type == "aten::rsub" |
| | || op->type == "aten::sign" |
| | || op->type == "aten::sin" |
| | || op->type == "aten::sinh" |
| | || op->type == "aten::sqrt" |
| | || op->type == "aten::square" |
| | || op->type == "aten::sub" |
| | || op->type == "aten::tan" |
| | || op->type == "aten::tanh" |
| | || op->type == "aten::trunc") |
| | { |
| | need_fuse = true; |
| | } |
| | if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__") |
| | { |
| | need_fuse = true; |
| | } |
| |
|
| | if (need_fuse) |
| | { |
| | std::string expr; |
| | std::vector<Operand*> inputs; |
| | fuse_expression(graph, op->outputs[0], expr, inputs, foldable_constants, zip, false); |
| | |
| |
|
| | |
| | char name[32]; |
| | sprintf(name, "pnnx_expr_%d", pnnx_expr_index++); |
| |
|
| | op->type = "pnnx.Expression"; |
| | op->name = name; |
| |
|
| | op->params.clear(); |
| | op->attrs.clear(); |
| |
|
| | op->params["expr"] = expr; |
| |
|
| | |
| | for (Operand* operand : op->inputs) |
| | { |
| | operand->consumers.erase(std::find(operand->consumers.begin(), operand->consumers.end(), op)); |
| | } |
| |
|
| | op->inputs = inputs; |
| |
|
| | for (Operand* operand : op->inputs) |
| | { |
| | operand->consumers.push_back(op); |
| | } |
| |
|
| | break; |
| | } |
| | } |
| |
|
| | if (!need_fuse) |
| | break; |
| | } |
| |
|
| | zip.close(); |
| | } |
| |
|
| | } |
| |
|