File size: 5,347 Bytes
fd49381 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
#include "HwcmScorer.h"
#include <fstream>
#include "ScoreStats.h"
#include "Util.h"
#include "util/tokenize_piece.hh"
// Head-word chain matching (HWCM) score (Liu and Gildea, 2005). Implements F1 instead of precision to better model hypothesis length.
// Assumes dependency trees on the target side (generated by scripts/training/wrappers/conll2mosesxml.py; use option --brackets for the reference).
// Reads reference trees from a separate file, {REFERENCE_FILE}.trees, to support a mix of string-based and tree-based metrics.
using namespace std;
namespace MosesTuning
{
// Constructs an HWCM scorer. The metric name "HWCM" and the raw config
// string are forwarded unchanged to the StatisticsBasedScorer base.
HwcmScorer::HwcmScorer(const string& config)
  : StatisticsBasedScorer("HWCM", config)
{
}

// No explicit cleanup is performed here.
HwcmScorer::~HwcmScorer()
{
}
void HwcmScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
// For each line in the reference file, create a tree object
if (referenceFiles.size() != 1) {
throw runtime_error("HWCM only supports a single reference");
}
m_ref_trees.clear();
m_ref_hwc.clear();
ifstream in((referenceFiles[0] + ".trees").c_str());
if (!in) {
throw runtime_error("Unable to open " + referenceFiles[0] + ".trees");
}
string line;
while (getline(in,line)) {
line = this->preprocessSentence(line);
TreePointer tree (boost::make_shared<InternalTree>(line));
m_ref_trees.push_back(tree);
vector<map<string, int> > hwc (kHwcmOrder);
vector<string> history(kHwcmOrder);
extractHeadWordChain(tree, history, hwc);
m_ref_hwc.push_back(hwc);
vector<int> totals(kHwcmOrder);
for (size_t i = 0; i < kHwcmOrder; i++) {
for (map<string, int>::const_iterator it = m_ref_hwc.back()[i].begin(); it != m_ref_hwc.back()[i].end(); it++) {
totals[i] += it->second;
}
}
m_ref_lengths.push_back(totals);
}
TRACE_ERR(endl);
}
// Recursively collects head-word chains from a dependency tree into `hwc`.
//
// history[k] holds the chain of k+1 head words ending at the nearest
// ancestor that has a head (an empty string means no such chain). Each
// node with a head adds itself as a 1-word chain to hwc[0] and, for every
// non-empty history[k], adds the chain history[k] + " " + head to hwc[k+1].
// Nodes without a head are transparent: their children receive the
// unchanged history.
void HwcmScorer::extractHeadWordChain(TreePointer tree, vector<string> & history, vector<map<string, int> > & hwc)
{
  if (tree->GetLength() == 0) {
    return;
  }

  const string head = getHead(tree);
  const vector<TreePointer> &children = tree->GetChildren();

  if (head.empty()) {
    // No head here: pass the ancestors' chains straight through.
    for (size_t c = 0; c < children.size(); ++c) {
      extractHeadWordChain(children[c], history, hwc);
    }
    return;
  }

  vector<string> childHistory(kHwcmOrder);
  childHistory[0] = head;
  ++hwc[0][head];

  for (size_t k = 0; k < kHwcmOrder-1; ++k) {
    if (history[k].empty()) {
      continue;
    }
    const string chain = history[k] + " " + head;
    ++hwc[k + 1][chain];
    // Only chains shorter than kHwcmOrder words can be extended below.
    if (k + 2 < kHwcmOrder) {
      childHistory[k + 1] = chain;
    }
  }

  for (size_t c = 0; c < children.size(); ++c) {
    extractHeadWordChain(children[c], childHistory, hwc);
  }
}
// Returns the head word of a constituent, or "" if none is found.
//
// Assumption (only true for dependency parses): each constituent has a
// preterminal child, and that child's terminal is the head. A child is
// treated as a preterminal when its GetLength() is 1 and its first child
// is a terminal. If several children qualify, the first one wins.
string HwcmScorer::getHead(TreePointer tree)
{
  const vector<TreePointer> &children = tree->GetChildren();
  for (size_t c = 0; c < children.size(); ++c) {
    const TreePointer &child = children[c];
    const bool isPreterminal =
      child->GetLength() == 1 && child->GetChildren()[0]->IsTerminal();
    if (isPreterminal) {
      return child->GetChildren()[0]->GetLabel();
    }
  }
  return "";
}
void HwcmScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
{
if (sid >= m_ref_trees.size()) {
stringstream msg;
msg << "Sentence id (" << sid << ") not found in reference set";
throw runtime_error(msg.str());
}
string sentence = this->preprocessSentence(text);
// if sentence has '|||', assume that tree is in second position (n-best-list);
// otherwise, assume it is in first position (calling 'evaluate' with tree as reference)
util::TokenIter<util::MultiCharacter> it(sentence, util::MultiCharacter("|||"));
++it;
if (it) {
sentence = it->as_string();
}
TreePointer tree (boost::make_shared<InternalTree>(sentence));
vector<map<string, int> > hwc_test (kHwcmOrder);
vector<string> history(kHwcmOrder);
extractHeadWordChain(tree, history, hwc_test);
ostringstream stats;
for (size_t i = 0; i < kHwcmOrder; i++) {
int correct = 0;
int test_total = 0;
for (map<string, int>::const_iterator it = hwc_test[i].begin(); it != hwc_test[i].end(); it++) {
test_total += it->second;
map<string, int>::const_iterator it2 = m_ref_hwc[sid][i].find(it->first);
if (it2 != m_ref_hwc[sid][i].end()) {
correct += std::min(it->second, it2->second);
}
}
stats << correct << " " << test_total << " " << m_ref_lengths[sid][i] << " " ;
}
string stats_str = stats.str();
entry.set(stats_str);
}
// Combines per-order statistics into the HWCM F1 score.
//
// `comps` holds kHwcmOrder consecutive triples (matches, hypothesis total,
// reference total), in the layout written by prepareStats. Precision and
// recall are averaged over all orders; an order whose denominator is zero
// contributes 0. The harmonic mean of the two averages is returned.
//
// Returns 0 when precision + recall is 0 (e.g. no matching chains at all),
// instead of the NaN that the 0/0 F1 expression would otherwise produce.
float HwcmScorer::calculateScore(const vector<ScoreStatsType>& comps) const
{
  float precision = 0;
  float recall = 0;
  for (size_t i = 0; i < kHwcmOrder; i++) {
    const float matches = comps[i*3];
    const float test_total = comps[1+(i*3)];
    const float ref_total = comps[2+(i*3)];
    if (test_total > 0) {
      precision += matches/test_total;
    }
    if (ref_total > 0) {
      recall += matches/ref_total;
    }
  }
  precision /= static_cast<float>(kHwcmOrder);
  recall /= static_cast<float>(kHwcmOrder);

  // F1 is undefined when both precision and recall are zero; define it as 0
  // rather than returning NaN to the caller.
  if (precision + recall <= 0) {
    return 0;
  }
  return (2*precision*recall)/(precision+recall); // f1-score
}
} |