| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include "fuse_index_expression.h" |
| |
|
| | #include <algorithm> |
| |
|
| | namespace pnnx { |
| |
|
// Replace every non-overlapping occurrence of `from` in `str` with `to`, scanning
// left to right. Text inserted from `to` is skipped over and never rescanned, so
// `to` may safely contain `from`.
// An empty `from` is treated as a no-op: std::string::find("") matches at every
// position, so without this guard the loop would never terminate.
static void replaceAll(std::string& str, const std::string& from, const std::string& to)
{
    if (from.empty())
        return;

    size_t start_pos = 0;
    while ((start_pos = str.find(from, start_pos)) != std::string::npos)
    {
        str.replace(start_pos, from.length(), to);
        // Continue after the inserted text so `to` is not matched again.
        start_pos += to.length();
    }
}
| |
|
| | static std::string fuse_attribute_expression(Operator* op_expr) |
| | { |
| | std::string expr = op_expr->params["expr"].s; |
| |
|
| | for (int i = (int)op_expr->inputs.size() - 1; i >= 0; i--) |
| | { |
| | Operator* op_attr = op_expr->inputs[i]->producer; |
| |
|
| | const Attribute& attr = op_attr->attrs.begin()->second; |
| |
|
| | std::string attr_expr; |
| |
|
| | int count = attr.shape[0]; |
| |
|
| | if (attr.type == 9) |
| | { |
| | |
| | const char* pdata = (const char*)attr.data.data(); |
| | attr_expr += "["; |
| | for (int j = 0; j < count; j++) |
| | { |
| | const char* ls = pdata[j] != 0 ? "True" : "False"; |
| | attr_expr += ls; |
| |
|
| | if (j != count - 1) |
| | attr_expr += ","; |
| | } |
| | attr_expr += "]"; |
| | } |
| | else if (attr.type == 5) |
| | { |
| | |
| | const int64_t* pdata = (const int64_t*)attr.data.data(); |
| | attr_expr += "["; |
| | for (int j = 0; j < count; j++) |
| | { |
| | int64_t n = pdata[j]; |
| | attr_expr += std::to_string(n); |
| |
|
| | if (j != count - 1) |
| | attr_expr += ","; |
| | } |
| | attr_expr += "]"; |
| | } |
| | else |
| | { |
| | fprintf(stderr, "unsupported index expression input %d attr type %d\n", i, attr.type); |
| | } |
| |
|
| | |
| | replaceAll(expr, "@" + std::to_string(i), attr_expr); |
| | } |
| |
|
| | return expr; |
| | } |
| |
|
| | void fuse_index_expression(Graph& graph) |
| | { |
| | while (1) |
| | { |
| | bool matched = false; |
| |
|
| | for (size_t i = 0; i < graph.ops.size(); i++) |
| | { |
| | Operator* op = graph.ops[i]; |
| |
|
| | if (op->type != "Tensor.index") |
| | continue; |
| |
|
| | if (op->inputs.size() != 2) |
| | continue; |
| |
|
| | if (op->inputs[1]->consumers.size() != 1) |
| | continue; |
| |
|
| | Operator* op2 = op->inputs[1]->producer; |
| | if (op2->type != "pnnx.Expression") |
| | continue; |
| |
|
| | bool all_inputs_fusable = true; |
| | for (Operand* a : op2->inputs) |
| | { |
| | if (a->producer->type != "pnnx.Attribute") |
| | { |
| | all_inputs_fusable = false; |
| | break; |
| | } |
| | } |
| | if (!all_inputs_fusable) |
| | continue; |
| |
|
| | matched = true; |
| |
|
| | std::string expr = fuse_attribute_expression(op2); |
| |
|
| | op->params["expr"] = expr; |
| |
|
| | op->inputs[1]->producer = 0; |
| | op->inputs[1]->remove_consumer(op); |
| |
|
| | op->inputs.resize(1); |
| |
|
| | for (auto& x : op2->inputs) |
| | { |
| | x->remove_consumer(op2); |
| | } |
| |
|
| | op2->inputs.clear(); |
| | op2->outputs.clear(); |
| |
|
| | graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); |
| |
|
| | delete op2; |
| |
|
| | break; |
| | } |
| |
|
| | if (!matched) |
| | break; |
| | } |
| | } |
| |
|
| | } |
| |
|