|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <cstdlib> |
|
|
#include <iostream> |
|
|
#include <map> |
|
|
#include <stdexcept> |
|
|
#include <set> |
|
|
|
|
|
#include "moses/IOWrapper.h" |
|
|
#include "moses/LatticeMBR.h" |
|
|
#include "moses/Manager.h" |
|
|
#include "moses/Timer.h" |
|
|
#include "moses/StaticData.h" |
|
|
#include "util/exception.hh" |
|
|
|
|
|
#include <boost/foreach.hpp> |
|
|
#include "moses/TranslationTask.h" |
|
|
|
|
|
using namespace std; |
|
|
using namespace Moses; |
|
|
|
|
|
|
|
|
// Dimensions of the lattice-MBR search grid.
enum gridkey {lmbr_p,lmbr_r,lmbr_prune,lmbr_scale};

namespace Moses
{

/**
 * Collects the lattice-MBR search grid from the command line.
 *
 * Each grid dimension (a gridkey) is registered with addParam() together with
 * the command-line flag that overrides it and a single default value.
 * parseArgs() then strips those flags, and their comma-separated value lists,
 * out of argv so the remaining arguments can be handed to the regular
 * parameter parser.
 */
class Grid
{
public:

  /**
   * Register a grid dimension.
   *
   * @param key          dimension being registered
   * @param arg          command-line flag that overrides it (e.g. "-lmbr-p")
   * @param defaultValue single value searched when the flag is absent
   * @throws std::runtime_error if @p key has already been registered
   */
  void addParam(gridkey key, const std::string& arg, float defaultValue) {
    // Validate before touching any state so a rejected call changes nothing.
    if (m_grid.find(key) != m_grid.end()) {
      throw std::runtime_error("Duplicate grid parameter registered for " + arg);
    }
    m_args[arg] = key;
    m_grid[key].push_back(defaultValue);
  }

  /**
   * Remove the registered grid flags from the command line.
   *
   * Every registered flag must be followed by a comma-separated list of
   * non-zero numbers, which replaces that dimension's default. All other
   * arguments are kept, in order.
   *
   * On return, argc/argv describe the remaining arguments. argv is terminated
   * by a NULL entry and points into storage owned by this Grid. It stays valid
   * until the Grid is destroyed or parseArgs() is called again.
   *
   * @throws std::runtime_error if a flag lacks its value, appears twice, or
   *         its value list is empty or contains an illegal number
   */
  void parseArgs(int& argc, char const**& argv) {
    std::vector<char const*> remaining;
    remaining.reserve(argc > 0 ? argc + 1 : 1);

    for (int i = 0; i < argc; ++i) {
      std::map<std::string,gridkey>::const_iterator argi = m_args.find(argv[i]);
      if (argi == m_args.end()) {
        // Not a grid flag: leave it for the regular parameter parser.
        remaining.push_back(argv[i]);
        continue;
      }

      ++i;  // consume the flag's value as well
      if (i >= argc) {
        std::cerr << "Error: missing parameter for " << argi->first << std::endl;
        throw std::runtime_error("Missing parameter");
      }

      const gridkey key = argi->second;
      // Track overrides explicitly. Inspecting the value count cannot tell
      // a default apart from a single value given on the command line.
      if (!m_overridden.insert(key).second) {
        throw std::runtime_error("Duplicate grid argument");
      }
      m_grid[key] = parseValues(argi->first, argv[i]);
    }

    // Conventional argv layout: argv[argc] == NULL.
    remaining.push_back(NULL);
    // Build into a local first: argv may point into m_argv from a previous call.
    m_argv.swap(remaining);
    argc = static_cast<int>(m_argv.size()) - 1;
    argv = &m_argv[0];
  }

  /**
   * Values to search for one dimension.
   *
   * @param key a dimension previously registered with addParam()
   * @return the default value, or the list given on the command line
   * @throws std::runtime_error if @p key was never registered
   */
  const std::vector<float>& getGrid(gridkey key) const {
    std::map<gridkey,std::vector<float> >::const_iterator iter = m_grid.find(key);
    // Checked at runtime: dereferencing end() would be UB in release builds.
    if (iter == m_grid.end()) {
      throw std::runtime_error("Unknown grid key");
    }
    return iter->second;
  }

private:

