| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include "ir.h" |
| |
|
| | #include <limits.h> |
| | #include <stdint.h> |
| | #include <string.h> |
| | #include <algorithm> |
| | #include <fstream> |
| | #include <sstream> |
| | #include <string> |
| | #include <stack> |
| |
|
| | #if BUILD_PNNX |
| | #include <torch/script.h> |
| | #include <torch/csrc/api/include/torch/version.h> |
| | #endif |
| |
|
| | #include "storezip.h" |
| | #include "utils.h" |
| |
|
| | namespace pnnx { |
| |
|
// Whether a pnnx element-type code denotes an integral (or boolean) type.
// Codes 4..9 (i32, i64, i16, i8, u8, bool) are integral; floating point,
// complex and unknown codes are not.
static bool type_is_integer(int type)
{
    switch (type)
    {
    case 4: // i32
    case 5: // i64
    case 6: // i16
    case 7: // i8
    case 8: // u8
    case 9: // bool
        return true;
    default:
        return false;
    }
}
| |
|
// Short textual tag for a pnnx element-type code, as written after the
// closing ')' of shape annotations in the param file. Unknown codes map
// to "null".
static const char* type_to_string(int type)
{
    switch (type)
    {
    case 1: return "f32";
    case 2: return "f64";
    case 3: return "f16";
    case 4: return "i32";
    case 5: return "i64";
    case 6: return "i16";
    case 7: return "i8";
    case 8: return "u8";
    case 9: return "bool";
    case 10: return "c64";
    case 11: return "c128";
    case 12: return "c32";
    default: return "null";
    }
}
| |
|
// numpy dtype name for a pnnx element-type code; unknown codes map to "null".
// NOTE(review): numpy deprecated the "bool8" alias in 1.24 and has no native
// complex32 ("chalf") type — confirm the consumer of this string accepts both.
static const char* type_to_numpy_string(int type)
{
    switch (type)
    {
    case 1: return "float32";
    case 2: return "float64";
    case 3: return "float16";
    case 4: return "int32";
    case 5: return "int64";
    case 6: return "int16";
    case 7: return "int8";
    case 8: return "uint8";
    case 9: return "bool8";
    case 10: return "csingle";
    case 11: return "cdouble";
    case 12: return "chalf";
    default: return "null";
    }
}
| |
|
// Python-side torch dtype spelling for a pnnx element-type code; unknown
// codes map to "null".
static const char* type_to_dtype_string(int type)
{
    switch (type)
    {
    case 1: return "torch.float";
    case 2: return "torch.double";
    case 3: return "torch.half";
    case 4: return "torch.int";
    case 5: return "torch.long";
    case 6: return "torch.short";
    case 7: return "torch.int8";
    case 8: return "torch.uint8";
    case 9: return "torch.bool";
    case 10: return "torch.complex64";
    case 11: return "torch.complex128";
    case 12: return "torch.complex32";
    default: return "null";
    }
}
| |
|
// Size in bytes of one element of the given pnnx element-type code, or 0
// for an unknown code.
static size_t type_to_elemsize(int type)
{
    switch (type)
    {
    case 7:  // i8
    case 8:  // u8
    case 9:  // bool
        return 1;
    case 3:  // f16
    case 6:  // i16
        return 2;
    case 1:  // f32
    case 4:  // i32
    case 12: // c32
        return 4;
    case 2:  // f64
    case 5:  // i64
    case 10: // c64
        return 8;
    case 11: // c128
        return 16;
    default:
        return 0;
    }
}
| |
|
// Inverse of type_to_string: map a short type tag ("f32", "i64", ...) to
// its pnnx element-type code. Matching is exact and case-sensitive; an
// unrecognised tag yields 0.
static int string_to_type(const char* s)
{
    // index k holds the tag for type code k + 1
    static const char* const tags[] = {
        "f32", "f64", "f16", "i32", "i64", "i16",
        "i8", "u8", "bool", "c64", "c128", "c32"
    };
    const int tag_count = (int)(sizeof(tags) / sizeof(tags[0]));

    for (int k = 0; k < tag_count; k++)
    {
        if (strcmp(s, tags[k]) == 0)
            return k + 1;
    }
    return 0;
}
| |
|
| | #if BUILD_PNNX |
| | int get_at_tensor_type(const at::ScalarType& st) |
| | { |
| | if (st == c10::ScalarType::Float) return 1; |
| | if (st == c10::ScalarType::Double) return 2; |
| | if (st == c10::ScalarType::Half) return 3; |
| | if (st == c10::ScalarType::Int) return 4; |
| | if (st == c10::ScalarType::QInt32) return 4; |
| | if (st == c10::ScalarType::Long) return 5; |
| | if (st == c10::ScalarType::Short) return 6; |
| | if (st == c10::ScalarType::Char) return 7; |
| | if (st == c10::ScalarType::QInt8) return 7; |
| | if (st == c10::ScalarType::Byte) return 8; |
| | if (st == c10::ScalarType::QUInt8) return 8; |
| | if (st == c10::ScalarType::Bool) return 9; |
| | if (st == c10::ScalarType::ComplexFloat) return 10; |
| | if (st == c10::ScalarType::ComplexDouble) return 11; |
| | if (st == c10::ScalarType::ComplexHalf) return 12; |
| | return 0; |
| | } |
| |
|
| | Parameter::Parameter(const torch::jit::Node* value_node) |
| | { |
| | type = 0; |
| |
|
| | if (value_node->kind() == c10::prim::Constant) |
| | { |
| | if (value_node->output()->type()->kind() == c10::TypeKind::NoneType) |
| | { |
| | type = 0; |
| | return; |
| | } |
| |
|
| | if (!value_node->hasAttribute(torch::jit::attr::value)) |
| | { |
| | fprintf(stderr, "no attribute value\n"); |
| | value_node->dump(); |
| | return; |
| | } |
| |
|
| | switch (value_node->output()->type()->kind()) |
| | { |
| | case c10::TypeKind::NoneType: |
| | { |
| | type = 0; |
| | break; |
| | } |
| | case c10::TypeKind::BoolType: |
| | { |
| | type = 1; |
| | b = value_node->i(torch::jit::attr::value); |
| | break; |
| | } |
| | case c10::TypeKind::IntType: |
| | { |
| | type = 2; |
| | int64_t i64 = value_node->i(torch::jit::attr::value); |
| | if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX; |
| | if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN; |
| | i = (int)i64; |
| | break; |
| | } |
| | case c10::TypeKind::FloatType: |
| | { |
| | type = 3; |
| | f = (float)value_node->f(torch::jit::attr::value); |
| | break; |
| | } |
| | case c10::TypeKind::StringType: |
| | { |
| | type = 4; |
| | s = value_node->s(torch::jit::attr::value); |
| | break; |
| | } |
| | case c10::TypeKind::DeviceObjType: |
| | { |
| | type = 4; |
| | s = value_node->s(torch::jit::attr::value); |
| | break; |
| | } |
| | #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 9) |
| | case c10::TypeKind::ComplexType: |
| | { |
| | type = 10; |
| | c = std::complex<float>(value_node->c(torch::jit::attr::value)); |
| | break; |
| | } |
| | #endif |
| | case c10::TypeKind::TensorType: |
| | { |
| | at::Tensor t = value_node->t(torch::jit::attr::value); |
| |
|
| | if (t.dim() == 0) |
| | { |
| | if (t.scalar_type() == c10::ScalarType::Long) |
| | { |
| | type = 2; |
| | int64_t i64 = t.item<int64_t>(); |
| | if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX; |
| | if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN; |
| | i = (int)i64; |
| | } |
| | else if (t.scalar_type() == c10::ScalarType::Int) |
| | { |
| | type = 2; |
| | i = t.item<int>(); |
| | } |
| | else if (t.scalar_type() == c10::ScalarType::Double) |
| | { |
| | type = 3; |
| | f = (float)t.item<double>(); |
| | } |
| | else if (t.scalar_type() == c10::ScalarType::Float) |
| | { |
| | type = 3; |
| | f = t.item<float>(); |
| | } |
| | else if (t.scalar_type() == c10::ScalarType::ComplexDouble) |
| | { |
| | type = 10; |
| | c = std::complex<float>(t.item<c10::complex<double> >()); |
| | } |
| | else if (t.scalar_type() == c10::ScalarType::ComplexFloat) |
| | { |
| | type = 10; |
| | c = std::complex<float>(t.item<c10::complex<float> >()); |
| | } |
| | else |
| | { |
| | fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = 0\n", value_node->kind().toDisplayString()); |
| | } |
| | } |
| | else |
| | { |
| | |
| | type = 8; |
| | } |
| |
|
| | break; |
| | } |
| | default: |
| | { |
| | fprintf(stderr, "unknown Parameter value kind %s\n", c10::typeKindToString(value_node->output()->type()->kind())); |
| | break; |
| | } |
| | } |
| | } |
| | else if (value_node->kind() == c10::prim::ListConstruct) |
| | { |
| | switch (value_node->output()->type()->cast<c10::ListType>()->getElementType()->kind()) |
| | { |
| | case c10::TypeKind::IntType: |
| | { |
| | type = 5; |
| | for (const auto& x : value_node->inputs()) |
| | { |
| | ai.push_back((int)x->node()->i(torch::jit::attr::value)); |
| | } |
| | break; |
| | } |
| | case c10::TypeKind::FloatType: |
| | { |
| | type = 6; |
| | for (const auto& x : value_node->inputs()) |
| | { |
| | af.push_back((float)x->node()->f(torch::jit::attr::value)); |
| | } |
| | break; |
| | } |
| | case c10::TypeKind::StringType: |
| | { |
| | type = 7; |
| | for (const auto& x : value_node->inputs()) |
| | { |
| | as.push_back(x->node()->s(torch::jit::attr::value)); |
| | } |
| | break; |
| | } |
| | #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 9) |
| | case c10::TypeKind::ComplexType: |
| | { |
| | type = 11; |
| | for (const auto& x : value_node->inputs()) |
| | { |
| | ac.push_back(std::complex<float>(x->node()->c(torch::jit::attr::value))); |
| | } |
| | break; |
| | } |
| | #endif |
| | default: |
| | { |
| | fprintf(stderr, "unknown Parameter value list element kind %s\n", c10::typeKindToString(value_node->output()->type()->cast<c10::ListType>()->getElementType()->kind())); |
| | break; |
| | } |
| | } |
| | } |
| | else |
| | { |
| | fprintf(stderr, "unknown Parameter value_node kind %s\n", value_node->kind().toDisplayString()); |
| | } |
| | } |
| |
|
| | Parameter::Parameter(const torch::jit::Value* value) |
| | : Parameter(value->node()) |
| | { |
| | } |
| | #endif |
| |
|
| | bool operator==(const Parameter& lhs, const Parameter& rhs) |
| | { |
| | if (lhs.type != rhs.type) |
| | return false; |
| |
|
| | if (lhs.type == 0) |
| | return true; |
| |
|
| | if (lhs.type == 1 && lhs.b == rhs.b) |
| | return true; |
| |
|
| | if (lhs.type == 2 && lhs.i == rhs.i) |
| | return true; |
| |
|
| | if (lhs.type == 3 && lhs.f == rhs.f) |
| | return true; |
| |
|
| | if (lhs.type == 4 && lhs.s == rhs.s) |
| | return true; |
| |
|
| | if (lhs.type == 5 && lhs.ai == rhs.ai) |
| | return true; |
| |
|
| | if (lhs.type == 6 && lhs.af == rhs.af) |
| | return true; |
| |
|
| | if (lhs.type == 7 && lhs.as == rhs.as) |
| | return true; |
| |
|
| | if (lhs.type == 10 && lhs.c == rhs.c) |
| | return true; |
| |
|
| | if (lhs.type == 11 && lhs.ac == rhs.ac) |
| | return true; |
| |
|
| | return false; |
| | } |
| |
|
| | #if BUILD_PNNX |
| | Attribute::Attribute(const at::Tensor& t) |
| | { |
| | type = get_at_tensor_type(t.scalar_type()); |
| |
|
| | const int ndim = (int)t.dim(); |
| |
|
| | if (ndim == 0) |
| | { |
| | shape = {1}; |
| |
|
| | data.resize(type_to_elemsize(type)); |
| |
|
| | if (t.scalar_type() == c10::ScalarType::Long) |
| | { |
| | int64_t i = t.item<int64_t>(); |
| | memcpy((void*)data.data(), (const void*)&i, data.size()); |
| | } |
| | else if (t.scalar_type() == c10::ScalarType::Int) |
| | { |
| | int i = t.item<int>(); |
| | memcpy((void*)data.data(), (const void*)&i, data.size()); |
| | } |
| | else if (t.scalar_type() == c10::ScalarType::Double) |
| | { |
| | double f = t.item<double>(); |
| | memcpy((void*)data.data(), (const void*)&f, data.size()); |
| | } |
| | else if (t.scalar_type() == c10::ScalarType::Float) |
| | { |
| | float f = t.item<float>(); |
| | memcpy((void*)data.data(), (const void*)&f, data.size()); |
| | } |
| | else |
| | { |
| | fprintf(stderr, "unknown Attribute tensor scalar type %d\n", type); |
| | } |
| |
|
| | return; |
| | } |
| |
|
| | shape.resize(ndim); |
| | for (int i = 0; i < ndim; i++) |
| | shape[i] = t.size(i); |
| |
|
| | if (shape.size() > 0) |
| | { |
| | data.resize(elemcount() * type_to_elemsize(type)); |
| | memcpy((void*)data.data(), (const void*)t.cpu().contiguous().data_ptr(), data.size()); |
| | } |
| | } |
| | #endif |
| |
|
| | Attribute::Attribute(const std::initializer_list<int>& _shape, const std::vector<float>& t) |
| | { |
| | type = 1; |
| | shape = _shape; |
| |
|
| | if (shape.size() > 0) |
| | { |
| | data.resize(elemcount() * type_to_elemsize(type)); |
| | memcpy((void*)data.data(), (const void*)t.data(), data.size()); |
| | } |
| | } |
| |
|
| | size_t Attribute::elemsize() const |
| | { |
| | return type_to_elemsize(type); |
| | } |
| |
|
| | int Attribute::elemcount() const |
| | { |
| | if (shape.empty()) |
| | return 0; |
| |
|
| | int size = shape[0]; |
| | for (size_t i = 1; i < shape.size(); i++) |
| | { |
| | size *= shape[i]; |
| | } |
| |
|
| | return size; |
| | } |
| |
|
| | std::vector<float> Attribute::get_float32_data() const |
| | { |
| | std::vector<float> v(elemcount()); |
| |
|
| | if (type == 1) |
| | { |
| | memcpy((void*)v.data(), (const void*)data.data(), data.size()); |
| | } |
| | else if (type == 2) |
| | { |
| | |
| | const double* p = (const double*)data.data(); |
| | for (size_t i = 0; i < v.size(); i++) |
| | { |
| | v[i] = float(p[i]); |
| | } |
| | } |
| | else if (type == 3) |
| | { |
| | |
| | const unsigned short* p = (const unsigned short*)data.data(); |
| | for (size_t i = 0; i < v.size(); i++) |
| | { |
| | v[i] = float16_to_float32(p[i]); |
| | } |
| | } |
| | else |
| | { |
| | fprintf(stderr, "cannot convert type %d to float32 data\n", type); |
| | } |
| |
|
| | return v; |
| | } |
| |
|
| | void Attribute::set_float32_data(const std::vector<float>& newdata) |
| | { |
| | data.resize(newdata.size() * elemsize()); |
| |
|
| | if (type == 1) |
| | { |
| | memcpy((void*)data.data(), (const void*)newdata.data(), data.size()); |
| | } |
| | else if (type == 2) |
| | { |
| | |
| | double* p = (double*)data.data(); |
| | for (size_t i = 0; i < newdata.size(); i++) |
| | { |
| | p[i] = newdata[i]; |
| | } |
| | } |
| | else if (type == 3) |
| | { |
| | |
| | unsigned short* p = (unsigned short*)data.data(); |
| | for (size_t i = 0; i < newdata.size(); i++) |
| | { |
| | p[i] = float32_to_float16(newdata[i]); |
| | } |
| | } |
| | else |
| | { |
| | fprintf(stderr, "cannot convert float32 data to type %d\n", type); |
| | } |
| | } |
| |
|
| | bool operator==(const Attribute& lhs, const Attribute& rhs) |
| | { |
| | if (lhs.type != rhs.type) |
| | return false; |
| |
|
| | if (lhs.type == 0) |
| | return true; |
| |
|
| | if (lhs.shape != rhs.shape) |
| | return false; |
| |
|
| | if (lhs.data != rhs.data) |
| | return false; |
| |
|
| | return true; |
| | } |
| |
|
| | Attribute operator+(const Attribute& a, const Attribute& b) |
| | { |
| | Attribute c; |
| |
|
| | if (a.type != b.type) |
| | { |
| | fprintf(stderr, "concat attribute type mismatch\n"); |
| | return c; |
| | } |
| |
|
| | if (a.shape.size() != b.shape.size()) |
| | { |
| | fprintf(stderr, "concat attribute shape rank mismatch\n"); |
| | return c; |
| | } |
| |
|
| | for (int i = 1; i < (int)a.shape.size(); i++) |
| | { |
| | if (a.shape[i] != b.shape[i]) |
| | { |
| | fprintf(stderr, "concat attribute shape mismatch\n"); |
| | return c; |
| | } |
| | } |
| |
|
| | c.type = a.type; |
| | c.shape = a.shape; |
| | c.shape[0] += b.shape[0]; |
| |
|
| | c.data.resize(a.data.size() + b.data.size()); |
| | memcpy(c.data.data(), a.data.data(), a.data.size()); |
| | memcpy(c.data.data() + a.data.size(), b.data.data(), b.data.size()); |
| |
|
| | return c; |
| | } |
| |
|
| | Parameter Parameter::parse_from_string(const std::string& value) |
| | { |
| | if (value.find('%') != std::string::npos) |
| | { |
| | Parameter p; |
| | p.type = 4; |
| | p.s = value; |
| | return p; |
| | } |
| |
|
| | Parameter p; |
| | p.type = 0; |
| |
|
| | if (value == "None" || value == "()" || value == "[]") |
| | { |
| | return p; |
| | } |
| |
|
| | if (value == "True" || value == "False") |
| | { |
| | |
| | p.type = 1; |
| | p.b = value == "True"; |
| | return p; |
| | } |
| |
|
| | if (value[0] == '(' || value[0] == '[') |
| | { |
| | |
| | std::string lc = value.substr(1, value.size() - 2); |
| | std::istringstream lcss(lc); |
| |
|
| | while (!lcss.eof()) |
| | { |
| | std::string elem; |
| | std::getline(lcss, elem, ','); |
| |
|
| | if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9'))) |
| | { |
| | |
| | p.type = 7; |
| | p.as.push_back(elem); |
| | } |
| | else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos) |
| | { |
| | |
| | p.type = 6; |
| | p.af.push_back(std::stof(elem)); |
| | } |
| | else |
| | { |
| | |
| | p.type = 5; |
| | p.ai.push_back(std::stoi(elem)); |
| | } |
| | } |
| | return p; |
| | } |
| |
|
| | if ((value[0] != '-' && (value[0] < '0' || value[0] > '9')) || (value[0] == '-' && (value[1] < '0' || value[1] > '9'))) |
| | { |
| | |
| | p.type = 4; |
| | p.s = value; |
| | return p; |
| | } |
| |
|
| | if (value.find('.') != std::string::npos || value.find('e') != std::string::npos) |
| | { |
| | |
| | p.type = 3; |
| | p.f = std::stof(value); |
| | return p; |
| | } |
| |
|
| | |
| | p.type = 2; |
| | p.i = std::stoi(value); |
| | return p; |
| | } |
| |
|
| | std::string Parameter::encode_to_string(const Parameter& param) |
| | { |
| | if (param.type == 0) |
| | { |
| | return std::string("None"); |
| | } |
| | if (param.type == 1) |
| | { |
| | if (param.b) |
| | return std::string("True"); |
| | else |
| | return std::string("False"); |
| | } |
| | if (param.type == 2) |
| | { |
| | return std::to_string(param.i); |
| | } |
| | if (param.type == 3) |
| | { |
| | char buf[64]; |
| | sprintf(buf, "%e", param.f); |
| | return std::string(buf); |
| | } |
| | if (param.type == 4) |
| | { |
| | return param.s; |
| | } |
| | if (param.type == 5) |
| | { |
| | std::string s("("); |
| | for (size_t i = 0; i < param.ai.size(); i++) |
| | { |
| | s += std::to_string(param.ai[i]); |
| | if (i + 1 != param.ai.size()) |
| | s += std::string(","); |
| | } |
| | s += std::string(")"); |
| | return s; |
| | } |
| | if (param.type == 6) |
| | { |
| | std::string s("("); |
| | for (size_t i = 0; i < param.af.size(); i++) |
| | { |
| | char buf[64]; |
| | sprintf(buf, "%e", param.af[i]); |
| | s += std::string(buf); |
| | if (i + 1 != param.af.size()) |
| | s += std::string(","); |
| | } |
| | s += std::string(")"); |
| | return s; |
| | } |
| | if (param.type == 7) |
| | { |
| | std::string s("("); |
| | for (size_t i = 0; i < param.as.size(); i++) |
| | { |
| | s += param.as[i]; |
| | if (i + 1 != param.as.size()) |
| | s += std::string(","); |
| | } |
| | s += std::string(")"); |
| | return s; |
| | } |
| | if (param.type == 10) |
| | { |
| | char buf[128]; |
| | sprintf(buf, "%e+%ej", param.c.real(), param.c.imag()); |
| | return std::string(buf); |
| | } |
| | if (param.type == 11) |
| | { |
| | std::string s("("); |
| | for (size_t i = 0; i < param.ac.size(); i++) |
| | { |
| | char buf[128]; |
| | sprintf(buf, "%e+%ej", param.ac[i].real(), param.ac[i].imag()); |
| | s += std::string(buf); |
| | if (i + 1 != param.ac.size()) |
| | s += std::string(","); |
| | } |
| | s += std::string(")"); |
| | return s; |
| | } |
| |
|
| | fprintf(stderr, "unknown parameter type %d\n", param.type); |
| | return std::string(); |
| | } |
| |
|
| | Graph::Graph() |
| | { |
| | } |
| |
|
| | Graph::~Graph() |
| | { |
| | for (auto x : ops) |
| | delete x; |
| |
|
| | for (auto x : operands) |
| | delete x; |
| |
|
| | ops.clear(); |
| | operands.clear(); |
| | } |
| |
|
| | Graph::Graph(const Graph& ) |
| | { |
| | } |
| |
|
| | Graph& Graph::operator=(const Graph& ) |
| | { |
| | return *this; |
| | } |
| |
|
| | static void load_parameter(Operator* op, const std::string& key, const std::string& value) |
| | { |
| | op->params[key] = Parameter::parse_from_string(value); |
| | } |
| |
|
| | static void load_input_key(Operator* op, const std::string& key, const std::string& value) |
| | { |
| | op->inputnames.resize(op->inputs.size()); |
| |
|
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | const Operand* oprand = op->inputs[i]; |
| | if (oprand->name == value) |
| | { |
| | op->inputnames[i] = key; |
| | break; |
| | } |
| | } |
| | } |
| |
|
| | static void load_shape(Operator* op, const std::string& key, const std::string& value) |
| | { |
| | Operand* operand = 0; |
| | for (auto r : op->inputs) |
| | { |
| | if (r->name == key) |
| | { |
| | operand = r; |
| | break; |
| | } |
| | } |
| |
|
| | if (!operand) |
| | { |
| | for (auto r : op->outputs) |
| | { |
| | if (r->name == key) |
| | { |
| | operand = r; |
| | break; |
| | } |
| | } |
| | } |
| |
|
| | if (!operand) |
| | { |
| | fprintf(stderr, "no such operand %s for operator %s\n", key.c_str(), op->name.c_str()); |
| | return; |
| | } |
| |
|
| | |
| | std::string typestr = value.substr(value.find_last_of(')') + 1); |
| | operand->type = string_to_type(typestr.c_str()); |
| |
|
| | |
| | std::string lc = value.substr(1, value.find_last_of(')') - 1); |
| | std::istringstream lcss(lc); |
| |
|
| | operand->shape.clear(); |
| | while (!lcss.eof()) |
| | { |
| | std::string elem; |
| | std::getline(lcss, elem, ','); |
| |
|
| | if (elem == "?") |
| | { |
| | operand->shape.push_back(-1); |
| | } |
| | else if (elem[0] == '%') |
| | { |
| | |
| | operand->shape.push_back(-233); |
| | int index = operand->shape.size() - 1; |
| | std::string key = elem.substr(1); |
| | operand->params[std::string("__shape_") + std::to_string(index)] = key; |
| | } |
| | else |
| | { |
| | int i = std::stoi(elem); |
| | operand->shape.push_back(i); |
| | } |
| | } |
| | } |
| |
|
| | static void load_attribute(Operator* op, const std::string& key, const std::string& value, StoreZipReader& szr) |
| | { |
| | Attribute& a = op->attrs[key]; |
| |
|
| | |
| | std::string typestr = value.substr(value.find_last_of(')') + 1); |
| | a.type = string_to_type(typestr.c_str()); |
| |
|
| | if (a.type == 0) |
| | return; |
| |
|
| | |
| | std::string lc = value.substr(1, value.find_last_of(')') - 1); |
| | std::istringstream lcss(lc); |
| |
|
| | a.shape.clear(); |
| | while (!lcss.eof()) |
| | { |
| | std::string elem; |
| | std::getline(lcss, elem, ','); |
| |
|
| | int i = std::stoi(elem); |
| | a.shape.push_back(i); |
| | } |
| |
|
| | if (a.shape.empty()) |
| | return; |
| |
|
| | |
| | size_t size = 1; |
| | for (int i : a.shape) |
| | { |
| | size *= i; |
| | } |
| |
|
| | size_t bytesize = size * type_to_elemsize(a.type); |
| |
|
| | std::string filename = op->name + "." + key; |
| |
|
| | size_t filesize = szr.get_file_size(filename); |
| |
|
| | if (filesize == 0) |
| | { |
| | |
| | return; |
| | } |
| |
|
| | if (filesize != bytesize) |
| | { |
| | fprintf(stderr, "file size not match expect %lu but got %lu\n", bytesize, filesize); |
| | } |
| |
|
| | a.data.resize(bytesize); |
| | szr.read_file(filename, (char*)a.data.data()); |
| | } |
| |
|
| | int Graph::load(const std::string& parampath, const std::string& binpath) |
| | { |
| | std::ifstream is(parampath, std::ios::in | std::ios::binary); |
| | if (!is.good()) |
| | { |
| | fprintf(stderr, "open failed\n"); |
| | return -1; |
| | } |
| |
|
| | StoreZipReader szr; |
| | if (szr.open(binpath) != 0) |
| | { |
| | fprintf(stderr, "open failed\n"); |
| | return -1; |
| | } |
| |
|
| | int magic = 0; |
| | { |
| | std::string line; |
| | std::getline(is, line); |
| | std::istringstream iss(line); |
| |
|
| | iss >> magic; |
| | } |
| |
|
| | int operator_count = 0; |
| | int operand_count = 0; |
| | { |
| | std::string line; |
| | std::getline(is, line); |
| | std::istringstream iss(line); |
| |
|
| | iss >> operator_count >> operand_count; |
| | } |
| |
|
| | for (int i = 0; i < operator_count; i++) |
| | { |
| | std::string line; |
| | std::getline(is, line); |
| | std::istringstream iss(line); |
| |
|
| | std::string type; |
| | std::string name; |
| | int input_count = 0; |
| | int output_count = 0; |
| |
|
| | iss >> type >> name >> input_count >> output_count; |
| |
|
| | Operator* op = new_operator(type, name); |
| |
|
| | for (int j = 0; j < input_count; j++) |
| | { |
| | std::string operand_name; |
| | iss >> operand_name; |
| |
|
| | Operand* r = get_operand(operand_name); |
| | r->consumers.push_back(op); |
| | op->inputs.push_back(r); |
| | } |
| |
|
| | for (int j = 0; j < output_count; j++) |
| | { |
| | std::string operand_name; |
| | iss >> operand_name; |
| |
|
| | Operand* r = new_operand(operand_name); |
| | r->producer = op; |
| | op->outputs.push_back(r); |
| | } |
| |
|
| | |
| | while (!iss.eof()) |
| | { |
| | std::string param; |
| | iss >> param; |
| |
|
| | std::string key; |
| | std::string value; |
| | std::istringstream pss(param); |
| | std::getline(pss, key, '='); |
| | std::getline(pss, value); |
| |
|
| | if (key[0] == '@') |
| | { |
| | |
| | load_attribute(op, key.substr(1), value, szr); |
| | } |
| | else if (key[0] == '$') |
| | { |
| | |
| | load_input_key(op, key.substr(1), value); |
| | } |
| | else if (key[0] == '#') |
| | { |
| | |
| | load_shape(op, key.substr(1), value); |
| | } |
| | else |
| | { |
| | |
| | load_parameter(op, key, value); |
| | } |
| | } |
| | } |
| |
|
| | return 0; |
| | } |
| |
|
| | int Graph::save(const std::string& parampath, const std::string& binpath) |
| | { |
| | FILE* paramfp = fopen(parampath.c_str(), "wb"); |
| | if (!paramfp) |
| | { |
| | fprintf(stderr, "fopen %s failed\n", parampath.c_str()); |
| | return -1; |
| | } |
| |
|
| | StoreZipWriter szw; |
| | if (szw.open(binpath) != 0) |
| | { |
| | fprintf(stderr, "open failed\n"); |
| | return -1; |
| | } |
| |
|
| | |
| | fprintf(paramfp, "7767517\n"); |
| |
|
| | |
| | fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size()); |
| |
|
| | for (const Operator* op : ops) |
| | { |
| | fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size()); |
| |
|
| | for (const Operand* oprand : op->inputs) |
| | { |
| | fprintf(paramfp, " %s", oprand->name.c_str()); |
| | } |
| |
|
| | for (const Operand* oprand : op->outputs) |
| | { |
| | fprintf(paramfp, " %s", oprand->name.c_str()); |
| | } |
| |
|
| | for (const auto& it : op->params) |
| | { |
| | fprintf(paramfp, " %s=", it.first.c_str()); |
| |
|
| | const Parameter& param = it.second; |
| | std::string s = Parameter::encode_to_string(param); |
| | fprintf(paramfp, "%s", s.c_str()); |
| | } |
| |
|
| | for (const auto& it : op->attrs) |
| | { |
| | fprintf(paramfp, " @%s=", it.first.c_str()); |
| |
|
| | const Attribute& attr = it.second; |
| | fprintf(paramfp, "("); |
| | for (int i = 0; i < (int)attr.shape.size() - 1; i++) |
| | { |
| | fprintf(paramfp, "%d,", attr.shape[i]); |
| | } |
| | if (attr.shape.size() > 0) |
| | fprintf(paramfp, "%d", attr.shape[attr.shape.size() - 1]); |
| | fprintf(paramfp, ")"); |
| |
|
| | fprintf(paramfp, type_to_string(attr.type)); |
| |
|
| | std::string filename = op->name + "." + it.first; |
| | szw.write_file(filename, attr.data.data(), attr.data.size()); |
| | } |
| |
|
| | if (op->inputnames.size() == op->inputs.size()) |
| | { |
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | if (op->inputnames[i].empty()) |
| | continue; |
| |
|
| | const Operand* oprand = op->inputs[i]; |
| | fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); |
| | } |
| | } |
| |
|
| | for (const Operand* oprand : op->inputs) |
| | { |
| | if (oprand->shape.empty()) |
| | continue; |
| |
|
| | fprintf(paramfp, " #%s=", oprand->name.c_str()); |
| |
|
| | fprintf(paramfp, "("); |
| | for (int i = 0; i < (int)oprand->shape.size() - 1; i++) |
| | { |
| | if (oprand->shape[i] == -1) |
| | fprintf(paramfp, "?,"); |
| | else |
| | fprintf(paramfp, "%d,", oprand->shape[i]); |
| | } |
| | if (oprand->shape.size() > 0) |
| | { |
| | if (oprand->shape[oprand->shape.size() - 1] == -1) |
| | fprintf(paramfp, "?"); |
| | else |
| | fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); |
| | } |
| | fprintf(paramfp, ")"); |
| |
|
| | fprintf(paramfp, type_to_string(oprand->type)); |
| | } |
| |
|
| | for (const Operand* oprand : op->outputs) |
| | { |
| | if (oprand->shape.empty()) |
| | continue; |
| |
|
| | fprintf(paramfp, " #%s=", oprand->name.c_str()); |
| |
|
| | fprintf(paramfp, "("); |
| | for (int i = 0; i < (int)oprand->shape.size() - 1; i++) |
| | { |
| | if (oprand->shape[i] == -1) |
| | fprintf(paramfp, "?,"); |
| | else |
| | fprintf(paramfp, "%d,", oprand->shape[i]); |
| | } |
| | if (oprand->shape.size() > 0) |
| | { |
| | if (oprand->shape[oprand->shape.size() - 1] == -1) |
| | fprintf(paramfp, "?"); |
| | else |
| | fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); |
| | } |
| | fprintf(paramfp, ")"); |
| |
|
| | fprintf(paramfp, type_to_string(oprand->type)); |
| | } |
| |
|
| | fprintf(paramfp, "\n"); |
| | } |
| |
|
| | fclose(paramfp); |
| |
|
| | return 0; |
| | } |
| |
|
// Return a copy of `s` with every '.' and ':' replaced by '_'; all other
// characters are kept as-is.
static std::string sanitize_identifier(const std::string& s)
{
    std::string out(s);
    std::replace_if(out.begin(), out.end(), [](char ch) { return ch == '.' || ch == ':'; }, '_');
    return out;
}
| |
|
| | static std::string expand_expression(const Operator* op) |
| | { |
| | std::string expr = op->params.at("expr").s; |
| |
|
| | |
| | std::vector<std::string> tokens; |
| | { |
| | std::string t; |
| | for (size_t i = 0; i < expr.size(); i++) |
| | { |
| | char ch = expr[i]; |
| |
|
| | if (ch == '[') |
| | { |
| | t += ch; |
| | tokens.push_back(t); |
| | t.clear(); |
| | } |
| | else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') |
| | { |
| | if (!t.empty()) |
| | { |
| | tokens.push_back(t); |
| | t.clear(); |
| | } |
| | } |
| | else |
| | { |
| | t += ch; |
| | } |
| | } |
| |
|
| | if (!t.empty()) |
| | { |
| | tokens.push_back(t); |
| | } |
| | } |
| |
|
| | |
| | std::stack<std::string> exprstack; |
| | for (int i = (int)tokens.size() - 1; i >= 0; i--) |
| | { |
| | const std::string& t = tokens[i]; |
| |
|
| | if (t == "size") |
| | { |
| | std::string a = exprstack.top(); |
| | exprstack.pop(); |
| | std::string b = exprstack.top(); |
| | exprstack.pop(); |
| |
|
| | std::string r = a + ".size(" + b + ")"; |
| | exprstack.push(r); |
| | } |
| | else if (t == "int" |
| | || t == "abs" |
| | || t == "acos" |
| | || t == "acosh" |
| | || t == "asin" |
| | || t == "asinh" |
| | || t == "atan" |
| | || t == "atanh" |
| | || t == "ceil" |
| | || t == "cos" |
| | || t == "cosh" |
| | || t == "exp" |
| | || t == "floor" |
| | || t == "log" |
| | || t == "log10" |
| | || t == "neg" |
| | || t == "reciprocal" |
| | || t == "round" |
| | || t == "rsqrt" |
| | || t == "sign" |
| | || t == "sin" |
| | || t == "sinh" |
| | || t == "sqrt" |
| | || t == "square" |
| | || t == "tan" |
| | || t == "tanh" |
| | || t == "trunc") |
| | { |
| | std::string unaryop; |
| | if (t == "int") unaryop = "int"; |
| | if (t == "abs") unaryop = "torch.abs"; |
| | if (t == "acos") unaryop = "torch.acos"; |
| | if (t == "acosh") unaryop = "torch.acosh"; |
| | if (t == "asin") unaryop = "torch.asin"; |
| | if (t == "asinh") unaryop = "torch.asinh"; |
| | if (t == "atan") unaryop = "torch.atan"; |
| | if (t == "atanh") unaryop = "torch.atanh"; |
| | if (t == "ceil") unaryop = "torch.ceil"; |
| | if (t == "cos") unaryop = "torch.cos"; |
| | if (t == "cosh") unaryop = "torch.cosh"; |
| | if (t == "exp") unaryop = "torch.exp"; |
| | if (t == "floor") unaryop = "torch.floor"; |
| | if (t == "log") unaryop = "torch.log"; |
| | if (t == "log10") unaryop = "torch.log10"; |
| | if (t == "neg") unaryop = "torch.neg"; |
| | if (t == "reciprocal") unaryop = "torch.reciprocal"; |
| | if (t == "round") unaryop = "torch.round"; |
| | if (t == "rsqrt") unaryop = "torch.rsqrt"; |
| | if (t == "sign") unaryop = "torch.sign"; |
| | if (t == "sin") unaryop = "torch.sin"; |
| | if (t == "sinh") unaryop = "torch.sinh"; |
| | if (t == "sqrt") unaryop = "torch.sqrt"; |
| | if (t == "square") unaryop = "torch.square"; |
| | if (t == "tan") unaryop = "torch.tan"; |
| | if (t == "tanh") unaryop = "torch.tanh"; |
| | if (t == "trunc") unaryop = "torch.trunc"; |
| |
|
| | std::string a = exprstack.top(); |
| | exprstack.pop(); |
| |
|
| | std::string r = unaryop + "(" + a + ")"; |
| | exprstack.push(r); |
| | } |
| | else if (t == "atan2" |
| | || t == "fmod" |
| | || t == "max" |
| | || t == "maximum" |
| | || t == "min" |
| | || t == "minimum" |
| | || t == "pow") |
| | { |
| | std::string binaryop; |
| | if (t == "atan2") binaryop = "torch.atan2"; |
| | if (t == "fmod") binaryop = "torch.fmod"; |
| | if (t == "max") binaryop = "torch.max"; |
| | if (t == "maximum") binaryop = "torch.maximum"; |
| | if (t == "min") binaryop = "torch.min"; |
| | if (t == "minimum") binaryop = "torch.minimum"; |
| | if (t == "pow") binaryop = "torch.pow"; |
| |
|
| | std::string a = exprstack.top(); |
| | exprstack.pop(); |
| | std::string b = exprstack.top(); |
| | exprstack.pop(); |
| |
|
| | std::string r = binaryop + "(" + a + ", " + b + ")"; |
| | exprstack.push(r); |
| | } |
| | else if (t == "add" |
| | || t == "sub" |
| | || t == "mul" |
| | || t == "div" |
| | || t == "floor_divide" |
| | || t == "remainder" |
| | || t == "and" |
| | || t == "or" |
| | || t == "xor" |
| | || t == "lshift" |
| | || t == "rshift") |
| | { |
| | std::string binaryop; |
| | if (t == "add") binaryop = "+"; |
| | if (t == "sub") binaryop = "-"; |
| | if (t == "mul") binaryop = "*"; |
| | if (t == "div") binaryop = "/"; |
| | if (t == "floor_divide") binaryop = "//"; |
| | if (t == "remainder") binaryop = "%"; |
| | if (t == "and") binaryop = "&"; |
| | if (t == "or") binaryop = "|"; |
| | if (t == "xor") binaryop = "^"; |
| | if (t == "lshift") binaryop = "<<"; |
| | if (t == "rshift") binaryop = ">>"; |
| |
|
| | std::string a = exprstack.top(); |
| | exprstack.pop(); |
| | std::string b = exprstack.top(); |
| | exprstack.pop(); |
| |
|
| | std::string r = std::string("(") + a + " " + binaryop + " " + b + ")"; |
| | exprstack.push(r); |
| | } |
| | else if (t == "[") |
| | { |
| | std::vector<std::string> elements; |
| | while (!exprstack.empty()) |
| | { |
| | std::string a = exprstack.top(); |
| | exprstack.pop(); |
| |
|
| | elements.push_back(a); |
| | } |
| |
|
| | std::string r = "["; |
| | for (int j = 0; j < (int)elements.size() - 1; j++) |
| | { |
| | r += elements[j]; |
| | if (j + 1 != (int)elements.size()) |
| | r += ", "; |
| | } |
| | if (!elements.empty()) |
| | { |
| | r += elements[elements.size() - 1]; |
| | } |
| | r += "]"; |
| |
|
| | exprstack.push(r); |
| | } |
| | else if (t[0] == '@') |
| | { |
| | int input_index = std::stoi(t.substr(1)); |
| | std::string varid = std::string("v_") + sanitize_identifier(op->inputs[input_index]->name); |
| | exprstack.push(varid); |
| | } |
| | else |
| | { |
| | |
| | if (t[t.size() - 1] == 'j') |
| | { |
| | |
| | std::string r = std::string("(") + t + ")"; |
| | exprstack.push(r); |
| | } |
| | else |
| | { |
| | exprstack.push(t); |
| | } |
| | } |
| | } |
| |
|
| | std::string r = exprstack.top(); |
| | exprstack.pop(); |
| |
|
| | return r; |
| | } |
| |
|
| | static std::string make_slice_expression(const Operator* op) |
| | { |
| | for (size_t j = 0; j < op->inputnames.size(); j++) |
| | { |
| | fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), op->inputs[j]->name.c_str()); |
| | } |
| |
|
| | std::vector<int> dims; |
| | if (op->params.find("dims") != op->params.end()) |
| | { |
| | dims = op->params.at("dims").ai; |
| | } |
| | else |
| | { |
| | dims.push_back(op->params.at("dim").i); |
| | } |
| |
|
| | std::string r; |
| |
|
| | int last_dim = -1; |
| | const int ndim = (int)dims.size(); |
| | for (int i = 0; i < ndim; i++) |
| | { |
| | int dim = dims[i]; |
| | for (int j = last_dim + 1; j < dim; j++) |
| | { |
| | r += ":,"; |
| | } |
| | last_dim = dim; |
| |
|
| | if (op->params.find("starts") != op->params.end()) |
| | { |
| | std::vector<int> starts = op->params.at("starts").ai; |
| | int start = starts[i]; |
| |
|
| | if (start != 0) |
| | r += std::to_string(start); |
| | } |
| | else |
| | { |
| | fprintf(stderr, "find start\n"); |
| | |
| | for (size_t j = 0; j < op->inputnames.size(); j++) |
| | { |
| | if (op->inputnames[j] == "start") |
| | { |
| | r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); |
| |
|
| | fprintf(stderr, "find start %s\n", op->inputs[j]->name.c_str()); |
| | break; |
| | } |
| | } |
| | } |
| |
|
| | r += ':'; |
| |
|
| | if (op->params.find("ends") != op->params.end()) |
| | { |
| | std::vector<int> ends = op->params.at("ends").ai; |
| | int end = ends[i]; |
| | if (end != INT_MAX) |
| | r += std::to_string(end); |
| | } |
| | else |
| | { |
| | |
| | for (size_t j = 0; j < op->inputnames.size(); j++) |
| | { |
| | if (op->inputnames[j] == "end") |
| | { |
| | r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); |
| | break; |
| | } |
| | } |
| | } |
| |
|
| | if (op->params.find("steps") != op->params.end()) |
| | { |
| | std::vector<int> steps = op->params.at("steps").ai; |
| | int step = steps[i]; |
| | if (step != 1) |
| | { |
| | r += ':'; |
| | r += std::to_string(step); |
| | } |
| | } |
| | else |
| | { |
| | |
| | for (size_t j = 0; j < op->inputnames.size(); j++) |
| | { |
| | if (op->inputnames[j] == "step") |
| | { |
| | r += ':'; |
| | r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); |
| | break; |
| | } |
| | } |
| | } |
| |
|
| | if (i + 1 != ndim) |
| | r += ','; |
| | } |
| |
|
| | return r; |
| | } |
| |
|
| | static std::string make_index_expression(const Operator* op) |
| | { |
| | fprintf(stderr, "make_index_expression %s\n", op->name.c_str()); |
| |
|
| | std::string index_expr = op->params.at("expr").s; |
| |
|
| | |
| | index_expr = index_expr.substr(1, index_expr.size() - 2); |
| |
|
| | |
| | bool leading_none = false; |
| | while (index_expr.substr(0, 5) == "None,") |
| | { |
| | leading_none = true; |
| | index_expr = index_expr.substr(5); |
| | } |
| | if (leading_none) |
| | { |
| | index_expr = "...," + index_expr; |
| | } |
| |
|
| | return index_expr; |
| | } |
| |
|
| | int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) |
| | { |
| | FILE* pyfp = fopen(pypath.c_str(), "wb"); |
| | if (!pyfp) |
| | { |
| | fprintf(stderr, "fopen %s failed\n", pypath.c_str()); |
| | return -1; |
| | } |
| |
|
| | fprintf(pyfp, "import os\n"); |
| | fprintf(pyfp, "import numpy as np\n"); |
| | fprintf(pyfp, "import tempfile, zipfile\n"); |
| | fprintf(pyfp, "import torch\n"); |
| | fprintf(pyfp, "import torch.nn as nn\n"); |
| | fprintf(pyfp, "import torch.nn.functional as F\n"); |
| | fprintf(pyfp, "try:\n"); |
| | fprintf(pyfp, " import torchvision\n"); |
| | fprintf(pyfp, "except:\n"); |
| | fprintf(pyfp, " pass\n"); |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | fprintf(pyfp, "class Model(nn.Module):\n"); |
| | fprintf(pyfp, " def __init__(self):\n"); |
| | fprintf(pyfp, " super(Model, self).__init__()\n"); |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | |
| | { |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type.substr(0, 3) != "nn." && op->type.substr(0, 16) != "torchvision.ops.") |
| | continue; |
| |
|
| | fprintf(pyfp, " self.%s = %s(", sanitize_identifier(op->name).c_str(), op->type.c_str()); |
| |
|
| | int param_count = op->params.size(); |
| | if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") |
| | { |
| | param_count -= 2; |
| | } |
| |
|
| | int param_index = 0; |
| | for (const auto& it : op->params) |
| | { |
| | if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") |
| | { |
| | if (it.first == "scale" || it.first == "zero_point") |
| | continue; |
| | } |
| |
|
| | fprintf(pyfp, "%s=", it.first.c_str()); |
| |
|
| | const Parameter& param = it.second; |
| | if (param.type == 0) |
| | { |
| | fprintf(pyfp, "None"); |
| | } |
| | if (param.type == 1) |
| | { |
| | if (param.b) |
| | fprintf(pyfp, "True"); |
| | else |
| | fprintf(pyfp, "False"); |
| | } |
| | if (param.type == 2) |
| | { |
| | fprintf(pyfp, "%d", param.i); |
| | } |
| | if (param.type == 3) |
| | { |
| | fprintf(pyfp, "%f", param.f); |
| | } |
| | if (param.type == 4) |
| | { |
| | if (param.s.substr(0, 6) == "torch.") |
| | { |
| | fprintf(pyfp, "%s", param.s.c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "\'%s\'", param.s.c_str()); |
| | } |
| | } |
| | if (param.type == 5) |
| | { |
| | fprintf(pyfp, "("); |
| | for (size_t i = 0; i < param.ai.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", param.ai[i]); |
| | if (i + 1 != param.ai.size() || param.ai.size() == 1) |
| | fprintf(pyfp, ","); |
| | } |
| | fprintf(pyfp, ")"); |
| | } |
| | if (param.type == 6) |
| | { |
| | fprintf(pyfp, "("); |
| | for (size_t i = 0; i < param.af.size(); i++) |
| | { |
| | fprintf(pyfp, "%f", param.af[i]); |
| | if (i + 1 != param.af.size() || param.af.size() == 1) |
| | fprintf(pyfp, ","); |
| | } |
| | fprintf(pyfp, ")"); |
| | } |
| | if (param.type == 7) |
| | { |
| | fprintf(pyfp, "("); |
| | for (size_t i = 0; i < param.as.size(); i++) |
| | { |
| | if (param.as[i].substr(0, 6) == "torch.") |
| | { |
| | fprintf(pyfp, "%s", param.as[i].c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "\'%s\'", param.as[i].c_str()); |
| | } |
| | if (i + 1 != param.as.size() || param.as.size() == 1) |
| | fprintf(pyfp, ","); |
| | } |
| | fprintf(pyfp, ")"); |
| | } |
| |
|
| | param_index++; |
| | if (param_index != param_count) |
| | fprintf(pyfp, ", "); |
| | } |
| |
|
| | fprintf(pyfp, ")\n"); |
| | } |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | |
| | { |
| | fprintf(pyfp, " archive = zipfile.ZipFile('%s', 'r')\n", pnnxbinpath.c_str()); |
| |
|
| | for (const Operator* op : ops) |
| | { |
| | if (op->type.substr(0, 3) != "nn." && op->type.substr(0, 16) != "torchvision.ops.") |
| | continue; |
| |
|
| | if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") |
| | { |
| | for (const auto& it : op->attrs) |
| | { |
| | if (it.first == "weight" || it.first == "bias") |
| | { |
| | fprintf(pyfp, " self_%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); |
| | } |
| | else |
| | { |
| | |
| | continue; |
| | } |
| |
|
| | const Attribute& attr = it.second; |
| | for (size_t i = 0; i < attr.shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", attr.shape[i]); |
| | if (i + 1 != attr.shape.size()) |
| | fprintf(pyfp, ","); |
| | } |
| |
|
| | fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); |
| | } |
| |
|
| | fprintf(pyfp, " self.%s.set_weight_bias(self_%s_weight, self_%s_bias)\n", sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str()); |
| | fprintf(pyfp, " self.%s.scale = %f\n", sanitize_identifier(op->name).c_str(), op->params.at("scale").f); |
| | fprintf(pyfp, " self.%s.zero_point = %d\n", sanitize_identifier(op->name).c_str(), op->params.at("zero_point").i); |
| |
|
| | continue; |
| | } |
| |
|
| | for (const auto& it : op->attrs) |
| | { |
| | if (it.first == "running_mean" || it.first == "running_var") |
| | { |
| | fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); |
| | } |
| |
|
| | const Attribute& attr = it.second; |
| | for (size_t i = 0; i < attr.shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", attr.shape[i]); |
| | if (i + 1 != attr.shape.size()) |
| | fprintf(pyfp, ","); |
| | } |
| |
|
| | if (attr.type == 1 || attr.type == 2 || attr.type == 3) |
| | { |
| | fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); |
| | } |
| | } |
| | } |
| |
|
| | for (const Operator* op : ops) |
| | { |
| | if (op->type != "pnnx.Attribute") |
| | continue; |
| |
|
| | const std::string& key = op->attrs.begin()->first; |
| | const Attribute& attr = op->attrs.begin()->second; |
| |
|
| | bool is_running_mean_var = false; |
| | { |
| | const Operand* r = op->outputs[0]; |
| | if (r->consumers.size() == 1) |
| | { |
| | const Operator* op2 = r->consumers[0]; |
| | if (op2->type == "F.batch_norm" || op2->type == "F.instance_norm") |
| | { |
| | if (r == op2->inputs[1] || r == op2->inputs[2]) |
| | { |
| | is_running_mean_var = true; |
| | } |
| | } |
| | } |
| | } |
| |
|
| | if (is_running_mean_var) |
| | { |
| | fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); |
| | } |
| |
|
| | for (size_t i = 0; i < attr.shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", attr.shape[i]); |
| | if (i + 1 != attr.shape.size()) |
| | fprintf(pyfp, ","); |
| | } |
| |
|
| | if (attr.type == 1 || attr.type == 2 || attr.type == 3) |
| | { |
| | fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); |
| | } |
| | } |
| |
|
| | fprintf(pyfp, " archive.close()\n"); |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | |
| | { |
| | fprintf(pyfp, " def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype, requires_grad=True):\n"); |
| | fprintf(pyfp, " return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype), requires_grad)\n"); |
| | fprintf(pyfp, "\n"); |
| | fprintf(pyfp, " def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):\n"); |
| | fprintf(pyfp, " fd, tmppath = tempfile.mkstemp()\n"); |
| | fprintf(pyfp, " with os.fdopen(fd, 'wb') as tmpf, archive.open(key) as keyfile:\n"); |
| | fprintf(pyfp, " tmpf.write(keyfile.read())\n"); |
| | fprintf(pyfp, " m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()\n"); |
| | fprintf(pyfp, " os.remove(tmppath)\n"); |
| | fprintf(pyfp, " return torch.from_numpy(m)\n"); |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | |
| | { |
| | fprintf(pyfp, " def forward(self"); |
| |
|
| | for (const Operator* op : ops) |
| | { |
| | if (op->type != "pnnx.Input") |
| | continue; |
| |
|
| | fprintf(pyfp, ", v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); |
| | } |
| |
|
| | fprintf(pyfp, "):\n"); |
| | } |
| |
|
| | |
| | { |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type == "pnnx.Input" || op->type == "pnnx.Output") |
| | continue; |
| |
|
| | fprintf(pyfp, " "); |
| |
|
| | if (op->type == "pnnx.Expression") |
| | { |
| | |
| | for (size_t i = 0; i < op->outputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); |
| | if (i + 1 != op->outputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | std::string expanded_expr = expand_expression(op); |
| | fprintf(pyfp, " = %s\n", expanded_expr.c_str()); |
| | } |
| | else if (op->type == "pnnx.Attribute") |
| | { |
| | const std::string& key = op->attrs.begin()->first; |
| | fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str()); |
| | } |
| | else if (op->type == "Tensor.slice") |
| | { |
| | |
| | std::string slice_expr = make_slice_expression(op); |
| | fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str()); |
| | } |
| | else if (op->type == "Tensor.slice_copy") |
| | { |
| | |
| | std::string slice_expr = make_slice_expression(op); |
| | fprintf(pyfp, "v_%s = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str()); |
| | fprintf(pyfp, " v_%s[%s] = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), slice_expr.c_str(), sanitize_identifier(op->inputs[1]->name).c_str()); |
| | } |
| | else if (op->type == "Tensor.index") |
| | { |
| | |
| | if (op->inputs.size() == 2) |
| | { |
| | std::string expanded_expr = expand_expression(op->inputs[1]->producer); |
| | fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); |
| | } |
| | else |
| | { |
| | std::string index_expr = make_index_expression(op); |
| | fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); |
| | } |
| | } |
| | else if (op->type == "Tensor.expand") |
| | { |
| | |
| | fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); |
| | if (op->inputs.size() == 2) |
| | { |
| | fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); |
| | } |
| | else |
| | { |
| | const std::vector<int>& shape = op->params.at("shape").ai; |
| | for (size_t i = 0; i < shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", shape[i]); |
| | if (i + 1 != shape.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | } |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type == "Tensor.view" || op->type == "Tensor.reshape") |
| | { |
| | |
| | fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); |
| | if (op->inputs.size() == 2) |
| | { |
| | fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); |
| | } |
| | else |
| | { |
| | const std::vector<int>& shape = op->params.at("shape").ai; |
| | for (size_t i = 0; i < shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", shape[i]); |
| | if (i + 1 != shape.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | } |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type == "Tensor.repeat") |
| | { |
| | |
| | fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); |
| | if (op->inputs.size() == 2) |
| | { |
| | fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); |
| | } |
| | else |
| | { |
| | const std::vector<int>& sizes = op->params.at("sizes").ai; |
| | for (size_t i = 0; i < sizes.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", sizes[i]); |
| | if (i + 1 != sizes.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | } |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type == "torch.cat" || op->type == "torch.stack") |
| | { |
| | |
| | fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str()); |
| | if (op->inputs.size() == 1) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "("); |
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | if (i + 1 != op->inputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | fprintf(pyfp, ")"); |
| | } |
| | fprintf(pyfp, ", dim=%d", op->params.at("dim").i); |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type == "torch.einsum") |
| | { |
| | |
| | fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str()); |
| |
|
| | fprintf(pyfp, "\'%s\'", op->params.at("equation").s.c_str()); |
| |
|
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | fprintf(pyfp, ", v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | } |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type == "prim::TupleUnpack") |
| | { |
| | for (size_t i = 0; i < op->outputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); |
| | if (i + 1 != op->outputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); |
| | } |
| | else if (op->type == "prim::TupleConstruct") |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); |
| | fprintf(pyfp, " = ("); |
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | } |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type == "prim::ListUnpack") |
| | { |
| | for (size_t i = 0; i < op->outputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); |
| | if (i + 1 != op->outputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); |
| | } |
| | else if (op->type == "prim::ListConstruct") |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); |
| | fprintf(pyfp, " = ["); |
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | if (i + 1 != op->inputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | fprintf(pyfp, "]\n"); |
| | } |
| | else if (op->type == "nn.GRU" || op->type == "nn.RNN") |
| | { |
| | if (op->outputs.size() == 1) |
| | { |
| | fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "v_%s, v_%s", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str()); |
| | } |
| | fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); |
| | if (op->inputs.size() == 2) |
| | { |
| | fprintf(pyfp, ", v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); |
| | } |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type == "nn.LSTM") |
| | { |
| | if (op->outputs.size() == 1) |
| | { |
| | fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "v_%s, (v_%s, v_%s)", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str(), sanitize_identifier(op->outputs[2]->name).c_str()); |
| | } |
| | fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); |
| | if (op->inputs.size() == 3) |
| | { |
| | fprintf(pyfp, ", (v_%s, v_%s)", sanitize_identifier(op->inputs[1]->name).c_str(), sanitize_identifier(op->inputs[2]->name).c_str()); |
| | } |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type == "nn.MultiheadAttention") |
| | { |
| | bool need_weights = true; |
| | if (op->outputs.size() == 1) |
| | { |
| | fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); |
| | need_weights = false; |
| | } |
| | else |
| | { |
| | for (size_t i = 0; i < op->outputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); |
| | if (i + 1 != op->outputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | } |
| | fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); |
| | if (op->inputs.size() == 1) |
| | { |
| | std::string in0 = sanitize_identifier(op->inputs[0]->name); |
| | fprintf(pyfp, "v_%s, v_%s, v_%s", in0.c_str(), in0.c_str(), in0.c_str()); |
| | } |
| | else if (op->inputs.size() == 2) |
| | { |
| | std::string in0 = sanitize_identifier(op->inputs[0]->name); |
| | std::string in1 = sanitize_identifier(op->inputs[1]->name); |
| | if (op->inputnames.size() == 2 && op->inputnames[1] == "attn_mask") |
| | { |
| | fprintf(pyfp, "v_%s, v_%s, v_%s, attn_mask=v_%s", in0.c_str(), in0.c_str(), in0.c_str(), in1.c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "v_%s, v_%s, v_%s", in0.c_str(), in1.c_str(), in1.c_str()); |
| | } |
| | } |
| | else if (op->inputs.size() == 3) |
| | { |
| | std::string in0 = sanitize_identifier(op->inputs[0]->name); |
| | std::string in1 = sanitize_identifier(op->inputs[1]->name); |
| | std::string in2 = sanitize_identifier(op->inputs[2]->name); |
| | if (op->inputnames.size() == 3 && op->inputnames[2] == "attn_mask") |
| | { |
| | fprintf(pyfp, "v_%s, v_%s, v_%s, attn_mask=v_%s", in0.c_str(), in1.c_str(), in1.c_str(), in2.c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "v_%s, v_%s, v_%s", in0.c_str(), in1.c_str(), in2.c_str()); |
| | } |
| | } |
| | else if (op->inputs.size() == 4) |
| | { |
| | std::string in0 = sanitize_identifier(op->inputs[0]->name); |
| | std::string in1 = sanitize_identifier(op->inputs[1]->name); |
| | std::string in2 = sanitize_identifier(op->inputs[2]->name); |
| | std::string in3 = sanitize_identifier(op->inputs[3]->name); |
| | if (op->inputnames.size() == 4 && op->inputnames[3] == "attn_mask") |
| | { |
| | fprintf(pyfp, "v_%s, v_%s, v_%s, attn_mask=v_%s", in0.c_str(), in1.c_str(), in2.c_str(), in3.c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "v_%s, v_%s, v_%s, v_%s", in0.c_str(), in1.c_str(), in2.c_str(), in3.c_str()); |
| | } |
| | } |
| | else |
| | { |
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | if (i + 1 != op->inputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | } |
| | if (need_weights) |
| | { |
| | fprintf(pyfp, ", need_weights=True"); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, ", need_weights=False"); |
| | } |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type.substr(0, 3) == "nn." || op->type.substr(0, 16) == "torchvision.ops.") |
| | { |
| | |
| | for (size_t i = 0; i < op->outputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); |
| | if (i + 1 != op->outputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); |
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | if (i + 1 != op->inputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | fprintf(pyfp, ")\n"); |
| | } |
| | else if (op->type.find("::") != std::string::npos || op->type.find(".") != std::string::npos) |
| | { |
| | |
| | for (size_t i = 0; i < op->outputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); |
| | if (i + 1 != op->outputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| |
|
| | if (op->type.substr(0, 7) == "Tensor.") |
| | { |
| | if (op->type == "Tensor.fill") |
| | { |
| | fprintf(pyfp, " = v_%s.fill_(", sanitize_identifier(op->inputs[0]->name).c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); |
| | } |
| |
|
| | if (op->inputnames.size() == op->inputs.size()) |
| | { |
| | for (size_t i = 1; i < op->inputs.size(); i++) |
| | { |
| | if (!op->inputnames[i].empty()) |
| | continue; |
| |
|
| | fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | } |
| |
|
| | for (size_t i = 1; i < op->inputs.size(); i++) |
| | { |
| | if (op->inputnames[i].empty()) |
| | continue; |
| |
|
| | fprintf(pyfp, "%s=v_%s, ", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str()); |
| | } |
| | } |
| | else |
| | { |
| | for (size_t i = 1; i < op->inputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | } |
| | } |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " = %s(", op->type.c_str()); |
| |
|
| | if (op->inputnames.size() == op->inputs.size()) |
| | { |
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | if (!op->inputnames[i].empty()) |
| | continue; |
| |
|
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | if (i + 1 != op->inputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| |
|
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | if (op->inputnames[i].empty()) |
| | continue; |
| |
|
| | fprintf(pyfp, "%s=v_%s", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str()); |
| | if (i + 1 != op->inputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | } |
| | else |
| | { |
| | for (size_t i = 0; i < op->inputs.size(); i++) |
| | { |
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); |
| | if (i + 1 != op->inputs.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| | } |
| | } |
| |
|
| | int i = 0; |
| | for (const auto& it : op->params) |
| | { |
| | if (op->type.substr(0, 7) == "Tensor." && i == 0) |
| | { |
| | fprintf(pyfp, "%s=", it.first.c_str()); |
| | } |
| | else if (op->inputs.empty() && i == 0) |
| | { |
| | fprintf(pyfp, "%s=", it.first.c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, ", %s=", it.first.c_str()); |
| | } |
| |
|
| | i++; |
| |
|
| | const Parameter& param = it.second; |
| | if (param.type == 0) |
| | { |
| | fprintf(pyfp, "None"); |
| | } |
| | if (param.type == 1) |
| | { |
| | if (param.b) |
| | fprintf(pyfp, "True"); |
| | else |
| | fprintf(pyfp, "False"); |
| | } |
| | if (param.type == 2) |
| | { |
| | fprintf(pyfp, "%d", param.i); |
| | } |
| | if (param.type == 3) |
| | { |
| | fprintf(pyfp, "%f", param.f); |
| | } |
| | if (param.type == 4) |
| | { |
| | if (param.s.substr(0, 6) == "torch.") |
| | { |
| | fprintf(pyfp, "%s", param.s.c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "\'%s\'", param.s.c_str()); |
| | } |
| | } |
| | if (param.type == 5) |
| | { |
| | fprintf(pyfp, "("); |
| | for (size_t i = 0; i < param.ai.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", param.ai[i]); |
| | if (i + 1 != param.ai.size() || param.ai.size() == 1) |
| | fprintf(pyfp, ","); |
| | } |
| | fprintf(pyfp, ")"); |
| | } |
| | if (param.type == 6) |
| | { |
| | fprintf(pyfp, "("); |
| | for (size_t i = 0; i < param.af.size(); i++) |
| | { |
| | fprintf(pyfp, "%f", param.af[i]); |
| | if (i + 1 != param.af.size() || param.af.size() == 1) |
| | fprintf(pyfp, ","); |
| | } |
| | fprintf(pyfp, ")"); |
| | } |
| | if (param.type == 7) |
| | { |
| | fprintf(pyfp, "("); |
| | for (size_t i = 0; i < param.as.size(); i++) |
| | { |
| | if (param.as[i].substr(0, 6) == "torch.") |
| | { |
| | fprintf(pyfp, "%s", param.as[i].c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, "\'%s\'", param.as[i].c_str()); |
| | } |
| | if (i + 1 != param.as.size() || param.as.size() == 1) |
| | fprintf(pyfp, ","); |
| | } |
| | fprintf(pyfp, ")"); |
| | } |
| | if (param.type == 10) |
| | { |
| | fprintf(pyfp, "(%f%+fj)", param.c.real(), param.c.imag()); |
| | } |
| | if (param.type == 11) |
| | { |
| | fprintf(pyfp, "("); |
| | for (size_t i = 0; i < param.ac.size(); i++) |
| | { |
| | fprintf(pyfp, "(%f%+fj)", param.ac[i].real(), param.ac[i].imag()); |
| | if (i + 1 != param.ac.size() || param.ac.size() == 1) |
| | fprintf(pyfp, ","); |
| | } |
| | fprintf(pyfp, ")"); |
| | } |
| | } |
| |
|
| | fprintf(pyfp, ")\n"); |
| | } |
| | else |
| | { |
| | fprintf(stderr, "todo %s\n", op->type.c_str()); |
| | } |
| | } |
| | } |
| |
|
| | |
| | { |
| | fprintf(pyfp, " return "); |
| |
|
| | int output_count = 0; |
| | { |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type == "pnnx.Output") |
| | output_count++; |
| | } |
| | } |
| |
|
| | int output_index = 0; |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type != "pnnx.Output") |
| | continue; |
| |
|
| | fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); |
| | if (output_index + 1 != output_count) |
| | fprintf(pyfp, ", "); |
| |
|
| | output_index++; |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | |
| | { |
| | fprintf(pyfp, "def export_torchscript():\n"); |
| | fprintf(pyfp, " net = Model()\n"); |
| | fprintf(pyfp, " net.eval()\n"); |
| | fprintf(pyfp, "\n"); |
| | fprintf(pyfp, " torch.manual_seed(0)\n"); |
| |
|
| | std::vector<std::string> input_names; |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type != "pnnx.Input") |
| | continue; |
| |
|
| | const Operand* r = op->outputs[0]; |
| | std::string input_name = std::string("v_") + sanitize_identifier(r->name); |
| | if (type_is_integer(r->type)) |
| | { |
| | fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); |
| | for (size_t i = 0; i < r->shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", r->shape[i]); |
| | if (i + 1 != r->shape.size() || r->shape.size() == 1) |
| | fprintf(pyfp, ", "); |
| | } |
| | fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); |
| | for (size_t i = 0; i < r->shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d, ", r->shape[i]); |
| | } |
| | fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); |
| | } |
| |
|
| | input_names.push_back(input_name); |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | if (input_names.size() == 1) |
| | { |
| | fprintf(pyfp, " mod = torch.jit.trace(net, %s)\n", input_names[0].c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " mod = torch.jit.trace(net, ("); |
| |
|
| | for (size_t i = 0; i < input_names.size(); i++) |
| | { |
| | fprintf(pyfp, "%s", input_names[i].c_str()); |
| | if (i + 1 != input_names.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| |
|
| | fprintf(pyfp, "))\n"); |
| | } |
| |
|
| | fprintf(pyfp, " mod.save(\"%s.pt\")\n", pypath.c_str()); |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | |
| | { |
| | fprintf(pyfp, "def export_onnx():\n"); |
| | fprintf(pyfp, " net = Model()\n"); |
| | fprintf(pyfp, " net.eval()\n"); |
| | fprintf(pyfp, "\n"); |
| | fprintf(pyfp, " torch.manual_seed(0)\n"); |
| |
|
| | std::vector<std::string> input_names; |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type != "pnnx.Input") |
| | continue; |
| |
|
| | const Operand* r = op->outputs[0]; |
| | std::string input_name = std::string("v_") + sanitize_identifier(r->name); |
| | if (type_is_integer(r->type)) |
| | { |
| | fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); |
| | for (size_t i = 0; i < r->shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", r->shape[i]); |
| | if (i + 1 != r->shape.size() || r->shape.size() == 1) |
| | fprintf(pyfp, ", "); |
| | } |
| | fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); |
| | for (size_t i = 0; i < r->shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d, ", r->shape[i]); |
| | } |
| | fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); |
| | } |
| |
|
| | input_names.push_back(input_name); |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | |
| |
|
| | if (input_names.size() == 1) |
| | { |
| | fprintf(pyfp, " torch.onnx._export(net, %s", input_names[0].c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " torch.onnx._export(net, ("); |
| |
|
| | for (size_t i = 0; i < input_names.size(); i++) |
| | { |
| | fprintf(pyfp, "%s", input_names[i].c_str()); |
| | if (i + 1 != input_names.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| |
|
| | fprintf(pyfp, ")"); |
| | } |
| |
|
| | fprintf(pyfp, ", \"%s.onnx\", export_params=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, opset_version=13", pypath.c_str()); |
| |
|
| | fprintf(pyfp, ", input_names=["); |
| | { |
| | int input_count = 0; |
| | { |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type == "pnnx.Input") |
| | input_count++; |
| | } |
| | } |
| |
|
| | int input_index = 0; |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type != "pnnx.Input") |
| | continue; |
| |
|
| | fprintf(pyfp, "'in%d'", input_index); |
| | if (input_index + 1 != input_count) |
| | fprintf(pyfp, ", "); |
| |
|
| | input_index++; |
| | } |
| | } |
| | fprintf(pyfp, "]"); |
| |
|
| | fprintf(pyfp, ", output_names=["); |
| | { |
| | int output_count = 0; |
| | { |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type == "pnnx.Output") |
| | output_count++; |
| | } |
| | } |
| |
|
| | int output_index = 0; |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type != "pnnx.Output") |
| | continue; |
| |
|
| | fprintf(pyfp, "'out%d'", output_index); |
| | if (output_index + 1 != output_count) |
| | fprintf(pyfp, ", "); |
| |
|
| | output_index++; |
| | } |
| | } |
| | fprintf(pyfp, "]"); |
| |
|
| | fprintf(pyfp, ")\n"); |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | |
| | { |
| | fprintf(pyfp, "def test_inference():\n"); |
| | fprintf(pyfp, " net = Model()\n"); |
| | fprintf(pyfp, " net.eval()\n"); |
| | fprintf(pyfp, "\n"); |
| | fprintf(pyfp, " torch.manual_seed(0)\n"); |
| |
|
| | std::vector<std::string> input_names; |
| | for (const Operator* op : ops) |
| | { |
| | if (op->type != "pnnx.Input") |
| | continue; |
| |
|
| | const Operand* r = op->outputs[0]; |
| | std::string input_name = std::string("v_") + sanitize_identifier(r->name); |
| | if (type_is_integer(r->type)) |
| | { |
| | fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); |
| | for (size_t i = 0; i < r->shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d", r->shape[i]); |
| | if (i + 1 != r->shape.size() || r->shape.size() == 1) |
| | fprintf(pyfp, ", "); |
| | } |
| | fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); |
| | for (size_t i = 0; i < r->shape.size(); i++) |
| | { |
| | fprintf(pyfp, "%d, ", r->shape[i]); |
| | } |
| | fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); |
| | } |
| |
|
| | input_names.push_back(input_name); |
| | } |
| |
|
| | fprintf(pyfp, "\n"); |
| |
|
| | if (input_names.size() == 1) |
| | { |
| | fprintf(pyfp, " return net(%s)\n", input_names[0].c_str()); |
| | } |
| | else |
| | { |
| | fprintf(pyfp, " return net("); |
| |
|
| | for (size_t i = 0; i < input_names.size(); i++) |
| | { |
| | fprintf(pyfp, "%s", input_names[i].c_str()); |
| | if (i + 1 != input_names.size()) |
| | fprintf(pyfp, ", "); |
| | } |
| |
|
| | fprintf(pyfp, ")\n"); |
| | } |
| | } |
| |
|
| | fclose(pyfp); |
| |
|
| | return 0; |
| | } |
| |
|
| | int Graph::parse(const std::string& param) |
| | { |
| | std::istringstream is(param); |
| | if (!is.good()) |
| | { |
| | fprintf(stderr, "open failed\n"); |
| | return -1; |
| | } |
| |
|
| | int magic = 0; |
| | { |
| | std::string line; |
| | std::getline(is, line); |
| | std::istringstream iss(line); |
| |
|
| | iss >> magic; |
| | } |
| |
|
| | int operator_count = 0; |
| | int operand_count = 0; |
| | { |
| | std::string line; |
| | std::getline(is, line); |
| | std::istringstream iss(line); |
| |
|
| | iss >> operator_count >> operand_count; |
| | } |
| |
|
| | for (int i = 0; i < operator_count; i++) |
| | { |
| | std::string line; |
| | std::getline(is, line); |
| | std::istringstream iss(line); |
| |
|
| | std::string type; |
| | std::string name; |
| | int input_count = 0; |
| | int output_count = 0; |
| |
|
| | iss >> type >> name >> input_count >> output_count; |
| |
|
| | Operator* op = new_operator(type, name); |
| |
|
| | for (int j = 0; j < input_count; j++) |
| | { |
| | std::string operand_name; |
| | iss >> operand_name; |
| |
|
| | Operand* r = get_operand(operand_name); |
| | r->consumers.push_back(op); |
| | op->inputs.push_back(r); |
| | } |
| |
|
| | for (int j = 0; j < output_count; j++) |
| | { |
| | std::string operand_name; |
| | iss >> operand_name; |
| |
|
| | Operand* r = new_operand(operand_name); |
| | r->producer = op; |
| | op->outputs.push_back(r); |
| | } |
| |
|
| | |
| | while (!iss.eof()) |
| | { |
| | std::string param; |
| | iss >> param; |
| |
|
| | std::string key; |
| | std::string value; |
| | std::istringstream pss(param); |
| | std::getline(pss, key, '='); |
| | std::getline(pss, value); |
| |
|
| | if (key[0] == '@') |
| | { |
| | |
| | |
| | op->attrs[key.substr(1)] = Attribute(); |
| |
|
| | Attribute& attr = op->attrs[key.substr(1)]; |
| |
|
| | attr.type = 0; |
| | if (value.empty()) |
| | continue; |
| |
|
| | if (value[0] == '%') |
| | { |
| | |
| | attr.data = std::vector<char>(value.begin(), value.end()); |
| | } |
| |
|
| | if (value[0] == '(') |
| | { |
| | |
| |
|
| | |
| | std::string typestr = value.substr(value.find_last_of(')') + 1); |
| | attr.type = string_to_type(typestr.c_str()); |
| |
|
| | |
| | std::string lc = value.substr(1, value.find_last_of(')') - 1); |
| | std::istringstream lcss(lc); |
| |
|
| | attr.shape.clear(); |
| | while (!lcss.eof()) |
| | { |
| | std::string elem; |
| | std::getline(lcss, elem, ','); |
| |
|
| | if (elem == "?") |
| | { |
| | attr.shape.push_back(-1); |
| | } |
| | else if (elem[0] == '%') |
| | { |
| | |
| | attr.shape.push_back(-233); |
| | int index = attr.shape.size() - 1; |
| | std::string key = elem.substr(1); |
| | attr.params[std::string("__shape_") + std::to_string(index)] = key; |
| | } |
| | else |
| | { |
| | int i = std::stoi(elem); |
| | attr.shape.push_back(i); |
| | } |
| | } |
| | } |
| | } |
| | else if (key[0] == '$') |
| | { |
| | |
| | load_input_key(op, key.substr(1), value); |
| | } |
| | else if (key[0] == '#') |
| | { |
| | |
| | load_shape(op, key.substr(1), value); |
| | } |
| | else |
| | { |
| | |
| | load_parameter(op, key, value); |
| | } |
| | } |
| | } |
| |
|
| | return 0; |
| | } |
| |
|
| | void Operand::remove_consumer(const Operator* c) |
| | { |
| | auto it = std::find(consumers.begin(), consumers.end(), c); |
| | consumers.erase(it); |
| | } |
| |
|
| | Operator* Graph::new_operator(const std::string& type, const std::string& name) |
| | { |
| | Operator* op = new Operator; |
| | op->type = type; |
| | op->name = name; |
| | ops.push_back(op); |
| | return op; |
| | } |
| |
|
| | Operator* Graph::new_operator_before(const std::string& type, const std::string& name, const Operator* cur) |
| | { |
| | Operator* op = new Operator; |
| | op->type = type; |
| | op->name = name; |
| | ops.insert(std::find(ops.begin(), ops.end(), cur), op); |
| | return op; |
| | } |
| |
|
| | Operator* Graph::new_operator_after(const std::string& type, const std::string& name, const Operator* cur) |
| | { |
| | Operator* op = new Operator; |
| | op->type = type; |
| | op->name = name; |
| | ops.insert(std::find(ops.begin(), ops.end(), cur) + 1, op); |
| | return op; |
| | } |
| |
|
#if BUILD_PNNX
// Create an operand mirroring a TorchScript value, register it in operands and
// return it. The operand takes the value's debug name and starts with type -1.
// Only when the value's type casts to c10::TensorType with both scalar type and
// rank known is the type set (via get_at_tensor_type) and the shape filled with
// one entry per dimension, using -1 for any dimension whose size is unknown.
Operand* Graph::new_operand(const torch::jit::Value* v)
{
    Operand* created = new Operand;
    created->name = v->debugName();
    created->type = -1;

    const auto tensor_type = v->type()->cast<c10::TensorType>();
    const bool fully_typed = tensor_type && tensor_type->scalarType().has_value() && tensor_type->dim().has_value();
    if (fully_typed)
    {
        created->type = get_at_tensor_type(tensor_type->scalarType().value());

        const int rank = (int)tensor_type->dim().value();
        created->shape.resize(rank);
        for (int d = 0; d < rank; d++)
        {
            const auto extent = tensor_type->sizes()[d];
            created->shape[d] = extent.has_value() ? (int)extent.value() : -1;
        }
    }

    operands.push_back(created);
    return created;
}
#endif
| |
|
| | Operand* Graph::new_operand(const std::string& name) |
| | { |
| | Operand* r = new Operand; |
| | r->name = name; |
| | operands.push_back(r); |
| | return r; |
| | } |
| |
|
| | Operand* Graph::get_operand(const std::string& name) |
| | { |
| | for (Operand* r : operands) |
| | { |
| | if (r->name == name) |
| | return r; |
| | } |
| |
|
| | return 0; |
| | } |
| |
|
| | const Operand* Graph::get_operand(const std::string& name) const |
| | { |
| | for (const Operand* r : operands) |
| | { |
| | if (r->name == name) |
| | return r; |
| | } |
| |
|
| | return 0; |
| | } |
| |
|
| | } |
| |
|