Spaces:
Running
Running
| // Copyright (c) OpenMMLab. All rights reserved | |
| using namespace parrots; | |
| using namespace std; | |
| template <typename T> | |
| void contour_expand_parrots(T& ctx, const SSElement& attr, | |
| const OperatorBase::in_list_t& ins, | |
| OperatorBase::out_list_t& outs) { | |
| int min_kernel_area, kernel_num; | |
| SSAttrs(attr) | |
| .get<int>("min_kernel_area", min_kernel_area) | |
| .get<int>("kernel_num", kernel_num) | |
| .done(); | |
| at::Tensor kernel_mask; | |
| at::Tensor internal_kernel_label; | |
| kernel_mask = buildATensor(ctx, ins[0]); | |
| internal_kernel_label = buildATensor(ctx, ins[1]); | |
| auto out = contour_expand(kernel_mask, internal_kernel_label, min_kernel_area, | |
| kernel_num); | |
| int n = out.size(), m = 0; | |
| for (int i = 0; i < n; ++i) | |
| if (m < out[i].size()) m = out[i].size(); | |
| auto options = torch::TensorOptions().dtype(at::kInt); | |
| auto tensor = torch::zeros({n, m}, options); | |
| for (int i = 0; i < n; i++) | |
| tensor.slice(0, i, i + 1) = | |
| torch::from_blob(out[i].data(), {out[i].size()}, options); | |
| updateDArray(ctx, tensor, outs[0]); | |
| } | |
// Registers the contour_expand operator with parrots: two attributes, two
// inputs and one output, implemented by the HostContext instantiation of
// contour_expand_parrots.
// The attribute names must match the keys read via SSAttrs in
// contour_expand_parrots ("min_kernel_area", "kernel_num").
PARROTS_EXTENSION_REGISTER(contour_expand)
    .attr("min_kernel_area")
    .attr("kernel_num")
    .input(2)   // ins[0]: kernel_mask, ins[1]: internal_kernel_label
    .output(1)  // outs[0]: packed (n, m) int32 result
    .apply(contour_expand_parrots<HostContext>)
    .done();