  /**
   * Split a comma-separated list of grid values.
   *
   * Empty items (e.g. in "0.1,,0.2" or a trailing comma) are skipped.
   * Surrounding blanks around a number are tolerated.
   *
   * @param flag  command-line flag, used only in the error message
   * @param value raw value list
   * @return the parsed values, in order
   * @throws std::runtime_error if an item is not a number, parses to zero,
   *         or the list contains no values at all
   */
  static std::vector<float> parseValues(const std::string& flag, const std::string& value) {
    const char delim = ',';
    std::vector<float> values;

    std::string::size_type start = value.find_first_not_of(delim);
    while (start != std::string::npos) {
      const std::string::size_type end = value.find(delim, start);
      const std::string item = (end == std::string::npos)
                               ? value.substr(start)
                               : value.substr(start, end - start);

      const char* text = item.c_str();
      char* parsed_end = NULL;
      const float param = static_cast<float>(std::strtod(text, &parsed_end));
      const bool only_blanks_left =
        std::string(parsed_end).find_first_not_of(" \t") == std::string::npos;
      // Reject unparsable text, trailing garbage, and zero (never a valid grid point here).
      if (parsed_end == text || !only_blanks_left || param == 0) {
        std::cerr << "Error: Illegal grid parameter for " << flag << std::endl;
        throw std::runtime_error("Illegal grid parameter");
      }
      values.push_back(param);

      if (end == std::string::npos) {
        break;
      }
      start = value.find_first_not_of(delim, end);
    }

    // An empty grid would silently skip every search for this dimension.
    if (values.empty()) {
      std::cerr << "Error: Illegal grid parameter for " << flag << std::endl;
      throw std::runtime_error("Illegal grid parameter");
    }
    return values;
  }

  // Values to search, per grid dimension.
  std::map<gridkey,std::vector<float> > m_grid;
  // Command-line flag -> grid dimension it overrides.
  std::map<std::string,gridkey> m_args;
  // Dimensions already overridden on the command line (duplicate detection).
  std::set<gridkey> m_overridden;
  // Storage for the filtered argv returned by parseArgs().
  std::vector<char const*> m_argv;
};

}
|
|
|
|
|
int main(int argc, char const* argv[]) |
|
|
{ |
|
|
cerr << "Lattice MBR Grid search" << endl; |
|
|
|
|
|
Grid grid; |
|
|
grid.addParam(lmbr_p, "-lmbr-p", 0.5); |
|
|
grid.addParam(lmbr_r, "-lmbr-r", 0.5); |
|
|
grid.addParam(lmbr_prune, "-lmbr-pruning-factor",30.0); |
|
|
grid.addParam(lmbr_scale, "-mbr-scale",1.0); |
|
|
|
|
|
grid.parseArgs(argc,argv); |
|
|
|
|
|
Parameter* params = new Parameter(); |
|
|
if (!params->LoadParam(argc,argv)) { |
|
|
params->Explain(); |
|
|
exit(1); |
|
|
} |
|
|
|
|
|
ResetUserTime(); |
|
|
if (!StaticData::LoadDataStatic(params, argv[0])) { |
|
|
exit(1); |
|
|
} |
|
|
|
|
|
StaticData& SD = const_cast<StaticData&>(StaticData::Instance()); |
|
|
boost::shared_ptr<AllOptions> opts(new AllOptions(*SD.options())); |
|
|
LMBR_Options& lmbr = opts->lmbr; |
|
|
MBR_Options& mbr = opts->mbr; |
|
|
lmbr.enabled = true; |
|
|
|
|
|
boost::shared_ptr<IOWrapper> ioWrapper(new IOWrapper(*opts)); |
|
|
if (!ioWrapper) { |
|
|
throw runtime_error("Failed to initialise IOWrapper"); |
|
|
} |
|
|
size_t nBestSize = mbr.size; |
|
|
|
|
|
if (nBestSize <= 0) { |
|
|
throw new runtime_error("Non-positive size specified for n-best list"); |
|
|
} |
|
|
|
|
|
const vector<float>& pgrid = grid.getGrid(lmbr_p); |
|
|
const vector<float>& rgrid = grid.getGrid(lmbr_r); |
|
|
const vector<float>& prune_grid = grid.getGrid(lmbr_prune); |
|
|
const vector<float>& scale_grid = grid.getGrid(lmbr_scale); |
|
|
|
|
|
boost::shared_ptr<InputType> source; |
|
|
while((source = ioWrapper->ReadInput()) != NULL) { |
|
|
|
|
|
boost::shared_ptr<TranslationTask> ttask; |
|
|
ttask = TranslationTask::create(source, ioWrapper); |
|
|
Manager manager(ttask); |
|
|
manager.Decode(); |
|
|
TrellisPathList nBestList; |
|
|
manager.CalcNBest(nBestSize, nBestList,true); |
|
|
|
|
|
BOOST_FOREACH(float const& p, pgrid) { |
|
|
lmbr.precision = p; |
|
|
BOOST_FOREACH(float const& r, rgrid) { |
|
|
lmbr.ratio = r; |
|
|
BOOST_FOREACH(size_t const prune_i, prune_grid) { |
|
|
lmbr.pruning_factor = prune_i; |
|
|
BOOST_FOREACH(float const& scale_i, scale_grid) { |
|
|
mbr.scale = scale_i; |
|
|
size_t lineCount = source->GetTranslationId(); |
|
|
cout << lineCount << " ||| " << p << " " |
|
|
<< r << " " << size_t(prune_i) << " " << scale_i |
|
|
<< " ||| "; |
|
|
vector<Word> mbrBestHypo = doLatticeMBR(manager,nBestList); |
|
|
manager.OutputBestHypo(mbrBestHypo, cout); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|