File size: 6,006 Bytes
fd49381 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
#include "search/vertex.hh"
#include "search/context.hh"
#include <boost/unordered_map.hpp>
#include <algorithm>
#include <functional>
#include <cassert>
namespace search {
namespace {
// All-ones 64-bit value.  DivideLeft and DivideRight subtract a state's
// left.full flag from it to build the key they return when the requested index
// lies past the end of that state.
const uint64_t kCompleteAdd = ~static_cast<uint64_t>(0);
// Functor producing the split key for the left side of a chart state: the left
// pointer stored at a fixed index, or a sentinel derived from kCompleteAdd when
// the left state has no entry at that index.
class DivideLeft {
  public:
    // index: position in state.left.pointers to read.
    explicit DivideLeft(unsigned char index) : index_(index) {}

    uint64_t operator()(const lm::ngram::ChartState &state) const {
      if (index_ >= state.left.length) {
        // No pointer at this position.  The key depends only on left.full, so
        // states with different full flags still receive different keys.
        return kCompleteAdd - state.left.full;
      }
      return state.left.pointers[index_];
    }

  private:
    unsigned char index_;
};
// Functor producing the split key for the right side of a chart state: the
// word stored at a fixed index, or a sentinel derived from kCompleteAdd when the
// right state has no entry at that index.
class DivideRight {
  public:
    // index: position in state.right.words to read.
    explicit DivideRight(unsigned char index) : index_(index) {}

    uint64_t operator()(const lm::ngram::ChartState &state) const {
      if (index_ >= state.right.length) {
        // No word at this position.  Note that the flag folded into the
        // sentinel is state.left.full, the same flag DivideLeft uses.
        return kCompleteAdd - state.left.full;
      }
      return static_cast<uint64_t>(state.right.words[index_]);
    }

  private:
    unsigned char index_;
};
// Distribute hypotheses into child nodes, one child per distinct divider key.
//
//   divider: callable invoked as divider(hypo.state); its result is used as a
//            uint64_t key.
//   hypos:   hypotheses to distribute.  They are visited in order and each one
//            is appended to exactly one child.
//   extend:  receives the children.  A new node is appended the first time a
//            key is seen; later hypotheses with the same key go to that node.
template <class Divider> void Split(const Divider &divider, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) {
  // Map from divider key to index in extend.
  typedef boost::unordered_map<uint64_t, std::size_t> Lookup;
  Lookup lookup;
  for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) {
    const uint64_t key = divider(i->state);
    // Tentatively map the key to the slot a new child would occupy.  If the
    // key is already known, the insert is a no-op and res.first points at the
    // existing slot.
    std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size())));
    if (res.second) {
      // First hypothesis with this key: create the child at the recorded index.
      extend.resize(extend.size() + 1);
    }
    extend[res.first->second].AppendHypothesis(*i);
  }
  //assert((extend.size() != 1) || (hypos.size() == 1));
}
// Value identifying position `index` of a right state: the word stored there.
// The caller is responsible for keeping `index` within the state's length.
lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) {
  const lm::WordIndex word = right.words[index];
  return word;
}
// Value identifying position `index` of a left state: the pointer stored there.
// The caller is responsible for keeping `index` within the state's length.
uint64_t Identify(const lm::ngram::Left &left, unsigned char index) {
  const uint64_t pointer = left.pointers[index];
  return pointer;
}
// Measures how long a prefix of a reference side (a left or right state) is
// shared with every other side passed to Consider(), and whether all of them
// were identical to the reference in length and content.
template <class Side> class DetermineSame {
  public:
    // side:       reference side.  Stored by reference, so it must outlive this
    //             object; its length seeds the shared-prefix length.
    // guaranteed: number of leading positions that are not compared, i.e. the
    //             caller vouches that they already match.
    DetermineSame(const Side &side, unsigned char guaranteed)
      : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}

    // Fold one more side into the result, shrinking the shared prefix if needed.
    void Consider(const Side &other) {
      if (other.length != shared_) {
        // Any length difference means the sides are not all identical.
        complete_ = false;
        if (other.length < shared_) {
          shared_ = other.length;
        }
      }
      // Walk forward until the first position where the two sides disagree.
      unsigned char pos = guaranteed_;
      while (pos < shared_ && Identify(side_, pos) == Identify(other, pos)) {
        ++pos;
      }
      if (pos < shared_) {
        shared_ = pos;
        complete_ = false;
      }
    }

    // Length of the prefix shared by everything considered so far.
    unsigned char Shared() const { return shared_; }

    // True while no considered side has differed from the reference.
    bool Complete() const { return complete_; }

  private:
    const Side &side_;
    unsigned char guaranteed_, shared_;
    bool complete_;
};
// Legal values of policy_.  Plain unsigned char constants stand in for an enum
// to keep the member small.

