| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include "eliminate_reshape_shape_expression.h" |
| |
|
| | #include <iostream> |
| | #include <sstream> |
| | #include <algorithm> |
| | #include <stack> |
| | #include <vector> |
| | #include <string> |
| |
|
| | namespace pnnx { |
| |
|
// Returns true when the whole of `t` parses as an int via operator>>.
// Leading whitespace is NOT skipped (std::noskipws), so " 3" is rejected.
// Trailing characters are rejected because the stream must reach EOF, so
// "3 " and "-1*2" are rejected too. An empty string or a value that does not
// fit in an int sets failbit and is also rejected.
static bool token_is_interger_literal(const std::string& t)
{
    int value = 0;
    std::istringstream stream(t);
    stream >> std::noskipws >> value;

    const bool parsed = !stream.fail();
    const bool consumed_all = stream.eof();
    return parsed && consumed_all;
}
| |
|
| | static void build_shape(const std::string& expr, std::vector<int>& shape, std::vector<std::string>& expr_tokens) |
| | { |
| | std::string listexpr = expr.substr(1, expr.size() - 2); |
| |
|
| | std::string t; |
| | std::string et; |
| | int level = 0; |
| | for (size_t i = 0; i < listexpr.size(); i++) |
| | { |
| | char ch = listexpr[i]; |
| |
|
| | if (ch == '(' || ch == '[') |
| | { |
| | level += 1; |
| | t = "-1"; |
| | et += ch; |
| | } |
| | else if (ch == ')' || ch == ']') |
| | { |
| | level -= 1; |
| | t = "-1"; |
| | et += ch; |
| | } |
| | else if (level == 0 && ch == ',') |
| | { |
| | int dimsize = token_is_interger_literal(t) ? std::stoi(t) : -1; |
| | shape.push_back(dimsize); |
| | expr_tokens.push_back(et); |
| | t.clear(); |
| | et.clear(); |
| | } |
| | else |
| | { |
| | t += ch; |
| | et += ch; |
| | } |
| | } |
| |
|
| | if (level == 0 && !t.empty()) |
| | { |
| | int dimsize = token_is_interger_literal(t) ? std::stoi(t) : -1; |
| | shape.push_back(dimsize); |
| | } |
| |
|
| | if (level == 0 && !et.empty()) |
| | { |
| | expr_tokens.push_back(et); |
| | } |
| | } |
| |
|
// Inverse of the tokenizer: join the element texts with ',' and wrap the
// result in '[' ... ']'. An empty token list yields "[]".
static std::string build_expr(const std::vector<std::string>& expr_tokens)
{
    std::string joined = "[";
    const char* separator = "";
    for (const std::string& token : expr_tokens)
    {
        joined += separator;
        joined += token;
        separator = ",";
    }
    joined += "]";
    return joined;
}
| |
|
| | void eliminate_reshape_shape_expression(Graph& graph) |
| | { |
| | while (1) |
| | { |
| | bool matched = false; |
| |
|
| | for (size_t i = 0; i < graph.ops.size(); i++) |
| | { |
| | Operator* op = graph.ops[i]; |
| |
|
| | if (op->type != "Tensor.view" && op->type != "Tensor.reshape") |
| | continue; |
| |
|
| | if (op->inputs.size() != 2) |
| | continue; |
| |
|
| | Operator* op_expr = op->inputs[1]->producer; |
| | if (op_expr->type != "pnnx.Expression") |
| | continue; |
| |
|
| | std::string expr = op_expr->params.at("expr").s; |
| | if (expr.empty() || expr[0] != '[') |
| | continue; |
| |
|
| | std::vector<int> outshape = op->outputs[0]->shape; |
| | if (outshape.empty()) |
| | continue; |
| |
|
| | std::vector<int> shape; |
| | std::vector<std::string> expr_tokens; |
| | build_shape(expr, shape, expr_tokens); |
| |
|
| | |
| | for (size_t j = 0; j < outshape.size(); j++) |
| | { |
| | if (outshape[j] != -1) |
| | { |
| | shape[j] = outshape[j]; |
| | expr_tokens[j] = std::to_string(outshape[j]); |
| | } |
| | } |
| |
|
| | |
| | int dynamic_dim_count = 0; |
| | for (size_t j = 0; j < shape.size(); j++) |
| | { |
| | if (shape[j] == -1) |
| | { |
| | dynamic_dim_count += 1; |
| | } |
| | } |
| |
|
| | if (dynamic_dim_count > 1) |
| | { |
| | op_expr->params["expr"] = build_expr(expr_tokens); |
| | continue; |
| | } |
| |
|
| | matched = true; |
| |
|
| | op->params["shape"] = shape; |
| |
|
| | op->inputs.resize(1); |
| | op_expr->outputs[0]->remove_consumer(op); |
| |
|
| | if (op_expr->outputs[0]->consumers.size() == 0) |
| | { |
| | |
| | for (auto x : op_expr->inputs) |
| | { |
| | x->remove_consumer(op_expr); |
| | } |
| |
|
| | Operand* op_expr_out = op_expr->outputs[0]; |
| |
|
| | graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op_expr_out)); |
| | delete op_expr_out; |
| |
|
| | op_expr->inputs.clear(); |
| | op_expr->outputs.clear(); |
| |
|
| | graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op_expr)); |
| | delete op_expr; |
| | } |
| |
|
| | break; |
| | } |
| |
|
| | if (!matched) |
| | break; |
| | } |
| |
|
| | for (size_t i = 0; i < graph.ops.size(); i++) |
| | { |
| | Operator* op = graph.ops[i]; |
| |
|
| | if (op->type != "Tensor.view" && op->type != "Tensor.reshape") |
| | continue; |
| |
|
| | if (op->inputs.size() != 1) |
| | continue; |
| |
|
| | std::vector<int> outshape = op->outputs[0]->shape; |
| | if (outshape.empty()) |
| | continue; |
| |
|
| | std::vector<int> shape = op->params.at("shape").ai; |
| |
|
| | |
| | for (size_t j = 0; j < outshape.size(); j++) |
| | { |
| | if (outshape[j] != -1) |
| | { |
| | shape[j] = outshape[j]; |
| | } |
| | } |
| |
|
| | op->params["shape"] = shape; |
| | } |
| | } |
| |
|
| | } |
| |
|