|
|
#include "ScoreStsg.h" |
|
|
|
|
|
#include <cassert> |
|
|
#include <cstdlib> |
|
|
#include <fstream> |
|
|
#include <iostream> |
|
|
#include <iterator> |
|
|
#include <string> |
|
|
#include <sstream> |
|
|
#include <vector> |
|
|
|
|
|
#include <boost/program_options.hpp> |
|
|
|
|
|
#include "util/string_piece.hh" |
|
|
#include "util/string_piece_hash.hh" |
|
|
#include "util/tokenize_piece.hh" |
|
|
|
|
|
#include "InputFileStream.h" |
|
|
#include "OutputFileStream.h" |
|
|
|
|
|
#include "syntax-common/exception.h" |
|
|
|
|
|
#include "LexicalTable.h" |
|
|
#include "Options.h" |
|
|
#include "RuleGroup.h" |
|
|
#include "RuleTableWriter.h" |
|
|
|
|
|
namespace MosesTraining |
|
|
{ |
|
|
namespace Syntax |
|
|
{ |
|
|
namespace ScoreStsg |
|
|
{ |
|
|
|
|
|
// Upper bound on the rule counts tracked for count-of-counts statistics
// (used when Good-Turing or Kneser-Ney smoothing is requested); counts
// 1..kCountOfCountsMax are recorded in m_countOfCounts.
const int ScoreStsg::kCountOfCountsMax = 10;
|
|
|
|
|
// Construct the scorer: register the tool name with the base class, bind the
// lexical table to the shared source/target vocabularies, and zero the
// count-of-counts statistics.
ScoreStsg::ScoreStsg()
  : Tool("score-stsg")
  , m_lexTable(m_srcVocab, m_tgtVocab)
    // Slots 1..kCountOfCountsMax are used (slot 0 is never written), so the
    // vector needs kCountOfCountsMax+1 elements.  Sizing it to only
    // kCountOfCountsMax made m_countOfCounts[kCountOfCountsMax] an
    // out-of-bounds access both in ProcessRuleGroup() and in the
    // count-of-counts dump loop in Main() (which iterates i <= kCountOfCountsMax).
  , m_countOfCounts(kCountOfCountsMax+1, 0)
  , m_totalDistinct(0)
{
}
|
|
|
|
|
// Entry point for the score-stsg tool.
//
// Reads the extract file line by line, groups consecutive lines that share
// the same source side into a RuleGroup, scores each completed group, and
// writes the resulting rule table.  Optionally writes count-of-counts
// statistics for Good-Turing / Kneser-Ney smoothing.
//
// Each extract line has the form (fields separated by "|||"):
//   source ||| target ||| ntAlign ||| fullAlign ||| count [||| treeScore]
// The optional treeScore field is only read when --TreeScore/--PCFG is set
// and --Inverse is not.
//
// Returns 0 on success; fatal errors are reported via Error()/Die helpers
// which do not return.
int ScoreStsg::Main(int argc, char *argv[])
{
  // Parse command-line options into m_options (exits on error/usage).
  ProcessOptions(argc, argv, m_options);

  // Open input files (both helpers die on failure).
  Moses::InputFileStream extractStream(m_options.extractFile);
  Moses::InputFileStream lexStream(m_options.lexFile);

  // Open output files.  The ".coc" stream is only needed when smoothing
  // statistics will be written at the end.
  Moses::OutputFileStream outStream;
  Moses::OutputFileStream countOfCountsStream;
  OpenOutputFileOrDie(m_options.tableFile, outStream);
  if (m_options.goodTuring || m_options.kneserNey) {
    OpenOutputFileOrDie(m_options.tableFile+".coc", countOfCountsStream);
  }

  // Load lexical probability table (unless lexical scoring is disabled).
  if (!m_options.noLex) {
    m_lexTable.Load(lexStream);
  }

  const util::MultiCharacter delimiter("|||");
  std::size_t lineNum = 0;
  std::size_t startLine= 0;   // first extract-file line of the current group
  std::string line;
  std::string tmp;
  RuleGroup ruleGroup;
  RuleTableWriter ruleTableWriter(m_options, outStream);

  while (std::getline(extractStream, line)) {
    ++lineNum;

    // Tokenize the line into its "|||"-separated fields.  No field-count
    // validation is done here; a malformed line will advance the iterator
    // past the end (TokenIter semantics) — assumed well-formed extract input.
    util::TokenIter<util::MultiCharacter> it(line, delimiter);
    StringPiece source = *it++;
    StringPiece target = *it++;
    StringPiece ntAlign = *it++;
    StringPiece fullAlign = *it++;
    it->CopyToString(&tmp);
    int count = std::atoi(tmp.c_str());
    double treeScore = 0.0f;
    if (m_options.treeScore && !m_options.inverse) {
      ++it;
      it->CopyToString(&tmp);
      treeScore = std::atof(tmp.c_str());
    }

    // If the source side changed then the current rule group is complete:
    // score it (except before the very first line, when the group is empty)
    // and start a new group.  This relies on the extract file being sorted
    // by source side so equal sources are adjacent.
    if (source != ruleGroup.GetSource()) {
      if (lineNum > 1) {
        ProcessRuleGroupOrDie(ruleGroup, ruleTableWriter, startLine, lineNum-1);
      }
      startLine = lineNum;
      ruleGroup.SetNewSource(source);
    }

    // Add the rule to the current rule group.
    ruleGroup.AddRule(target, ntAlign, fullAlign, count, treeScore);
  }

  // Process the final rule group (the loop only flushes on a source change).
  ProcessRuleGroupOrDie(ruleGroup, ruleTableWriter, startLine, lineNum);

  // Write count-of-counts statistics: total number of distinct rules, then
  // the number of distinct rules seen with count 1..kCountOfCountsMax.
  if (m_options.goodTuring || m_options.kneserNey) {
    // Write total number of distinct phrase pairs.
    countOfCountsStream << m_totalDistinct << std::endl;

    // Write counts-of-counts.
    for (int i = 1; i <= kCountOfCountsMax; ++i) {
      countOfCountsStream << m_countOfCounts[i] << std::endl;
    }
  }

  return 0;
}
|
|
|
|
|
void ScoreStsg::TokenizeRuleHalf(const std::string &s, TokenizedRuleHalf &half) |
|
|
{ |
|
|
|
|
|
std::size_t start = s.find_first_not_of(" \t"); |
|
|
if (start == std::string::npos) { |
|
|
throw Exception("rule half is empty"); |
|
|
} |
|
|
std::size_t end = s.find_last_not_of(" \t"); |
|
|
assert(end != std::string::npos); |
|
|
half.string = s.substr(start, end-start+1); |
|
|
|
|
|
|
|
|
half.tokens.clear(); |
|
|
for (TreeFragmentTokenizer p(half.string); |
|
|
p != TreeFragmentTokenizer(); ++p) { |
|
|
half.tokens.push_back(*p); |
|
|
} |
|
|
|
|
|
|
|
|
half.frontierSymbols.clear(); |
|
|
const std::size_t numTokens = half.tokens.size(); |
|
|
for (int i = 0; i < numTokens; ++i) { |
|
|
if (half.tokens[i].type != TreeFragmentToken_WORD) { |
|
|
continue; |
|
|
} |
|
|
if (i == 0 || half.tokens[i-1].type != TreeFragmentToken_LSB) { |
|
|
|
|
|
half.frontierSymbols.resize(half.frontierSymbols.size()+1); |
|
|
half.frontierSymbols.back().value = half.tokens[i].value; |
|
|
half.frontierSymbols.back().isNonTerminal = false; |
|
|
} else if (i+1 < numTokens && |
|
|
half.tokens[i+1].type == TreeFragmentToken_RSB) { |
|
|
|
|
|
half.frontierSymbols.resize(half.frontierSymbols.size()+1); |
|
|
half.frontierSymbols.back().value = half.tokens[i].value; |
|
|
half.frontierSymbols.back().isNonTerminal = true; |
|
|
++i; |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
void ScoreStsg::ProcessRuleGroupOrDie(const RuleGroup &group, |
|
|
RuleTableWriter &writer, |
|
|
std::size_t start, |
|
|
std::size_t end) |
|
|
{ |
|
|
try { |
|
|
ProcessRuleGroup(group, writer); |
|
|
} catch (const Exception &e) { |
|
|
std::ostringstream msg; |
|
|
msg << "failed to process rule group at lines " << start << "-" << end |
|
|
<< ": " << e.msg(); |
|
|
Error(msg.str()); |
|
|
} catch (const std::exception &e) { |
|
|
std::ostringstream msg; |
|
|
msg << "failed to process rule group at lines " << start << "-" << end |
|
|
<< ": " << e.what(); |
|
|
Error(msg.str()); |
|
|
} |
|
|
} |
|
|
|
|
|
// Score every distinct rule in a group (all rules sharing one source side)
// and write the surviving rules to the rule table.
//
// For each distinct rule this:
//   1. updates count-of-counts statistics (if smoothing is enabled),
//   2. applies the MinCountHierarchical filter (fully-lexical source sides
//      are exempt),
//   3. picks the most frequent alignment observed for the rule,
//   4. computes the lexical translation score, and
//   5. emits the rule via the writer.
void ScoreStsg::ProcessRuleGroup(const RuleGroup &group,
                                 RuleTableWriter &writer)
{
  const std::size_t totalCount = group.GetTotalCount();
  const std::size_t distinctCount = group.GetSize();

  // Tokenize the shared source side once for the whole group.
  TokenizeRuleHalf(group.GetSource(), m_sourceHalf);

  const bool fullyLexical = m_sourceHalf.IsFullyLexical();

  for (RuleGroup::ConstIterator p = group.Begin(); p != group.End(); ++p) {
    const RuleGroup::DistinctRule &rule = *p;

    // Update count-of-counts statistics.  Adding 0.99999 before the int
    // truncation rounds fractional counts up (a ceiling for practical
    // count values) — rule.count is presumably fractional; TODO confirm.
    if (m_options.goodTuring || m_options.kneserNey) {
      ++m_totalDistinct;
      int countInt = rule.count + 0.99999;
      if (countInt <= kCountOfCountsMax) {
        // NOTE(review): countInt can equal kCountOfCountsMax here, so
        // m_countOfCounts must have kCountOfCountsMax+1 elements — verify
        // the constructor sizes it accordingly.
        ++m_countOfCounts[countInt];
      }
    }

    // Filter out rules with a count below the threshold (unless the source
    // side is fully lexical, which is always kept).
    if (!fullyLexical && rule.count < m_options.minCountHierarchical) {
      continue;
    }

    TokenizeRuleHalf(rule.target, m_targetHalf);

    // Find the most frequently observed alignment for this rule.
    // NOTE(review): the first element is dereferenced unconditionally, so
    // rule.alignments is assumed non-empty — confirm RuleGroup guarantees
    // at least one alignment per distinct rule.
    std::vector<std::pair<std::string, int> >::const_iterator q =
      rule.alignments.begin();
    const std::pair<std::string, int> *bestAlignmentAndCount = &(*q++);
    for (; q != rule.alignments.end(); ++q) {
      if (q->second > bestAlignmentAndCount->second) {
        bestAlignmentAndCount = &(*q);
      }
    }
    const std::string &bestAlignment = bestAlignmentAndCount->first;
    // Parse the winning alignment into the target-to-source map.
    ParseAlignmentString(bestAlignment, m_targetHalf.frontierSymbols.size(),
                         m_tgtToSrc);

    // Compute the lexical translation score over frontier terminals.
    double lexProb = ComputeLexProb(m_sourceHalf.frontierSymbols,
                                    m_targetHalf.frontierSymbols, m_tgtToSrc);

    // Write a line to the rule table.
    writer.WriteLine(m_sourceHalf, m_targetHalf, bestAlignment, lexProb,
                     rule.treeScore, p->count, totalCount, distinctCount);
  }
}
|
|
|
|
|
// Parse a word alignment string of the form "s-t s-t ..." (source-target
// index pairs) into tgtToSrc, where tgtToSrc[t] becomes the set of source
// indices aligned to target index t.
//
// The map is cleared and resized to numTgtWords before parsing.  Parsing
// stops at the first segment containing no '-'.  Throws Exception if a pair
// ends with '-' but no target index.
//
// NOTE(review): tgt is not range-checked against numTgtWords before
// tgtToSrc[tgt] is indexed, so a malformed alignment with an out-of-range
// target index would write out of bounds — presumably extract output is
// trusted here; confirm.
void ScoreStsg::ParseAlignmentString(const std::string &s, int numTgtWords,
                                     ALIGNMENT &tgtToSrc)
{
  tgtToSrc.clear();
  tgtToSrc.resize(numTgtWords);

  const std::string digits = "0123456789";

  std::string::size_type begin = 0;
  while (true) {
    // Locate the '-' separating the source index from the target index.
    // No '-' means no more pairs (also handles trailing whitespace).
    std::string::size_type end = s.find("-", begin);
    if (end == std::string::npos) {
      return;
    }
    int src = std::atoi(s.substr(begin, end-begin).c_str());
    if (end+1 == s.size()) {
      throw Exception("Target index missing");
    }
    begin = end+1;
    // Scan past the target index digits; npos means the pair ends the
    // string (no trailing separator).
    end = s.find_first_not_of(digits, begin+1);
    int tgt;
    if (end == std::string::npos) {
      // Last pair in the string: consume the rest and finish.
      tgt = std::atoi(s.substr(begin).c_str());
      tgtToSrc[tgt].insert(src);
      return;
    } else {
      tgt = std::atoi(s.substr(begin, end-begin).c_str());
      tgtToSrc[tgt].insert(src);
    }
    // Skip the separator character and continue with the next pair.
    begin = end+1;
  }
}
|
|
|
|
|
// Compute the lexical translation score for a rule.
//
// For every terminal in the target frontier the score is multiplied by the
// average lexical probability over the source words it is aligned to (via
// tgtToSrc); an unaligned target terminal is scored against the NULL source
// word.  Non-terminals are skipped.  Lookups go through
// LexicalTable::PermissiveLookup, so words absent from the table do not
// abort scoring.
//
// NOTE(review): the direction of the probability (p(t|s) vs p(s|t)) is
// determined by how the lexical table was loaded, which is not visible
// here — confirm against LexicalTable::Load and the Inverse option.
double ScoreStsg::ComputeLexProb(const std::vector<RuleSymbol> &sourceFrontier,
                                 const std::vector<RuleSymbol> &targetFrontier,
                                 const ALIGNMENT &tgtToSrc)
{
  double lexScore = 1.0;
  for (std::size_t i = 0; i < targetFrontier.size(); ++i) {
    // Only terminals contribute to the lexical score.
    if (targetFrontier[i].isNonTerminal) {
      continue;
    }
    // StringPiece-compatible hash/equality let us look up without
    // constructing a std::string key.
    Vocabulary::IdType tgtId = m_tgtVocab.Lookup(targetFrontier[i].value,
                               StringPieceCompatibleHash(),
                               StringPieceCompatibleEquals());
    const std::set<std::size_t> &srcIndices = tgtToSrc[i];
    if (srcIndices.empty()) {
      // Unaligned target word: score against the NULL source word.
      lexScore *= m_lexTable.PermissiveLookup(Vocabulary::NullId(), tgtId);
    } else {
      // Aligned target word: average the probabilities over all aligned
      // source words.
      double thisWordScore = 0.0;
      for (std::set<std::size_t>::const_iterator p = srcIndices.begin();
           p != srcIndices.end(); ++p) {
        Vocabulary::IdType srcId =
          m_srcVocab.Lookup(sourceFrontier[*p].value,
                            StringPieceCompatibleHash(),
                            StringPieceCompatibleEquals());
        thisWordScore += m_lexTable.PermissiveLookup(srcId, tgtId);
      }
      lexScore *= thisWordScore / static_cast<double>(srcIndices.size());
    }
  }
  return lexScore;
}
|
|
|
|
|
void ScoreStsg::ProcessOptions(int argc, char *argv[], Options &options) const |
|
|
{ |
|
|
namespace po = boost::program_options; |
|
|
namespace cls = boost::program_options::command_line_style; |
|
|
|
|
|
|
|
|
|
|
|
std::ostringstream usageTop; |
|
|
usageTop << "Usage: " << name() |
|
|
<< " [OPTION]... EXTRACT LEX TABLE\n\n" |
|
|
<< "STSG rule scorer\n\n" |
|
|
<< "Options"; |
|
|
|
|
|
|
|
|
std::ostringstream usageBottom; |
|
|
usageBottom << "TODO"; |
|
|
|
|
|
|
|
|
po::options_description visible(usageTop.str()); |
|
|
visible.add_options() |
|
|
("GoodTuring", |
|
|
"apply Good-Turing smoothing to relative frequency probability estimates") |
|
|
("Hierarchical", |
|
|
"ignored (included for compatibility with score)") |
|
|
("Inverse", |
|
|
"use inverse mode") |
|
|
("KneserNey", |
|
|
"apply Kneser-Ney smoothing to relative frequency probability estimates") |
|
|
("LogProb", |
|
|
"output log probabilities") |
|
|
("MinCountHierarchical", |
|
|
po::value(&options.minCountHierarchical)-> |
|
|
default_value(options.minCountHierarchical), |
|
|
"filter out rules with frequency < arg (except fully lexical rules)") |
|
|
("NegLogProb", |
|
|
"output negative log probabilities") |
|
|
("NoLex", |
|
|
"do not compute lexical translation score") |
|
|
("NoWordAlignment", |
|
|
"do not output word alignments") |
|
|
("PCFG", |
|
|
"synonym for TreeScore (included for compatibility with score)") |
|
|
("TreeScore", |
|
|
"include pre-computed tree score from extract") |
|
|
("UnpairedExtractFormat", |
|
|
"ignored (included for compatibility with score)") |
|
|
; |
|
|
|
|
|
|
|
|
|
|
|
po::options_description hidden("Hidden options"); |
|
|
hidden.add_options() |
|
|
("ExtractFile", |
|
|
po::value(&options.extractFile), |
|
|
"extract file") |
|
|
("LexFile", |
|
|
po::value(&options.lexFile), |
|
|
"lexical probability file") |
|
|
("TableFile", |
|
|
po::value(&options.tableFile), |
|
|
"output file") |
|
|
; |
|
|
|
|
|
|
|
|
po::options_description cmdLineOptions; |
|
|
cmdLineOptions.add(visible).add(hidden); |
|
|
|
|
|
|
|
|
po::positional_options_description p; |
|
|
p.add("ExtractFile", 1); |
|
|
p.add("LexFile", 1); |
|
|
p.add("TableFile", 1); |
|
|
|
|
|
|
|
|
po::variables_map vm; |
|
|
try { |
|
|
po::store(po::command_line_parser(argc, argv).style(MosesOptionStyle()). |
|
|
options(cmdLineOptions).positional(p).run(), vm); |
|
|
po::notify(vm); |
|
|
} catch (const std::exception &e) { |
|
|
std::ostringstream msg; |
|
|
msg << e.what() << "\n\n" << visible << usageBottom.str(); |
|
|
Error(msg.str()); |
|
|
} |
|
|
|
|
|
if (vm.count("help")) { |
|
|
std::cout << visible << usageBottom.str() << std::endl; |
|
|
std::exit(0); |
|
|
} |
|
|
|
|
|
|
|
|
if (!vm.count("ExtractFile") || |
|
|
!vm.count("LexFile") || |
|
|
!vm.count("TableFile")) { |
|
|
std::ostringstream msg; |
|
|
std::cerr << visible << usageBottom.str() << std::endl; |
|
|
std::exit(1); |
|
|
} |
|
|
|
|
|
|
|
|
if (vm.count("GoodTuring")) { |
|
|
options.goodTuring = true; |
|
|
} |
|
|
if (vm.count("Inverse")) { |
|
|
options.inverse = true; |
|
|
} |
|
|
if (vm.count("KneserNey")) { |
|
|
options.kneserNey = true; |
|
|
} |
|
|
if (vm.count("LogProb")) { |
|
|
options.logProb = true; |
|
|
} |
|
|
if (vm.count("NegLogProb")) { |
|
|
options.negLogProb = true; |
|
|
} |
|
|
if (vm.count("NoLex")) { |
|
|
options.noLex = true; |
|
|
} |
|
|
if (vm.count("NoWordAlignment")) { |
|
|
options.noWordAlignment = true; |
|
|
} |
|
|
if (vm.count("TreeScore") || vm.count("PCFG")) { |
|
|
options.treeScore = true; |
|
|
} |
|
|
} |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|