// Branch on whichever side the node decides at build time; alternation between
// left and right is still possible.
const unsigned char kPolicyAlternate = 0;

// Branch on the left state only, because the right side ran out or this is a
// left tree.
const unsigned char kPolicyOneLeft = 1;

// Branch on the right state only.
const unsigned char kPolicyOneRight = 2;

// Value 3 (kPolicyEverything: reveal everything in the next branch, used to
// terminate the left/right policies) is currently disabled.
} // namespace
namespace {
struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> {
bool operator()(const HypoState &first, const HypoState &second) const {
return first.score > second.score;
}
};
} // namespace
void VertexNode::FinishRoot() {
std::sort(hypos_.begin(), hypos_.end(), GreaterByScore());
extend_.clear();
// HACK: extend to one hypo so that root can be blank.
state_.left.full = false;
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
niceness_ = 0;
policy_ = kPolicyAlternate;
if (hypos_.size() == 1) {
extend_.resize(1);
extend_.front().AppendHypothesis(hypos_.front());
extend_.front().FinishedAppending(0, 0);
}
if (hypos_.empty()) {
bound_ = -INFINITY;
} else {
bound_ = hypos_.front().score;
}
}
void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) {
assert(!hypos_.empty());
assert(extend_.empty());
bound_ = hypos_.front().score;
state_ = hypos_.front().state;
bool all_full = state_.left.full;
bool all_non_full = !state_.left.full;
DetermineSame<lm::ngram::Left> left(state_.left, common_left);
DetermineSame<lm::ngram::Right> right(state_.right, common_right);
for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) {
all_full &= i->state.left.full;
all_non_full &= !i->state.left.full;
left.Consider(i->state.left);
right.Consider(i->state.right);
}
state_.left.full = all_full && left.Complete();
right_full_ = all_full && right.Complete();
state_.left.length = left.Shared();
state_.right.length = right.Shared();
if (!all_full && !all_non_full) {
policy_ = kPolicyAlternate;
} else if (left.Complete()) {
policy_ = kPolicyOneRight;
} else if (right.Complete()) {
policy_ = kPolicyOneLeft;
} else {
policy_ = kPolicyAlternate;
}
niceness_ = state_.left.length + state_.right.length;
}
// Lazily create this node's children by splitting its hypotheses on the next
// unrevealed entry of the left or right state, then finalize each child.
void VertexNode::BuildExtend() {
  // Nothing to do if the children already exist, or if this is a leaf with at
  // most one hypothesis.
  if (!extend_.empty() || hypos_.size() <= 1) return;

  // Choose the side to split on.  Alternation picks the side with the shorter
  // shared length, preferring left on a tie; any other value behaves like
  // kPolicyOneLeft.
  bool split_on_left;
  if (policy_ == kPolicyOneRight) {
    split_on_left = false;
  } else if (policy_ == kPolicyAlternate) {
    split_on_left = (state_.left.length <= state_.right.length);
  } else {
    split_on_left = true;
  }

  if (!split_on_left) {
    Split(DivideRight(state_.right.length), hypos_, extend_);
  } else {
    Split(DivideLeft(state_.left.length), hypos_, extend_);
  }

  for (std::size_t n = 0; n < extend_.size(); ++n) {
    // TODO: provide more here for branching?
    extend_[n].FinishedAppending(state_.left.length, state_.right.length);
  }
}
} // namespace search
|