| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include "pass_level1.h" |
| |
|
| | #include "../utils.h" |
| |
|
| | namespace pnnx { |
| |
|
| | class LSTM : public FuseModulePass |
| | { |
| | public: |
| | const char* match_type_str() const |
| | { |
| | return "__torch__.torch.nn.modules.rnn.LSTM"; |
| | } |
| |
|
| | const char* type_str() const |
| | { |
| | return "nn.LSTM"; |
| | } |
| |
|
| | void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const |
| | { |
| | |
| | |
| | |
| |
|
| | const torch::jit::Node* lstm = find_node_by_kind(graph, "aten::lstm"); |
| |
|
| | const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); |
| | if (return_tuple && return_tuple->inputs().size() == 3 && lstm->outputs().size() == 3 |
| | && return_tuple->inputs()[0] == lstm->outputs()[1] && return_tuple->inputs()[1] == lstm->outputs()[2] && return_tuple->inputs()[2] == lstm->outputs()[0]) |
| | { |
| | |
| | |
| | fprintf(stderr, "swapped detected !\n"); |
| | op->params["pnnx_rnn_output_swapped"] = 1; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); |
| | const auto& weight_hh_l0 = mod.attr("weight_hh_l0").toTensor(); |
| |
|
| | op->params["input_size"] = weight_ih_l0.size(1); |
| | op->params["hidden_size"] = weight_ih_l0.size(0) / 4; |
| | op->params["num_layers"] = lstm->namedInput("num_layers"); |
| | op->params["bias"] = lstm->namedInput("has_biases"); |
| | op->params["batch_first"] = lstm->namedInput("batch_first"); |
| | op->params["bidirectional"] = lstm->namedInput("bidirectional"); |
| | op->params["proj_size"] = weight_ih_l0.size(0) / 4 == weight_hh_l0.size(1) ? 0 : weight_hh_l0.size(1); |
| |
|
| | const int num_layers = op->params["num_layers"].i; |
| | const bool bias = op->params["bias"].b; |
| | const bool bidirectional = op->params["bidirectional"].b; |
| | const int proj_size = op->params["proj_size"].i; |
| |
|
| | for (int k = 0; k < num_layers; k++) |
| | { |
| | std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); |
| | std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); |
| |
|
| | op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); |
| | op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); |
| |
|
| | if (bias) |
| | { |
| | std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); |
| | std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); |
| |
|
| | op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); |
| | op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); |
| | } |
| |
|
| | if (proj_size > 0) |
| | { |
| | std::string weight_hr_lk_key = std::string("weight_hr_l") + std::to_string(k); |
| |
|
| | op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key).toTensor(); |
| | } |
| |
|
| | if (bidirectional) |
| | { |
| | std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; |
| | std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; |
| |
|
| | op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); |
| | op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); |
| |
|
| | if (bias) |
| | { |
| | std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; |
| | std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; |
| |
|
| | op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); |
| | op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); |
| | } |
| |
|
| | if (proj_size > 0) |
| | { |
| | std::string weight_hr_lk_reverse_key = std::string("weight_hr_l") + std::to_string(k) + "_reverse"; |
| |
|
| | op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key).toTensor(); |
| | } |
| | } |
| | } |
| | } |
| | }; |
| |
|
| | REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(LSTM) |
| |
|
| | } |
| |
|