| | #ifndef LM_SEARCH_HASHED_H |
| | #define LM_SEARCH_HASHED_H |
| |
|
| | #include "lm/model_type.hh" |
| | #include "lm/config.hh" |
| | #include "lm/read_arpa.hh" |
| | #include "lm/return.hh" |
| | #include "lm/weights.hh" |
| |
|
| | #include "util/bit_packing.hh" |
| | #include "util/probing_hash_table.hh" |
| |
|
| | #include <algorithm> |
| | #include <iostream> |
| | #include <vector> |
| |
|
| | namespace util { class FilePiece; } |
| |
|
| | namespace lm { |
| | namespace ngram { |
| | class BinaryFormat; |
| | class ProbingVocabulary; |
| | namespace detail { |
| |
|
| | inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { |
| | uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL); |
| | return ret; |
| | } |
| |
|
| | #pragma pack(push) |
| | #pragma pack(4) |
| | struct ProbEntry { |
| | uint64_t key; |
| | Prob value; |
| | typedef uint64_t Key; |
| | typedef Prob Value; |
| | uint64_t GetKey() const { |
| | return key; |
| | } |
| | }; |
| |
|
| | #pragma pack(pop) |
| |
|
// Result handle for a lookup that yields a single float, or nothing.
//
// The handle does not own anything: it stores the address of the float it
// was constructed from, so that float must outlive the handle (do not pass a
// temporary).  A default-constructed handle represents "not found".
class LongestPointer {
  public:
    // Refer to an existing value; only its address is kept.
    explicit LongestPointer(const float &to) : to_(&to) {}

    // "Not found" handle.
    LongestPointer() : to_(NULL) {}

    // True when the handle refers to a value.
    bool Found() const {
      return to_ != NULL;
    }

    // The referenced value.  Precondition: Found() is true; calling this on
    // a "not found" handle dereferences a null pointer.
    float Prob() const {
      return *to_;
    }

  private:
    const float *to_;
};
| |
|
| | template <class Value> class HashedSearch { |
| | public: |
| | typedef uint64_t Node; |
| |
|
| | typedef typename Value::ProbingProxy UnigramPointer; |
| | typedef typename Value::ProbingProxy MiddlePointer; |
| | typedef ::lm::ngram::detail::LongestPointer LongestPointer; |
| |
|
| | static const ModelType kModelType = Value::kProbingModelType; |
| | static const bool kDifferentRest = Value::kDifferentRest; |
| | static const unsigned int kVersion = 0; |
| |
|
| | |
| | static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {} |
| |
|
| | static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { |
| | uint64_t ret = Unigram::Size(counts[0]); |
| | for (unsigned char n = 1; n < counts.size() - 1; ++n) { |
| | ret += Middle::Size(counts[n], config.probing_multiplier); |
| | } |
| | return ret + Longest::Size(counts.back(), config.probing_multiplier); |
| | } |
| |
|
| | uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); |
| |
|
| | void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing); |
| |
|
| | unsigned char Order() const { |
| | return middle_.size() + 2; |
| | } |
| |
|
| | typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); } |
| |
|
| | UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const { |
| | extend_left = static_cast<uint64_t>(word); |
| | next = extend_left; |
| | UnigramPointer ret(unigram_.Lookup(word)); |
| | independent_left = ret.IndependentLeft(); |
| | return ret; |
| | } |
| |
|
| | MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { |
| | node = extend_pointer; |
| | return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value); |
| | } |
| |
|
| | MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const { |
| | node = CombineWordHash(node, word); |
| | typename Middle::ConstIterator found; |
| | if (!middle_[order_minus_2].Find(node, found)) { |
| | independent_left = true; |
| | return MiddlePointer(); |
| | } |
| | extend_pointer = node; |
| | MiddlePointer ret(found->value); |
| | independent_left = ret.IndependentLeft(); |
| | return ret; |
| | } |
| |
|
| | LongestPointer LookupLongest(WordIndex word, const Node &node) const { |
| | |
| | typename Longest::ConstIterator found; |
| | if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer(); |
| | return LongestPointer(found->value.prob); |
| | } |
| |
|
| | |
| | |
| | bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { |
| | assert(begin != end); |
| | node = static_cast<Node>(*begin); |
| | for (const WordIndex *i = begin + 1; i < end; ++i) { |
| | node = CombineWordHash(node, *i); |
| | } |
| | return true; |
| | } |
| |
|
| | private: |
| | |
| | void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); |
| |
|
| | template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); |
| |
|
| | class Unigram { |
| | public: |
| | Unigram() {} |
| |
|
| | Unigram(void *start, uint64_t count) : |
| | unigram_(static_cast<typename Value::Weights*>(start)) |
| | #ifdef DEBUG |
| | , count_(count) |
| | #endif |
| | {} |
| |
|
| | static uint64_t Size(uint64_t count) { |
| | return (count + 1) * sizeof(typename Value::Weights); |
| | } |
| |
|
| | const typename Value::Weights &Lookup(WordIndex index) const { |
| | #ifdef DEBUG |
| | assert(index < count_); |
| | #endif |
| | return unigram_[index]; |
| | } |
| |
|
| | typename Value::Weights &Unknown() { return unigram_[0]; } |
| |
|
| | |
| | typename Value::Weights *Raw() { return unigram_; } |
| |
|
| | private: |
| | typename Value::Weights *unigram_; |
| | #ifdef DEBUG |
| | uint64_t count_; |
| | #endif |
| | }; |
| |
|
| | Unigram unigram_; |
| |
|
| | typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; |
| | std::vector<Middle> middle_; |
| |
|
| | typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest; |
| | Longest longest_; |
| | }; |
| |
|
| | } |
| | } |
| | } |
| |
|
| | #endif |
| |
|