// File: ncnn/tools/pnnx/src/ir.h
// Mirror metadata (not part of the original source): uploaded by camenduru,
// commit be903e2 "thanks to ncnn ❤".
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef PNNX_IR_H
#define PNNX_IR_H
#include <limits.h>
#include <stdint.h>

#include <complex>
#include <initializer_list>
#include <limits>
#include <map>
#include <set>
#include <string>
#include <vector>
#if BUILD_PNNX
// Forward declarations of the TorchScript / ATen types that appear in the
// BUILD_PNNX-only signatures of this header (Parameter from a jit Node/Value,
// Attribute from an at::Tensor, Graph::new_operand from a jit Value).
// Declaring them here lets this header avoid including the libtorch headers.
namespace torch {
namespace jit {
struct Value;
struct Node;
} // namespace jit
} // namespace torch
namespace at {
class Tensor;
}
#endif // BUILD_PNNX
namespace pnnx {
// A tagged scalar / list value used for operator and operand parameters.
// `type` selects which of the value members is meaningful (codes listed next
// to `type`); the remaining value members keep their default values.
class Parameter
{
public:
    Parameter()
        : type(0)
    {
    }
    Parameter(bool _b)
        : type(1), b(_b)
    {
    }
    Parameter(int _i)
        : type(2), i(_i)
    {
    }
    // Wide integers are stored as int. Values outside the int range saturate
    // to INT_MAX / INT_MIN instead of being truncated by a narrowing cast.
    // This also maps the exact max/min of the source type to INT_MAX / INT_MIN
    // (presumably used as "unbounded" sentinels by producers of these values).
    Parameter(long _l)
        : type(2), i(saturate_to_int(_l))
    {
    }
    Parameter(long long _l)
        : type(2), i(saturate_to_int(_l))
    {
    }
    Parameter(float _f)
        : type(3), f(_f)
    {
    }
    Parameter(double _d)
        : type(3), f(static_cast<float>(_d))
    {
    }
    Parameter(const char* _s)
        : type(4), s(_s)
    {
    }
    Parameter(const std::string& _s)
        : type(4), s(_s)
    {
    }
    Parameter(const std::initializer_list<int>& _ai)
        : type(5), ai(_ai)
    {
    }
    // Each element saturates to the int range, like the scalar overloads.
    Parameter(const std::initializer_list<int64_t>& _ai)
        : type(5)
    {
        ai.reserve(_ai.size());
        for (const auto& x : _ai)
            ai.push_back(saturate_to_int(x));
    }
    Parameter(const std::vector<int>& _ai)
        : type(5), ai(_ai)
    {
    }
    Parameter(const std::initializer_list<float>& _af)
        : type(6), af(_af)
    {
    }
    Parameter(const std::initializer_list<double>& _af)
        : type(6)
    {
        af.reserve(_af.size());
        for (const auto& x : _af)
            af.push_back(static_cast<float>(x));
    }
    Parameter(const std::vector<float>& _af)
        : type(6), af(_af)
    {
    }
    Parameter(const std::vector<double>& _af)
        : type(6)
    {
        af.reserve(_af.size());
        for (const auto& x : _af)
            af.push_back(static_cast<float>(x));
    }
    Parameter(const std::initializer_list<const char*>& _as)
        : type(7)
    {
        as.reserve(_as.size());
        for (const auto& x : _as)
            as.push_back(std::string(x));
    }
    Parameter(const std::initializer_list<std::string>& _as)
        : type(7), as(_as)
    {
    }
    Parameter(const std::vector<std::string>& _as)
        : type(7), as(_as)
    {
    }
    Parameter(const std::complex<float>& _c)
        : type(10), c(_c)
    {
    }
    Parameter(const std::complex<double>& _c)
        : type(10), c(_c)
    {
    }
    Parameter(const std::initializer_list<std::complex<float> >& _ac)
        : type(11), ac(_ac)
    {
    }
    Parameter(const std::initializer_list<std::complex<double> >& _ac)
        : type(11)
    {
        ac.reserve(_ac.size());
        for (const auto& x : _ac)
            ac.push_back(std::complex<float>(x));
    }
    Parameter(const std::vector<std::complex<float> >& _ac)
        : type(11), ac(_ac)
    {
    }
    Parameter(const std::vector<std::complex<double> >& _ac)
        : type(11)
    {
        ac.reserve(_ac.size());
        for (const auto& x : _ac)
            ac.push_back(std::complex<float>(x));
    }

#if BUILD_PNNX
    Parameter(const torch::jit::Node* value_node);
    Parameter(const torch::jit::Value* value);
#endif // BUILD_PNNX

    static Parameter parse_from_string(const std::string& value);
    static std::string encode_to_string(const Parameter& param);

    // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others 10=c 11=ac
    int type;

    // value
    // The scalar members get default initializers so that a Parameter whose
    // `type` does not use them (and its implicit copies) never holds an
    // indeterminate value. The member order itself is unchanged.
    bool b = false;
    int i = 0;
    float f = 0.f;
    std::complex<float> c;
    std::vector<int> ai;
    std::vector<float> af;
    std::vector<std::complex<float> > ac;

    // keep std::string typed member the last for cross cxxabi compatibility
    std::string s;
    std::vector<std::string> as;

private:
    // Narrow a wider integer to int, clamping to [INT_MIN, INT_MAX].
    template<typename T>
    static int saturate_to_int(T v)
    {
        if (v > static_cast<T>(INT_MAX))
            return INT_MAX;
        if (v < static_cast<T>(INT_MIN))
            return INT_MIN;
        return static_cast<int>(v);
    }
};

// Value equality of two parameters (defined out of line).
bool operator==(const Parameter& lhs, const Parameter& rhs);
// A tensor-like payload attached to an operator: raw element bytes in `data`,
// dimensions in `shape`, element type code in `type`, plus free-form `params`.
class Attribute
{
public:
    // Default-constructed attributes are empty with type 0 (null); `type` is
    // initialised in-class below.
    Attribute()
    {
    }

#if BUILD_PNNX
    Attribute(const at::Tensor& t);
#endif // BUILD_PNNX

    // Build from an explicit shape and fp32 values (defined out of line).
    Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);

    // Element size and element count helpers (defined out of line).
    size_t elemsize() const;
    int elemcount() const;

    // Read / write the payload as fp32 values; per the original note these are
    // convenience routines for manipulating fp32/fp16 weights.
    std::vector<float> get_float32_data() const;
    void set_float32_data(const std::vector<float>& data);

    // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=c64 11=c128 12=c32
    int type = 0;
    std::vector<int> shape;
    std::vector<char> data;

    std::map<std::string, Parameter> params;
};

// Equality of two attributes (defined out of line).
bool operator==(const Attribute& lhs, const Attribute& rhs);

// concat two attributes along the first axis
Attribute operator+(const Attribute& a, const Attribute& b);
class Operator;

// A value on a graph edge: written by a single `producer` operator and read by
// any number of `consumers`. Only Graph may create operands (the constructor
// is private and Graph is a friend).
class Operand
{
public:
    // Remove `c` from `consumers` (defined out of line).
    void remove_consumer(const Operator* c);

    // Operator that produces this operand; nullptr until one is assigned.
    Operator* producer;
    std::vector<Operator*> consumers;

    // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=c64 11=c128 12=c32
    int type;
    std::vector<int> shape;

    // keep std::string typed member the last for cross cxxabi compatibility
    std::string name;

    std::map<std::string, Parameter> params;

private:
    friend class Graph;

    // Initialise both scalar members. `producer` must not be left
    // indeterminate: code that checks it before a producer is wired up would
    // otherwise read an uninitialised pointer (undefined behaviour).
    Operand()
        : producer(nullptr), type(0)
    {
    }
};
// A node of the pnnx graph: an operation with input/output operands, named
// scalar parameters and named tensor attributes. Only Graph may create
// operators (the constructor is private and Graph is a friend).
class Operator
{
public:
// operands read by / written by this operator
std::vector<Operand*> inputs;
std::vector<Operand*> outputs;
// keep std::string typed member the last for cross cxxabi compatibility
// operator kind and instance name, as strings
std::string type;
std::string name;
// NOTE(review): presumably per-input names parallel to `inputs` — TODO confirm
std::vector<std::string> inputnames;
// scalar/list parameters and tensor attributes, keyed by name
std::map<std::string, Parameter> params;
std::map<std::string, Attribute> attrs;
private:
friend class Graph;
Operator()
{
}
};
class Graph
{
public:
Graph();
~Graph();
int load(const std::string& parampath, const std::string& binpath);
int save(const std::string& parampath, const std::string& binpath);
int python(const std::string& pypath, const std::string& binpath);
int parse(const std::string& param);
Operator* new_operator(const std::string& type, const std::string& name);
Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur);
Operator* new_operator_after(const std::string& type, const std::string& name, const Operator* cur);
#if BUILD_PNNX
Operand* new_operand(const torch::jit::Value* v);
#endif
Operand* new_operand(const std::string& name);
Operand* get_operand(const std::string& name);
const Operand* get_operand(const std::string& name) const;
std::vector<Operator*> ops;
std::vector<Operand*> operands;
private:
Graph(const Graph& rhs);
Graph& operator=(const Graph& rhs);
};
} // namespace pnnx
#endif // PNNX_IR_H