|
|
|
|
|
|
|
|
|
|
|
#include "text_region_grouping.h" |
|
|
|
|
|
#include <algorithm> |
|
|
#include <memory> |
|
|
#include <unordered_set> |
|
|
#include <unordered_map> |
|
|
#include <chrono> |
|
|
#include <stack> |
|
|
|
|
|
#include "../geometry.h" |
|
|
#include "../common.h" |
|
|
#include "../scope_timer.h" |
|
|
#include "../non_maximal_suppression/nms_kd_tree.h" |
|
|
|
|
|
using namespace std; |
|
|
|
|
|
|
|
|
/// Converts a "directly followed by" relation into ordered clusters (chains).
///
/// @param lineRelations  Maps a quad index to the index of the quad that directly follows it.
///                       Each index is expected to have at most one successor (map key) and
///                       at most one predecessor (unique values).
/// @param numQuads       Total number of quads; every index in [0, numQuads) that takes part in
///                       no relation is emitted as a singleton cluster.
/// @return One vector per chain, ordered head -> tail. Every related index and every unrelated
///         index in [0, numQuads) appears exactly once. The order of the clusters follows
///         unordered_map iteration order and is therefore unspecified (singletons come last,
///         in ascending index order).
///
/// A cyclic relation (a -> b -> ... -> a, including a self-loop) is broken open at the node the
/// scan reaches first, so that no quad is ever dropped from the output.
std::vector<std::vector<int64_t>> relations_to_clusters(const std::unordered_map<int64_t, int64_t> &lineRelations, int64_t numQuads)
{
    // successor -> predecessor
    std::unordered_map<int64_t, int64_t> reverseLookup;
    reverseLookup.reserve(lineRelations.size());
    for (const auto &kv : lineRelations) {
        reverseLookup.emplace(kv.second, kv.first);
    }

    std::vector<std::vector<int64_t>> ret;
    std::unordered_set<int64_t> visited;

    for (const auto &kv : lineRelations) {
        if (visited.count(kv.first)) continue;

        // Walk backwards to the head of the chain. Every index has at most one predecessor and
        // one successor, so the walk can only fail to terminate by cycling back to kv.first;
        // in that case the chain is started at kv.first instead of being discarded.
        int64_t head = kv.first;
        for (auto it = reverseLookup.find(head); it != reverseLookup.end(); it = reverseLookup.find(head)) {
            if (it->second == kv.first) {
                head = kv.first;
                break;
            }
            head = it->second;
        }

        // Walk forwards from the head. Stopping at an already-visited index guarantees
        // termination after one lap of a cycle.
        std::vector<int64_t> line;
        int64_t curr = head;
        while (visited.insert(curr).second) {
            line.push_back(curr);
            auto next = lineRelations.find(curr);
            if (next == lineRelations.end()) break;
            curr = next->second;
        }

        if (! line.empty()) {
            ret.push_back(std::move(line));
        }
    }

    // Quads that take part in no relation form their own single-element cluster.
    for (int64_t i = 0; i < numQuads; ++i) {
        if (! visited.count(i)) {
            ret.push_back(std::vector<int64_t>{ i });
        }
    }

    return ret;
}
|
|
|
|
|
template<typename T> |
|
|
inline T default_match(const Quad_<T> &a, const Quad_<T> &query, const Quad_<T> &b) |
|
|
{ |
|
|
return std::max<T>(intersection_area(query, b), 0); |
|
|
} |
|
|
|
|
|
template<typename T> |
|
|
inline T height_match(const Quad_<T> &a, const Quad_<T> &query, const Quad_<T> &b) |
|
|
{ |
|
|
T aHeight = a.Height(); |
|
|
T bHeight = b.Height(); |
|
|
|
|
|
T ratio = aHeight / bHeight; |
|
|
if (ratio > 1) { |
|
|
ratio = 1 / ratio; |
|
|
} |
|
|
|
|
|
|
|
|
if (ratio < 0.5) { |
|
|
return 0; |
|
|
} |
|
|
|
|
|
T dfMatch = default_match(a, query, b); |
|
|
return dfMatch * ratio; |
|
|
} |
|
|
|
|
|
template<typename T, typename CtorFn, typename MatchFn> |
|
|
vector<vector<int64_t>> cluster_quads(const vector<Quad_<T>> &vQuads, CtorFn queryConstructor, MatchFn matchFn) |
|
|
{ |
|
|
torch::Tensor tAllIxAreas = torch::zeros({ (int)vQuads.size(), (int)vQuads.size() }, torch::kFloat32); |
|
|
auto accAllIxAreas = tAllIxAreas.accessor<float, 2>(); |
|
|
|
|
|
NMS_KDTree<Quad_<T>> kdTree; |
|
|
kdTree.Build(vQuads); |
|
|
|
|
|
for (int64_t i = 0; i < vQuads.size(); ++i) { |
|
|
for (int64_t direction = 0; direction < 2; ++direction) { |
|
|
auto queryPts = queryConstructor(i, direction); |
|
|
Quad_<T> queryQuad{ queryPts.data() }; |
|
|
|
|
|
kdTree.FindIntersections(queryQuad, |
|
|
[i, &accAllIxAreas, &vQuads, &queryQuad, &matchFn, direction] |
|
|
(int64_t k, float pctN, float pctM, float bdsIOU) |
|
|
{ |
|
|
if (i == k) return; |
|
|
|
|
|
auto oI = i, oK = k; |
|
|
if (direction == 1) { |
|
|
swap(oI, oK); |
|
|
} |
|
|
|
|
|
float matchVal = matchFn(vQuads[oI], queryQuad, vQuads[oK]); |
|
|
accAllIxAreas[oI][oK] = max(accAllIxAreas[oI][oK], matchVal); |
|
|
} |
|
|
); |
|
|
} |
|
|
} |
|
|
|
|
|
torch::Tensor tAllIxIdxs; |
|
|
tie(tAllIxAreas, tAllIxIdxs) = torch::sort(tAllIxAreas, 1, true); |
|
|
|
|
|
accAllIxAreas = tAllIxAreas.accessor<float, 2>(); |
|
|
auto accAllIxIdxs = tAllIxIdxs.accessor<int64_t, 2>(); |
|
|
|
|
|
stack<tuple<int64_t, int64_t>> idxsToProcess; |
|
|
for (int64_t i = 0; i < vQuads.size(); ++i) { |
|
|
idxsToProcess.emplace(i, 0); |
|
|
} |
|
|
|
|
|
unordered_map<int64_t, tuple<int64_t, T, int64_t>> ownerLookup; |
|
|
|
|
|
while (! idxsToProcess.empty()) { |
|
|
int64_t i, k; |
|
|
tie(i, k) = idxsToProcess.top(); |
|
|
idxsToProcess.pop(); |
|
|
|
|
|
for (; k < vQuads.size(); ++k) { |
|
|
T ixArea = accAllIxAreas[i][k]; |
|
|
|
|
|
|
|
|
if (ixArea == 0) break; |
|
|
|
|
|
int64_t oIdx = accAllIxIdxs[i][k]; |
|
|
auto ownerIter = ownerLookup.find(oIdx); |
|
|
|
|
|
if (ownerIter == ownerLookup.end()) { |
|
|
ownerLookup.emplace(oIdx, make_tuple(i, ixArea, k)); |
|
|
break; |
|
|
} else { |
|
|
int64_t exI, exK; |
|
|
T exIxArea; |
|
|
tie(exI, exIxArea, exK) = ownerIter->second; |
|
|
|
|
|
|
|
|
if (ixArea > exIxArea) { |
|
|
ownerIter->second = make_tuple(i, ixArea, k); |
|
|
|
|
|
idxsToProcess.emplace(exI, exK + 1); |
|
|
break; |
|
|
} |
|
|
|
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
unordered_map<int64_t, int64_t> bijection; |
|
|
for (auto &kv : ownerLookup) { |
|
|
bijection.emplace(get<0>(kv.second), kv.first); |
|
|
} |
|
|
|
|
|
return relations_to_clusters(bijection, vQuads.size()); |
|
|
} |
|
|
|
|
|
template<typename T> |
|
|
vector<TextLine> quads_to_lines(const vector<Quad_<T>> &vQuads, T horizontalTolerance) |
|
|
{ |
|
|
auto queryCtor = [&] (int64_t i, int64_t direction) { |
|
|
const Quad_<T> &currQuad = vQuads[i]; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Point_<T> d1 = currQuad[1] - currQuad[0]; |
|
|
Point_<T> d2 = currQuad[2] - currQuad[3]; |
|
|
Point_<T> dEnd = direction == 0 ? (currQuad[2] - currQuad[1]) : (currQuad[3] - currQuad[0]); |
|
|
|
|
|
T w1 = length(d1); |
|
|
T w2 = length(d2); |
|
|
T endHeight = length(dEnd); |
|
|
T width = (w1 + w2) / 2; |
|
|
|
|
|
d1 /= w1; |
|
|
d2 /= w2; |
|
|
dEnd /= endHeight; |
|
|
|
|
|
T avgCharWidth = std::max<T>(endHeight * 0.75f, 1.0f); |
|
|
|
|
|
Point_<T> endPt = direction == 0 ? currQuad[1] : currQuad[0]; |
|
|
|
|
|
Point_<T> rp0 = endPt + (T(0.1) * endHeight * dEnd); |
|
|
Point_<T> rp1 = endPt + (T(0.9) * endHeight * dEnd); |
|
|
|
|
|
if (direction == 1) { |
|
|
d1 *= -1.0f; |
|
|
d2 *= -1.0f; |
|
|
} |
|
|
|
|
|
Point_<T> qp1 = rp0 + (avgCharWidth * horizontalTolerance * d1); |
|
|
Point_<T> qp2 = rp1 + (avgCharWidth * horizontalTolerance * d2); |
|
|
|
|
|
if (direction == 0) { |
|
|
|
|
|
array<Point_<T>, 4> pts{ rp0, qp1, qp2, rp1 }; |
|
|
|
|
|
return pts; |
|
|
} else { |
|
|
array<Point_<T>, 4> pts{ qp1, rp0, rp1, qp2 }; |
|
|
|
|
|
return pts; |
|
|
} |
|
|
}; |
|
|
|
|
|
return cluster_quads(vQuads, queryCtor, height_match<T>); |
|
|
} |
|
|
|
|
|
template<typename T> |
|
|
PhraseList lines_to_phrases(const vector<Quad_<T>> &vQuads, const vector<TextLine> &lines, |
|
|
T verticalTolerance) |
|
|
{ |
|
|
vector<array<Point_<T>, 4>> linesPts; |
|
|
for (const TextLine &line : lines) { |
|
|
const Quad_<T> &leftQuad = vQuads[line.front()]; |
|
|
const Quad_<T> &rightQuad = vQuads[line.back()]; |
|
|
|
|
|
linesPts.push_back({leftQuad[0], rightQuad[1], rightQuad[2], leftQuad[3]}); |
|
|
} |
|
|
|
|
|
vector<Quad_<T>> vLines; |
|
|
for (auto &line : linesPts) { |
|
|
vLines.emplace_back(line.data()); |
|
|
} |
|
|
|
|
|
auto queryCtor = [&] (int64_t i, int64_t direction) { |
|
|
const Quad_<T> &currQuad = vLines[i]; |
|
|
|
|
|
Point_<T> d1 = currQuad[3] - currQuad[0]; |
|
|
Point_<T> d2 = currQuad[2] - currQuad[1]; |
|
|
|
|
|
if (direction == 0) { |
|
|
Point_<T> qp1 = currQuad[3] + (verticalTolerance * d1); |
|
|
Point_<T> qp2 = currQuad[2] + (verticalTolerance * d2); |
|
|
|
|
|
array<Point_<T>, 4> pts{ currQuad[3], currQuad[2], qp2, qp1 }; |
|
|
|
|
|
return pts; |
|
|
} else { |
|
|
Point_<T> qp1 = currQuad[0] - (verticalTolerance * d1); |
|
|
Point_<T> qp2 = currQuad[1] - (verticalTolerance * d2); |
|
|
|
|
|
array<Point_<T>, 4> pts{ qp1, qp2, currQuad[1], currQuad[0] }; |
|
|
|
|
|
return pts; |
|
|
} |
|
|
}; |
|
|
|
|
|
vector<vector<int64_t>> phraseClusters = cluster_quads(vLines, queryCtor, height_match<T>); |
|
|
|
|
|
PhraseList phrases; |
|
|
for (const vector<int64_t> &lineIdxs : phraseClusters) { |
|
|
Phrase phrase; |
|
|
for (int64_t lineIdx : lineIdxs) { |
|
|
phrase.push_back(lines[lineIdx]); |
|
|
} |
|
|
phrases.push_back(move(phrase)); |
|
|
} |
|
|
|
|
|
return phrases; |
|
|
} |
|
|
|
|
|
|
|
|
template<typename T> |
|
|
PhraseList process_image(torch::Tensor quads, |
|
|
T horizontalTolerance, T verticalTolerance, bool verbose) |
|
|
{ |
|
|
static bool s_timerEnabled = true; |
|
|
|
|
|
if (verbose) { |
|
|
cout << "Text Grouper - Processing Image..." << endl; |
|
|
} |
|
|
|
|
|
auto quadsAccess = quads.accessor<T, 3>(); |
|
|
|
|
|
vector<Quad_<T>> vQuads; |
|
|
for (int64_t i = 0; i < quadsAccess.size(0); ++i) { |
|
|
vQuads.emplace_back(quadsAccess[i].data()); |
|
|
} |
|
|
|
|
|
double tQuadsToLines, tLinesToPhrases; |
|
|
vector<TextLine> lines; |
|
|
PhraseList phrases; |
|
|
|
|
|
{ |
|
|
|
|
|
CudaStoreTimer t(tQuadsToLines, s_timerEnabled && verbose, false); |
|
|
lines = quads_to_lines(vQuads, horizontalTolerance); |
|
|
} |
|
|
|
|
|
{ |
|
|
|
|
|
CudaStoreTimer t(tLinesToPhrases, s_timerEnabled && verbose, false); |
|
|
phrases = lines_to_phrases(vQuads, lines, verticalTolerance); |
|
|
} |
|
|
|
|
|
if (s_timerEnabled && verbose) { |
|
|
cout << "Text Grouper " << quads.size(0) |
|
|
<< " - To Lines: " << tQuadsToLines << "ms" |
|
|
<< ", To Phrases: " << tLinesToPhrases << "ms" |
|
|
<< endl; |
|
|
} |
|
|
|
|
|
return phrases; |
|
|
} |
|
|
|
|
|
|
|
|
/// Groups text-region quads into lines and phrases, image by image.
///
/// @param sparseQuads   Quads of the whole batch concatenated along dim 0. Each dim-0 slice is
///                      one quad, presumably shaped (4, 2) (TODO confirm against caller).
/// @param sparseCounts  1-D tensor holding the number of quads per image, in batch order.
/// @param horizontalTolerance  Reach of the line search, in multiples of the average char width.
/// @param verticalTolerance    Reach of the phrase search, as a fraction of the line height.
/// @param verbose       Print progress and per-stage timings.
/// @return One PhraseList per image.
std::vector<PhraseList> text_region_grouping(torch::Tensor sparseQuads, torch::Tensor sparseCounts,
                                             float horizontalTolerance,
                                             float verticalTolerance,
                                             bool verbose)
{
    // The grouping runs on the host and builds every Quad_ from a raw pointer into the tensor
    // (quadsAccess[i].data()). The data must therefore be CPU-resident and contiguous: a
    // strided view would otherwise be read as the wrong points, and a CUDA tensor would be
    // dereferenced on the host. Both calls are no-ops for already-conforming inputs.
    sparseQuads = sparseQuads.to(torch::Device(torch::kCPU), torch::kFloat32).contiguous();
    sparseCounts = sparseCounts.to(torch::Device(torch::kCPU), torch::kInt64);

    auto countsAccess = sparseCounts.accessor<int64_t, 1>();
    const int64_t numImages = countsAccess.size(0);

    vector<PhraseList> ret;
    ret.reserve(numImages);

    // Each image owns the next `ct` rows of sparseQuads.
    int64_t offset = 0;
    for (int64_t i = 0; i < numImages; ++i) {
        const int64_t ct = countsAccess[i];

        auto currQuads = sparseQuads.slice(0, offset, offset + ct);

        ret.push_back(process_image<float>(currQuads, horizontalTolerance, verticalTolerance, verbose));

        offset += ct;
    }

    return ret;
}
|
|
|