| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
#include "utils.h"

#include <cstring>
| |
|
| | namespace pnnx { |
| |
|
| | const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Graph>& graph, const std::string& kind) |
| | { |
| | for (const auto& n : graph->nodes()) |
| | { |
| | if (n->kind().toDisplayString() == kind) |
| | return n; |
| | } |
| |
|
| | return 0; |
| | } |
| |
|
// Convert an IEEE-754 binary32 value to binary16 bits (1 sign : 5 exponent : 10 mantissa).
//
// Conversion rules:
//  - the mantissa is truncated (rounded toward zero), not rounded to nearest;
//  - magnitudes below the smallest normal fp16 (2^-14), including fp32 zeros
//    and subnormals, become a signed zero (no fp16 subnormals are produced);
//  - magnitudes of 2^16 or more become a signed infinity;
//  - infinities stay infinities and every NaN becomes a quiet NaN (0x7E00 / 0xFE00).
// NOTE(review): truncation biases results toward zero; IEEE conversion would use
// round-to-nearest-even. Kept as-is so emitted fp16 bits do not change.
//
// The float's bits are read with std::memcpy rather than through a union:
// reading an inactive union member is undefined behaviour in C++.
unsigned short float32_to_float16(float value)
{
    static_assert(sizeof(unsigned int) == sizeof(float), "float32_to_float16 expects a 32-bit unsigned int");

    unsigned int bits;
    std::memcpy(&bits, &value, sizeof(bits));

    // fp32 layout 1 : 8 : 23
    const unsigned int sign = (bits >> 31) & 0x1u;
    const unsigned int exponent = (bits >> 23) & 0xFFu;
    const unsigned int significand = bits & 0x7FFFFFu;

    const unsigned int sign16 = sign << 15;

    if (exponent == 0)
    {
        // fp32 zero or subnormal: far below fp16 range, underflow to signed zero
        return static_cast<unsigned short>(sign16);
    }

    if (exponent == 0xFF)
    {
        // infinity keeps a zero mantissa; any NaN payload collapses to a quiet NaN
        return static_cast<unsigned short>(sign16 | (0x1Fu << 10) | (significand ? 0x200u : 0x0u));
    }

    // rebias the exponent from fp32 (bias 127) to fp16 (bias 15)
    const int newexp = static_cast<int>(exponent) - 127 + 15;

    if (newexp >= 31)
    {
        // too large for fp16: overflow to signed infinity
        return static_cast<unsigned short>(sign16 | (0x1Fu << 10));
    }

    if (newexp <= 0)
    {
        // would need an fp16 subnormal: flushed to signed zero
        return static_cast<unsigned short>(sign16);
    }

    // normal fp16: keep the top 10 mantissa bits (truncation)
    return static_cast<unsigned short>(sign16 | (static_cast<unsigned int>(newexp) << 10) | (significand >> 13));
}
| |
|
// Convert IEEE-754 binary16 bits (1 sign : 5 exponent : 10 mantissa) to a binary32 float.
//
// Every fp16 value is exactly representable in fp32, so the conversion is exact:
//  - signed zeros map to signed zeros;
//  - fp16 subnormals (mantissa * 2^-24) are renormalized into normal fp32 values;
//  - infinities map to infinities and NaN payload bits are carried into the
//    top of the fp32 mantissa.
//
// The result is assembled in an unsigned integer and copied out with std::memcpy:
// reading an inactive union member is undefined behaviour in C++. Unsigned
// arithmetic also avoids shifting a promoted signed int into its sign bit.
float float16_to_float32(unsigned short value)
{
    static_assert(sizeof(unsigned int) == sizeof(float), "float16_to_float32 expects a 32-bit unsigned int");

    const unsigned int sign = (value >> 15) & 0x1u;
    const unsigned int exponent = (value >> 10) & 0x1Fu;
    unsigned int significand = value & 0x3FFu;

    unsigned int bits;
    if (exponent == 0)
    {
        if (significand == 0)
        {
            // signed zero
            bits = sign << 31;
        }
        else
        {
            // fp16 subnormal: shift until the leading 1 reaches bit 9, counting the shifts
            unsigned int shift = 0;
            while ((significand & 0x200u) == 0)
            {
                significand <<= 1;
                shift++;
            }
            // drop the leading 1, which becomes the implicit bit of the normal fp32 value
            significand = (significand << 1) & 0x3FFu;
            bits = (sign << 31) | ((127u - 15u - shift) << 23) | (significand << 13);
        }
    }
    else if (exponent == 0x1F)
    {
        // infinity or NaN
        bits = (sign << 31) | (0xFFu << 23) | (significand << 13);
    }
    else
    {
        // normal: rebias the exponent from fp16 (bias 15) to fp32 (bias 127)
        bits = (sign << 31) | ((exponent + 127u - 15u) << 23) | (significand << 13);
    }

    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
}
| |
|
| | } |
| |
|