// Tencent is pleased to support the open source community by making ncnn available. // // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at // // https://opensource.org/licenses/BSD-3-Clause // // Unless required by applicable law or agreed to in writing, software distributed // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. #include "ir.h" #include #include #include #include #include #include #include #include #if BUILD_PNNX #include #include #endif #include "storezip.h" #include "utils.h" namespace pnnx { static bool type_is_integer(int type) { if (type == 1) return false; if (type == 2) return false; if (type == 3) return false; if (type == 4) return true; if (type == 5) return true; if (type == 6) return true; if (type == 7) return true; if (type == 8) return true; if (type == 9) return true; if (type == 10) return false; if (type == 11) return false; if (type == 12) return false; return false; } static const char* type_to_string(int type) { if (type == 1) return "f32"; if (type == 2) return "f64"; if (type == 3) return "f16"; if (type == 4) return "i32"; if (type == 5) return "i64"; if (type == 6) return "i16"; if (type == 7) return "i8"; if (type == 8) return "u8"; if (type == 9) return "bool"; if (type == 10) return "c64"; if (type == 11) return "c128"; if (type == 12) return "c32"; return "null"; } static const char* type_to_numpy_string(int type) { if (type == 1) return "float32"; if (type == 2) return "float64"; if (type == 3) return "float16"; if (type == 4) return "int32"; if (type == 5) return "int64"; if (type == 6) return "int16"; if (type == 7) return "int8"; if (type == 8) return "uint8"; if (type == 9) return "bool8"; if (type == 10) return "csingle"; if (type == 11) return "cdouble"; if (type == 12) return "chalf"; return "null"; } static const char* type_to_dtype_string(int type) { if (type == 1) return "torch.float"; if (type == 2) return "torch.double"; if (type == 3) return "torch.half"; if (type == 4) return "torch.int"; if (type == 5) return "torch.long"; if (type == 6) return "torch.short"; if (type == 7) return "torch.int8"; if (type == 8) return "torch.uint8"; if (type == 9) return "torch.bool"; if (type == 10) return "torch.complex64"; if (type == 11) return "torch.complex128"; if (type == 12) return "torch.complex32"; return "null"; } static size_t type_to_elemsize(int type) { if (type == 1) return 4; if (type == 2) return 8; if (type == 3) return 2; if (type == 4) return 4; if (type == 5) return 8; if (type == 6) return 2; if (type == 7) return 1; if (type == 8) return 1; if (type == 9) return 1; if (type == 10) return 8; if (type == 11) return 16; if (type == 12) return 4; return 0; // null } static int string_to_type(const char* s) { if (strcmp(s, "f32") == 0) return 1; if (strcmp(s, "f64") == 0) return 2; if (strcmp(s, "f16") == 0) return 3; if (strcmp(s, "i32") == 0) return 4; if (strcmp(s, "i64") == 0) return 5; if (strcmp(s, "i16") == 0) return 6; if (strcmp(s, "i8") == 0) return 7; if (strcmp(s, "u8") == 0) return 8; if (strcmp(s, "bool") == 0) return 9; if (strcmp(s, "c64") == 0) return 10; if (strcmp(s, "c128") == 0) return 11; if (strcmp(s, "c32") == 0) return 12; return 0; // null } #if BUILD_PNNX int get_at_tensor_type(const at::ScalarType& st) { if (st == c10::ScalarType::Float) return 1; if (st == c10::ScalarType::Double) return 2; if (st == c10::ScalarType::Half) return 3; if (st == c10::ScalarType::Int) return 4; if (st == c10::ScalarType::QInt32) return 4; if (st == c10::ScalarType::Long) return 5; if (st == c10::ScalarType::Short) return 6; if (st == c10::ScalarType::Char) return 7; if (st == c10::ScalarType::QInt8) return 7; if (st == c10::ScalarType::Byte) return 8; if (st == c10::ScalarType::QUInt8) return 8; if (st == c10::ScalarType::Bool) return 9; if (st == c10::ScalarType::ComplexFloat) return 10; if (st == c10::ScalarType::ComplexDouble) return 11; if (st == c10::ScalarType::ComplexHalf) return 12; return 0; // unknown type } Parameter::Parameter(const torch::jit::Node* value_node) { type = 0; if (value_node->kind() == c10::prim::Constant) { if (value_node->output()->type()->kind() == c10::TypeKind::NoneType) { type = 0; return; } if (!value_node->hasAttribute(torch::jit::attr::value)) { fprintf(stderr, "no attribute value\n"); value_node->dump(); return; } switch (value_node->output()->type()->kind()) { case c10::TypeKind::NoneType: { type = 0; break; } case c10::TypeKind::BoolType: { type = 1; b = value_node->i(torch::jit::attr::value); break; } case c10::TypeKind::IntType: { type = 2; int64_t i64 = value_node->i(torch::jit::attr::value); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; i = (int)i64; break; } case c10::TypeKind::FloatType: { type = 3; f = (float)value_node->f(torch::jit::attr::value); break; } case c10::TypeKind::StringType: { type = 4; s = value_node->s(torch::jit::attr::value); break; } case c10::TypeKind::DeviceObjType: { type = 4; s = value_node->s(torch::jit::attr::value); break; } #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 9) case c10::TypeKind::ComplexType: { type = 10; c = std::complex(value_node->c(torch::jit::attr::value)); break; } #endif case c10::TypeKind::TensorType: { at::Tensor t = value_node->t(torch::jit::attr::value); if (t.dim() == 0) { if (t.scalar_type() == c10::ScalarType::Long) { type = 2; int64_t i64 = t.item(); if (i64 == std::numeric_limits::max()) i64 = INT_MAX; if (i64 == std::numeric_limits::min()) i64 = INT_MIN; i = (int)i64; } else if (t.scalar_type() == c10::ScalarType::Int) { type = 2; i = t.item(); } else if (t.scalar_type() == c10::ScalarType::Double) { type = 3; f = (float)t.item(); } else if (t.scalar_type() == c10::ScalarType::Float) { type = 3; f = t.item(); } else if (t.scalar_type() == c10::ScalarType::ComplexDouble) { type = 10; c = std::complex(t.item >()); } else if (t.scalar_type() == c10::ScalarType::ComplexFloat) { type = 10; c = std::complex(t.item >()); } else { fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = 0\n", value_node->kind().toDisplayString()); } } else { // constant tensor will become pnnx attribute node later type = 8; } break; } default: { fprintf(stderr, "unknown Parameter value kind %s\n", c10::typeKindToString(value_node->output()->type()->kind())); break; } } } else if (value_node->kind() == c10::prim::ListConstruct) { switch (value_node->output()->type()->cast()->getElementType()->kind()) { case c10::TypeKind::IntType: { type = 5; for (const auto& x : value_node->inputs()) { ai.push_back((int)x->node()->i(torch::jit::attr::value)); } break; } case c10::TypeKind::FloatType: { type = 6; for (const auto& x : value_node->inputs()) { af.push_back((float)x->node()->f(torch::jit::attr::value)); } break; } case c10::TypeKind::StringType: { type = 7; for (const auto& x : value_node->inputs()) { as.push_back(x->node()->s(torch::jit::attr::value)); } break; } #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 9) case c10::TypeKind::ComplexType: { type = 11; for (const auto& x : value_node->inputs()) { ac.push_back(std::complex(x->node()->c(torch::jit::attr::value))); } break; } #endif default: { fprintf(stderr, "unknown Parameter value list element kind %s\n", c10::typeKindToString(value_node->output()->type()->cast()->getElementType()->kind())); break; } } } else { fprintf(stderr, "unknown Parameter value_node kind %s\n", value_node->kind().toDisplayString()); } } Parameter::Parameter(const torch::jit::Value* value) : Parameter(value->node()) { } #endif // BUILD_PNNX bool operator==(const Parameter& lhs, const Parameter& rhs) { if (lhs.type != rhs.type) return false; if (lhs.type == 0) return true; if (lhs.type == 1 && lhs.b == rhs.b) return true; if (lhs.type == 2 && lhs.i == rhs.i) return true; if (lhs.type == 3 && lhs.f == rhs.f) return true; if (lhs.type == 4 && lhs.s == rhs.s) return true; if (lhs.type == 5 && lhs.ai == rhs.ai) return true; if (lhs.type == 6 && lhs.af == rhs.af) return true; if (lhs.type == 7 && lhs.as == rhs.as) return true; if (lhs.type == 10 && lhs.c == rhs.c) return true; if (lhs.type == 11 && lhs.ac == rhs.ac) return true; return false; } #if BUILD_PNNX Attribute::Attribute(const at::Tensor& t) { type = get_at_tensor_type(t.scalar_type()); const int ndim = (int)t.dim(); if (ndim == 0) { shape = {1}; data.resize(type_to_elemsize(type)); if (t.scalar_type() == c10::ScalarType::Long) { int64_t i = t.item(); memcpy((void*)data.data(), (const void*)&i, data.size()); } else if (t.scalar_type() == c10::ScalarType::Int) { int i = t.item(); memcpy((void*)data.data(), (const void*)&i, data.size()); } else if (t.scalar_type() == c10::ScalarType::Double) { double f = t.item(); memcpy((void*)data.data(), (const void*)&f, data.size()); } else if (t.scalar_type() == c10::ScalarType::Float) { float f = t.item(); memcpy((void*)data.data(), (const void*)&f, data.size()); } else { fprintf(stderr, "unknown Attribute tensor scalar type %d\n", type); } return; } shape.resize(ndim); for (int i = 0; i < ndim; i++) shape[i] = t.size(i); if (shape.size() > 0) { data.resize(elemcount() * type_to_elemsize(type)); memcpy((void*)data.data(), (const void*)t.cpu().contiguous().data_ptr(), data.size()); } } #endif // BUILD_PNNX Attribute::Attribute(const std::initializer_list& _shape, const std::vector& t) { type = 1; shape = _shape; if (shape.size() > 0) { data.resize(elemcount() * type_to_elemsize(type)); memcpy((void*)data.data(), (const void*)t.data(), data.size()); } } size_t Attribute::elemsize() const { return type_to_elemsize(type); } int Attribute::elemcount() const { if (shape.empty()) return 0; int size = shape[0]; for (size_t i = 1; i < shape.size(); i++) { size *= shape[i]; } return size; } std::vector Attribute::get_float32_data() const { std::vector v(elemcount()); if (type == 1) { memcpy((void*)v.data(), (const void*)data.data(), data.size()); } else if (type == 2) { // f64 const double* p = (const double*)data.data(); for (size_t i = 0; i < v.size(); i++) { v[i] = float(p[i]); } } else if (type == 3) { // f16 const unsigned short* p = (const unsigned short*)data.data(); for (size_t i = 0; i < v.size(); i++) { v[i] = float16_to_float32(p[i]); } } else { fprintf(stderr, "cannot convert type %d to float32 data\n", type); } return v; } void Attribute::set_float32_data(const std::vector& newdata) { data.resize(newdata.size() * elemsize()); if (type == 1) { memcpy((void*)data.data(), (const void*)newdata.data(), data.size()); } else if (type == 2) { // f64 double* p = (double*)data.data(); for (size_t i = 0; i < newdata.size(); i++) { p[i] = newdata[i]; } } else if (type == 3) { // f16 unsigned short* p = (unsigned short*)data.data(); for (size_t i = 0; i < newdata.size(); i++) { p[i] = float32_to_float16(newdata[i]); } } else { fprintf(stderr, "cannot convert float32 data to type %d\n", type); } } bool operator==(const Attribute& lhs, const Attribute& rhs) { if (lhs.type != rhs.type) return false; if (lhs.type == 0) return true; if (lhs.shape != rhs.shape) return false; if (lhs.data != rhs.data) return false; return true; } Attribute operator+(const Attribute& a, const Attribute& b) { Attribute c; if (a.type != b.type) { fprintf(stderr, "concat attribute type mismatch\n"); return c; } if (a.shape.size() != b.shape.size()) { fprintf(stderr, "concat attribute shape rank mismatch\n"); return c; } for (int i = 1; i < (int)a.shape.size(); i++) { if (a.shape[i] != b.shape[i]) { fprintf(stderr, "concat attribute shape mismatch\n"); return c; } } c.type = a.type; c.shape = a.shape; c.shape[0] += b.shape[0]; // concat the first dim c.data.resize(a.data.size() + b.data.size()); memcpy(c.data.data(), a.data.data(), a.data.size()); memcpy(c.data.data() + a.data.size(), b.data.data(), b.data.size()); return c; } Parameter Parameter::parse_from_string(const std::string& value) { if (value.find('%') != std::string::npos) { Parameter p; p.type = 4; p.s = value; return p; } Parameter p; p.type = 0; if (value == "None" || value == "()" || value == "[]") { return p; } if (value == "True" || value == "False") { // bool p.type = 1; p.b = value == "True"; return p; } if (value[0] == '(' || value[0] == '[') { // list std::string lc = value.substr(1, value.size() - 2); std::istringstream lcss(lc); while (!lcss.eof()) { std::string elem; std::getline(lcss, elem, ','); if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9'))) { // string p.type = 7; p.as.push_back(elem); } else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos) { // float p.type = 6; p.af.push_back(std::stof(elem)); } else { // integer p.type = 5; p.ai.push_back(std::stoi(elem)); } } return p; } if ((value[0] != '-' && (value[0] < '0' || value[0] > '9')) || (value[0] == '-' && (value[1] < '0' || value[1] > '9'))) { // string p.type = 4; p.s = value; return p; } if (value.find('.') != std::string::npos || value.find('e') != std::string::npos) { // float p.type = 3; p.f = std::stof(value); return p; } // integer p.type = 2; p.i = std::stoi(value); return p; } std::string Parameter::encode_to_string(const Parameter& param) { if (param.type == 0) { return std::string("None"); } if (param.type == 1) { if (param.b) return std::string("True"); else return std::string("False"); } if (param.type == 2) { return std::to_string(param.i); } if (param.type == 3) { char buf[64]; sprintf(buf, "%e", param.f); return std::string(buf); } if (param.type == 4) { return param.s; } if (param.type == 5) { std::string s("("); for (size_t i = 0; i < param.ai.size(); i++) { s += std::to_string(param.ai[i]); if (i + 1 != param.ai.size()) s += std::string(","); } s += std::string(")"); return s; } if (param.type == 6) { std::string s("("); for (size_t i = 0; i < param.af.size(); i++) { char buf[64]; sprintf(buf, "%e", param.af[i]); s += std::string(buf); if (i + 1 != param.af.size()) s += std::string(","); } s += std::string(")"); return s; } if (param.type == 7) { std::string s("("); for (size_t i = 0; i < param.as.size(); i++) { s += param.as[i]; if (i + 1 != param.as.size()) s += std::string(","); } s += std::string(")"); return s; } if (param.type == 10) { char buf[128]; sprintf(buf, "%e+%ej", param.c.real(), param.c.imag()); return std::string(buf); } if (param.type == 11) { std::string s("("); for (size_t i = 0; i < param.ac.size(); i++) { char buf[128]; sprintf(buf, "%e+%ej", param.ac[i].real(), param.ac[i].imag()); s += std::string(buf); if (i + 1 != param.ac.size()) s += std::string(","); } s += std::string(")"); return s; } fprintf(stderr, "unknown parameter type %d\n", param.type); return std::string(); } Graph::Graph() { } Graph::~Graph() { for (auto x : ops) delete x; for (auto x : operands) delete x; ops.clear(); operands.clear(); } Graph::Graph(const Graph& /*rhs*/) { } Graph& Graph::operator=(const Graph& /*rhs*/) { return *this; } static void load_parameter(Operator* op, const std::string& key, const std::string& value) { op->params[key] = Parameter::parse_from_string(value); } static void load_input_key(Operator* op, const std::string& key, const std::string& value) { op->inputnames.resize(op->inputs.size()); for (size_t i = 0; i < op->inputs.size(); i++) { const Operand* oprand = op->inputs[i]; if (oprand->name == value) { op->inputnames[i] = key; break; } } } static void load_shape(Operator* op, const std::string& key, const std::string& value) { Operand* operand = 0; for (auto r : op->inputs) { if (r->name == key) { operand = r; break; } } if (!operand) { for (auto r : op->outputs) { if (r->name == key) { operand = r; break; } } } if (!operand) { fprintf(stderr, "no such operand %s for operator %s\n", key.c_str(), op->name.c_str()); return; } // type std::string typestr = value.substr(value.find_last_of(')') + 1); operand->type = string_to_type(typestr.c_str()); // shape std::string lc = value.substr(1, value.find_last_of(')') - 1); std::istringstream lcss(lc); operand->shape.clear(); while (!lcss.eof()) { std::string elem; std::getline(lcss, elem, ','); if (elem == "?") { operand->shape.push_back(-1); } else if (elem[0] == '%') { // encode %abc as symbolic tag operand->shape.push_back(-233); int index = operand->shape.size() - 1; std::string key = elem.substr(1); operand->params[std::string("__shape_") + std::to_string(index)] = key; } else { int i = std::stoi(elem); operand->shape.push_back(i); } } } static void load_attribute(Operator* op, const std::string& key, const std::string& value, StoreZipReader& szr) { Attribute& a = op->attrs[key]; // type std::string typestr = value.substr(value.find_last_of(')') + 1); a.type = string_to_type(typestr.c_str()); if (a.type == 0) return; // shape std::string lc = value.substr(1, value.find_last_of(')') - 1); std::istringstream lcss(lc); a.shape.clear(); while (!lcss.eof()) { std::string elem; std::getline(lcss, elem, ','); int i = std::stoi(elem); a.shape.push_back(i); } if (a.shape.empty()) return; // data size_t size = 1; for (int i : a.shape) { size *= i; } size_t bytesize = size * type_to_elemsize(a.type); std::string filename = op->name + "." + key; size_t filesize = szr.get_file_size(filename); if (filesize == 0) { // no such file return; } if (filesize != bytesize) { fprintf(stderr, "file size not match expect %lu but got %lu\n", bytesize, filesize); } a.data.resize(bytesize); szr.read_file(filename, (char*)a.data.data()); } int Graph::load(const std::string& parampath, const std::string& binpath) { std::ifstream is(parampath, std::ios::in | std::ios::binary); if (!is.good()) { fprintf(stderr, "open failed\n"); return -1; } StoreZipReader szr; if (szr.open(binpath) != 0) { fprintf(stderr, "open failed\n"); return -1; } int magic = 0; { std::string line; std::getline(is, line); std::istringstream iss(line); iss >> magic; } int operator_count = 0; int operand_count = 0; { std::string line; std::getline(is, line); std::istringstream iss(line); iss >> operator_count >> operand_count; } for (int i = 0; i < operator_count; i++) { std::string line; std::getline(is, line); std::istringstream iss(line); std::string type; std::string name; int input_count = 0; int output_count = 0; iss >> type >> name >> input_count >> output_count; Operator* op = new_operator(type, name); for (int j = 0; j < input_count; j++) { std::string operand_name; iss >> operand_name; Operand* r = get_operand(operand_name); r->consumers.push_back(op); op->inputs.push_back(r); } for (int j = 0; j < output_count; j++) { std::string operand_name; iss >> operand_name; Operand* r = new_operand(operand_name); r->producer = op; op->outputs.push_back(r); } // key=value while (!iss.eof()) { std::string param; iss >> param; std::string key; std::string value; std::istringstream pss(param); std::getline(pss, key, '='); std::getline(pss, value); if (key[0] == '@') { // attribute load_attribute(op, key.substr(1), value, szr); } else if (key[0] == '$') { // operand input key load_input_key(op, key.substr(1), value); } else if (key[0] == '#') { // operand shape load_shape(op, key.substr(1), value); } else { // parameter load_parameter(op, key, value); } } } return 0; } int Graph::save(const std::string& parampath, const std::string& binpath) { FILE* paramfp = fopen(parampath.c_str(), "wb"); if (!paramfp) { fprintf(stderr, "fopen %s failed\n", parampath.c_str()); return -1; } StoreZipWriter szw; if (szw.open(binpath) != 0) { fprintf(stderr, "open failed\n"); return -1; } // magic fprintf(paramfp, "7767517\n"); // op count and oprand count fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size()); for (const Operator* op : ops) { fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size()); for (const Operand* oprand : op->inputs) { fprintf(paramfp, " %s", oprand->name.c_str()); } for (const Operand* oprand : op->outputs) { fprintf(paramfp, " %s", oprand->name.c_str()); } for (const auto& it : op->params) { fprintf(paramfp, " %s=", it.first.c_str()); const Parameter& param = it.second; std::string s = Parameter::encode_to_string(param); fprintf(paramfp, "%s", s.c_str()); } for (const auto& it : op->attrs) { fprintf(paramfp, " @%s=", it.first.c_str()); const Attribute& attr = it.second; fprintf(paramfp, "("); for (int i = 0; i < (int)attr.shape.size() - 1; i++) { fprintf(paramfp, "%d,", attr.shape[i]); } if (attr.shape.size() > 0) fprintf(paramfp, "%d", attr.shape[attr.shape.size() - 1]); fprintf(paramfp, ")"); fprintf(paramfp, type_to_string(attr.type)); std::string filename = op->name + "." + it.first; szw.write_file(filename, attr.data.data(), attr.data.size()); } if (op->inputnames.size() == op->inputs.size()) { for (size_t i = 0; i < op->inputs.size(); i++) { if (op->inputnames[i].empty()) continue; const Operand* oprand = op->inputs[i]; fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); } } for (const Operand* oprand : op->inputs) { if (oprand->shape.empty()) continue; fprintf(paramfp, " #%s=", oprand->name.c_str()); fprintf(paramfp, "("); for (int i = 0; i < (int)oprand->shape.size() - 1; i++) { if (oprand->shape[i] == -1) fprintf(paramfp, "?,"); else fprintf(paramfp, "%d,", oprand->shape[i]); } if (oprand->shape.size() > 0) { if (oprand->shape[oprand->shape.size() - 1] == -1) fprintf(paramfp, "?"); else fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); } fprintf(paramfp, ")"); fprintf(paramfp, type_to_string(oprand->type)); } for (const Operand* oprand : op->outputs) { if (oprand->shape.empty()) continue; fprintf(paramfp, " #%s=", oprand->name.c_str()); fprintf(paramfp, "("); for (int i = 0; i < (int)oprand->shape.size() - 1; i++) { if (oprand->shape[i] == -1) fprintf(paramfp, "?,"); else fprintf(paramfp, "%d,", oprand->shape[i]); } if (oprand->shape.size() > 0) { if (oprand->shape[oprand->shape.size() - 1] == -1) fprintf(paramfp, "?"); else fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); } fprintf(paramfp, ")"); fprintf(paramfp, type_to_string(oprand->type)); } fprintf(paramfp, "\n"); } fclose(paramfp); return 0; } static std::string sanitize_identifier(const std::string& s) { std::string ss = s; for (size_t i = 0; i < ss.size(); i++) { if (ss[i] == '.' || ss[i] == ':') ss[i] = '_'; } return ss; } static std::string expand_expression(const Operator* op) { std::string expr = op->params.at("expr").s; // split into tokens std::vector tokens; { std::string t; for (size_t i = 0; i < expr.size(); i++) { char ch = expr[i]; if (ch == '[') // list { t += ch; tokens.push_back(t); t.clear(); } else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') { if (!t.empty()) { tokens.push_back(t); t.clear(); } } else { t += ch; } } if (!t.empty()) { tokens.push_back(t); } } // scan and stack std::stack exprstack; for (int i = (int)tokens.size() - 1; i >= 0; i--) { const std::string& t = tokens[i]; if (t == "size") { std::string a = exprstack.top(); exprstack.pop(); std::string b = exprstack.top(); exprstack.pop(); std::string r = a + ".size(" + b + ")"; exprstack.push(r); } else if (t == "int" || t == "abs" || t == "acos" || t == "acosh" || t == "asin" || t == "asinh" || t == "atan" || t == "atanh" || t == "ceil" || t == "cos" || t == "cosh" || t == "exp" || t == "floor" || t == "log" || t == "log10" || t == "neg" || t == "reciprocal" || t == "round" || t == "rsqrt" || t == "sign" || t == "sin" || t == "sinh" || t == "sqrt" || t == "square" || t == "tan" || t == "tanh" || t == "trunc") { std::string unaryop; if (t == "int") unaryop = "int"; if (t == "abs") unaryop = "torch.abs"; if (t == "acos") unaryop = "torch.acos"; if (t == "acosh") unaryop = "torch.acosh"; if (t == "asin") unaryop = "torch.asin"; if (t == "asinh") unaryop = "torch.asinh"; if (t == "atan") unaryop = "torch.atan"; if (t == "atanh") unaryop = "torch.atanh"; if (t == "ceil") unaryop = "torch.ceil"; if (t == "cos") unaryop = "torch.cos"; if (t == "cosh") unaryop = "torch.cosh"; if (t == "exp") unaryop = "torch.exp"; if (t == "floor") unaryop = "torch.floor"; if (t == "log") unaryop = "torch.log"; if (t == "log10") unaryop = "torch.log10"; if (t == "neg") unaryop = "torch.neg"; if (t == "reciprocal") unaryop = "torch.reciprocal"; if (t == "round") unaryop = "torch.round"; if (t == "rsqrt") unaryop = "torch.rsqrt"; if (t == "sign") unaryop = "torch.sign"; if (t == "sin") unaryop = "torch.sin"; if (t == "sinh") unaryop = "torch.sinh"; if (t == "sqrt") unaryop = "torch.sqrt"; if (t == "square") unaryop = "torch.square"; if (t == "tan") unaryop = "torch.tan"; if (t == "tanh") unaryop = "torch.tanh"; if (t == "trunc") unaryop = "torch.trunc"; std::string a = exprstack.top(); exprstack.pop(); std::string r = unaryop + "(" + a + ")"; exprstack.push(r); } else if (t == "atan2" || t == "fmod" || t == "max" || t == "maximum" || t == "min" || t == "minimum" || t == "pow") { std::string binaryop; if (t == "atan2") binaryop = "torch.atan2"; if (t == "fmod") binaryop = "torch.fmod"; if (t == "max") binaryop = "torch.max"; if (t == "maximum") binaryop = "torch.maximum"; if (t == "min") binaryop = "torch.min"; if (t == "minimum") binaryop = "torch.minimum"; if (t == "pow") binaryop = "torch.pow"; std::string a = exprstack.top(); exprstack.pop(); std::string b = exprstack.top(); exprstack.pop(); std::string r = binaryop + "(" + a + ", " + b + ")"; exprstack.push(r); } else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "remainder" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift") { std::string binaryop; if (t == "add") binaryop = "+"; if (t == "sub") binaryop = "-"; if (t == "mul") binaryop = "*"; if (t == "div") binaryop = "/"; if (t == "floor_divide") binaryop = "//"; if (t == "remainder") binaryop = "%"; if (t == "and") binaryop = "&"; if (t == "or") binaryop = "|"; if (t == "xor") binaryop = "^"; if (t == "lshift") binaryop = "<<"; if (t == "rshift") binaryop = ">>"; std::string a = exprstack.top(); exprstack.pop(); std::string b = exprstack.top(); exprstack.pop(); std::string r = std::string("(") + a + " " + binaryop + " " + b + ")"; exprstack.push(r); } else if (t == "[") // list { std::vector elements; while (!exprstack.empty()) { std::string a = exprstack.top(); exprstack.pop(); elements.push_back(a); } std::string r = "["; for (int j = 0; j < (int)elements.size() - 1; j++) { r += elements[j]; if (j + 1 != (int)elements.size()) r += ", "; } if (!elements.empty()) { r += elements[elements.size() - 1]; } r += "]"; exprstack.push(r); } else if (t[0] == '@') { int input_index = std::stoi(t.substr(1)); std::string varid = std::string("v_") + sanitize_identifier(op->inputs[input_index]->name); exprstack.push(varid); } else { // literal if (t[t.size() - 1] == 'j') { // complex std::string r = std::string("(") + t + ")"; exprstack.push(r); } else { exprstack.push(t); } } } std::string r = exprstack.top(); exprstack.pop(); return r; } static std::string make_slice_expression(const Operator* op) { for (size_t j = 0; j < op->inputnames.size(); j++) { fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), op->inputs[j]->name.c_str()); } std::vector dims; if (op->params.find("dims") != op->params.end()) { dims = op->params.at("dims").ai; } else { dims.push_back(op->params.at("dim").i); } std::string r; int last_dim = -1; const int ndim = (int)dims.size(); for (int i = 0; i < ndim; i++) { int dim = dims[i]; for (int j = last_dim + 1; j < dim; j++) { r += ":,"; } last_dim = dim; if (op->params.find("starts") != op->params.end()) { std::vector starts = op->params.at("starts").ai; int start = starts[i]; if (start != 0) r += std::to_string(start); } else { fprintf(stderr, "find start\n"); // find start for (size_t j = 0; j < op->inputnames.size(); j++) { if (op->inputnames[j] == "start") { r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); fprintf(stderr, "find start %s\n", op->inputs[j]->name.c_str()); break; } } } r += ':'; if (op->params.find("ends") != op->params.end()) { std::vector ends = op->params.at("ends").ai; int end = ends[i]; if (end != INT_MAX) r += std::to_string(end); } else { // find end for (size_t j = 0; j < op->inputnames.size(); j++) { if (op->inputnames[j] == "end") { r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); break; } } } if (op->params.find("steps") != op->params.end()) { std::vector steps = op->params.at("steps").ai; int step = steps[i]; if (step != 1) { r += ':'; r += std::to_string(step); } } else { // find step for (size_t j = 0; j < op->inputnames.size(); j++) { if (op->inputnames[j] == "step") { r += ':'; r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); break; } } } if (i + 1 != ndim) r += ','; } return r; } static std::string make_index_expression(const Operator* op) { fprintf(stderr, "make_index_expression %s\n", op->name.c_str()); std::string index_expr = op->params.at("expr").s; // strip out-most [ ] pair index_expr = index_expr.substr(1, index_expr.size() - 2); // None,None, -> ..., bool leading_none = false; while (index_expr.substr(0, 5) == "None,") { leading_none = true; index_expr = index_expr.substr(5); } if (leading_none) { index_expr = "...," + index_expr; } return index_expr; } int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { FILE* pyfp = fopen(pypath.c_str(), "wb"); if (!pyfp) { fprintf(stderr, "fopen %s failed\n", pypath.c_str()); return -1; } fprintf(pyfp, "import os\n"); fprintf(pyfp, "import numpy as np\n"); fprintf(pyfp, "import tempfile, zipfile\n"); fprintf(pyfp, "import torch\n"); fprintf(pyfp, "import torch.nn as nn\n"); fprintf(pyfp, "import torch.nn.functional as F\n"); fprintf(pyfp, "try:\n"); fprintf(pyfp, " import torchvision\n"); fprintf(pyfp, "except:\n"); fprintf(pyfp, " pass\n"); fprintf(pyfp, "\n"); fprintf(pyfp, "class Model(nn.Module):\n"); fprintf(pyfp, " def __init__(self):\n"); fprintf(pyfp, " super(Model, self).__init__()\n"); fprintf(pyfp, "\n"); // module { for (const Operator* op : ops) { if (op->type.substr(0, 3) != "nn." && op->type.substr(0, 16) != "torchvision.ops.") continue; fprintf(pyfp, " self.%s = %s(", sanitize_identifier(op->name).c_str(), op->type.c_str()); int param_count = op->params.size(); if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") { param_count -= 2; // ignore scale and zero_point } int param_index = 0; for (const auto& it : op->params) { if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") { if (it.first == "scale" || it.first == "zero_point") continue; } fprintf(pyfp, "%s=", it.first.c_str()); const Parameter& param = it.second; if (param.type == 0) { fprintf(pyfp, "None"); } if (param.type == 1) { if (param.b) fprintf(pyfp, "True"); else fprintf(pyfp, "False"); } if (param.type == 2) { fprintf(pyfp, "%d", param.i); } if (param.type == 3) { fprintf(pyfp, "%f", param.f); } if (param.type == 4) { if (param.s.substr(0, 6) == "torch.") { fprintf(pyfp, "%s", param.s.c_str()); } else { fprintf(pyfp, "\'%s\'", param.s.c_str()); } } if (param.type == 5) { fprintf(pyfp, "("); for (size_t i = 0; i < param.ai.size(); i++) { fprintf(pyfp, "%d", param.ai[i]); if (i + 1 != param.ai.size() || param.ai.size() == 1) fprintf(pyfp, ","); } fprintf(pyfp, ")"); } if (param.type == 6) { fprintf(pyfp, "("); for (size_t i = 0; i < param.af.size(); i++) { fprintf(pyfp, "%f", param.af[i]); if (i + 1 != param.af.size() || param.af.size() == 1) fprintf(pyfp, ","); } fprintf(pyfp, ")"); } if (param.type == 7) { fprintf(pyfp, "("); for (size_t i = 0; i < param.as.size(); i++) { if (param.as[i].substr(0, 6) == "torch.") { fprintf(pyfp, "%s", param.as[i].c_str()); } else { fprintf(pyfp, "\'%s\'", param.as[i].c_str()); } if (i + 1 != param.as.size() || param.as.size() == 1) fprintf(pyfp, ","); } fprintf(pyfp, ")"); } param_index++; if (param_index != param_count) fprintf(pyfp, ", "); } fprintf(pyfp, ")\n"); } } fprintf(pyfp, "\n"); // load weights { fprintf(pyfp, " archive = zipfile.ZipFile('%s', 'r')\n", pnnxbinpath.c_str()); for (const Operator* op : ops) { if (op->type.substr(0, 3) != "nn." && op->type.substr(0, 16) != "torchvision.ops.") continue; if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") { for (const auto& it : op->attrs) { if (it.first == "weight" || it.first == "bias") { fprintf(pyfp, " self_%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); } else { // unknown attr continue; } const Attribute& attr = it.second; for (size_t i = 0; i < attr.shape.size(); i++) { fprintf(pyfp, "%d", attr.shape[i]); if (i + 1 != attr.shape.size()) fprintf(pyfp, ","); } fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); } fprintf(pyfp, " self.%s.set_weight_bias(self_%s_weight, self_%s_bias)\n", sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str()); fprintf(pyfp, " self.%s.scale = %f\n", sanitize_identifier(op->name).c_str(), op->params.at("scale").f); fprintf(pyfp, " self.%s.zero_point = %d\n", sanitize_identifier(op->name).c_str(), op->params.at("zero_point").i); continue; } for (const auto& it : op->attrs) { if (it.first == "running_mean" || it.first == "running_var") { fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); } else { fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); } const Attribute& attr = it.second; for (size_t i = 0; i < attr.shape.size(); i++) { fprintf(pyfp, "%d", attr.shape[i]); if (i + 1 != attr.shape.size()) fprintf(pyfp, ","); } if (attr.type == 1 || attr.type == 2 || attr.type == 3) { fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); } else { fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); } } } for (const Operator* op : ops) { if (op->type != "pnnx.Attribute") continue; const std::string& key = op->attrs.begin()->first; const Attribute& attr = op->attrs.begin()->second; bool is_running_mean_var = false; { const Operand* r = op->outputs[0]; if (r->consumers.size() == 1) { const Operator* op2 = r->consumers[0]; if (op2->type == "F.batch_norm" || op2->type == "F.instance_norm") { if (r == op2->inputs[1] || r == op2->inputs[2]) { is_running_mean_var = true; } } } } if (is_running_mean_var) { fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); } else { fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); } for (size_t i = 0; i < attr.shape.size(); i++) { fprintf(pyfp, "%d", attr.shape[i]); if (i + 1 != attr.shape.size()) fprintf(pyfp, ","); } if (attr.type == 1 || attr.type == 2 || attr.type == 3) { fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); } else { fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); } } fprintf(pyfp, " archive.close()\n"); } fprintf(pyfp, "\n"); // utility function { fprintf(pyfp, " def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype, requires_grad=True):\n"); fprintf(pyfp, " return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype), requires_grad)\n"); fprintf(pyfp, "\n"); fprintf(pyfp, " def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):\n"); fprintf(pyfp, " fd, tmppath = tempfile.mkstemp()\n"); fprintf(pyfp, " with os.fdopen(fd, 'wb') as tmpf, archive.open(key) as keyfile:\n"); fprintf(pyfp, " tmpf.write(keyfile.read())\n"); fprintf(pyfp, " m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()\n"); fprintf(pyfp, " os.remove(tmppath)\n"); fprintf(pyfp, " return torch.from_numpy(m)\n"); } fprintf(pyfp, "\n"); // def forward { fprintf(pyfp, " def forward(self"); for (const Operator* op : ops) { if (op->type != "pnnx.Input") continue; fprintf(pyfp, ", v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); } fprintf(pyfp, "):\n"); } // forward body { for (const Operator* op : ops) { if (op->type == "pnnx.Input" || op->type == "pnnx.Output") continue; fprintf(pyfp, " "); if (op->type == "pnnx.Expression") { // expr for (size_t i = 0; i < op->outputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); } std::string expanded_expr = expand_expression(op); fprintf(pyfp, " = %s\n", expanded_expr.c_str()); } else if (op->type == "pnnx.Attribute") { const std::string& key = op->attrs.begin()->first; fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str()); } else if (op->type == "Tensor.slice") { // slice expr std::string slice_expr = make_slice_expression(op); fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str()); } else if (op->type == "Tensor.slice_copy") { // slice copy expr std::string slice_expr = make_slice_expression(op); fprintf(pyfp, "v_%s = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str()); fprintf(pyfp, " v_%s[%s] = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), slice_expr.c_str(), sanitize_identifier(op->inputs[1]->name).c_str()); } else if (op->type == "Tensor.index") { // index expr if (op->inputs.size() == 2) { std::string expanded_expr = expand_expression(op->inputs[1]->producer); fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); } else { std::string index_expr = make_index_expression(op); fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); } } else if (op->type == "Tensor.expand") { // expand fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); if (op->inputs.size() == 2) { fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); } else { const std::vector& shape = op->params.at("shape").ai; for (size_t i = 0; i < shape.size(); i++) { fprintf(pyfp, "%d", shape[i]); if (i + 1 != shape.size()) fprintf(pyfp, ", "); } } fprintf(pyfp, ")\n"); } else if (op->type == "Tensor.view" || op->type == "Tensor.reshape") { // view reshape fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); if (op->inputs.size() == 2) { fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); } else { const std::vector& shape = op->params.at("shape").ai; for (size_t i = 0; i < shape.size(); i++) { fprintf(pyfp, "%d", shape[i]); if (i + 1 != shape.size()) fprintf(pyfp, ", "); } } fprintf(pyfp, ")\n"); } else if (op->type == "Tensor.repeat") { // view reshape fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); if (op->inputs.size() == 2) { fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); } else { const std::vector& sizes = op->params.at("sizes").ai; for (size_t i = 0; i < sizes.size(); i++) { fprintf(pyfp, "%d", sizes[i]); if (i + 1 != sizes.size()) fprintf(pyfp, ", "); } } fprintf(pyfp, ")\n"); } else if (op->type == "torch.cat" || op->type == "torch.stack") { // cat fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str()); if (op->inputs.size() == 1) { fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); } else { fprintf(pyfp, "("); for (size_t i = 0; i < op->inputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); } fprintf(pyfp, ")"); } fprintf(pyfp, ", dim=%d", op->params.at("dim").i); fprintf(pyfp, ")\n"); } else if (op->type == "torch.einsum") { // einsum fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str()); fprintf(pyfp, "\'%s\'", op->params.at("equation").s.c_str()); for (size_t i = 0; i < op->inputs.size(); i++) { fprintf(pyfp, ", v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); } fprintf(pyfp, ")\n"); } else if (op->type == "prim::TupleUnpack") { for (size_t i = 0; i < op->outputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); } fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); } else if (op->type == "prim::TupleConstruct") { fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); fprintf(pyfp, " = ("); for (size_t i = 0; i < op->inputs.size(); i++) { fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); } fprintf(pyfp, ")\n"); } else if (op->type == "prim::ListUnpack") { for (size_t i = 0; i < op->outputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); } fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); } else if (op->type == "prim::ListConstruct") { fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); fprintf(pyfp, " = ["); for (size_t i = 0; i < op->inputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); } fprintf(pyfp, "]\n"); } else if (op->type == "nn.GRU" || op->type == "nn.RNN") { if (op->outputs.size() == 1) { fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); } else { fprintf(pyfp, "v_%s, v_%s", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str()); } fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); if (op->inputs.size() == 2) { fprintf(pyfp, ", v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); } fprintf(pyfp, ")\n"); } else if (op->type == "nn.LSTM") { if (op->outputs.size() == 1) { fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); } else { fprintf(pyfp, "v_%s, (v_%s, v_%s)", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str(), sanitize_identifier(op->outputs[2]->name).c_str()); } fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); if (op->inputs.size() == 3) { fprintf(pyfp, ", (v_%s, v_%s)", sanitize_identifier(op->inputs[1]->name).c_str(), sanitize_identifier(op->inputs[2]->name).c_str()); } fprintf(pyfp, ")\n"); } else if (op->type == "nn.MultiheadAttention") { bool need_weights = true; if (op->outputs.size() == 1) { fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); need_weights = false; } else { for (size_t i = 0; i < op->outputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); } } fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); if (op->inputs.size() == 1) { std::string in0 = sanitize_identifier(op->inputs[0]->name); fprintf(pyfp, "v_%s, v_%s, v_%s", in0.c_str(), in0.c_str(), in0.c_str()); } else if (op->inputs.size() == 2) { std::string in0 = sanitize_identifier(op->inputs[0]->name); std::string in1 = sanitize_identifier(op->inputs[1]->name); if (op->inputnames.size() == 2 && op->inputnames[1] == "attn_mask") { fprintf(pyfp, "v_%s, v_%s, v_%s, attn_mask=v_%s", in0.c_str(), in0.c_str(), in0.c_str(), in1.c_str()); } else { fprintf(pyfp, "v_%s, v_%s, v_%s", in0.c_str(), in1.c_str(), in1.c_str()); } } else if (op->inputs.size() == 3) { std::string in0 = sanitize_identifier(op->inputs[0]->name); std::string in1 = sanitize_identifier(op->inputs[1]->name); std::string in2 = sanitize_identifier(op->inputs[2]->name); if (op->inputnames.size() == 3 && op->inputnames[2] == "attn_mask") { fprintf(pyfp, "v_%s, v_%s, v_%s, attn_mask=v_%s", in0.c_str(), in1.c_str(), in1.c_str(), in2.c_str()); } else { fprintf(pyfp, "v_%s, v_%s, v_%s", in0.c_str(), in1.c_str(), in2.c_str()); } } else if (op->inputs.size() == 4) { std::string in0 = sanitize_identifier(op->inputs[0]->name); std::string in1 = sanitize_identifier(op->inputs[1]->name); std::string in2 = sanitize_identifier(op->inputs[2]->name); std::string in3 = sanitize_identifier(op->inputs[3]->name); if (op->inputnames.size() == 4 && op->inputnames[3] == "attn_mask") { fprintf(pyfp, "v_%s, v_%s, v_%s, attn_mask=v_%s", in0.c_str(), in1.c_str(), in2.c_str(), in3.c_str()); } else { fprintf(pyfp, "v_%s, v_%s, v_%s, v_%s", in0.c_str(), in1.c_str(), in2.c_str(), in3.c_str()); } } else { for (size_t i = 0; i < op->inputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); } } if (need_weights) { fprintf(pyfp, ", need_weights=True"); } else { fprintf(pyfp, ", need_weights=False"); } fprintf(pyfp, ")\n"); } else if (op->type.substr(0, 3) == "nn." || op->type.substr(0, 16) == "torchvision.ops.") { // self.xxx() for (size_t i = 0; i < op->outputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); } fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); for (size_t i = 0; i < op->inputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); } fprintf(pyfp, ")\n"); } else if (op->type.find("::") != std::string::npos || op->type.find(".") != std::string::npos) { // direct for (size_t i = 0; i < op->outputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); } if (op->type.substr(0, 7) == "Tensor.") { if (op->type == "Tensor.fill") { fprintf(pyfp, " = v_%s.fill_(", sanitize_identifier(op->inputs[0]->name).c_str()); } else { fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); } if (op->inputnames.size() == op->inputs.size()) { for (size_t i = 1; i < op->inputs.size(); i++) { if (!op->inputnames[i].empty()) continue; fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); } for (size_t i = 1; i < op->inputs.size(); i++) { if (op->inputnames[i].empty()) continue; fprintf(pyfp, "%s=v_%s, ", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str()); } } else { for (size_t i = 1; i < op->inputs.size(); i++) { fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); } } } else { fprintf(pyfp, " = %s(", op->type.c_str()); if (op->inputnames.size() == op->inputs.size()) { for (size_t i = 0; i < op->inputs.size(); i++) { if (!op->inputnames[i].empty()) continue; fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); } for (size_t i = 0; i < op->inputs.size(); i++) { if (op->inputnames[i].empty()) continue; fprintf(pyfp, "%s=v_%s", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str()); if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); } } else { for (size_t i = 0; i < op->inputs.size(); i++) { fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); } } } int i = 0; for (const auto& it : op->params) { if (op->type.substr(0, 7) == "Tensor." && i == 0) { fprintf(pyfp, "%s=", it.first.c_str()); } else if (op->inputs.empty() && i == 0) { fprintf(pyfp, "%s=", it.first.c_str()); } else { fprintf(pyfp, ", %s=", it.first.c_str()); } i++; const Parameter& param = it.second; if (param.type == 0) { fprintf(pyfp, "None"); } if (param.type == 1) { if (param.b) fprintf(pyfp, "True"); else fprintf(pyfp, "False"); } if (param.type == 2) { fprintf(pyfp, "%d", param.i); } if (param.type == 3) { fprintf(pyfp, "%f", param.f); } if (param.type == 4) { if (param.s.substr(0, 6) == "torch.") { fprintf(pyfp, "%s", param.s.c_str()); } else { fprintf(pyfp, "\'%s\'", param.s.c_str()); } } if (param.type == 5) { fprintf(pyfp, "("); for (size_t i = 0; i < param.ai.size(); i++) { fprintf(pyfp, "%d", param.ai[i]); if (i + 1 != param.ai.size() || param.ai.size() == 1) fprintf(pyfp, ","); } fprintf(pyfp, ")"); } if (param.type == 6) { fprintf(pyfp, "("); for (size_t i = 0; i < param.af.size(); i++) { fprintf(pyfp, "%f", param.af[i]); if (i + 1 != param.af.size() || param.af.size() == 1) fprintf(pyfp, ","); } fprintf(pyfp, ")"); } if (param.type == 7) { fprintf(pyfp, "("); for (size_t i = 0; i < param.as.size(); i++) { if (param.as[i].substr(0, 6) == "torch.") { fprintf(pyfp, "%s", param.as[i].c_str()); } else { fprintf(pyfp, "\'%s\'", param.as[i].c_str()); } if (i + 1 != param.as.size() || param.as.size() == 1) fprintf(pyfp, ","); } fprintf(pyfp, ")"); } if (param.type == 10) { fprintf(pyfp, "(%f%+fj)", param.c.real(), param.c.imag()); } if (param.type == 11) { fprintf(pyfp, "("); for (size_t i = 0; i < param.ac.size(); i++) { fprintf(pyfp, "(%f%+fj)", param.ac[i].real(), param.ac[i].imag()); if (i + 1 != param.ac.size() || param.ac.size() == 1) fprintf(pyfp, ","); } fprintf(pyfp, ")"); } } fprintf(pyfp, ")\n"); } else { fprintf(stderr, "todo %s\n", op->type.c_str()); } } } // return { fprintf(pyfp, " return "); int output_count = 0; { for (const Operator* op : ops) { if (op->type == "pnnx.Output") output_count++; } } int output_index = 0; for (const Operator* op : ops) { if (op->type != "pnnx.Output") continue; fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); if (output_index + 1 != output_count) fprintf(pyfp, ", "); output_index++; } fprintf(pyfp, "\n"); } fprintf(pyfp, "\n"); // export torchscript { fprintf(pyfp, "def export_torchscript():\n"); fprintf(pyfp, " net = Model()\n"); fprintf(pyfp, " net.eval()\n"); fprintf(pyfp, "\n"); fprintf(pyfp, " torch.manual_seed(0)\n"); std::vector input_names; for (const Operator* op : ops) { if (op->type != "pnnx.Input") continue; const Operand* r = op->outputs[0]; std::string input_name = std::string("v_") + sanitize_identifier(r->name); if (type_is_integer(r->type)) { fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); for (size_t i = 0; i < r->shape.size(); i++) { fprintf(pyfp, "%d", r->shape[i]); if (i + 1 != r->shape.size() || r->shape.size() == 1) fprintf(pyfp, ", "); } fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); } else { fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); for (size_t i = 0; i < r->shape.size(); i++) { fprintf(pyfp, "%d, ", r->shape[i]); } fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); } input_names.push_back(input_name); } fprintf(pyfp, "\n"); if (input_names.size() == 1) { fprintf(pyfp, " mod = torch.jit.trace(net, %s)\n", input_names[0].c_str()); } else { fprintf(pyfp, " mod = torch.jit.trace(net, ("); for (size_t i = 0; i < input_names.size(); i++) { fprintf(pyfp, "%s", input_names[i].c_str()); if (i + 1 != input_names.size()) fprintf(pyfp, ", "); } fprintf(pyfp, "))\n"); } fprintf(pyfp, " mod.save(\"%s.pt\")\n", pypath.c_str()); } fprintf(pyfp, "\n"); // export onnx { fprintf(pyfp, "def export_onnx():\n"); fprintf(pyfp, " net = Model()\n"); fprintf(pyfp, " net.eval()\n"); fprintf(pyfp, "\n"); fprintf(pyfp, " torch.manual_seed(0)\n"); std::vector input_names; for (const Operator* op : ops) { if (op->type != "pnnx.Input") continue; const Operand* r = op->outputs[0]; std::string input_name = std::string("v_") + sanitize_identifier(r->name); if (type_is_integer(r->type)) { fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); for (size_t i = 0; i < r->shape.size(); i++) { fprintf(pyfp, "%d", r->shape[i]); if (i + 1 != r->shape.size() || r->shape.size() == 1) fprintf(pyfp, ", "); } fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); } else { fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); for (size_t i = 0; i < r->shape.size(); i++) { fprintf(pyfp, "%d, ", r->shape[i]); } fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); } input_names.push_back(input_name); } fprintf(pyfp, "\n"); // torch.onnx._export(net, v_0, "test_swin_t.onnx", export_params=True, opset_version=14, input_names=['in0'], output_names=['out0']) if (input_names.size() == 1) { fprintf(pyfp, " torch.onnx._export(net, %s", input_names[0].c_str()); } else { fprintf(pyfp, " torch.onnx._export(net, ("); for (size_t i = 0; i < input_names.size(); i++) { fprintf(pyfp, "%s", input_names[i].c_str()); if (i + 1 != input_names.size()) fprintf(pyfp, ", "); } fprintf(pyfp, ")"); } fprintf(pyfp, ", \"%s.onnx\", export_params=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, opset_version=13", pypath.c_str()); fprintf(pyfp, ", input_names=["); { int input_count = 0; { for (const Operator* op : ops) { if (op->type == "pnnx.Input") input_count++; } } int input_index = 0; for (const Operator* op : ops) { if (op->type != "pnnx.Input") continue; fprintf(pyfp, "'in%d'", input_index); if (input_index + 1 != input_count) fprintf(pyfp, ", "); input_index++; } } fprintf(pyfp, "]"); fprintf(pyfp, ", output_names=["); { int output_count = 0; { for (const Operator* op : ops) { if (op->type == "pnnx.Output") output_count++; } } int output_index = 0; for (const Operator* op : ops) { if (op->type != "pnnx.Output") continue; fprintf(pyfp, "'out%d'", output_index); if (output_index + 1 != output_count) fprintf(pyfp, ", "); output_index++; } } fprintf(pyfp, "]"); fprintf(pyfp, ")\n"); } fprintf(pyfp, "\n"); // test inference { fprintf(pyfp, "def test_inference():\n"); fprintf(pyfp, " net = Model()\n"); fprintf(pyfp, " net.eval()\n"); fprintf(pyfp, "\n"); fprintf(pyfp, " torch.manual_seed(0)\n"); std::vector input_names; for (const Operator* op : ops) { if (op->type != "pnnx.Input") continue; const Operand* r = op->outputs[0]; std::string input_name = std::string("v_") + sanitize_identifier(r->name); if (type_is_integer(r->type)) { fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); for (size_t i = 0; i < r->shape.size(); i++) { fprintf(pyfp, "%d", r->shape[i]); if (i + 1 != r->shape.size() || r->shape.size() == 1) fprintf(pyfp, ", "); } fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); } else { fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); for (size_t i = 0; i < r->shape.size(); i++) { fprintf(pyfp, "%d, ", r->shape[i]); } fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); } input_names.push_back(input_name); } fprintf(pyfp, "\n"); if (input_names.size() == 1) { fprintf(pyfp, " return net(%s)\n", input_names[0].c_str()); } else { fprintf(pyfp, " return net("); for (size_t i = 0; i < input_names.size(); i++) { fprintf(pyfp, "%s", input_names[i].c_str()); if (i + 1 != input_names.size()) fprintf(pyfp, ", "); } fprintf(pyfp, ")\n"); } } fclose(pyfp); return 0; } int Graph::parse(const std::string& param) { std::istringstream is(param); if (!is.good()) { fprintf(stderr, "open failed\n"); return -1; } int magic = 0; { std::string line; std::getline(is, line); std::istringstream iss(line); iss >> magic; } int operator_count = 0; int operand_count = 0; { std::string line; std::getline(is, line); std::istringstream iss(line); iss >> operator_count >> operand_count; } for (int i = 0; i < operator_count; i++) { std::string line; std::getline(is, line); std::istringstream iss(line); std::string type; std::string name; int input_count = 0; int output_count = 0; iss >> type >> name >> input_count >> output_count; Operator* op = new_operator(type, name); for (int j = 0; j < input_count; j++) { std::string operand_name; iss >> operand_name; Operand* r = get_operand(operand_name); r->consumers.push_back(op); op->inputs.push_back(r); } for (int j = 0; j < output_count; j++) { std::string operand_name; iss >> operand_name; Operand* r = new_operand(operand_name); r->producer = op; op->outputs.push_back(r); } // key=value while (!iss.eof()) { std::string param; iss >> param; std::string key; std::string value; std::istringstream pss(param); std::getline(pss, key, '='); std::getline(pss, value); if (key[0] == '@') { // attribute // load_attribute(op, key.substr(1), value, szr); op->attrs[key.substr(1)] = Attribute(); Attribute& attr = op->attrs[key.substr(1)]; attr.type = 0; if (value.empty()) continue; if (value[0] == '%') { // @data=%op1.data attr.data = std::vector(value.begin(), value.end()); } if (value[0] == '(') { // @data=(1,%c,?,4)f32 // type std::string typestr = value.substr(value.find_last_of(')') + 1); attr.type = string_to_type(typestr.c_str()); // shape std::string lc = value.substr(1, value.find_last_of(')') - 1); std::istringstream lcss(lc); attr.shape.clear(); while (!lcss.eof()) { std::string elem; std::getline(lcss, elem, ','); if (elem == "?") { attr.shape.push_back(-1); } else if (elem[0] == '%') { // encode %abc as symbolic tag attr.shape.push_back(-233); int index = attr.shape.size() - 1; std::string key = elem.substr(1); attr.params[std::string("__shape_") + std::to_string(index)] = key; } else { int i = std::stoi(elem); attr.shape.push_back(i); } } } } else if (key[0] == '$') { // operand input key load_input_key(op, key.substr(1), value); } else if (key[0] == '#') { // operand shape load_shape(op, key.substr(1), value); } else { // parameter load_parameter(op, key, value); } } } return 0; } void Operand::remove_consumer(const Operator* c) { auto it = std::find(consumers.begin(), consumers.end(), c); consumers.erase(it); } Operator* Graph::new_operator(const std::string& type, const std::string& name) { Operator* op = new Operator; op->type = type; op->name = name; ops.push_back(op); return op; } Operator* Graph::new_operator_before(const std::string& type, const std::string& name, const Operator* cur) { Operator* op = new Operator; op->type = type; op->name = name; ops.insert(std::find(ops.begin(), ops.end(), cur), op); return op; } Operator* Graph::new_operator_after(const std::string& type, const std::string& name, const Operator* cur) { Operator* op = new Operator; op->type = type; op->name = name; ops.insert(std::find(ops.begin(), ops.end(), cur) + 1, op); return op; } #if BUILD_PNNX Operand* Graph::new_operand(const torch::jit::Value* v) { Operand* r = new Operand; r->name = v->debugName(); r->type = -1; auto pt = v->type()->cast(); if (pt) { if (pt->scalarType().has_value() && pt->dim().has_value()) { r->type = get_at_tensor_type(pt->scalarType().value()); const int ndim = (int)pt->dim().value(); r->shape.resize(ndim); for (int i = 0; i < ndim; i++) { if (pt->sizes()[i].has_value()) r->shape[i] = (int)pt->sizes()[i].value(); else r->shape[i] = -1; } } } operands.push_back(r); return r; } #endif // BUILD_PNNX Operand* Graph::new_operand(const std::string& name) { Operand* r = new Operand; r->name = name; operands.push_back(r); return r; } Operand* Graph::get_operand(const std::string& name) { for (Operand* r : operands) { if (r->name == name) return r; } return 0; } const Operand* Graph::get_operand(const std::string& name) const { for (const Operand* r : operands) { if (r->name == name) return r; } return 0; } } // namespace pnnx