| #include "lm/builder/adjust_counts.hh" |
| #include "lm/common/ngram_stream.hh" |
| #include "lm/builder/payload.hh" |
| #include "util/stream/timer.hh" |
|
|
| #include <algorithm> |
| #include <iostream> |
| #include <limits> |
|
|
| namespace lm { namespace builder { |
|
|
// Thrown when Kneser-Ney discount estimation fails (see
// StatCollector::CalculateDiscounts below).  Definitions are out of line;
// throw() matches the declaration in adjust_counts.hh.
BadDiscountException::BadDiscountException() throw() {}
BadDiscountException::~BadDiscountException() throw() {}
|
|
| namespace { |
| |
| const WordIndex* FindDifference(const NGram<BuildingPayload> &full, const NGram<BuildingPayload> &lower_last) { |
| const WordIndex *cur_word = full.end() - 1; |
| const WordIndex *pre_word = lower_last.end() - 1; |
| |
| for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {} |
| return cur_word; |
| } |
|
|
| class StatCollector { |
| public: |
| StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<uint64_t> &counts_pruned, std::vector<Discount> &discounts) |
| : orders_(order), full_(orders_.back()), counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts) { |
| memset(&orders_[0], 0, sizeof(OrderStat) * order); |
| } |
|
|
| ~StatCollector() {} |
|
|
| void CalculateDiscounts(const DiscountConfig &config) { |
| counts_.resize(orders_.size()); |
| counts_pruned_.resize(orders_.size()); |
| for (std::size_t i = 0; i < orders_.size(); ++i) { |
| const OrderStat &s = orders_[i]; |
| counts_[i] = s.count; |
| counts_pruned_[i] = s.count_pruned; |
| } |
|
|
| discounts_ = config.overwrite; |
| discounts_.resize(orders_.size()); |
| for (std::size_t i = config.overwrite.size(); i < orders_.size(); ++i) { |
| const OrderStat &s = orders_[i]; |
| try { |
| for (unsigned j = 1; j < 4; ++j) { |
| |
| UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for " |
| << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any " |
| << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?\n" |
| << "Try deduplicating the input. To override this error for e.g. a class-based model, rerun with --discount_fallback\n"); |
| } |
|
|
| |
| discounts_[i].amount[0] = 0.0; |
| float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]); |
| for (unsigned j = 1; j < 4; ++j) { |
| discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]); |
| UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]); |
| } |
| } catch (const BadDiscountException &e) { |
| switch (config.bad_action) { |
| case THROW_UP: |
| throw; |
| case COMPLAIN: |
| std::cerr << "Substituting fallback discounts for order " << i << ": D1=" << config.fallback.amount[1] << " D2=" << config.fallback.amount[2] << " D3+=" << config.fallback.amount[3] << std::endl; |
| case SILENT: |
| break; |
| } |
| discounts_[i] = config.fallback; |
| } |
| } |
| } |
|
|
| void Add(std::size_t order_minus_1, uint64_t count, bool pruned = false) { |
| OrderStat &stat = orders_[order_minus_1]; |
| ++stat.count; |
| if (!pruned) |
| ++stat.count_pruned; |
| if (count < 5) ++stat.n[count]; |
| } |
|
|
| void AddFull(uint64_t count, bool pruned = false) { |
| ++full_.count; |
| if (!pruned) |
| ++full_.count_pruned; |
| if (count < 5) ++full_.n[count]; |
| } |
|
|
| private: |
| struct OrderStat { |
| |
| uint64_t n[5]; |
| uint64_t count; |
| uint64_t count_pruned; |
| }; |
|
|
| std::vector<OrderStat> orders_; |
| OrderStat &full_; |
|
|
| std::vector<uint64_t> &counts_; |
| std::vector<uint64_t> &counts_pruned_; |
| std::vector<Discount> &discounts_; |
| }; |
|
|
| |
| |
| |
| |
// Iterates over the highest-order n-gram records of a chain, in place.
// Records whose second word is <s> are overwritten with records pulled from
// the tail of the block (copy_from_), and the block's valid size is shrunk
// accordingly — collapsing those entries out of the stream.  As each record
// becomes current, it is marked for pruning when its count is at or below
// prune_threshold_ or when it contains a word flagged in prune_words_.
class CollapseStream {
  public:
    CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold, const std::vector<bool>& prune_words) :
      current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
      prune_threshold_(prune_threshold),
      prune_words_(prune_words),
      block_(position) {
      StartBlock();
    }

    const NGram<BuildingPayload> &operator*() const { return current_; }
    const NGram<BuildingPayload> *operator->() const { return &current_; }

    // False once every block in the chain has been consumed.
    operator bool() const { return block_; }

    CollapseStream &operator++() {
      assert(block_);

      // The record being left behind has <s> in second position: replace it
      // in place with the last unconsumed record from the end of the block.
      if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
        memcpy(current_.Base(), copy_from_, current_.TotalSize());
        UpdateCopyFrom();

        // Re-apply pruning marks to the record just copied in.
        if(current_.Value().count <= prune_threshold_) {
          current_.Value().Mark();
        }

        if(!prune_words_.empty()) {
          for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
            if(prune_words_[*i]) {
              current_.Value().Mark();
              break;
            }
          }
        }
      }

      current_.NextInMemory();
      uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
      if (current_.Base() == block_base + block_->ValidSize()) {
        // This block is exhausted: truncate it to exclude the tail records
        // already consumed as donors, then advance to the next block.
        block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base);
        ++block_;
        StartBlock();
      }

      // Mark the new current record for pruning if needed.
      if(current_.Value().count <= prune_threshold_) {
        current_.Value().Mark();
      }

      if(!prune_words_.empty()) {
        for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
          if(prune_words_[*i]) {
            current_.Value().Mark();
            break;
          }
        }
      }

      return *this;
    }

  private:
    // Position current_ at the first record of the next non-empty block and
    // initialize copy_from_ for that block.
    void StartBlock() {
      for (; ; ++block_) {
        if (!block_) return;
        if (block_->ValidSize()) break;
      }
      current_.ReBase(block_->Get());
      copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
      UpdateCopyFrom();

      // Mark the first record of the block for pruning if needed.
      if(current_.Value().count <= prune_threshold_) {
        current_.Value().Mark();
      }

      if(!prune_words_.empty()) {
        for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
          if(prune_words_[*i]) {
            current_.Value().Mark();
            break;
          }
        }
      }
    }

    // Step copy_from_ backward one record at a time, skipping records whose
    // second word is <s>, so it lands on the last usable donor record.  It
    // may end up below current_.Base() when no donor remains.
    void UpdateCopyFrom() {
      for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
        if (NGram<BuildingPayload>(copy_from_, current_.Order()).begin()[1] != kBOS) break;
      }
    }

    NGram<BuildingPayload> current_;

    // Tail pointer into the current block: the last record not yet pulled
    // forward to replace an <s>-second-word record.
    uint8_t *copy_from_;
    uint64_t prune_threshold_;
    const std::vector<bool>& prune_words_;
    util::stream::Link block_;
};
|
|
| } |
|
|
// Computes adjusted counts for Kneser-Ney estimation: the highest order keeps
// its raw counts, while each lower-order n-gram's count becomes the number of
// unique single-word left extensions observed for it (see the ++count on the
// longest match below).  N-grams beginning with <s> keep raw counts.  Pruning
// marks are applied against prune_thresholds_ / prune_words_ along the way,
// and discounts are estimated from the gathered statistics at the end.
void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
  UTIL_TIMER("(%w s) Adjusted counts\n");

  const std::size_t order = positions.size();
  StatCollector stats(order, counts_, counts_pruned_, discounts_);
  if (order == 1) {
    // Unigram-only model: nothing to adjust.  Just mark prunable entries and
    // collect count-of-count statistics.
    for (NGramStream<BuildingPayload> full(positions[0]); full; ++full) {
      // Never prune the first three word ids — presumably <unk>, <s>, </s>
      // given the kUNK/kBOS usage below; confirm against the vocabulary.
      if(*full->begin() > 2) {
        if(full->Value().count <= prune_thresholds_[0])
          full->Value().Mark();

        if(!prune_words_.empty() && prune_words_[*full->begin()])
          full->Value().Mark();
      }

      stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
    }

    stats.CalculateDiscounts(discount_config_);
    return;
  }

  // Lower orders are produced into streams[0..order-2]; the highest order is
  // consumed (and rewritten) in place through CollapseStream.
  NGramStreams<BuildingPayload> streams;
  streams.Init(positions, positions.size() - 1);

  CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back(), prune_words_);

  // lower_valid points at the highest-order stream currently holding a
  // partially accumulated ("pending") n-gram.
  NGramStream<BuildingPayload> *lower_valid = streams.begin();
  const NGramStream<BuildingPayload> *const streams_begin = streams.begin();
  // Seed the unigram stream: <unk> with count 0 (recorded in stats now),
  // then <s> with count 0 (it is flushed later like any other pending entry).
  streams[0]->Value().count = 0;
  *streams[0]->begin() = kUNK;
  stats.Add(0, 0);
  (++streams[0])->Value().count = 0;
  *streams[0]->begin() = kBOS;

  // Raw (unadjusted) count accumulated for the pending n-gram of each order;
  // used for the pruning decision when that n-gram is flushed.  The unigram
  // slot is set to max so unigrams are never count-pruned in the flush loops.
  std::vector<uint64_t> actual_counts(positions.size(), 0);
  actual_counts[0] = std::numeric_limits<uint64_t>::max();

  for (; full; ++full) {
    const WordIndex *different = FindDifference(*full, **lower_valid);
    std::size_t same = full->end() - 1 - different;

    // STEP 1: flush pending lower-order n-grams whose suffix no longer
    // matches the incoming highest-order entry, marking prunable ones.
    for (; lower_valid >= streams.begin() + same; --lower_valid) {
      uint64_t order_minus_1 = lower_valid - streams_begin;
      if(actual_counts[order_minus_1] <= prune_thresholds_[order_minus_1])
        (*lower_valid)->Value().Mark();

      if(!prune_words_.empty()) {
        for(WordIndex* i = (*lower_valid)->begin(); i != (*lower_valid)->end(); i++) {
          if(prune_words_[*i]) {
            (*lower_valid)->Value().Mark();
            break;
          }
        }
      }

      stats.Add(order_minus_1, (*lower_valid)->Value().UnmarkedCount(), (*lower_valid)->Value().IsMarked());
      ++*lower_valid;
    }

    // STEP 2: orders that still match absorb the incoming entry's raw count;
    // the longest match gains one unique extension — the adjusted count.
    for (std::size_t i = 0; i < same; ++i) {
      actual_counts[i] += full->Value().UnmarkedCount();
    }

    if (same) ++streams[same - 1]->Value().count;

    // STEP 3: open new pending n-grams for the changed suffixes, walking
    // from the point of difference back toward the start, stopping at <s>.
    const WordIndex *full_end = full->end();

    const WordIndex *bos;
    for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
      NGramStream<BuildingPayload> &to = *++lower_valid;
      std::copy(bos, full_end, to->begin());
      to->Value().count = 1;
      actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
    }

    // bos now points at <s>, or at the 0th word if there is no <s>.
    if (bos != full->begin()) {
      // A suffix beginning with <s> keeps the full raw (unadjusted) count
      // rather than a unique-extension count.
      NGramStream<BuildingPayload> &to = *++lower_valid;
      std::copy(bos, full_end, to->begin());

      to->Value().count = full->Value().UnmarkedCount();
      actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
    } else {
      // The whole highest-order entry is new: record its statistics.
      stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
    }
    assert(lower_valid >= &streams[0]);
  }

  // Input exhausted: flush every still-pending lower-order n-gram.
  for (NGramStream<BuildingPayload> *s = streams.begin(); s <= lower_valid; ++s) {
    uint64_t lower_count = actual_counts[(*s)->Order() - 1];
    if(lower_count <= prune_thresholds_[(*s)->Order() - 1])
      (*s)->Value().Mark();

    if(!prune_words_.empty()) {
      for(WordIndex* i = (*s)->begin(); i != (*s)->end(); i++) {
        if(prune_words_[*i]) {
          (*s)->Value().Mark();
          break;
        }
      }
    }

    stats.Add(s - streams.begin(), lower_count, (*s)->Value().IsMarked());
    ++*s;
  }

  // Signal downstream consumers that the lower-order streams are complete.
  for (NGramStream<BuildingPayload> *s = streams.begin(); s != streams.end(); ++s)
    s->Poison();

  stats.CalculateDiscounts(discount_config_);
}
|
|
| }} |
|
|