| | #include "search/vertex.hh" |
| |
|
| | #include "search/context.hh" |
| |
|
| | #include <boost/unordered_map.hpp> |
| |
|
| | #include <algorithm> |
| | #include <functional> |
| | #include <cassert> |
| |
|
| | namespace search { |
| |
|
| | namespace { |
| |
|
// Largest uint64_t value (all bits set).  The dividers below subtract the
// left.full flag from it to build the key for states that have no entry at
// the requested index.
const uint64_t kCompleteAdd = ~static_cast<uint64_t>(0);
| |
|
| | class DivideLeft { |
| | public: |
| | explicit DivideLeft(unsigned char index) |
| | : index_(index) {} |
| |
|
| | uint64_t operator()(const lm::ngram::ChartState &state) const { |
| | return (index_ < state.left.length) ? |
| | state.left.pointers[index_] : |
| | (kCompleteAdd - state.left.full); |
| | } |
| |
|
| | private: |
| | unsigned char index_; |
| | }; |
| |
|
| | class DivideRight { |
| | public: |
| | explicit DivideRight(unsigned char index) |
| | : index_(index) {} |
| |
|
| | uint64_t operator()(const lm::ngram::ChartState &state) const { |
| | return (index_ < state.right.length) ? |
| | static_cast<uint64_t>(state.right.words[index_]) : |
| | (kCompleteAdd - state.left.full); |
| | } |
| |
|
| | private: |
| | unsigned char index_; |
| | }; |
| |
|
| | template <class Divider> void Split(const Divider ÷r, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) { |
| | |
| | typedef boost::unordered_map<uint64_t, std::size_t> Lookup; |
| | Lookup lookup; |
| | for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) { |
| | uint64_t key = divider(i->state); |
| | std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size()))); |
| | if (res.second) { |
| | extend.resize(extend.size() + 1); |
| | extend.back().AppendHypothesis(*i); |
| | } else { |
| | extend[res.first->second].AppendHypothesis(*i); |
| | } |
| | } |
| | |
| | } |
| |
|
| | lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) { |
| | return right.words[index]; |
| | } |
| |
|
| | uint64_t Identify(const lm::ngram::Left &left, unsigned char index) { |
| | return left.pointers[index]; |
| | } |
| |
|
// Tracks how much of a reference side (left or right state) is shared with
// every other side passed to Consider().
//
// Side must expose a `length` member and have an Identify(side, index)
// overload returning a comparable value.
template <class Side> class DetermineSame {
  public:
    // side:       reference side; held by reference, so it must outlive this
    //             object.
    // guaranteed: number of leading positions assumed equal already; they are
    //             never compared.
    DetermineSame(const Side &side, unsigned char guaranteed)
      : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}

    // Narrow the shared prefix so that it also agrees with other.
    void Consider(const Side &other) {
      // Any difference in length means the sides are not identical, and the
      // shared prefix cannot exceed the shorter of the two.
      if (other.length != shared_) {
        complete_ = false;
        if (other.length < shared_) {
          shared_ = other.length;
        }
      }
      // Find the first position past the guaranteed prefix where they differ.
      unsigned char pos = guaranteed_;
      while (pos < shared_ && !(Identify(side_, pos) != Identify(other, pos))) {
        ++pos;
      }
      if (pos < shared_) {
        shared_ = pos;
        complete_ = false;
      }
    }

    // Length of the prefix common to the reference and all considered sides.
    unsigned char Shared() const { return shared_; }

    // True while every considered side matched the reference in both length
    // and content.
    bool Complete() const { return complete_; }

  private:
    const Side &side_;
    unsigned char guaranteed_, shared_;
    bool complete_;
};
| |
|
| | |
| | |
// Valid values of VertexNode::policy_, stored as unsigned char.

// Branch on whichever side currently has the shorter revealed state
// (left on ties).
const unsigned char kPolicyAlternate = 0;

// Always branch on left state.
const unsigned char kPolicyOneLeft = 1;

// Always branch on right state.
const unsigned char kPolicyOneRight = 2;
| | |
| | |
| |
|
| | } |
| |
|
| | namespace { |
| | struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> { |
| | bool operator()(const HypoState &first, const HypoState &second) const { |
| | return first.score > second.score; |
| | } |
| | }; |
| | } |
| |
|
| | void VertexNode::FinishRoot() { |
| | std::sort(hypos_.begin(), hypos_.end(), GreaterByScore()); |
| | extend_.clear(); |
| | |
| | state_.left.full = false; |
| | state_.left.length = 0; |
| | state_.right.length = 0; |
| | right_full_ = false; |
| | niceness_ = 0; |
| | policy_ = kPolicyAlternate; |
| | if (hypos_.size() == 1) { |
| | extend_.resize(1); |
| | extend_.front().AppendHypothesis(hypos_.front()); |
| | extend_.front().FinishedAppending(0, 0); |
| | } |
| | if (hypos_.empty()) { |
| | bound_ = -INFINITY; |
| | } else { |
| | bound_ = hypos_.front().score; |
| | } |
| | } |
| |
|
| | void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) { |
| | assert(!hypos_.empty()); |
| | assert(extend_.empty()); |
| | bound_ = hypos_.front().score; |
| | state_ = hypos_.front().state; |
| | bool all_full = state_.left.full; |
| | bool all_non_full = !state_.left.full; |
| | DetermineSame<lm::ngram::Left> left(state_.left, common_left); |
| | DetermineSame<lm::ngram::Right> right(state_.right, common_right); |
| | for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) { |
| | all_full &= i->state.left.full; |
| | all_non_full &= !i->state.left.full; |
| | left.Consider(i->state.left); |
| | right.Consider(i->state.right); |
| | } |
| | state_.left.full = all_full && left.Complete(); |
| | right_full_ = all_full && right.Complete(); |
| | state_.left.length = left.Shared(); |
| | state_.right.length = right.Shared(); |
| |
|
| | if (!all_full && !all_non_full) { |
| | policy_ = kPolicyAlternate; |
| | } else if (left.Complete()) { |
| | policy_ = kPolicyOneRight; |
| | } else if (right.Complete()) { |
| | policy_ = kPolicyOneLeft; |
| | } else { |
| | policy_ = kPolicyAlternate; |
| | } |
| | niceness_ = state_.left.length + state_.right.length; |
| | } |
| |
|
| | void VertexNode::BuildExtend() { |
| | |
| | if (!extend_.empty()) return; |
| | |
| | if (hypos_.size() <= 1) return; |
| | bool left_branch = true; |
| | switch (policy_) { |
| | case kPolicyAlternate: |
| | left_branch = (state_.left.length <= state_.right.length); |
| | break; |
| | case kPolicyOneLeft: |
| | left_branch = true; |
| | break; |
| | case kPolicyOneRight: |
| | left_branch = false; |
| | break; |
| | } |
| | if (left_branch) { |
| | Split(DivideLeft(state_.left.length), hypos_, extend_); |
| | } else { |
| | Split(DivideRight(state_.right.length), hypos_, extend_); |
| | } |
| | for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) { |
| | |
| | i->FinishedAppending(state_.left.length, state_.right.length); |
| | } |
| | } |
| |
|
| | } |
| |
|