// Tencent is pleased to support the open source community by making ncnn available. // // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at // // https://opensource.org/licenses/BSD-3-Clause // // Unless required by applicable law or agreed to in writing, software distributed // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. #include "fuse_expression.h" #include #include "storezip.h" namespace pnnx { static bool operand_maybe_tensor(const Operand* operand) { const Operator* op = operand->producer; if (op->type == "prim::Constant") { const Parameter& param = op->params.at("value"); if (param.type == 0 || param.type == 1 || param.type == 2 || param.type == 3 || param.type == 4 || param.type == 10) { return false; } else { return true; } } if (op->type == "prim::NumToTensor") { return operand_maybe_tensor(op->inputs[0]); } if (op->type == "prim::ListConstruct") { return false; } if (op->type == "torch.unbind" && op->inputs[0]->shape.size() == 1) { return false; } if (op->type == "aten::size") { return false; } if (op->type == "aten::Int") { return operand_maybe_tensor(op->inputs[0]); } if (op->type == "aten::to" || op->type == "aten::detach") { return operand_maybe_tensor(op->inputs[0]); } if (op->type == "aten::ScalarImplicit") { return false; } if (op->type == "aten::abs" || op->type == "aten::acos" || op->type == "aten::acosh" || op->type == "aten::asin" || op->type == "aten::asinh" || op->type == "aten::atan" || op->type == "aten::atanh" || op->type == "aten::ceil" || op->type == "aten::cos" || op->type == "aten::cosh" || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::log" || op->type == "aten::log10" || op->type == "aten::neg" || op->type == "aten::reciprocal" || op->type == "aten::round" || op->type == "aten::rsqrt" || op->type == "aten::sign" || op->type == "aten::sin" || op->type == "aten::sinh" || op->type == "aten::sqrt" || op->type == "aten::square" || op->type == "aten::tan" || op->type == "aten::tanh" || op->type == "aten::trunc") { return operand_maybe_tensor(op->inputs[0]); } if (op->type == "aten::atan2" || op->type == "aten::div" || op->type == "aten::floor_divide" || op->type == "aten::fmod" || op->type == "aten::max" || op->type == "aten::maximum" || op->type == "aten::min" || op->type == "aten::minimum" || op->type == "aten::mul" || op->type == "aten::pow" || op->type == "aten::remainder") { return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]); } if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__") { return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]); } if (op->type == "aten::add" || op->type == "aten::sub" || op->type == "aten::rsub") { return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]) || operand_maybe_tensor(op->inputs[2]); } return true; } static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector& inputs, const std::set& foldable_constants, StoreZipReader& zip, bool checksubgraph = true) { // fprintf(stderr, "fuse_expression %s\n", operand->name.c_str()); Operator* op = operand->producer; if (checksubgraph && operand_maybe_tensor(operand)) { if (op->outputs.size() > 1 || op->outputs[0]->consumers.size() > 1) { goto DEFAULT; } } if (op->type == "prim::Constant") { const Parameter& param = op->params["value"]; // fprintf(stderr, "fuse_expression prim::Constant %d\n", param.type); if (param.type == 0) { expr += "None"; } else if (param.type == 1) { expr += param.b ? "True" : "False"; } else if (param.type == 2) { char tmp[32]; sprintf(tmp, "%d", param.i); expr += tmp; } else if (param.type == 3) { char tmp[32]; sprintf(tmp, "%e", param.f); expr += tmp; } else if (param.type == 4) { expr += param.s; } else if (param.type == 10) { char tmp[32]; sprintf(tmp, "%e%+ej", param.c.real(), param.c.imag()); expr += tmp; } else { goto DEFAULT; } } else if (op->type == "pnnx.Attribute") { // fprintf(stderr, "operand pnnx.Attribute %s\n", operand->name.c_str()); const Attribute& data = op->attrs["data"]; if (data.shape.size() == 1 && data.shape[0] == 1 && data.type != -1) { if (data.type == 0) { expr += "None"; } else if (data.type == 1) { char tmp[32]; sprintf(tmp, "%e", ((const float*)data.data.data())[0]); expr += tmp; } else if (data.type == 2) { char tmp[32]; sprintf(tmp, "%e", ((const double*)data.data.data())[0]); expr += tmp; } else if (data.type == 4) { char tmp[32]; sprintf(tmp, "%d", ((const int*)data.data.data())[0]); expr += tmp; } else if (data.type == 5) { int64_t v = ((const int64_t*)data.data.data())[0]; if (v == std::numeric_limits::max()) v = INT_MAX; if (v == std::numeric_limits::min()) v = INT_MIN; char tmp[32]; sprintf(tmp, "%d", (int)v); expr += tmp; } else if (data.type == 6) { char tmp[32]; sprintf(tmp, "%d", ((const short*)data.data.data())[0]); expr += tmp; } else if (data.type == 7) { char tmp[32]; sprintf(tmp, "%d", ((const signed char*)data.data.data())[0]); expr += tmp; } else if (data.type == 8) { char tmp[32]; sprintf(tmp, "%u", ((const unsigned char*)data.data.data())[0]); expr += tmp; } else if (data.type == 9) { expr += ((const char*)data.data.data())[0] ? "True" : "False"; } else { // unsupported type fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type); } } else { goto DEFAULT; } } else if (op->type == "torch.unbind") { // track chain // pnnx.Attribute/foldable with 1-rank // torch.unbind to constant scalar Operand* operand2 = op->inputs[0]; if (operand2->producer->type == "pnnx.Attribute") { const Attribute& data = operand2->producer->attrs["data"]; if (data.shape.size() == 1 && data.type != -1) { // resolve scalar i int si = 0; for (size_t i = 0; i < op->outputs.size(); i++) { if (op->outputs[i] == operand) { si = (int)i; break; } } if (data.type == 0) { expr += "None"; } else if (data.type == 1) { char tmp[32]; sprintf(tmp, "%e", ((const float*)data.data.data())[si]); expr += tmp; } else if (data.type == 2) { char tmp[32]; sprintf(tmp, "%e", ((const double*)data.data.data())[si]); expr += tmp; } else if (data.type == 4) { char tmp[32]; sprintf(tmp, "%d", ((const int*)data.data.data())[si]); expr += tmp; } else if (data.type == 5) { int64_t v = ((const int64_t*)data.data.data())[si]; if (v == std::numeric_limits::max()) v = INT_MAX; if (v == std::numeric_limits::min()) v = INT_MIN; char tmp[32]; sprintf(tmp, "%d", (int)v); expr += tmp; } else if (data.type == 6) { char tmp[32]; sprintf(tmp, "%d", ((const short*)data.data.data())[si]); expr += tmp; } else if (data.type == 7) { char tmp[32]; sprintf(tmp, "%d", ((const signed char*)data.data.data())[si]); expr += tmp; } else if (data.type == 8) { char tmp[32]; sprintf(tmp, "%u", ((const unsigned char*)data.data.data())[si]); expr += tmp; } else if (data.type == 9) { expr += ((const char*)data.data.data())[si] ? "True" : "False"; } else { // unsupported type fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type); goto DEFAULT; } return; } } goto DEFAULT; } else if (checksubgraph && operand_maybe_tensor(operand) && foldable_constants.find(operand->name) != foldable_constants.end()) { // fprintf(stderr, "operand_is_foldable %s\n", operand->name.c_str()); if (operand->shape.size() == 0 && operand->type != -1) { // fuse literal constant into expression if (operand->type == 0) { expr += "None"; } else if (operand->type == 1) { float v; zip.read_file(operand->name, (char*)&v); char tmp[32]; sprintf(tmp, "%e", v); expr += tmp; } else if (operand->type == 2) { double v; zip.read_file(operand->name, (char*)&v); char tmp[32]; sprintf(tmp, "%e", v); expr += tmp; } else if (operand->type == 4) { int v; zip.read_file(operand->name, (char*)&v); char tmp[32]; sprintf(tmp, "%d", v); expr += tmp; } else if (operand->type == 5) { int64_t v; zip.read_file(operand->name, (char*)&v); if (v == std::numeric_limits::max()) v = INT_MAX; if (v == std::numeric_limits::min()) v = INT_MIN; char tmp[32]; sprintf(tmp, "%ld", v); expr += tmp; } else if (operand->type == 6) { short v; zip.read_file(operand->name, (char*)&v); char tmp[32]; sprintf(tmp, "%d", v); expr += tmp; } else if (operand->type == 7) { signed char v; zip.read_file(operand->name, (char*)&v); char tmp[32]; sprintf(tmp, "%d", v); expr += tmp; } else if (operand->type == 8) { unsigned char v; zip.read_file(operand->name, (char*)&v); char tmp[32]; sprintf(tmp, "%u", v); expr += tmp; } else if (operand->type == 9) { char v; zip.read_file(operand->name, &v); expr += v ? "True" : "False"; } else { // fprintf(stderr, "unknown foldable literal %s %d\n", operand->name.c_str(), operand->type); auto it = std::find(inputs.begin(), inputs.end(), operand); if (it == inputs.end()) { // tensor char tmp[32]; sprintf(tmp, "@%d", (int)inputs.size()); expr += tmp; inputs.push_back(operand); } else { // tensor char tmp[32]; sprintf(tmp, "@%d", (int)(it - inputs.begin())); expr += tmp; } } } else { goto DEFAULT; } } else if (op->type == "prim::NumToTensor") { fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); } else if (op->type == "prim::ListConstruct") { expr += "["; for (int i = 0; i < (int)op->inputs.size() - 1; i++) { fuse_expression(graph, op->inputs[i], expr, inputs, foldable_constants, zip); expr += ","; } if (op->inputs.size() > 0) { fuse_expression(graph, op->inputs[op->inputs.size() - 1], expr, inputs, foldable_constants, zip); } expr += "]"; } else if (op->type == "aten::size") { expr += "size("; fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); expr += ","; fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip); expr += ")"; } else if (op->type == "aten::Int") { expr += "int("; fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); expr += ")"; } else if (op->type == "Tensor.to") { bool noop_type_cast = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type); if (noop_type_cast) { fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); } else { goto DEFAULT; } } else if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit") { fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); } else if (op->type == "aten::abs" || op->type == "aten::acos" || op->type == "aten::acosh" || op->type == "aten::asin" || op->type == "aten::asinh" || op->type == "aten::atan" || op->type == "aten::atanh" || op->type == "aten::ceil" || op->type == "aten::cos" || op->type == "aten::cosh" || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::log" || op->type == "aten::log10" || op->type == "aten::neg" || op->type == "aten::reciprocal" || op->type == "aten::round" || op->type == "aten::rsqrt" || op->type == "aten::sign" || op->type == "aten::sin" || op->type == "aten::sinh" || op->type == "aten::sqrt" || op->type == "aten::square" || op->type == "aten::tan" || op->type == "aten::tanh" || op->type == "aten::trunc") { std::string mathop = op->type.substr(6); expr += mathop; expr += "("; fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); expr += ")"; } else if (op->type == "aten::atan2" || op->type == "aten::floor_divide" || op->type == "aten::fmod" || op->type == "aten::max" || op->type == "aten::maximum" || op->type == "aten::min" || op->type == "aten::minimum" || op->type == "aten::mul" || op->type == "aten::pow" || op->type == "aten::remainder") { std::string mathop = op->type.substr(6); expr += mathop; expr += "("; fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); expr += ","; fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip); expr += ")"; } else if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__") { std::string mathop = op->type.substr(8, op->type.size() - 10); expr += mathop; expr += "("; fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); expr += ","; fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip); expr += ")"; } else if (op->type == "aten::add" || op->type == "aten::sub") { std::string mathop = op->type.substr(6); expr += mathop; expr += "("; fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); expr += ","; std::string expr1; std::string expr2; fuse_expression(graph, op->inputs[1], expr1, inputs, foldable_constants, zip); fuse_expression(graph, op->inputs[2], expr2, inputs, foldable_constants, zip); if (expr2 == "1") { expr += expr1; } else { expr += ","; expr += "mul("; expr += expr1; expr += ","; expr += expr2; expr += ")"; } expr += ")"; } else if (op->type == "aten::rsub") { expr += "sub("; std::string expr1; std::string expr2; fuse_expression(graph, op->inputs[1], expr1, inputs, foldable_constants, zip); fuse_expression(graph, op->inputs[2], expr2, inputs, foldable_constants, zip); if (expr2 == "1") { expr += expr1; } else { expr += ","; expr += "mul("; expr += expr1; expr += ","; expr += expr2; expr += ")"; } expr += ","; fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); expr += ")"; } else if (op->type == "aten::div") { std::string rounding_mode; if (op->inputs.size() == 3) fuse_expression(graph, op->inputs[2], rounding_mode, inputs, foldable_constants, zip); if (rounding_mode == "trunc") { expr += "floor_divide"; } else { expr += "div"; } expr += "("; fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); expr += ","; fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip); expr += ")"; } else { goto DEFAULT; } return; DEFAULT: auto it = std::find(inputs.begin(), inputs.end(), operand); if (it == inputs.end()) { // tensor char tmp[32]; sprintf(tmp, "@%d", (int)inputs.size()); expr += tmp; inputs.push_back(operand); } else { // tensor char tmp[32]; sprintf(tmp, "@%d", (int)(it - inputs.begin())); expr += tmp; } } void fuse_expression(Graph& graph, const std::set& foldable_constants, const std::string& foldable_constants_zippath) { StoreZipReader zip; zip.open(foldable_constants_zippath); int pnnx_expr_index = 0; for (;;) { bool need_fuse = false; // build expression via reverse order for (int i = (int)graph.ops.size() - 1; i >= 0; i--) { Operator* op = graph.ops[i]; if (op->type == "prim::Constant") { need_fuse = true; } if (op->type == "prim::NumToTensor") { need_fuse = true; } if (op->type == "prim::ListConstruct") { need_fuse = true; } if (op->type == "aten::size") { need_fuse = true; } if (op->type == "aten::Int") { need_fuse = true; } if (op->type == "Tensor.to") { // fuse noop type cast only bool noop_to = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type); need_fuse = noop_to; } if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit") { need_fuse = true; } if (op->type == "aten::abs" || op->type == "aten::acos" || op->type == "aten::acosh" || op->type == "aten::add" || op->type == "aten::asin" || op->type == "aten::asinh" || op->type == "aten::atan" || op->type == "aten::atanh" || op->type == "aten::atan2" || op->type == "aten::ceil" || op->type == "aten::cos" || op->type == "aten::cosh" || op->type == "aten::div" || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::floor_divide" || op->type == "aten::fmod" || op->type == "aten::log" || op->type == "aten::log10" || op->type == "aten::max" || op->type == "aten::maximum" || op->type == "aten::min" || op->type == "aten::minimum" || op->type == "aten::mul" || op->type == "aten::neg" || op->type == "aten::pow" || op->type == "aten::reciprocal" || op->type == "aten::remainder" || op->type == "aten::round" || op->type == "aten::rsqrt" || op->type == "aten::rsub" || op->type == "aten::sign" || op->type == "aten::sin" || op->type == "aten::sinh" || op->type == "aten::sqrt" || op->type == "aten::square" || op->type == "aten::sub" || op->type == "aten::tan" || op->type == "aten::tanh" || op->type == "aten::trunc") { need_fuse = true; } if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__") { need_fuse = true; } if (need_fuse) { std::string expr; std::vector inputs; fuse_expression(graph, op->outputs[0], expr, inputs, foldable_constants, zip, false); // fprintf(stderr, "expr = %s\n", expr.c_str()); // lets rewrite graph char name[32]; sprintf(name, "pnnx_expr_%d", pnnx_expr_index++); op->type = "pnnx.Expression"; op->name = name; op->params.clear(); op->attrs.clear(); op->params["expr"] = expr; // fix input output for (Operand* operand : op->inputs) { operand->consumers.erase(std::find(operand->consumers.begin(), operand->consumers.end(), op)); } op->inputs = inputs; for (Operand* operand : op->inputs) { operand->consumers.push_back(op); } break; } } if (!need_fuse) break; } zip.close(); } } // namespace pnnx