|
|
|
|
|
|
|
|
|
|
|
#include "text_region_grouping.h" |
|
|
|
|
|
#include <algorithm> |
|
|
#include <memory> |
|
|
#include <unordered_set> |
|
|
#include <unordered_map> |
|
|
#include <chrono> |
|
|
#include <stack> |
|
|
#include <numeric> |
|
|
#include <vector> |
|
|
|
|
|
using namespace std; |
|
|
|
|
|
/// Strips the per-region probabilities from a grouping result.
///
/// For every line of `relList`, collects the first tuple element (the region index)
/// of each entry into a TextLine, and appends one PhraseList element built from it.
/// Line order and entry order are preserved.
PhraseList rel_list_to_phrases(const relations_list_t &relList)
{
    PhraseList phrases;
    phrases.reserve(relList.size());

    for (const text_line_t &relLine : relList) {
        TextLine indices;
        indices.reserve(relLine.size());

        for (auto entry = relLine.begin(); entry != relLine.end(); ++entry) {
            indices.push_back(get<0>(*entry));
        }

        phrases.push_back({ std::move(indices) });
    }

    return phrases;
}
|
|
|
|
|
template<typename rel_to_2_from_map_t, typename T> |
|
|
relations_list_t rel_chain_to_groups(const rel_to_2_from_map_t &inChain, int64_t numRegions, const T *inProbs); |
|
|
|
|
|
template<typename T> |
|
|
relations_list_t dense_relations_to_graph_impl(torch::Tensor relationsTensor) |
|
|
{ |
|
|
if (relationsTensor.size(0) == 0) { |
|
|
return relations_list_t{}; |
|
|
} |
|
|
|
|
|
if (relationsTensor.size(0) != relationsTensor.size(1)) { |
|
|
throw std::runtime_error("The relations tensor must be a square matrix!"); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto relations = relationsTensor.accessor<T, 2>(); |
|
|
|
|
|
const int64_t numRegions = relationsTensor.size(0); |
|
|
torch::Tensor fromIdxsTensor = torch::full({ numRegions }, -1, torch::kInt64); |
|
|
torch::Tensor fromProbsTensor = torch::zeros({ numRegions }, relationsTensor.options()); |
|
|
|
|
|
|
|
|
|
|
|
auto fromIdxs = fromIdxsTensor.data_ptr<int64_t>(); |
|
|
auto fromProbs = fromProbsTensor.data_ptr<T>(); |
|
|
|
|
|
for (int64_t fromIdx = 0; fromIdx < numRegions; ++fromIdx) { |
|
|
auto fromRel = relations[fromIdx]; |
|
|
|
|
|
for (int64_t toIdx = 0; toIdx < numRegions; ++toIdx) { |
|
|
auto relProb = fromRel[toIdx]; |
|
|
|
|
|
if (relProb >= 0.5) { |
|
|
T &maxProb = fromProbs[toIdx]; |
|
|
if (fromIdxs[toIdx] == -1 || relProb > maxProb) { |
|
|
fromIdxs[toIdx] = fromIdx; |
|
|
maxProb = relProb; |
|
|
} |
|
|
|
|
|
|
|
|
break; |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
return rel_chain_to_groups(fromIdxs, numRegions, fromProbs); |
|
|
} |
|
|
|
|
|
/// Runs dense_relations_to_graph_impl for the floating-point dtype of `relationsTensor`
/// (selected by AT_DISPATCH_FLOATING_TYPES) and returns the groups together with their
/// per-region probabilities.
relations_list_t dense_relations_to_graph_with_probs(torch::Tensor relationsTensor)
{
    relations_list_t groups;

    // The lambda is wrapped in parentheses so commas inside it can never be taken
    // as macro-argument separators.
    AT_DISPATCH_FLOATING_TYPES(relationsTensor.scalar_type(), "dense_relations_to_graph", ([&] {
        groups = dense_relations_to_graph_impl<scalar_t>(relationsTensor);
    }));

    return groups;
}
|
|
|
|
|
/// Same grouping as dense_relations_to_graph_with_probs, but with the probabilities
/// dropped: only the region indices of each group are returned.
PhraseList dense_relations_to_graph(torch::Tensor relations)
{
    const relations_list_t groupsWithProbs = dense_relations_to_graph_with_probs(relations);
    return rel_list_to_phrases(groupsWithProbs);
}
|
|
|
|
|
template<typename T> |
|
|
relations_list_t sparse_relations_to_graph_impl(torch::Tensor relationsTensor, torch::Tensor neighborIdxsTensor) |
|
|
{ |
|
|
if (relationsTensor.size(0) == 0) { |
|
|
return relations_list_t{}; |
|
|
} |
|
|
|
|
|
auto maxRelsTensor = torch::zeros({ relationsTensor.size(0) }, relationsTensor.options()); |
|
|
auto fromIdxsTensor = torch::full({ relationsTensor.size(0) }, -1, torch::kInt64); |
|
|
|
|
|
auto relations = relationsTensor.accessor<T, 2>(); |
|
|
auto neighborIdxs = neighborIdxsTensor.accessor<int64_t, 2>(); |
|
|
auto maxRels = maxRelsTensor.data_ptr<T>(); |
|
|
auto fromIdxs = fromIdxsTensor.data_ptr<int64_t>(); |
|
|
|
|
|
const int64_t N = relationsTensor.size(0); |
|
|
const int64_t K = relationsTensor.size(1); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int64_t fromIdx = 0; fromIdx < N; ++fromIdx) { |
|
|
auto fromNeighborIdxs = neighborIdxs[fromIdx].data(); |
|
|
auto fromRelations = relations[fromIdx].data(); |
|
|
|
|
|
|
|
|
for (int64_t c = 1; c < K; ++c) { |
|
|
|
|
|
int64_t toIdx = fromNeighborIdxs[c] - 1; |
|
|
|
|
|
T toProb = fromRelations[c]; |
|
|
|
|
|
if (toProb > 0.5f) { |
|
|
T &bestProb = maxRels[toIdx]; |
|
|
if (toProb > bestProb) { |
|
|
bestProb = toProb; |
|
|
fromIdxs[toIdx] = fromIdx; |
|
|
} |
|
|
|
|
|
|
|
|
break; |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
return rel_chain_to_groups(fromIdxs, N, maxRels); |
|
|
} |
|
|
|
|
|
/// Runs sparse_relations_to_graph_impl for the floating-point dtype of `relationsTensor`
/// (selected by AT_DISPATCH_FLOATING_TYPES) and returns the groups with probabilities.
relations_list_t sparse_relations_to_graph(torch::Tensor relationsTensor, torch::Tensor neighborIdxs)
{
    relations_list_t groups;

    // Parenthesized lambda: the comma in the impl call must not split the macro argument.
    AT_DISPATCH_FLOATING_TYPES(relationsTensor.scalar_type(), "sparse_relations_to_graph", ([&] {
        groups = sparse_relations_to_graph_impl<scalar_t>(relationsTensor, neighborIdxs);
    }));

    return groups;
}
|
|
|
|
|
template<typename rel_to_2_from_map_t, typename T> |
|
|
relations_list_t rel_chain_to_groups(const rel_to_2_from_map_t &inChain, const int64_t numRegions, const T *inProbs) |
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto outChainTensor = torch::full({ numRegions }, -1, torch::kInt64); |
|
|
auto outChain = outChainTensor.data_ptr<int64_t>(); |
|
|
|
|
|
auto outProbsTensor = torch::ones({ numRegions }, torch::kFloat); |
|
|
auto outProbs = outProbsTensor.data_ptr<float>(); |
|
|
|
|
|
for (int64_t toIdx = 0; toIdx < numRegions; ++toIdx) { |
|
|
int64_t fromIdx = inChain[toIdx]; |
|
|
if (fromIdx != -1) { |
|
|
outChain[fromIdx] = toIdx; |
|
|
outProbs[fromIdx] = static_cast<float>(inProbs[toIdx]); |
|
|
} |
|
|
} |
|
|
|
|
|
std::vector<bool> processed; processed.resize(numRegions, false); |
|
|
|
|
|
text_line_t currChain; currChain.reserve(32); |
|
|
relations_list_t groups; |
|
|
|
|
|
for (int64_t toIdx = 0; toIdx < numRegions; ++toIdx) { |
|
|
int64_t fromIdx = inChain[toIdx]; |
|
|
|
|
|
if (fromIdx == -1 || processed[toIdx]) { |
|
|
continue; |
|
|
} |
|
|
|
|
|
processed[toIdx] = true; |
|
|
currChain.clear(); |
|
|
currChain.emplace_back(toIdx, outProbs[fromIdx]); |
|
|
|
|
|
int64_t currIdx = toIdx; |
|
|
while (true) { |
|
|
fromIdx = inChain[currIdx]; |
|
|
|
|
|
if (fromIdx == -1 || processed[fromIdx]) { |
|
|
break; |
|
|
} |
|
|
|
|
|
processed[fromIdx] = true; |
|
|
currChain.emplace_back(fromIdx, outProbs[fromIdx]); |
|
|
currIdx = fromIdx; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
text_line_t group{ std::rbegin(currChain), std::rend(currChain) }; |
|
|
|
|
|
|
|
|
int64_t nextIdx = toIdx; |
|
|
while (true) { |
|
|
int64_t nextToIdx = outChain[nextIdx]; |
|
|
|
|
|
if (nextToIdx == -1 || processed[nextToIdx]) { |
|
|
break; |
|
|
} |
|
|
|
|
|
processed[nextToIdx] = true; |
|
|
group.emplace_back(nextToIdx, static_cast<float>(inProbs[nextToIdx])); |
|
|
nextIdx = nextToIdx; |
|
|
} |
|
|
|
|
|
groups.push_back(move(group)); |
|
|
} |
|
|
|
|
|
|
|
|
for (int64_t wIdx = 0; wIdx < numRegions; ++wIdx) { |
|
|
if (! processed[wIdx]) { |
|
|
groups.push_back({ { wIdx, 1.0f } }); |
|
|
} |
|
|
} |
|
|
|
|
|
return groups; |
|
|
} |
|
|
|