File size: 5,793 Bytes
fd49381 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
#include "OxLM.h"
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
#include <boost/filesystem.hpp>
#include <boost/functional/hash.hpp>
#include "moses/FactorCollection.h"
#include "moses/InputType.h"
#include "moses/TranslationTask.h"
using namespace std;
using namespace oxlm;
namespace Moses
{
/**
 * Constructs the OxLM feature from its configuration line.
 *
 * Defaults set before the parameters are read: normalized scoring on,
 * POS back-off off, POS factor type 1, persistent cache off.
 */
template<class Model>
OxLM<Model>::OxLM(const string &line)
  : LanguageModelSingleFactor(line),
    normalized(true),
    posBackOff(false),
    posFactorType(1),
    persistentCache(false)
{
  ReadParameters();

  // The parent language model classes need the sentence boundary members
  // below; apparently they do not set them themselves, so it is done here.
  FactorCollection &factors = FactorCollection::Instance();

  m_sentenceStart = factors.AddFactor(Output, m_factorType, BOS_);
  m_sentenceStartWord[m_factorType] = m_sentenceStart;

  m_sentenceEnd = factors.AddFactor(Output, m_factorType, EOS_);
  m_sentenceEndWord[m_factorType] = m_sentenceEnd;

  // Start with clean cache statistics.
  totalHits = 0;
  cacheHits = 0;
}
/**
 * When the persistent cache is enabled, saves the query cache (if one was
 * allocated) to "<m_filePath>.phrases.cache.bin" and prints the cache hit
 * ratio, as a percentage, to stderr. Does nothing otherwise.
 */
template<class Model>
OxLM<Model>::~OxLM()
{
  if (!persistentCache) {
    return;
  }

  if (cache.get()) {
    string cache_file = m_filePath + ".phrases.cache.bin";
    savePersistentCache(cache_file);
  }

  // If no query was ever counted, the ratio would be 0/0 and print as NaN;
  // report 0 in that case instead.
  double cache_hit_ratio = 0.0;
  if (totalHits != 0) {
    cache_hit_ratio = 100.0 * cacheHits / totalHits;
  }
  cerr << "Cache hit ratio: " << cache_hit_ratio << endl;
}
/**
 * Applies a single key/value configuration option.
 *
 * Keys handled here: "normalized", "persistent-cache", "pos-back-off" and
 * "pos-factor-type". Every other key is forwarded to
 * LanguageModelSingleFactor::SetParameter().
 *
 * @param key    option name
 * @param value  option value, parsed with Scan<>() into the target type
 */
template<class Model>
void OxLM<Model>::SetParameter(const string& key, const string& value)
{
  if (key == "normalized") {
    normalized = Scan<bool>(value);
  } else if (key == "persistent-cache") {
    persistentCache = Scan<bool>(value);
  } else if (key == "pos-back-off") {
    posBackOff = Scan<bool>(value);
  } else if (key == "pos-factor-type") {
    posFactorType = Scan<FactorType>(value);
  } else {
    LanguageModelSingleFactor::SetParameter(key, value);
  }
}
/**
 * Loads the model from m_filePath, creates the OxLMMapper from the model's
 * vocabulary and stores the vocabulary ids of "<s>", "</s>" and "<unk>".
 *
 * UTIL_THROW_IF2 is triggered when the n-gram order in the model's config
 * differs from m_nGramOrder. The opts argument is not used here.
 */
template<class Model>
void OxLM<Model>::Load(AllOptions::ptr const& opts)
{
  model.load(m_filePath);

  boost::shared_ptr<Vocabulary> vocabulary = model.getVocab();
  mapper = boost::make_shared<OxLMMapper>(vocabulary, posBackOff, posFactorType);

  // Ids of the special tokens, looked up once.
  kSTART = vocabulary->convert("<s>");
  kSTOP = vocabulary->convert("</s>");
  kUNKNOWN = vocabulary->convert("<unk>");

  // The order configured in Moses must agree with the loaded model.
  size_t model_order = model.getConfig()->ngram_order;
  UTIL_THROW_IF2(
    m_nGramOrder != model_order,
    "Wrong order for OxLM: LM has " << model_order << ", but Moses expects " << m_nGramOrder);
}
/**
 * Scores `word` given `context` with the underlying model.
 *
 * Returns model.getLogProb() when `normalized` is set, and
 * model.getUnnormalizedScore() otherwise.
 */
template<class Model>
double OxLM<Model>::GetScore(int word, const vector<int>& context) const
{
  if (!normalized) {
    return model.getUnnormalizedScore(word, context);
  }
  return model.getLogProb(word, context);
}
/**
 * Scores the n-gram described by contextFactor and writes a hash-based
 * state into *finalState.
 *
 * The mapper converts contextFactor into a predicted word id and a vector of
 * context ids. The context is resized to m_nGramOrder - 1 entries, filling
 * with kSTART when its last id is kSTART and with kUNKNOWN otherwise. With
 * the persistent cache enabled the score is looked up in (and, on a miss,
 * added to) the query cache; otherwise GetScore() is called directly.
 *
 * NOTE(review): finalState is dereferenced unconditionally — confirm that
 * callers never pass a null pointer.
 */
template<class Model>
LMResult OxLM<Model>::GetValue(
  const vector<const Word*> &contextFactor, State* finalState) const
{
  // First use: allocate the query cache and try to fill it from the phrase
  // cache file next to the model.
  if (cache.get() == nullptr) {
    cache.reset(new QueryCache());
    string phrase_cache_file = m_filePath + ".phrases.cache.bin";
    loadPersistentCache(phrase_cache_file);
  }

  vector<int> history;
  int target;
  mapper->convert(contextFactor, history, target);

  // Bring the history to exactly n-1 ids.
  size_t history_width = m_nGramOrder - 1;
  const bool ends_with_start = !history.empty() && history.back() == kSTART;
  if (ends_with_start) {
    history.resize(history_width, kSTART);
  } else {
    history.resize(history_width, kUNKNOWN);
  }

  double score;
  if (!persistentCache) {
    score = GetScore(target, history);
  } else {
    ++totalHits;
    NGram query(target, history);
    pair<double, bool> cached = cache->get(query);
    if (cached.second) {
      ++cacheHits;
      score = cached.first;
    } else {
      score = GetScore(target, history);
      cache->put(query, score);
    }
  }

  LMResult result;
  result.score = score;
  result.unknown = (target == kUNKNOWN);

  // The state is a hash over the predicted word followed by the first
  // history_width - 1 history ids.
  size_t state_hash = 0;
  boost::hash_combine(state_hash, target);
  for (size_t pos = 0; pos < history.size() && pos < history_width - 1; ++pos) {
    int history_id = history[pos];
    boost::hash_combine(state_hash, history_id);
  }
  *finalState = reinterpret_cast<State*>(state_hash);

  return result;
}
/**
 * Reads a serialized query cache from cache_file into *cache.
 *
 * If the file does not exist, "Cache file not found" is logged to stderr and
 * *cache is left untouched. cache is dereferenced unconditionally when the
 * file exists, so it must already be allocated.
 *
 * @param cache_file  path of the Boost binary archive to read
 */
template<class Model>
void OxLM<Model>::loadPersistentCache(const string& cache_file) const
{
  if (boost::filesystem::exists(cache_file)) {
    // Binary archives must be read from a stream opened in binary mode;
    // text mode can corrupt the data on platforms that translate newlines.
    ifstream f(cache_file, std::ios::binary);
    boost::archive::binary_iarchive iar(f);
    cerr << "Loading n-gram probability cache from " << cache_file << endl;
    iar >> *cache;
    cerr << "Done loading " << cache->size()
         << " n-gram probabilities..." << endl;
  } else {
    cerr << "Cache file not found" << endl;
  }
}
/**
 * Serializes *cache to cache_file as a Boost binary archive and logs the
 * number of stored entries to stderr.
 *
 * cache is dereferenced unconditionally, so it must already be allocated.
 *
 * @param cache_file  path of the archive to write
 */
template<class Model>
void OxLM<Model>::savePersistentCache(const string& cache_file) const
{
  // Binary archives must be written to a stream opened in binary mode;
  // text mode can corrupt the data on platforms that translate newlines.
  ofstream f(cache_file, std::ios::binary);
  boost::archive::binary_oarchive oar(f);
  cerr << "Saving persistent cache to " << cache_file << endl;
  oar << *cache;
  cerr << "Done saving " << cache->size()
       << " n-gram probabilities..." << endl;
}
/**
 * Per-input setup: runs the base-class initialization and, when the
 * persistent cache is enabled, makes sure a query cache exists and loads
 * "<m_filePath>.<translation id>.cache.bin" into it.
 */
template<class Model>
void OxLM<Model>::InitializeForInput(ttasksptr const& ttask)
{
  const InputType& source = *ttask->GetSource();
  LanguageModelSingleFactor::InitializeForInput(ttask);

  if (!persistentCache) {
    return;
  }

  if (cache.get() == nullptr) {
    cache.reset(new QueryCache());
  }

  // One cache file per input, keyed by its translation id.
  int translation_id = source.GetTranslationId();
  string sentence_cache_file = m_filePath;
  sentence_cache_file += ".";
  sentence_cache_file += to_string(translation_id);
  sentence_cache_file += ".cache.bin";
  loadPersistentCache(sentence_cache_file);
}
/**
 * Per-input teardown: clears the model's cache and, when the persistent
 * cache is enabled, saves the query cache to
 * "<m_filePath>.<translation id>.cache.bin" and empties it. The base-class
 * cleanup runs last.
 */
template<class Model>
void OxLM<Model>::CleanUpAfterSentenceProcessing(const InputType& source)
{
  // Thread safe: the model cache is thread specific.
  model.clearCache();

  if (persistentCache) {
    int translation_id = source.GetTranslationId();
    string sentence_cache_file = m_filePath;
    sentence_cache_file += ".";
    sentence_cache_file += to_string(translation_id);
    sentence_cache_file += ".cache.bin";

    savePersistentCache(sentence_cache_file);
    cache->clear();
  }

  LanguageModelSingleFactor::CleanUpAfterSentenceProcessing(source);
}
// Explicit instantiations of OxLM for each of the model types listed below.
template class OxLM<LM>;
template class OxLM<FactoredLM>;
template class OxLM<FactoredMaxentLM>;
template class OxLM<FactoredTreeLM>;
}
|