| #include "search/vertex.hh" |
|
|
| #include "search/context.hh" |
|
|
| #include <boost/unordered_map.hpp> |
|
|
| #include <algorithm> |
| #include <functional> |
| #include <cassert> |
|
|
| namespace search { |
|
|
| namespace { |
|
|
// All-ones 64-bit value. DivideLeft/DivideRight return `kCompleteAdd - full`
// as the split key once the requested index is past the state's length, which
// yields one of two reserved keys: all ones (not full) or all ones minus one
// (full).
const uint64_t kCompleteAdd = ~static_cast<uint64_t>(0);
|
|
| class DivideLeft { |
| public: |
| explicit DivideLeft(unsigned char index) |
| : index_(index) {} |
|
|
| uint64_t operator()(const lm::ngram::ChartState &state) const { |
| return (index_ < state.left.length) ? |
| state.left.pointers[index_] : |
| (kCompleteAdd - state.left.full); |
| } |
|
|
| private: |
| unsigned char index_; |
| }; |
|
|
| class DivideRight { |
| public: |
| explicit DivideRight(unsigned char index) |
| : index_(index) {} |
|
|
| uint64_t operator()(const lm::ngram::ChartState &state) const { |
| return (index_ < state.right.length) ? |
| static_cast<uint64_t>(state.right.words[index_]) : |
| (kCompleteAdd - state.left.full); |
| } |
|
|
| private: |
| unsigned char index_; |
| }; |
|
|
| template <class Divider> void Split(const Divider ÷r, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) { |
| |
| typedef boost::unordered_map<uint64_t, std::size_t> Lookup; |
| Lookup lookup; |
| for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) { |
| uint64_t key = divider(i->state); |
| std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size()))); |
| if (res.second) { |
| extend.resize(extend.size() + 1); |
| extend.back().AppendHypothesis(*i); |
| } else { |
| extend[res.first->second].AppendHypothesis(*i); |
| } |
| } |
| |
| } |
|
|
| lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) { |
| return right.words[index]; |
| } |
|
|
| uint64_t Identify(const lm::ngram::Left &left, unsigned char index) { |
| return left.pointers[index]; |
| } |
|
|
// Computes, over a reference state `side` and every state passed to Consider(),
// the length of the prefix they all share, and whether they are all identical.
//
// Positions below `guaranteed` are assumed equal and are never compared.
// `side` is held by reference, so it must outlive this object.
//
// Shared():   length of the common prefix seen so far (starts at side.length).
// Complete(): true while every considered state has the same length as `side`
//             and matches it at every position from `guaranteed` up to that
//             length.
template <class Side> class DetermineSame {
  public:
    DetermineSame(const Side &side, unsigned char guaranteed)
      : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}

    // Fold one more state into the running result.
    void Consider(const Side &other) {
      // A length mismatch alone rules out the states being identical.
      if (other.length != shared_) {
        complete_ = false;
        if (other.length < shared_)
          shared_ = other.length;
      }
      // Scan for the first position (at or after guaranteed_) where they differ.
      unsigned char pos = guaranteed_;
      while (pos < shared_ && Identify(side_, pos) == Identify(other, pos)) {
        ++pos;
      }
      if (pos < shared_) {
        shared_ = pos;
        complete_ = false;
      }
    }

    unsigned char Shared() const { return shared_; }

    bool Complete() const { return complete_; }

  private:
    const Side &side_;
    unsigned char guaranteed_, shared_;
    bool complete_;
};
|
|
| |
| |
// Values stored in VertexNode::policy_; BuildExtend reads them to decide which
// side of the state to split on. Kept as unsigned char to keep the node small.

// Split on the left state when it is no longer than the right one, otherwise on
// the right.
const unsigned char kPolicyAlternate(0);
// Always split on the left state.
const unsigned char kPolicyOneLeft(1);
// Always split on the right state.
const unsigned char kPolicyOneRight(2);
| |
| |
|
|
| } |
|
|
| namespace { |
| struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> { |
| bool operator()(const HypoState &first, const HypoState &second) const { |
| return first.score > second.score; |
| } |
| }; |
| } |
|
|
| void VertexNode::FinishRoot() { |
| std::sort(hypos_.begin(), hypos_.end(), GreaterByScore()); |
| extend_.clear(); |
| |
| state_.left.full = false; |
| state_.left.length = 0; |
| state_.right.length = 0; |
| right_full_ = false; |
| niceness_ = 0; |
| policy_ = kPolicyAlternate; |
| if (hypos_.size() == 1) { |
| extend_.resize(1); |
| extend_.front().AppendHypothesis(hypos_.front()); |
| extend_.front().FinishedAppending(0, 0); |
| } |
| if (hypos_.empty()) { |
| bound_ = -INFINITY; |
| } else { |
| bound_ = hypos_.front().score; |
| } |
| } |
|
|
| void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) { |
| assert(!hypos_.empty()); |
| assert(extend_.empty()); |
| bound_ = hypos_.front().score; |
| state_ = hypos_.front().state; |
| bool all_full = state_.left.full; |
| bool all_non_full = !state_.left.full; |
| DetermineSame<lm::ngram::Left> left(state_.left, common_left); |
| DetermineSame<lm::ngram::Right> right(state_.right, common_right); |
| for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) { |
| all_full &= i->state.left.full; |
| all_non_full &= !i->state.left.full; |
| left.Consider(i->state.left); |
| right.Consider(i->state.right); |
| } |
| state_.left.full = all_full && left.Complete(); |
| right_full_ = all_full && right.Complete(); |
| state_.left.length = left.Shared(); |
| state_.right.length = right.Shared(); |
|
|
| if (!all_full && !all_non_full) { |
| policy_ = kPolicyAlternate; |
| } else if (left.Complete()) { |
| policy_ = kPolicyOneRight; |
| } else if (right.Complete()) { |
| policy_ = kPolicyOneLeft; |
| } else { |
| policy_ = kPolicyAlternate; |
| } |
| niceness_ = state_.left.length + state_.right.length; |
| } |
|
|
| void VertexNode::BuildExtend() { |
| |
| if (!extend_.empty()) return; |
| |
| if (hypos_.size() <= 1) return; |
| bool left_branch = true; |
| switch (policy_) { |
| case kPolicyAlternate: |
| left_branch = (state_.left.length <= state_.right.length); |
| break; |
| case kPolicyOneLeft: |
| left_branch = true; |
| break; |
| case kPolicyOneRight: |
| left_branch = false; |
| break; |
| } |
| if (left_branch) { |
| Split(DivideLeft(state_.left.length), hypos_, extend_); |
| } else { |
| Split(DivideRight(state_.right.length), hypos_, extend_); |
| } |
| for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) { |
| |
| i->FinishedAppending(state_.left.length, state_.right.length); |
| } |
| } |
|
|
| } |
|
|