// Tencent is pleased to support the open source community by making ncnn available. // // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at // // https://opensource.org/licenses/BSD-3-Clause // // Unless required by applicable law or agreed to in writing, software distributed // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. #include "shape_inference.h" #include #include "storezip.h" #include "pass_level0/constant_unpooling.h" #include "pass_level0/convert_half_to_float.h" #include "pass_level0/flatten_input.h" #include "pass_level0/inline_block.h" #include "pass_level0/reset_device.h" #include "pass_level0/shape_inference.h" namespace pnnx { static bool value_link_input(const torch::jit::Value* v, const std::vector& inputs, bool ignore_aten_size) { if (ignore_aten_size) { // any intermediate shape is constant with static input shape std::string optype = v->node()->kind().toDisplayString(); if (optype == "aten::size" || optype == "aten::new_empty" || optype == "aten::new_full" || optype == "aten::new_ones" || optype == "aten::new_zeros" || optype == "aten::empty_like" || optype == "aten::full_like" || optype == "aten::ones_like" || optype == "aten::zeros_like" || optype == "aten::_shape_as_tensor") return false; } for (auto x : inputs) { if (v == x) return true; } for (size_t i = 0; i < v->node()->inputs().size(); i++) { bool link = value_link_input(v->node()->inputs()[i], inputs, ignore_aten_size); if (link) return true; } return false; } static bool value_link_output(const torch::jit::Value* v, const std::vector& outputs) { for (auto x : outputs) { if (v == x) return true; } for (size_t i = 0; i < v->uses().size(); i++) { auto node = v->uses()[i].user; for (auto x : node->outputs()) { bool link = value_link_output(x, outputs); if (link) return true; } std::string op_type = node->kind().toDisplayString(); bool is_inplace_op = op_type.size() > 2 && op_type[op_type.size() - 2] != '_' && op_type[op_type.size() - 1] == '_'; if (is_inplace_op) { // optimize me: track other inplace op inputs return true; } } return false; } void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, const std::string& device, std::set& foldable_constants, const std::string& foldable_constants_zippath) { // collect all intermediate output tensors std::vector > more_value_names; std::vector > more_values; { std::unordered_set value_names; std::vector values; for (const auto& n : graph->nodes()) { for (const auto& v : n->outputs()) { auto tensor_type = v->type()->cast(); if (!tensor_type) continue; value_names.insert(v->debugName()); values.push_back(v); } // too many intermediate blobs in one inference results oom if (value_names.size() >= 1000) { more_value_names.push_back(value_names); value_names.clear(); more_values.push_back(values); values.clear(); } } if (value_names.size() > 0) { more_value_names.push_back(value_names); more_values.push_back(values); } } // collect graph inputs outputs std::vector g_inputs; for (size_t i = 1; i < graph->inputs().size(); i++) { g_inputs.push_back(graph->inputs()[i]); } std::vector g_outputs; for (size_t i = 0; i < graph->outputs().size(); i++) { g_outputs.push_back(graph->outputs()[i]); } std::vector inputs; for (size_t i = 0; i < input_tensors.size(); i++) { const at::Tensor& it = input_tensors[i]; inputs.push_back(it); } std::vector inputs2; for (size_t i = 0; i < input_tensors2.size(); i++) { const at::Tensor& it = input_tensors2[i]; inputs2.push_back(it); } StoreZipWriter zip; zip.open(foldable_constants_zippath); for (size_t p = 0; p < more_value_names.size(); p++) { std::unordered_set& value_names = more_value_names[p]; std::vector& values = more_values[p]; // auto mod2 = mod.deepcopy(); torch::jit::Module mod2 = torch::jit::load(ptpath, (device == "gpu") ? c10::kCUDA : c10::kCPU); mod2.eval(); convert_half_to_float(mod2); auto method = mod2.find_method("forward"); if (!method) { method = mod2.get_methods()[0]; } auto graph2 = method->graph(); inline_block(graph2, module_operators); reset_device(graph2, device); flatten_input(graph2); constant_unpooling(graph2); std::vector values2; for (auto n : graph2->nodes()) { for (const auto& v : n->outputs()) { auto tensor_type = v->type()->cast(); if (!tensor_type) continue; if (value_names.find(v->debugName()) != value_names.end()) { values2.push_back(v); // fprintf(stderr, "%s ", v->debugName().c_str()); } } } fprintf(stderr, "\n----------------\n\n"); // set new graph output torch::jit::Node* new_return_node = graph2->createTuple(at::ArrayRef(values2)); graph2->appendNode(new_return_node); graph2->eraseOutput(0); graph2->registerOutput(new_return_node->outputs()[0]); // construct schema for new inputs and outputs { auto oldfs = method->function().getSchema(); std::vector arguments; std::vector returns; for (size_t i = 0; i < graph2->inputs().size(); i++) { auto v = graph2->inputs()[i]; arguments.push_back(c10::Argument(v->debugName(), v->type())); } for (size_t i = 0; i < graph2->outputs().size(); i++) { auto v = graph2->outputs()[i]; returns.push_back(c10::Argument(v->debugName(), v->type())); } c10::FunctionSchema newfs(oldfs.name(), oldfs.overload_name(), arguments, returns); method->function().setSchema(newfs); } // inference for all tensors auto outputs = mod2.copy().get_method(method->name())(inputs).toTuple(); if (input_tensors2.empty()) { // assign shape info for (size_t i = 0; i < values2.size(); i++) { auto v = values[i]; auto t = outputs->elements()[i].toTensor(); v->setType(c10::TensorType::create(t)); // check if value that does not depend on inputs if (!value_link_input(v, g_inputs, true) && value_link_output(v, g_outputs)) { // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str()); foldable_constants.insert(v->debugName()); at::Tensor t2 = t.cpu().contiguous(); zip.write_file(v->debugName(), (const char*)t2.data_ptr(), t2.nbytes()); } } } else { // assign dynamic shape info auto outputs2 = mod2.copy().get_method(method->name())(inputs2).toTuple(); fprintf(stderr, "assign dynamic shape info\n"); for (size_t i = 0; i < values2.size(); i++) { auto v = values[i]; auto t = outputs->elements()[i].toTensor(); auto t2 = outputs2->elements()[i].toTensor(); auto type1 = c10::TensorType::create(t); auto type2 = c10::TensorType::create(t2); std::vector sizes1 = type1->symbolic_sizes().sizes().value(); std::vector sizes2 = type2->symbolic_sizes().sizes().value(); for (size_t i = 0; i < sizes1.size(); i++) { if (sizes1[i] == sizes2[i]) continue; sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1); } auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1)); v->setType(finaltype); // check if value that does not depend on inputs if (!value_link_input(v, g_inputs, false) && value_link_output(v, g_outputs)) { // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str()); foldable_constants.insert(v->debugName()); at::Tensor t2 = t.cpu().contiguous(); zip.write_file(v->debugName(), (const char*)t2.data_ptr(), t2.nbytes()); } } } } zip.close(); if (input_tensors2.empty()) { for (size_t i = 0; i < input_tensors.size(); i++) { auto type = c10::TensorType::create(input_tensors[i]); graph->inputs()[1 + i]->setType(type); } } else { for (size_t i = 0; i < input_tensors.size(); i++) { auto type1 = c10::TensorType::create(input_tensors[i]); auto type2 = c10::TensorType::create(input_tensors2[i]); std::vector sizes1 = type1->symbolic_sizes().sizes().value(); std::vector sizes2 = type2->symbolic_sizes().sizes().value(); for (size_t i = 0; i < sizes1.size(); i++) { if (sizes1[i] == sizes2[i]) continue; sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1); } auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1)); graph->inputs()[1 + i]->setType(finaltype); } } } } // namespace pnnx