Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- fairseq-0.10.2/examples/backtranslation/prepare-wmt18en2de.sh +135 -0
- fairseq-0.10.2/examples/byte_level_bpe/README.md +88 -0
- fairseq-0.10.2/examples/byte_level_bpe/get_bitext.py +254 -0
- fairseq-0.10.2/examples/byte_level_bpe/get_data.sh +47 -0
- fairseq-0.10.2/examples/byte_level_bpe/gru_transformer.py +107 -0
- fairseq-0.10.2/examples/conv_seq2seq/README.md +25 -0
- fairseq-0.10.2/examples/layerdrop/README.md +154 -0
- fairseq-0.10.2/examples/noisychannel/rerank_score_lm.py +81 -0
- fairseq-0.10.2/examples/quant_noise/README.md +298 -0
- fairseq-0.10.2/examples/quant_noise/transformer_quantization_config.yaml +33 -0
- mosesdecoder/moses/FF/BleuScoreFeature.h +191 -0
- mosesdecoder/moses/FF/ConstrainedDecoding.cpp +212 -0
- mosesdecoder/moses/FF/ControlRecombination.cpp +96 -0
- mosesdecoder/moses/FF/ControlRecombination.h +88 -0
- mosesdecoder/moses/FF/CorrectionPattern.cpp +354 -0
- mosesdecoder/moses/FF/CorrectionPattern.h +73 -0
- mosesdecoder/moses/FF/CoveredReferenceFeature.cpp +129 -0
- mosesdecoder/moses/FF/DecodeFeature.cpp +117 -0
- mosesdecoder/moses/FF/DistortionScoreProducer.h +57 -0
- mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.cpp +459 -0
- mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.h +164 -0
- mosesdecoder/moses/FF/EditOps.h +64 -0
- mosesdecoder/moses/FF/ExampleStatelessFF.cpp +69 -0
- mosesdecoder/moses/FF/ExampleTranslationOptionListFeature.h +67 -0
- mosesdecoder/moses/FF/FFState.cpp +9 -0
- mosesdecoder/moses/FF/FFState.h +39 -0
- mosesdecoder/moses/FF/FeatureFunction.h +200 -0
- mosesdecoder/moses/FF/GlobalLexicalModel.h +102 -0
- mosesdecoder/moses/FF/GlobalLexicalModelUnlimited.cpp +340 -0
- mosesdecoder/moses/FF/HyperParameterAsWeight.cpp +29 -0
- mosesdecoder/moses/FF/InputFeature.h +70 -0
- mosesdecoder/moses/FF/Model1Feature.cpp +276 -0
- mosesdecoder/moses/FF/Model1Feature.h +118 -0
- mosesdecoder/moses/FF/PhraseBoundaryFeature.h +70 -0
- mosesdecoder/moses/FF/PhraseDistanceFeature.cpp +123 -0
- mosesdecoder/moses/FF/PhraseDistanceFeature.h +56 -0
- mosesdecoder/moses/FF/PhraseLengthFeature.cpp +46 -0
- mosesdecoder/moses/FF/PhrasePenalty.cpp +51 -0
- mosesdecoder/moses/FF/PhrasePenalty.h +50 -0
- mosesdecoder/moses/FF/ReferenceComparison.cpp +11 -0
- mosesdecoder/moses/FF/RuleScope.h +56 -0
- mosesdecoder/moses/FF/SoftMatchingFeature.h +68 -0
- mosesdecoder/moses/FF/SoftSourceSyntacticConstraintsFeature.cpp +651 -0
- mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.cpp +75 -0
- mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.h +49 -0
- mosesdecoder/moses/FF/SourceWordDeletionFeature.h +63 -0
- mosesdecoder/moses/FF/SpanLength.cpp +88 -0
- mosesdecoder/moses/FF/SpanLength.h +56 -0
- mosesdecoder/moses/FF/SparseHieroReorderingFeature.cpp +224 -0
- mosesdecoder/moses/FF/SparseHieroReorderingFeatureTest.cpp +36 -0
fairseq-0.10.2/examples/backtranslation/prepare-wmt18en2de.sh
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
|
| 3 |
+
|
| 4 |
+
echo 'Cloning Moses github repository (for tokenization scripts)...'
|
| 5 |
+
git clone https://github.com/moses-smt/mosesdecoder.git
|
| 6 |
+
|
| 7 |
+
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
|
| 8 |
+
git clone https://github.com/rsennrich/subword-nmt.git
|
| 9 |
+
|
| 10 |
+
SCRIPTS=mosesdecoder/scripts
|
| 11 |
+
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
|
| 12 |
+
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
|
| 13 |
+
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
|
| 14 |
+
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
|
| 15 |
+
BPEROOT=subword-nmt/subword_nmt
|
| 16 |
+
BPE_TOKENS=32000
|
| 17 |
+
|
| 18 |
+
URLS=(
|
| 19 |
+
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
|
| 20 |
+
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
|
| 21 |
+
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
|
| 22 |
+
"http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
|
| 23 |
+
"http://data.statmt.org/wmt17/translation-task/dev.tgz"
|
| 24 |
+
"http://statmt.org/wmt14/test-full.tgz"
|
| 25 |
+
)
|
| 26 |
+
FILES=(
|
| 27 |
+
"training-parallel-europarl-v7.tgz"
|
| 28 |
+
"training-parallel-commoncrawl.tgz"
|
| 29 |
+
"training-parallel-nc-v13.tgz"
|
| 30 |
+
"rapid2016.tgz"
|
| 31 |
+
"dev.tgz"
|
| 32 |
+
"test-full.tgz"
|
| 33 |
+
)
|
| 34 |
+
CORPORA=(
|
| 35 |
+
"training/europarl-v7.de-en"
|
| 36 |
+
"commoncrawl.de-en"
|
| 37 |
+
"training-parallel-nc-v13/news-commentary-v13.de-en"
|
| 38 |
+
"rapid2016.de-en"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
if [ ! -d "$SCRIPTS" ]; then
|
| 42 |
+
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
|
| 43 |
+
exit 1
|
| 44 |
+
fi
|
| 45 |
+
|
| 46 |
+
OUTDIR=wmt18_en_de
|
| 47 |
+
|
| 48 |
+
src=en
|
| 49 |
+
tgt=de
|
| 50 |
+
lang=en-de
|
| 51 |
+
prep=$OUTDIR
|
| 52 |
+
tmp=$prep/tmp
|
| 53 |
+
orig=orig
|
| 54 |
+
|
| 55 |
+
mkdir -p $orig $tmp $prep
|
| 56 |
+
|
| 57 |
+
cd $orig
|
| 58 |
+
|
| 59 |
+
for ((i=0;i<${#URLS[@]};++i)); do
|
| 60 |
+
file=${FILES[i]}
|
| 61 |
+
if [ -f $file ]; then
|
| 62 |
+
echo "$file already exists, skipping download"
|
| 63 |
+
else
|
| 64 |
+
url=${URLS[i]}
|
| 65 |
+
wget "$url"
|
| 66 |
+
if [ -f $file ]; then
|
| 67 |
+
echo "$url successfully downloaded."
|
| 68 |
+
else
|
| 69 |
+
echo "$url not successfully downloaded."
|
| 70 |
+
exit 1
|
| 71 |
+
fi
|
| 72 |
+
if [ ${file: -4} == ".tgz" ]; then
|
| 73 |
+
tar zxvf $file
|
| 74 |
+
elif [ ${file: -4} == ".tar" ]; then
|
| 75 |
+
tar xvf $file
|
| 76 |
+
fi
|
| 77 |
+
fi
|
| 78 |
+
done
|
| 79 |
+
cd ..
|
| 80 |
+
|
| 81 |
+
echo "pre-processing train data..."
|
| 82 |
+
for l in $src $tgt; do
|
| 83 |
+
rm $tmp/train.tags.$lang.tok.$l
|
| 84 |
+
for f in "${CORPORA[@]}"; do
|
| 85 |
+
cat $orig/$f.$l | \
|
| 86 |
+
perl $NORM_PUNC $l | \
|
| 87 |
+
perl $REM_NON_PRINT_CHAR | \
|
| 88 |
+
perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
|
| 89 |
+
done
|
| 90 |
+
done
|
| 91 |
+
|
| 92 |
+
echo "pre-processing test data..."
|
| 93 |
+
for l in $src $tgt; do
|
| 94 |
+
if [ "$l" == "$src" ]; then
|
| 95 |
+
t="src"
|
| 96 |
+
else
|
| 97 |
+
t="ref"
|
| 98 |
+
fi
|
| 99 |
+
grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
|
| 100 |
+
sed -e 's/<seg id="[0-9]*">\s*//g' | \
|
| 101 |
+
sed -e 's/\s*<\/seg>\s*//g' | \
|
| 102 |
+
sed -e "s/\’/\'/g" | \
|
| 103 |
+
perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
|
| 104 |
+
echo ""
|
| 105 |
+
done
|
| 106 |
+
|
| 107 |
+
echo "splitting train and valid..."
|
| 108 |
+
for l in $src $tgt; do
|
| 109 |
+
awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
|
| 110 |
+
awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
|
| 111 |
+
done
|
| 112 |
+
|
| 113 |
+
TRAIN=$tmp/train.de-en
|
| 114 |
+
BPE_CODE=$prep/code
|
| 115 |
+
rm -f $TRAIN
|
| 116 |
+
for l in $src $tgt; do
|
| 117 |
+
cat $tmp/train.$l >> $TRAIN
|
| 118 |
+
done
|
| 119 |
+
|
| 120 |
+
echo "learn_bpe.py on ${TRAIN}..."
|
| 121 |
+
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
|
| 122 |
+
|
| 123 |
+
for L in $src $tgt; do
|
| 124 |
+
for f in train.$L valid.$L test.$L; do
|
| 125 |
+
echo "apply_bpe.py to ${f}..."
|
| 126 |
+
python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
|
| 127 |
+
done
|
| 128 |
+
done
|
| 129 |
+
|
| 130 |
+
perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
|
| 131 |
+
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
|
| 132 |
+
|
| 133 |
+
for L in $src $tgt; do
|
| 134 |
+
cp $tmp/bpe.test.$L $prep/test.$L
|
| 135 |
+
done
|
fairseq-0.10.2/examples/byte_level_bpe/README.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Neural Machine Translation with Byte-Level Subwords
|
| 2 |
+
|
| 3 |
+
https://arxiv.org/abs/1909.03341
|
| 4 |
+
|
| 5 |
+
We provide an implementation of byte-level byte-pair encoding (BBPE), taking IWSLT 2017 Fr-En translation as
|
| 6 |
+
example.
|
| 7 |
+
|
| 8 |
+
## Data
|
| 9 |
+
Get data and generate fairseq binary dataset:
|
| 10 |
+
```bash
|
| 11 |
+
bash ./get_data.sh
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
## Model Training
|
| 15 |
+
Train Transformer model with Bi-GRU embedding contextualization (implemented in `gru_transformer.py`):
|
| 16 |
+
```bash
|
| 17 |
+
# VOCAB=bytes
|
| 18 |
+
# VOCAB=chars
|
| 19 |
+
VOCAB=bbpe2048
|
| 20 |
+
# VOCAB=bpe2048
|
| 21 |
+
# VOCAB=bbpe4096
|
| 22 |
+
# VOCAB=bpe4096
|
| 23 |
+
# VOCAB=bpe16384
|
| 24 |
+
```
|
| 25 |
+
```bash
|
| 26 |
+
fairseq-train "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
|
| 27 |
+
--arch gru_transformer --encoder-layers 2 --decoder-layers 2 --dropout 0.3 --share-all-embeddings \
|
| 28 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' \
|
| 29 |
+
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
|
| 30 |
+
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
|
| 31 |
+
--log-format 'simple' --log-interval 100 --save-dir "checkpoints/${VOCAB}" \
|
| 32 |
+
--batch-size 100 --max-update 100000 --update-freq 2
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Generation
|
| 36 |
+
`fairseq-generate` requires bytes (BBPE) decoder to convert byte-level representation back to characters:
|
| 37 |
+
```bash
|
| 38 |
+
# BPE=--bpe bytes
|
| 39 |
+
# BPE=--bpe characters
|
| 40 |
+
BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe2048.model
|
| 41 |
+
# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe2048.model
|
| 42 |
+
# BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe4096.model
|
| 43 |
+
# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe4096.model
|
| 44 |
+
# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe16384.model
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
fairseq-generate "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
|
| 49 |
+
--source-lang fr --gen-subset test --sacrebleu --path "checkpoints/${VOCAB}/checkpoint_last.pt" \
|
| 50 |
+
--tokenizer moses --moses-target-lang en ${BPE}
|
| 51 |
+
```
|
| 52 |
+
When using `fairseq-interactive`, bytes (BBPE) encoder/decoder is required to tokenize input data and detokenize model predictions:
|
| 53 |
+
```bash
|
| 54 |
+
fairseq-interactive "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
|
| 55 |
+
--path "checkpoints/${VOCAB}/checkpoint_last.pt" --input data/test.fr --tokenizer moses --moses-source-lang fr \
|
| 56 |
+
--moses-target-lang en ${BPE} --buffer-size 1000 --max-tokens 10000
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Results
|
| 60 |
+
| Vocabulary | Model | BLEU |
|
| 61 |
+
|:-------------:|:-------------:|:-------------:|
|
| 62 |
+
| Joint BPE 16k ([Kudo, 2018](https://arxiv.org/abs/1804.10959)) | 512d LSTM 2+2 | 33.81 |
|
| 63 |
+
| Joint BPE 16k | Transformer base 2+2 (w/ GRU) | 36.64 (36.72) |
|
| 64 |
+
| Joint BPE 4k | Transformer base 2+2 (w/ GRU) | 35.49 (36.10) |
|
| 65 |
+
| Joint BBPE 4k | Transformer base 2+2 (w/ GRU) | 35.61 (35.82) |
|
| 66 |
+
| Joint BPE 2k | Transformer base 2+2 (w/ GRU) | 34.87 (36.13) |
|
| 67 |
+
| Joint BBPE 2k | Transformer base 2+2 (w/ GRU) | 34.98 (35.43) |
|
| 68 |
+
| Characters | Transformer base 2+2 (w/ GRU) | 31.78 (33.30) |
|
| 69 |
+
| Bytes | Transformer base 2+2 (w/ GRU) | 31.57 (33.62) |
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
## Citation
|
| 73 |
+
```
|
| 74 |
+
@misc{wang2019neural,
|
| 75 |
+
title={Neural Machine Translation with Byte-Level Subwords},
|
| 76 |
+
author={Changhan Wang and Kyunghyun Cho and Jiatao Gu},
|
| 77 |
+
year={2019},
|
| 78 |
+
eprint={1909.03341},
|
| 79 |
+
archivePrefix={arXiv},
|
| 80 |
+
primaryClass={cs.CL}
|
| 81 |
+
}
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
## Contact
|
| 86 |
+
Changhan Wang ([changhan@fb.com](mailto:changhan@fb.com)),
|
| 87 |
+
Kyunghyun Cho ([kyunghyuncho@fb.com](mailto:kyunghyuncho@fb.com)),
|
| 88 |
+
Jiatao Gu ([jgu@fb.com](mailto:jgu@fb.com))
|
fairseq-0.10.2/examples/byte_level_bpe/get_bitext.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import os
|
| 9 |
+
import os.path as op
|
| 10 |
+
from collections import namedtuple
|
| 11 |
+
from multiprocessing import cpu_count
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
|
| 14 |
+
import sentencepiece as sp
|
| 15 |
+
from fairseq.data.encoders.byte_bpe import ByteBPE
|
| 16 |
+
from fairseq.data.encoders.byte_utils import byte_encode
|
| 17 |
+
from fairseq.data.encoders.bytes import Bytes
|
| 18 |
+
from fairseq.data.encoders.characters import Characters
|
| 19 |
+
from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
|
| 20 |
+
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
SPLITS = ["train", "valid", "test"]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _convert_xml(in_path: str, out_path: str):
|
| 27 |
+
with open(in_path) as f, open(out_path, "w") as f_o:
|
| 28 |
+
for s in f:
|
| 29 |
+
ss = s.strip()
|
| 30 |
+
if not ss.startswith("<seg"):
|
| 31 |
+
continue
|
| 32 |
+
ss = ss.replace("</seg>", "").split('">')
|
| 33 |
+
assert len(ss) == 2
|
| 34 |
+
f_o.write(ss[1].strip() + "\n")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _convert_train(in_path: str, out_path: str):
|
| 38 |
+
with open(in_path) as f, open(out_path, "w") as f_o:
|
| 39 |
+
for s in f:
|
| 40 |
+
ss = s.strip()
|
| 41 |
+
if ss.startswith("<"):
|
| 42 |
+
continue
|
| 43 |
+
f_o.write(ss.strip() + "\n")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _get_bytes(in_path: str, out_path: str):
|
| 47 |
+
with open(in_path) as f, open(out_path, "w") as f_o:
|
| 48 |
+
for s in f:
|
| 49 |
+
f_o.write(Bytes.encode(s.strip()) + "\n")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _get_chars(in_path: str, out_path: str):
|
| 53 |
+
with open(in_path) as f, open(out_path, "w") as f_o:
|
| 54 |
+
for s in f:
|
| 55 |
+
f_o.write(Characters.encode(s.strip()) + "\n")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
|
| 59 |
+
Args = namedtuple(
|
| 60 |
+
"Args",
|
| 61 |
+
[
|
| 62 |
+
"moses_source_lang",
|
| 63 |
+
"moses_target_lang",
|
| 64 |
+
"moses_no_dash_splits",
|
| 65 |
+
"moses_no_escape",
|
| 66 |
+
],
|
| 67 |
+
)
|
| 68 |
+
args = Args(
|
| 69 |
+
moses_source_lang=src,
|
| 70 |
+
moses_target_lang=tgt,
|
| 71 |
+
moses_no_dash_splits=False,
|
| 72 |
+
moses_no_escape=False,
|
| 73 |
+
)
|
| 74 |
+
pretokenizer = MosesTokenizer(args)
|
| 75 |
+
with open(in_path) as f, open(out_path, "w") as f_o:
|
| 76 |
+
for s in f:
|
| 77 |
+
f_o.write(pretokenizer.encode(s.strip()) + "\n")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
|
| 81 |
+
with open(out_path, "w") as f_o:
|
| 82 |
+
for lang in [src, tgt]:
|
| 83 |
+
with open(f"{in_path_prefix}.{lang}") as f:
|
| 84 |
+
for s in f:
|
| 85 |
+
f_o.write(byte_encode(s.strip()) + "\n")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
|
| 89 |
+
arguments = [
|
| 90 |
+
f"--input={in_path}",
|
| 91 |
+
f"--model_prefix={model_prefix}",
|
| 92 |
+
f"--model_type=bpe",
|
| 93 |
+
f"--vocab_size={vocab_size}",
|
| 94 |
+
"--character_coverage=1.0",
|
| 95 |
+
"--normalization_rule_name=identity",
|
| 96 |
+
f"--num_threads={cpu_count()}",
|
| 97 |
+
]
|
| 98 |
+
sp.SentencePieceTrainer.Train(" ".join(arguments))
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _apply_bbpe(model_path: str, in_path: str, out_path: str):
|
| 102 |
+
Args = namedtuple("Args", ["sentencepiece_model_path"])
|
| 103 |
+
args = Args(sentencepiece_model_path=model_path)
|
| 104 |
+
tokenizer = ByteBPE(args)
|
| 105 |
+
with open(in_path) as f, open(out_path, "w") as f_o:
|
| 106 |
+
for s in f:
|
| 107 |
+
f_o.write(tokenizer.encode(s.strip()) + "\n")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _apply_bpe(model_path: str, in_path: str, out_path: str):
|
| 111 |
+
Args = namedtuple("Args", ["sentencepiece_model"])
|
| 112 |
+
args = Args(sentencepiece_model=model_path)
|
| 113 |
+
tokenizer = SentencepieceBPE(args)
|
| 114 |
+
with open(in_path) as f, open(out_path, "w") as f_o:
|
| 115 |
+
for s in f:
|
| 116 |
+
f_o.write(tokenizer.encode(s.strip()) + "\n")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _concat_files(in_paths: List[str], out_path: str):
|
| 120 |
+
with open(out_path, "w") as f_o:
|
| 121 |
+
for p in in_paths:
|
| 122 |
+
with open(p) as f:
|
| 123 |
+
for r in f:
|
| 124 |
+
f_o.write(r)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def preprocess_iwslt17(
|
| 128 |
+
root: str,
|
| 129 |
+
src: str,
|
| 130 |
+
tgt: str,
|
| 131 |
+
bpe_size: Optional[int],
|
| 132 |
+
need_chars: bool,
|
| 133 |
+
bbpe_size: Optional[int],
|
| 134 |
+
need_bytes: bool,
|
| 135 |
+
):
|
| 136 |
+
# extract bitext
|
| 137 |
+
in_root = op.join(root, f"{src}-{tgt}")
|
| 138 |
+
for lang in [src, tgt]:
|
| 139 |
+
_convert_train(
|
| 140 |
+
op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
|
| 141 |
+
op.join(root, f"train.{lang}"),
|
| 142 |
+
)
|
| 143 |
+
_convert_xml(
|
| 144 |
+
op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
|
| 145 |
+
op.join(root, f"valid.{lang}"),
|
| 146 |
+
)
|
| 147 |
+
_convert_xml(
|
| 148 |
+
op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
|
| 149 |
+
op.join(root, f"test.{lang}"),
|
| 150 |
+
)
|
| 151 |
+
# pre-tokenize
|
| 152 |
+
for lang in [src, tgt]:
|
| 153 |
+
for split in SPLITS:
|
| 154 |
+
pretokenize(
|
| 155 |
+
op.join(root, f"{split}.{lang}"),
|
| 156 |
+
op.join(root, f"{split}.moses.{lang}"),
|
| 157 |
+
src,
|
| 158 |
+
tgt,
|
| 159 |
+
)
|
| 160 |
+
# tokenize with BPE vocabulary
|
| 161 |
+
if bpe_size is not None:
|
| 162 |
+
# learn vocabulary
|
| 163 |
+
concated_train_path = op.join(root, "train.all")
|
| 164 |
+
_concat_files(
|
| 165 |
+
[op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
|
| 166 |
+
concated_train_path,
|
| 167 |
+
)
|
| 168 |
+
bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
|
| 169 |
+
_get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
|
| 170 |
+
os.remove(concated_train_path)
|
| 171 |
+
# apply
|
| 172 |
+
for lang in [src, tgt]:
|
| 173 |
+
for split in SPLITS:
|
| 174 |
+
_apply_bpe(
|
| 175 |
+
bpe_model_prefix + ".model",
|
| 176 |
+
op.join(root, f"{split}.moses.{lang}"),
|
| 177 |
+
op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
|
| 178 |
+
)
|
| 179 |
+
# tokenize with bytes vocabulary
|
| 180 |
+
if need_bytes:
|
| 181 |
+
for lang in [src, tgt]:
|
| 182 |
+
for split in SPLITS:
|
| 183 |
+
_get_bytes(
|
| 184 |
+
op.join(root, f"{split}.moses.{lang}"),
|
| 185 |
+
op.join(root, f"{split}.moses.bytes.{lang}"),
|
| 186 |
+
)
|
| 187 |
+
# tokenize with characters vocabulary
|
| 188 |
+
if need_chars:
|
| 189 |
+
for lang in [src, tgt]:
|
| 190 |
+
for split in SPLITS:
|
| 191 |
+
_get_chars(
|
| 192 |
+
op.join(root, f"{split}.moses.{lang}"),
|
| 193 |
+
op.join(root, f"{split}.moses.chars.{lang}"),
|
| 194 |
+
)
|
| 195 |
+
# tokenize with byte-level BPE vocabulary
|
| 196 |
+
if bbpe_size is not None:
|
| 197 |
+
# learn vocabulary
|
| 198 |
+
bchar_path = op.join(root, "train.bchar")
|
| 199 |
+
_convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
|
| 200 |
+
bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
|
| 201 |
+
_get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
|
| 202 |
+
os.remove(bchar_path)
|
| 203 |
+
# apply
|
| 204 |
+
for lang in [src, tgt]:
|
| 205 |
+
for split in SPLITS:
|
| 206 |
+
_apply_bbpe(
|
| 207 |
+
bbpe_model_prefix + ".model",
|
| 208 |
+
op.join(root, f"{split}.moses.{lang}"),
|
| 209 |
+
op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def main():
|
| 214 |
+
parser = argparse.ArgumentParser()
|
| 215 |
+
parser.add_argument("--root", type=str, default="data")
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--bpe-vocab",
|
| 218 |
+
default=None,
|
| 219 |
+
type=int,
|
| 220 |
+
help="Generate tokenized bitext with BPE of size K."
|
| 221 |
+
"Default to None (disabled).",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--bbpe-vocab",
|
| 225 |
+
default=None,
|
| 226 |
+
type=int,
|
| 227 |
+
help="Generate tokenized bitext with BBPE of size K."
|
| 228 |
+
"Default to None (disabled).",
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--byte-vocab",
|
| 232 |
+
action="store_true",
|
| 233 |
+
help="Generate tokenized bitext with bytes vocabulary",
|
| 234 |
+
)
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
"--char-vocab",
|
| 237 |
+
action="store_true",
|
| 238 |
+
help="Generate tokenized bitext with chars vocabulary",
|
| 239 |
+
)
|
| 240 |
+
args = parser.parse_args()
|
| 241 |
+
|
| 242 |
+
preprocess_iwslt17(
|
| 243 |
+
args.root,
|
| 244 |
+
"fr",
|
| 245 |
+
"en",
|
| 246 |
+
args.bpe_vocab,
|
| 247 |
+
args.char_vocab,
|
| 248 |
+
args.bbpe_vocab,
|
| 249 |
+
args.byte_vocab,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
main()
|
fairseq-0.10.2/examples/byte_level_bpe/get_data.sh
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the MIT license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
PY_BIN_ROOT=
|
| 9 |
+
|
| 10 |
+
# PyPI dependency
|
| 11 |
+
${PY_BIN_ROOT}pip install sentencepiece sacremoses
|
| 12 |
+
|
| 13 |
+
# Get data
|
| 14 |
+
if [ ! -d "data" ]; then
|
| 15 |
+
mkdir data
|
| 16 |
+
fi
|
| 17 |
+
|
| 18 |
+
if [ ! -f "data/fr-en.tgz" ]; then
|
| 19 |
+
wget https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz -P data
|
| 20 |
+
tar xvf data/fr-en.tgz -C data
|
| 21 |
+
fi
|
| 22 |
+
${PY_BIN_ROOT}python get_bitext.py --bpe-vocab 16384 --byte-vocab --char-vocab
|
| 23 |
+
for VOCAB_SIZE in 2048 4096; do
|
| 24 |
+
${PY_BIN_ROOT}python get_bitext.py --bpe-vocab ${VOCAB_SIZE} --bbpe-vocab ${VOCAB_SIZE}
|
| 25 |
+
done
|
| 26 |
+
rm -r data/fr-en data/fr-en.tgz
|
| 27 |
+
|
| 28 |
+
# Generate binary dataset
|
| 29 |
+
${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bpe16384 --joined-dictionary \
|
| 30 |
+
--workers "$(nproc)" --trainpref data/train.moses.bpe16384 --validpref data/valid.moses.bpe16384 \
|
| 31 |
+
--testpref data/test.moses.bpe16384
|
| 32 |
+
|
| 33 |
+
${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bytes --joined-dictionary \
|
| 34 |
+
--workers "$(nproc)" --trainpref data/train.moses.bytes --validpref data/valid.moses.bytes \
|
| 35 |
+
--testpref data/test.moses.bytes
|
| 36 |
+
|
| 37 |
+
${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_chars --joined-dictionary \
|
| 38 |
+
--workers "$(nproc)" --trainpref data/train.moses.chars --validpref data/valid.moses.chars \
|
| 39 |
+
--testpref data/test.moses.chars
|
| 40 |
+
|
| 41 |
+
for VOCAB_SIZE in 2048 4096; do
|
| 42 |
+
for TYPE in bbpe bpe; do
|
| 43 |
+
${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir "data/bin_${TYPE}${VOCAB_SIZE}" \
|
| 44 |
+
--joined-dictionary --workers "$(nproc)" --trainpref "data/train.moses.${TYPE}${VOCAB_SIZE}" \
|
| 45 |
+
--validpref "data/valid.moses.${TYPE}${VOCAB_SIZE}" --testpref "data/test.moses.${TYPE}${VOCAB_SIZE}"
|
| 46 |
+
done
|
| 47 |
+
done
|
fairseq-0.10.2/examples/byte_level_bpe/gru_transformer.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 7 |
+
#
|
| 8 |
+
# This source code is licensed under the MIT license found in the
|
| 9 |
+
# LICENSE file in the root directory of this source tree.
|
| 10 |
+
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from fairseq.models import register_model, register_model_architecture
|
| 14 |
+
from fairseq.models.transformer import TransformerEncoder, TransformerModel
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@register_model("gru_transformer")
|
| 18 |
+
class GRUTransformerModel(TransformerModel):
|
| 19 |
+
@classmethod
|
| 20 |
+
def build_encoder(cls, args, src_dict, embed_tokens):
|
| 21 |
+
return GRUTransformerEncoder(args, src_dict, embed_tokens)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class GRUTransformerEncoder(TransformerEncoder):
|
| 25 |
+
def __init__(self, args, dictionary, embed_tokens):
|
| 26 |
+
super().__init__(args, dictionary, embed_tokens)
|
| 27 |
+
self.emb_ctx = nn.GRU(
|
| 28 |
+
input_size=embed_tokens.embedding_dim,
|
| 29 |
+
hidden_size=embed_tokens.embedding_dim // 2,
|
| 30 |
+
num_layers=1,
|
| 31 |
+
bidirectional=True,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def forward_embedding(self, src_tokens):
|
| 35 |
+
# embed tokens and positions
|
| 36 |
+
x = embed = self.embed_scale * self.embed_tokens(src_tokens)
|
| 37 |
+
if self.embed_positions is not None:
|
| 38 |
+
x = embed + self.embed_positions(src_tokens)
|
| 39 |
+
|
| 40 |
+
# contextualize embeddings
|
| 41 |
+
x = x.transpose(0, 1)
|
| 42 |
+
x = self.dropout_module(x)
|
| 43 |
+
x, _ = self.emb_ctx.forward(x)
|
| 44 |
+
x = x.transpose(0, 1)
|
| 45 |
+
|
| 46 |
+
if self.layernorm_embedding is not None:
|
| 47 |
+
x = self.layernorm_embedding(x)
|
| 48 |
+
x = self.dropout_module(x)
|
| 49 |
+
return x, embed
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@register_model_architecture("gru_transformer", "gru_transformer")
|
| 53 |
+
def gru_transformer_base_architecture(args):
|
| 54 |
+
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
|
| 55 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
| 56 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
|
| 57 |
+
args.encoder_layers = getattr(args, "encoder_layers", 6)
|
| 58 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
|
| 59 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
| 60 |
+
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
|
| 61 |
+
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
|
| 62 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
| 63 |
+
args.decoder_ffn_embed_dim = getattr(
|
| 64 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
| 65 |
+
)
|
| 66 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
| 67 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
| 68 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
| 69 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
| 70 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
| 71 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
| 72 |
+
args.activation_fn = getattr(args, "activation_fn", "relu")
|
| 73 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
| 74 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
| 75 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
| 76 |
+
args.share_decoder_input_output_embed = getattr(
|
| 77 |
+
args, "share_decoder_input_output_embed", False
|
| 78 |
+
)
|
| 79 |
+
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
|
| 80 |
+
args.no_token_positional_embeddings = getattr(
|
| 81 |
+
args, "no_token_positional_embeddings", False
|
| 82 |
+
)
|
| 83 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
| 84 |
+
args.no_cross_attention = getattr(args, "no_cross_attention", False)
|
| 85 |
+
args.cross_self_attention = getattr(args, "cross_self_attention", False)
|
| 86 |
+
args.layer_wise_attention = getattr(args, "layer_wise_attention", False)
|
| 87 |
+
|
| 88 |
+
args.decoder_output_dim = getattr(
|
| 89 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
| 90 |
+
)
|
| 91 |
+
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
| 92 |
+
|
| 93 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
| 94 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@register_model_architecture("gru_transformer", "gru_transformer_big")
def gru_transformer_big(args):
    """'Big' variant of the GRU-transformer: wider embeddings/FFNs, more
    attention heads, and heavier dropout than the base configuration.

    Each option is only filled in when the user did not set it on the
    command line; every remaining default is then delegated to
    ``gru_transformer_base_architecture``.
    """
    big_defaults = (
        ("encoder_embed_dim", 1024),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_attention_heads", 16),
        ("encoder_normalize_before", False),
        ("decoder_embed_dim", 1024),
        ("decoder_ffn_embed_dim", 4096),
        ("decoder_attention_heads", 16),
        ("dropout", 0.3),
    )
    for option, fallback in big_defaults:
        setattr(args, option, getattr(args, option, fallback))
    gru_transformer_base_architecture(args)
|
fairseq-0.10.2/examples/conv_seq2seq/README.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Convolutional Sequence to Sequence Learning (Gehring et al., 2017)
|
| 2 |
+
|
| 3 |
+
## Pre-trained models
|
| 4 |
+
|
| 5 |
+
Description | Dataset | Model | Test set(s)
|
| 6 |
+
---|---|---|---
|
| 7 |
+
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
|
| 8 |
+
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
|
| 9 |
+
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
|
| 10 |
+
|
| 11 |
+
## Example usage
|
| 12 |
+
|
| 13 |
+
See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and
|
| 14 |
+
WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures.
|
| 15 |
+
|
| 16 |
+
## Citation
|
| 17 |
+
|
| 18 |
+
```bibtex
|
| 19 |
+
@inproceedings{gehring2017convs2s,
|
| 20 |
+
title = {Convolutional Sequence to Sequence Learning},
|
| 21 |
+
author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
|
| 22 |
+
booktitle = {Proc. of ICML},
|
| 23 |
+
year = 2017,
|
| 24 |
+
}
|
| 25 |
+
```
|
fairseq-0.10.2/examples/layerdrop/README.md
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)
|
| 2 |
+
This page contains information for how to train models with LayerDrop, based on this [paper](https://arxiv.org/abs/1909.11556).
|
| 3 |
+
|
| 4 |
+
## Citation:
|
| 5 |
+
If you found this technique useful, please cite our paper:
|
| 6 |
+
```bibtex
|
| 7 |
+
@article{fan2019reducing,
|
| 8 |
+
title={Reducing Transformer Depth on Demand with Structured Dropout},
|
| 9 |
+
author={Fan, Angela and Grave, Edouard and Joulin, Armand},
|
| 10 |
+
journal={arXiv preprint arXiv:1909.11556},
|
| 11 |
+
year={2019}
|
| 12 |
+
}
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
## Pre-trained models
|
| 16 |
+
|
| 17 |
+
Model | Description | Download
|
| 18 |
+
---|---|---
|
| 19 |
+
`layerdrop_wmt_en_de_12_6` | Transformer + LayerDrop 0.2 trained on WMT16 en-de with 12 encoder and 6 decoder layers | [layerdrop_wmt_en_de_12_6.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/layerdrop_wmt_en_de_12_6.tar.gz)
|
| 20 |
+
`roberta_layerdrop.base` | RoBERTa Base + LayerDrop 0.2 | [roberta_layerdrop.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.base.qnli.tar.gz)
|
| 21 |
+
`roberta_layerdrop.large` | RoBERTa Large + LayerDrop 0.2 | [roberta_layerdrop.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.tar.gz)
|
| 22 |
+
`roberta_layerdrop.large.mnli` | `roberta_layerdrop.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | [roberta_layerdrop.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.mnli.tar.gz)
|
| 23 |
+
`roberta_layerdrop.large.qnli` | `roberta_layerdrop.large` finetuned on [QNLI](https://arxiv.org/abs/1804.07461) | [roberta_layerdrop.large.qnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.qnli.tar.gz)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
Evaluate performance of these pre-trained models:
|
| 27 |
+
```bash
|
| 28 |
+
# Example for Machine Translation
|
| 29 |
+
fairseq-generate /path/to/bped/wmt/data --path nmt_checkpoint.pt \
|
| 30 |
+
--beam 8 --lenpen 0.4 \
|
| 31 |
+
--batch-size 64 \
|
| 32 |
+
--remove-bpe \
|
| 33 |
+
--gen-subset test > wmt16_gen.txt
|
| 34 |
+
bash scripts/compound_split_bleu.sh wmt16_gen.txt
|
| 35 |
+
# prints BLEU4 = 30.17
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
```python
|
| 39 |
+
# Example for RoBERTa + LayerDrop finetuned on MNLI:
|
| 40 |
+
from fairseq.models.roberta import RobertaModel
|
| 41 |
+
|
| 42 |
+
roberta_layerdrop = RobertaModel.from_pretrained(
|
| 43 |
+
'/path/to/MNLI/model',
|
| 44 |
+
checkpoint_file='mnli_checkpoint.pt',
|
| 45 |
+
data_name_or_path='/path/to/MNLI/data/MNLI-bin'
|
| 46 |
+
)
|
| 47 |
+
label_map = {0: 'contradiction', 2: 'neutral', 1: 'entailment'}
|
| 48 |
+
ncorrect, nsamples = 0, 0
|
| 49 |
+
roberta_layerdrop.cuda()
|
| 50 |
+
roberta_layerdrop.eval()
|
| 51 |
+
with open('/path/to/MNLI/data/dev_matched.tsv') as fin:
|
| 52 |
+
fin.readline()
|
| 53 |
+
for index, line in enumerate(fin):
|
| 54 |
+
tokens = line.strip().split('\t')
|
| 55 |
+
sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
|
| 56 |
+
tokens = roberta_layerdrop.encode(sent1, sent2)
|
| 57 |
+
prediction = roberta_layerdrop.predict('sentence_classification_head', tokens).argmax().item()
|
| 58 |
+
prediction_label = label_map[prediction]
|
| 59 |
+
ncorrect += int(prediction_label == target)
|
| 60 |
+
nsamples += 1
|
| 61 |
+
print('| Accuracy: ', float(ncorrect)/float(nsamples))
|
| 62 |
+
# prints | Accuracy: 0.9026999490575649
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Example for RoBERTa + LayerDrop finetuned on QNLI:
|
| 66 |
+
roberta = RobertaModel.from_pretrained(
|
| 67 |
+
'/path/to/QNLI/model',
|
| 68 |
+
checkpoint_file='qnli_checkpoint.pt',
|
| 69 |
+
data_name_or_path='/path/to/QNLI/data/QNLI-bin'
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
label_fn = lambda label: roberta.task.label_dictionary.string(
|
| 73 |
+
[label + roberta.task.target_dictionary.nspecial]
|
| 74 |
+
)
|
| 75 |
+
ncorrect, nsamples = 0, 0
|
| 76 |
+
roberta.cuda()
|
| 77 |
+
roberta.eval()
|
| 78 |
+
with open('/path/to/QNLI/data/dev.tsv') as fin:
|
| 79 |
+
fin.readline()
|
| 80 |
+
for index, line in enumerate(fin):
|
| 81 |
+
tokens = line.strip().split('\t')
|
| 82 |
+
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
|
| 83 |
+
tokens = roberta.encode(sent1, sent2)
|
| 84 |
+
prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
|
| 85 |
+
prediction_label = label_fn(prediction)
|
| 86 |
+
ncorrect += int(prediction_label == target)
|
| 87 |
+
nsamples += 1
|
| 88 |
+
print('| Accuracy: ', float(ncorrect)/float(nsamples))
|
| 89 |
+
# prints | Accuracy: 0.9480139117700896
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
## Example usage
|
| 94 |
+
|
| 95 |
+
To train a model with LayerDrop, add the following flags. We recommend 0.2, a value that worked well in our experiments. For Language Models that are decoder-only, you need only the decoder flag. For RoBERTa, an encoder, you need only the encoder flag. The encoder and decoder LayerDrop values can be set differently.
|
| 96 |
+
```
|
| 97 |
+
--encoder-layerdrop 0.2 --decoder-layerdrop 0.2
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
To prune a model that has been trained with LayerDrop, add the following flags followed by a comma separated list of which layers you would like to keep.
|
| 101 |
+
```
|
| 102 |
+
--encoder-layers-to-keep 0,2,4,6,8,10,12,14 --decoder-layers-to-keep 0,2,4,6,8,10,12,14
|
| 103 |
+
```
|
| 104 |
+
Setting these flags should print a message such as:
|
| 105 |
+
```
|
| 106 |
+
| Pruning model to specified layer configuration
|
| 107 |
+
```
|
| 108 |
+
You should also see a smaller number of parameters in the model, for example the 16-Layer Transformer Language Model prints:
|
| 109 |
+
```
|
| 110 |
+
num. model params: 246933504
|
| 111 |
+
```
|
| 112 |
+
while a model pruned to 8 Layers prints:
|
| 113 |
+
```
|
| 114 |
+
num. model params: 146163712
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
If you would like to pick up training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass an override command. A specific example would be for language modeling:
|
| 118 |
+
```bash
|
| 119 |
+
fairseq-eval-lm /path/to/wikitext-103 \
|
| 120 |
+
--path /path/to/model/checkpoint.pt \
|
| 121 |
+
--model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
|
| 122 |
+
```
|
| 123 |
+
This model override command overrides the training parameters and updates the model arguments so that the pruned model is run instead of the full model.
|
| 124 |
+
|
| 125 |
+
## Reproduce Paper Results
|
| 126 |
+
|
| 127 |
+
Looking to reproduce the results in the paper?
|
| 128 |
+
|
| 129 |
+
1. For Translation on WMT16 en-de, we followed this setting [here](https://github.com/pytorch/fairseq/blob/master/examples/scaling_nmt/README.md)
|
| 130 |
+
2. To train RoBERTa, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta)
|
| 131 |
+
3. To train Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
## Tips
|
| 135 |
+
|
| 136 |
+
1. If you would like to train large models with better performance, LayerDrop should be set to a smaller value such as 0.1 or 0.2. Too much LayerDrop will mean the model has too much regularization, so may not reach the best performance. Since LayerDrop adds regularization, you may achieve the best performance by slightly reducing the amount of standard dropout (for example, reduce by 0.1).
|
| 137 |
+
|
| 138 |
+
2. If you would like to train large models to be pruned and made smaller, LayerDrop should be set to a larger value such as 0.5 if you want to prune very aggressively (such as removing half the network or more). If you would like to prune fewer layers away, LayerDrop can be set to a smaller value such as 0.2. Our experiments were conducted with low values of LayerDrop (such as 0.1 and 0.2), for reference.
|
| 139 |
+
|
| 140 |
+
3. When pruning layers at inference time, it is best to spread out the layers remaining so they are evenly spaced throughout the network. For example, if you want to remove 50% of the network, keeping every other layer is good.
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
## FAQ
|
| 144 |
+
|
| 145 |
+
1. How did the sharing layers experiment work? In an appendix (https://openreview.net/pdf?id=SylO2yStDr) we added an experiment on Wikitext-103 language modeling that combined LayerDrop with Weight Sharing. We shared chunks of 2 layers such that every other layer had shared weights. For example, if our network has layers 1 through 6, then layer 1 and 2 are shared, layer 3 and 4 are shared, and layer 5 and 6 are shared.
|
| 146 |
+
|
| 147 |
+
2. LayerDrop hasn't been helping in my setting? During training time, LayerDrop can help regularize your network. This is most important if your network is already overfitting - if your network is underfitting, it is possible LayerDrop is adding too much regularization. We recommend using smaller values (such as 0.1 or 0.2) and also decreasing the quantity of standard dropout (for example, reduce by 0.1).
|
| 148 |
+
|
| 149 |
+
3. Can you train a model without LayerDrop and finetune with LayerDrop (e.g. for BERT)? In our experiments, we did not see great performance. Models such as RoBERTa have trained for a long time in the pre-training setting, so only finetuning with LayerDrop for a few epochs on a downstream task such as MNLI does not achieve the robustness required for successful pruning.
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
## Having an issue or have a question?
|
| 153 |
+
|
| 154 |
+
Please open an issue in this repository with the details of your question. Thanks!
|
fairseq-0.10.2/examples/noisychannel/rerank_score_lm.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
from fairseq import options
|
| 9 |
+
|
| 10 |
+
from . import rerank_options, rerank_utils
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def score_lm(args):
    """STEP 4.5 of the noisy-channel reranking pipeline: score generated
    hypotheses with a language model to obtain P(T).

    Reads the BPE'd generation output (or a user-supplied n-best list from
    ``interactive.py``), then runs LM scoring — skipping the work entirely
    when no language model was given or the score file already exists.
    """
    nbest_mode = args.nbest_list is not None

    # Resolve the per-experiment working directories; only `pre_gen` (raw
    # generation output) and `lm_preprocessed_dir` are used by this step.
    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    if nbest_mode:
        print("Using predefined n-best list from interactive.py")
        hypo_file = args.nbest_list
    else:
        hypo_file = pre_gen + "/generate_output_bpe.txt"

    gen_output = rerank_utils.BitextOutputFromGen(
        hypo_file, bpe_symbol=args.remove_bpe, nbest=nbest_mode
    )

    # Without a language model there is nothing to score.
    if args.language_model is None:
        return

    lm_score_file = rerank_utils.rescore_file_name(
        pre_gen, args.prefix_len, args.lm_name, lm_file=True
    )
    # LM scores are cached on disk; recompute only when the file is missing.
    if os.path.isfile(lm_score_file):
        return

    print("STEP 4.5: language modeling for P(T)")
    bpe_status = (
        "no bpe"
        if args.lm_bpe_code is None
        else ("shared" if args.lm_bpe_code == "shared" else "different")
    )

    rerank_utils.lm_scoring(
        lm_preprocessed_dir,
        bpe_status,
        gen_output,
        pre_gen,
        args.lm_dict,
        args.lm_name,
        args.language_model,
        args.lm_bpe_code,
        128,  # batch size used for LM scoring
        lm_score_file,
        args.target_lang,
        args.source_lang,
        prefix_len=args.prefix_len,
    )
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def cli_main():
    """Command-line entry point: build the reranking argument parser, parse
    the args (including model/architecture-specific ones), and score with
    the language model."""
    parser = rerank_options.get_reranking_parser()
    score_lm(options.parse_args_and_arch(parser))


if __name__ == "__main__":
    cli_main()
|
fairseq-0.10.2/examples/quant_noise/README.md
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training with Quantization Noise for Extreme Model Compression ({Fan\*, Stock\*} *et al.*, 2020)
|
| 2 |
+
This page contains information for how to train and quantize models with Quantization Noise, for both scalar quantization like `int8` and Iterative Product Quantization.
|
| 3 |
+
Check out our paper [here](https://arxiv.org/abs/2004.07320).
|
| 4 |
+
|
| 5 |
+
Looking for pretrained models? They will be added shortly.
|
| 6 |
+
Looking for code to train vision models? We are working on open sourcing our code as part of ClassyVision. Please check back, but note that both the Scalar and Iterative Product Quantization counterparts of the `nn.Conv2d` module are already included in this release.
|
| 7 |
+
|
| 8 |
+
**Contents**:
|
| 9 |
+
- [Walk through of code](#walk-through-the-code)
|
| 10 |
+
- [Reproduce NLP Results](#looking-to-reproduce-the-nlp-results-in-the-paper)
|
| 11 |
+
- [Reproduce Vision Results](#looking-to-reproduce-the-vision-results-in-the-paper)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
## Citation
|
| 15 |
+
```bibtex
|
| 16 |
+
@article{fan2020training,
|
| 17 |
+
title={Training with Quantization Noise for Extreme Model Compression},
|
| 18 |
+
    author={Angela Fan* and Pierre Stock* and Benjamin Graham and Edouard Grave and Remi Gribonval and Herve Jegou and Armand Joulin},
|
| 19 |
+
year={2020},
|
| 20 |
+
eprint={2004.07320},
|
| 21 |
+
archivePrefix={arXiv},
|
| 22 |
+
primaryClass={cs.ML}
|
| 23 |
+
}
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
## Walk through the code
|
| 27 |
+
|
| 28 |
+
Training a model with Quant-Noise improves the performance in subsequent inference-time quantization by training models to be robust to quantization. This technique is useful for both scalar and product quantization methods, as well as multiple domains. We detail below our approach to train, quantize models and integrate our code to quantize your favorite models.
|
| 29 |
+
|
| 30 |
+
### Scalar Quantization
|
| 31 |
+
|
| 32 |
+
Unlike the section [Iterative Product Quantization](#iterative-product-quantization) which gives state-of-the-art compression, this section showcases the usefulness of our approach for simple scalar quantization baselines such as int8 using on-GPU Fake Quantization.
|
| 33 |
+
|
| 34 |
+
#### Training
|
| 35 |
+
|
| 36 |
+
Scalar quantization with Quant-Noise consists in randomly quantizing a proportion `p` of the weights during training. Scalar quantization is implemented [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/scalar) under the form of Fake Quantization, meaning that we emulate int8 on GPU by quantizing and de-quantizing both the weights and the activations. We rely on PyTorch's [quantization primitives](https://github.com/pytorch/pytorch/tree/master/torch/quantization).
|
| 37 |
+
|
| 38 |
+
To train a model with Quant-Noise, add the following flag:
|
| 39 |
+
```
|
| 40 |
+
--quant-noise-scalar 0.5
|
| 41 |
+
```
|
| 42 |
+
Large values of noise make the network easier to quantize but may result in higher non-quantized test and validation perplexities.
|
| 43 |
+
|
| 44 |
+
#### Quantization
|
| 45 |
+
|
| 46 |
+
When evaluating a network, all quantized modules and activation hooks automatically switch to `p=1` so the validation accuracy reported by Fairseq is actually the quantized one, nothing more to do.
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
#### Integration with your own code
|
| 50 |
+
|
| 51 |
+
Looking to quantize your own models with Quant-Noise + Scalar Quantization?
|
| 52 |
+
- Use the function `quantize_model_` implemented [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/scalar/utils.py) to (1) replace all your modules by their quantized counterparts and (2) add hooks to those modules to quantize the activations.
|
| 53 |
+
- Then, perform your training as usual. Note that in `eval()` mode, the network is always fully quantized (weights and activations) by default (`p=1`).
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
### Iterative Product Quantization
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
Iterative Product Quantization with Quant-Noise proceeds in two steps. First, a model must be trained uncompressed with Quant-Noise. Second, the model must be quantized with iPQ. Note that we implement here the simplest form of noise, which consists in randomly dropping a proportion `p` of blocks, and that worked as well as assigning those blocks to their current centroid.
|
| 61 |
+
|
| 62 |
+
#### Training
|
| 63 |
+
|
| 64 |
+
To train a model with Quant-Noise, add the following flags:
|
| 65 |
+
```
|
| 66 |
+
--quant-noise-pq 0.1 --quant-noise-pq-block-size 8
|
| 67 |
+
```
|
| 68 |
+
`quant-noise-pq` controls how much dropout is applied to the blocks of the weight matrix. `quant-noise-pq-block-size` controls the size of the weight matrix blocks.
|
| 69 |
+
We recommend training with 0.05 to 0.2 Quant-Noise, a value that worked well in our experiments. For the block-size, we recommend training with block-size of 8. Note that the block size must be a multiple of `input_features`, see the size checks [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quant_noise.py). Large block sizes result in higher compression ratio but may induce a loss in accuracy.
|
| 70 |
+
|
| 71 |
+
We currently support training Transformer based models, such as sequence-to-sequence, language models, and BERT architectures. The `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quant_noise.py) wraps a module. It splits a weight matrix into blocks and applies random dropout to these blocks.
|
| 72 |
+
In the Transformer architectures, quant-noise is applied to the input and output embeddings, the attention, and the FFN.
|
| 73 |
+
|
| 74 |
+
Quant-Noise can also be combined with **LayerDrop** (see [here](https://github.com/pytorch/fairseq/tree/master/examples/layerdrop)) to add its pruning effect to the quantized model and make the model even smaller. We recommend training with LayerDrop 0.1 or 0.2.
|
| 75 |
+
|
| 76 |
+
#### Quantization
|
| 77 |
+
|
| 78 |
+
We implement an improved version of product quantization from Stock et al, **iPQ**, described [here](https://arxiv.org/abs/1907.05686), see code with old API [here](https://github.com/facebookresearch/kill-the-bits). Note that we improved the iPQ API in terms of both compute speed and usability as described below.
|
| 79 |
+
|
| 80 |
+
For the particular case of PQ, quantization is made sequentially. We recommend first quantizing the FFNs, then the EMBs, and finally the ATTNs. Quantization is done in two sub-steps:
|
| 81 |
+
- First, perform `n` steps of Product Quantization (generally `n=20` is enough).
|
| 82 |
+
- Then, finetune the obtained centroids.
|
| 83 |
+
|
| 84 |
+
#### Integration with your own code
|
| 85 |
+
|
| 86 |
+
Looking to quantize your own models with Quant-Noise + iPQ?
|
| 87 |
+
- First wrap your modules with the `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quant_noise.py), which is module-agnostic and train your favorite model.
|
| 88 |
+
- Then, quantize your trained model using the code [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/pq). This can be done *without any changes to your training loop*. Below is an example code for integration.
|
| 89 |
+
Note that we tried our approach only on Transformers and various Convolutional Models such as EfficientNets.
|
| 90 |
+
|
| 91 |
+
```python
|
| 92 |
+
from fairseq.modules.quantization.pq import quantize_model_, SizeTracker
|
| 93 |
+
|
| 94 |
+
# get configuration parameters
|
| 95 |
+
n_centroids_config = config["n_centroids"]
|
| 96 |
+
block_sizes_config = config["block_sizes"]
|
| 97 |
+
layers_to_quantize = config["layers_to_quantize"]
|
| 98 |
+
|
| 99 |
+
# size tracker for keeping track of assignments, centroids and non-compressed sizes
|
| 100 |
+
size_tracker = SizeTracker(model)
|
| 101 |
+
|
| 102 |
+
# Quantize model by stages
|
| 103 |
+
for step in range(len(layers_to_quantize)):
|
| 104 |
+
|
| 105 |
+
# quantize model in-place
|
| 106 |
+
quantized_layers = quantize_model_(
|
| 107 |
+
model,
|
| 108 |
+
size_tracker,
|
| 109 |
+
layers_to_quantize,
|
| 110 |
+
block_sizes_config,
|
| 111 |
+
n_centroids_config,
|
| 112 |
+
step=step,
|
| 113 |
+
)
|
| 114 |
+
logger.info(f"Finetuning stage {step}, quantized layers: {quantized_layers}")
|
| 115 |
+
logger.info(f"{size_tracker}")
|
| 116 |
+
|
| 117 |
+
# Don't forget to re-create/update trainer/optimizer since model parameters have changed
|
| 118 |
+
optimizer = ...
|
| 119 |
+
|
| 120 |
+
# Finetune the centroids with your usual training loop for a few epochs
|
| 121 |
+
trainer.train_epoch()
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
## Looking to reproduce the NLP results in the paper?
|
| 126 |
+
|
| 127 |
+
We detail below how to reproduce the state-of-the-art results reported in the paper for Quant-Noise + Iterative Product Quantization.
|
| 128 |
+
|
| 129 |
+
### Training with Quant-Noise
|
| 130 |
+
|
| 131 |
+
To **train** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta).
|
| 132 |
+
The following command can be used to train a RoBERTa Base + QuantNoise model:
|
| 133 |
+
|
| 134 |
+
```bash
|
| 135 |
+
TOTAL_UPDATES=125000
|
| 136 |
+
WARMUP_UPDATES=10000
|
| 137 |
+
PEAK_LR=0.0005
|
| 138 |
+
TOKENS_PER_SAMPLE=512
|
| 139 |
+
MAX_POSITIONS=512
|
| 140 |
+
MAX_SENTENCES=16
|
| 141 |
+
UPDATE_FREQ=2
|
| 142 |
+
DATA_DIR=/path/to/data/here
|
| 143 |
+
|
| 144 |
+
fairseq-train $DATA_DIR \
|
| 145 |
+
--task masked_lm --criterion masked_lm --arch roberta_base \
|
| 146 |
+
--sample-break-mode complete \
|
| 147 |
+
--tokens-per-sample $TOKENS_PER_SAMPLE --max-positions $MAX_POSITIONS \
|
| 148 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 \
|
| 149 |
+
--clip-norm 0.0 \
|
| 150 |
+
--lr-scheduler polynomial_decay --lr $PEAK_LR \
|
| 151 |
+
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \
|
| 152 |
+
--dropout 0.1 --attention-dropout 0.1 \
|
| 153 |
+
--weight-decay 0.01 \
|
| 154 |
+
--batch-size $MAX_SENTENCES \
|
| 155 |
+
--update-freq $UPDATE_FREQ --max-update $TOTAL_UPDATES \
|
| 156 |
+
--save-dir checkpoint/roberta \
|
| 157 |
+
--ddp-backend no_c10d --encoder-layerdrop 0.2 \
|
| 158 |
+
--quant-noise-pq 0.2 --quant-noise-pq-block-size 8 --untie-weights-roberta
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
To **finetune** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.glue.md).
|
| 162 |
+
The following command can be used to finetune a RoBERTa Base + QuantNoise model on the RTE dataset:
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
TOTAL_NUM_UPDATES=2036
|
| 166 |
+
WARMUP_UPDATES=122
|
| 167 |
+
LR=2e-05
|
| 168 |
+
NUM_CLASSES=2
|
| 169 |
+
MAX_SENTENCES=16
|
| 170 |
+
ROBERTA_PATH=/path/to/roberta_quantnoise/model.pt
|
| 171 |
+
|
| 172 |
+
fairseq-train /path/to/rte/data/ \
|
| 173 |
+
--restore-file $ROBERTA_PATH \
|
| 174 |
+
--max-positions 512 \
|
| 175 |
+
--batch-size $MAX_SENTENCES \
|
| 176 |
+
--max-tokens 4400 \
|
| 177 |
+
--task sentence_prediction \
|
| 178 |
+
--reset-optimizer --reset-dataloader --reset-meters \
|
| 179 |
+
--required-batch-size-multiple 1 \
|
| 180 |
+
--init-token 0 --separator-token 2 \
|
| 181 |
+
--arch roberta_large \
|
| 182 |
+
--criterion sentence_prediction \
|
| 183 |
+
--num-classes $NUM_CLASSES \
|
| 184 |
+
--dropout 0.1 --attention-dropout 0.1 \
|
| 185 |
+
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
|
| 186 |
+
--clip-norm 0.0 \
|
| 187 |
+
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
|
| 188 |
+
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
|
| 189 |
+
--max-epoch 10 \
|
| 190 |
+
--find-unused-parameters \
|
| 191 |
+
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
|
| 192 |
+
--ddp-backend no_c10d \
|
| 193 |
+
--quant-noise-pq 0.2 --quant-noise-pq-block-size 8
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
To **train** Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model).
|
| 197 |
+
The following command can be used to train a Transformer + QuantNoise model on Wikitext-103:
|
| 198 |
+
|
| 199 |
+
```bash
|
| 200 |
+
fairseq-train --task language_modeling /path/to/wikitext-103/data \
|
| 201 |
+
--save-dir checkpoints/transformer_wikitext-103 \
|
| 202 |
+
--adaptive-input --adaptive-input-cutoff 20000,60000 --adaptive-input-factor 4 \
|
| 203 |
+
--adaptive-softmax-cutoff 20000,60000 --adaptive-softmax-dropout 0.2 --adaptive-softmax-factor 4.0 \
|
| 204 |
+
--tie-adaptive-proj --tie-adaptive-weights \
|
| 205 |
+
--arch transformer_lm_gbw \
|
| 206 |
+
--attention-dropout 0.1 --dropout 0.2 --relu-dropout 0.1 \
|
| 207 |
+
--clip-norm 0.1 --criterion adaptive_loss \
|
| 208 |
+
--ddp-backend no_c10d \
|
| 209 |
+
--decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 \
|
| 210 |
+
--decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \
|
| 211 |
+
--lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 1.0 --t-mult 2.0 \
|
| 212 |
+
--max-tokens 3072 --tokens-per-sample 3072 --momentum 0.99 --optimizer nag \
|
| 213 |
+
--sample-break-mode none --update-freq 3 \
|
| 214 |
+
--warmup-init-lr 1e-07 --warmup-updates 16000 \
|
| 215 |
+
--weight-decay 0 --seed 1 --min-lr 1e-09 \
|
| 216 |
+
--quant-noise-pq 0.05 --quant-noise-pq-block-size 8
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
To **evaluate** this model, note you need to use the `eval.py` script. The following command can be used to evaluate:
|
| 220 |
+
|
| 221 |
+
```bash
|
| 222 |
+
fairseq-eval-lm /path/to/wikitext-103/data --path /path/to/model/checkpoint \
|
| 223 |
+
--sample-break-mode complete \
|
| 224 |
+
--max-tokens 3072 \
|
| 225 |
+
--context-window 2560 \
|
| 226 |
+
--softmax-batch 1024 \
|
| 227 |
+
--gen-subset valid
|
| 228 |
+
```
|
| 229 |
+
and change the `--gen-subset` to `test` if you would like to evaluate on the test set instead.
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
### Iterative Product Quantization
|
| 233 |
+
|
| 234 |
+
To quantize the finetuned RoBERTa model, we use this command on 1 GPU. This should run in a day.
|
| 235 |
+
```bash
|
| 236 |
+
TOTAL_NUM_UPDATES=6108 # 2036 updates for each iteration
|
| 237 |
+
WARMUP_UPDATES=122
|
| 238 |
+
LR=2e-05
|
| 239 |
+
NUM_CLASSES=2
|
| 240 |
+
MAX_SENTENCES=16
|
| 241 |
+
fairseq-train --task sentence_prediction /path/to/data/ \
|
| 242 |
+
--restore-file $ROBERTA_PATH \
|
| 243 |
+
--save-dir checkpoints/roberta_finetuned \
|
| 244 |
+
--max-positions 512 \
|
| 245 |
+
--batch-size $MAX_SENTENCES \
|
| 246 |
+
--max-tokens 4400 \
|
| 247 |
+
--init-token 0 --separator-token 2 \
|
| 248 |
+
--arch roberta_large \
|
| 249 |
+
--criterion sentence_prediction \
|
| 250 |
+
--num-classes $NUM_CLASSES \
|
| 251 |
+
--dropout 0.1 --attention-dropout 0.1 \
|
| 252 |
+
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
|
| 253 |
+
--clip-norm 0.0 --lr-scheduler polynomial_decay \
|
| 254 |
+
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
|
| 255 |
+
--no-progress-bar --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \
|
| 256 |
+
--quantization-config-path /path/to/config/yaml
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
To quantize the trained Language Model, we use this command on 8 V100 23GB GPUs. This should run in a couple of hours.
|
| 260 |
+
```bash
|
| 261 |
+
fairseq-train --task language_modeling /path/to/wikitext-103/data \
|
| 262 |
+
--save-dir checkpoints/transformer_wikitext-103 \
|
| 263 |
+
--adaptive-input --adaptive-input-cutoff 20000,60000 --adaptive-input-factor 4 \
|
| 264 |
+
--adaptive-softmax-cutoff 20000,60000 --adaptive-softmax-dropout 0.2 --adaptive-softmax-factor 4.0 \
|
| 265 |
+
--arch transformer_lm_gbw \
|
| 266 |
+
--attention-dropout 0.1 --dropout 0.2 --relu-dropout 0.1 \
|
| 267 |
+
--bucket-cap-mb 25 --char-embedder-highway-layers 2 --character-embedding-dim 4 \
|
| 268 |
+
--clip-norm 0.1 --criterion adaptive_loss \
|
| 269 |
+
--ddp-backend no_c10d \
|
| 270 |
+
--decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \
|
| 271 |
+
--fp16 --keep-last-epochs -1 \
|
| 272 |
+
--lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 0.05 --min-lr 1e-09 \
|
| 273 |
+
--max-tokens 2944 --tokens-per-sample 2944\
|
| 274 |
+
--momentum 0.99 --no-epoch-checkpoints --no-progress-bar --optimizer nag --required-batch-size-multiple 8 \
|
| 275 |
+
--sample-break-mode none --t-mult 2.0 --skip-invalid-size-inputs-valid-test \
|
| 276 |
+
--tie-adaptive-proj --tie-adaptive-weights --update-freq 3 --weight-decay 0 --seed 1 \
|
| 277 |
+
--log-interval 100 --no-progress-bar --skip-invalid-size-inputs-valid-test \
|
| 278 |
+
--restore-file path/to/trained/lm/with/quant/noise \
|
| 279 |
+
--max-update 13500 --quantization-config-path /path/to/config/yaml
|
| 280 |
+
```
|
| 281 |
+
If you have less capacity or if your distributed training freezes, try reducing `--max-tokens` and `--tokens-per-sample` (this may reduce the quantized accuracy a bit).
|
| 282 |
+
|
| 283 |
+
### Remarks
|
| 284 |
+
|
| 285 |
+
We try to keep the open-sourced code as readable and as easy-to-plug as possible. Therefore, we did not test it for the following cases:
|
| 286 |
+
- Scalar quantization with RoBERTa.
|
| 287 |
+
- Quantization with iPQ and `int8` combined.
|
| 288 |
+
|
| 289 |
+
If you have trouble adapting it, we will be more than happy to help!
|
| 290 |
+
|
| 291 |
+
## Looking to reproduce the Vision results in the paper?
|
| 292 |
+
|
| 293 |
+
We are working on open sourcing our code as part of ClassyVision. Please check back.
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
## Having an issue or have a question?
|
| 297 |
+
|
| 298 |
+
Please open an issue in this repository with the details of your question. Thanks!
|
fairseq-0.10.2/examples/quant_noise/transformer_quantization_config.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# This file defines example configuration arguments for quantizing
|
| 7 |
+
# a transformer model with product quantization
|
| 8 |
+
|
| 9 |
+
# Number of Centroids for Product Quantization, by default 256 (byte-aligned)
|
| 10 |
+
n_centroids:
|
| 11 |
+
Linear:
|
| 12 |
+
key: in_features
|
| 13 |
+
value: {"*": 256}
|
| 14 |
+
Embedding:
|
| 15 |
+
key: embedding_dim
|
| 16 |
+
value: {"*": 256}
|
| 17 |
+
|
| 18 |
+
# Block Sizes for Product Quantization
|
| 19 |
+
# We suggest: 8 for FFN, 4 for ATTN, 4 for embedding projections, 8 for embeddings
|
| 20 |
+
block_sizes:
|
| 21 |
+
Linear:
|
| 22 |
+
key: fuzzy_name
|
| 23 |
+
value: {fc: 8, attn: 4, emb: 4}
|
| 24 |
+
Embedding:
|
| 25 |
+
key: fuzzy_name
|
| 26 |
+
value: {emb: 8}
|
| 27 |
+
|
| 28 |
+
# Layers to Quantize Sequentially
|
| 29 |
+
# We suggest: first FFN, then EMB, then ATTN
|
| 30 |
+
layers_to_quantize:
|
| 31 |
+
- decoder\\.layers\\.\d+\\.fc[12]
|
| 32 |
+
- decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]
|
| 33 |
+
- decoder\\.layers\\.\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)
|
mosesdecoder/moses/FF/BleuScoreFeature.h
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef BLUESCOREFEATURE_H
|
| 2 |
+
#define BLUESCOREFEATURE_H
|
| 3 |
+
|
| 4 |
+
#include <utility>
|
| 5 |
+
#include <string>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
#include <boost/unordered_map.hpp>
|
| 9 |
+
|
| 10 |
+
#include "StatefulFeatureFunction.h"
|
| 11 |
+
|
| 12 |
+
#include "moses/FF/FFState.h"
|
| 13 |
+
#include "moses/Phrase.h"
|
| 14 |
+
#include "moses/ChartHypothesis.h"
|
| 15 |
+
|
| 16 |
+
namespace Moses
|
| 17 |
+
{
|
| 18 |
+
|
| 19 |
+
class BleuScoreFeature;
|
| 20 |
+
|
| 21 |
+
class BleuScoreState : public FFState
|
| 22 |
+
{
|
| 23 |
+
public:
|
| 24 |
+
friend class BleuScoreFeature;
|
| 25 |
+
static size_t bleu_order;
|
| 26 |
+
|
| 27 |
+
BleuScoreState(bool is_syntax);
|
| 28 |
+
size_t hash() const;
|
| 29 |
+
virtual bool operator==(const FFState& other) const;
|
| 30 |
+
|
| 31 |
+
void print(std::ostream& out) const;
|
| 32 |
+
|
| 33 |
+
private:
|
| 34 |
+
Phrase m_words;
|
| 35 |
+
size_t m_source_length;
|
| 36 |
+
size_t m_target_length;
|
| 37 |
+
bool m_is_syntax;
|
| 38 |
+
// scaled reference length is needed for scoring incomplete hypotheses against reference translation
|
| 39 |
+
float m_scaled_ref_length;
|
| 40 |
+
|
| 41 |
+
std::vector< size_t > m_ngram_counts;
|
| 42 |
+
std::vector< size_t > m_ngram_matches;
|
| 43 |
+
|
| 44 |
+
void AddNgramCountAndMatches(std::vector< size_t >& counts, std::vector< size_t >& matches);
|
| 45 |
+
};
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
std::ostream& operator<<(std::ostream& out, const BleuScoreState& state);
|
| 49 |
+
|
| 50 |
+
typedef boost::unordered_map< Phrase, size_t > NGrams;
|
| 51 |
+
|
| 52 |
+
class RefValue : public std::pair<std::vector<size_t>,NGrams>
|
| 53 |
+
{
|
| 54 |
+
public:
|
| 55 |
+
RefValue& operator=( const RefValue& rhs ) {
|
| 56 |
+
first = rhs.first;
|
| 57 |
+
second = rhs.second;
|
| 58 |
+
return *this;
|
| 59 |
+
}
|
| 60 |
+
};
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class BleuScoreFeature : public StatefulFeatureFunction
|
| 64 |
+
{
|
| 65 |
+
public:
|
| 66 |
+
static const std::vector<BleuScoreFeature*>& GetColl() {
|
| 67 |
+
return s_staticColl;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
typedef boost::unordered_map<size_t, RefValue > RefCounts;
|
| 71 |
+
typedef boost::unordered_map<size_t, NGrams> Matches;
|
| 72 |
+
|
| 73 |
+
BleuScoreFeature(const std::string &line);
|
| 74 |
+
|
| 75 |
+
void SetParameter(const std::string& key, const std::string& value);
|
| 76 |
+
|
| 77 |
+
std::vector<float> DefaultWeights() const;
|
| 78 |
+
|
| 79 |
+
void PrintHistory(std::ostream& out) const;
|
| 80 |
+
void LoadReferences(const std::vector< std::vector< std::string > > &);
|
| 81 |
+
void SetCurrSourceLength(size_t);
|
| 82 |
+
void SetCurrNormSourceLength(size_t);
|
| 83 |
+
void SetCurrShortestRefLength(size_t);
|
| 84 |
+
void SetCurrAvgRefLength(size_t sent_id);
|
| 85 |
+
void SetAvgInputLength (float l) {
|
| 86 |
+
m_avg_input_length = l;
|
| 87 |
+
}
|
| 88 |
+
void SetCurrReferenceNgrams(size_t sent_id);
|
| 89 |
+
size_t GetShortestRefIndex(size_t ref_id);
|
| 90 |
+
size_t GetClosestRefLength(size_t ref_id, int hypoLength);
|
| 91 |
+
void UpdateHistory(const std::vector< const Word* >&);
|
| 92 |
+
void UpdateHistory(const std::vector< std::vector< const Word* > >& hypos, std::vector<size_t>& sourceLengths, std::vector<size_t>& ref_ids, size_t rank, size_t epoch);
|
| 93 |
+
void PrintRefLength(const std::vector<size_t>& ref_ids);
|
| 94 |
+
void SetBleuParameters(bool disable, bool sentenceBleu, bool scaleByInputLength, bool scaleByAvgInputLength,
|
| 95 |
+
bool scaleByInverseLength, bool scaleByAvgInverseLength,
|
| 96 |
+
float scaleByX, float historySmoothing, size_t scheme, bool simpleHistoryBleu);
|
| 97 |
+
|
| 98 |
+
void GetNgramMatchCounts(Phrase&,
|
| 99 |
+
const NGrams&,
|
| 100 |
+
std::vector< size_t >&,
|
| 101 |
+
std::vector< size_t >&,
|
| 102 |
+
size_t skip = 0) const;
|
| 103 |
+
void GetNgramMatchCounts_prefix(Phrase&,
|
| 104 |
+
const NGrams&,
|
| 105 |
+
std::vector< size_t >&,
|
| 106 |
+
std::vector< size_t >&,
|
| 107 |
+
size_t new_start_indices,
|
| 108 |
+
size_t last_end_index) const;
|
| 109 |
+
void GetNgramMatchCounts_overlap(Phrase& phrase,
|
| 110 |
+
const NGrams& ref_ngram_counts,
|
| 111 |
+
std::vector< size_t >& ret_counts,
|
| 112 |
+
std::vector< size_t >& ret_matches,
|
| 113 |
+
size_t overlap_index) const;
|
| 114 |
+
void GetClippedNgramMatchesAndCounts(Phrase&,
|
| 115 |
+
const NGrams&,
|
| 116 |
+
std::vector< size_t >&,
|
| 117 |
+
std::vector< size_t >&,
|
| 118 |
+
size_t skip = 0) const;
|
| 119 |
+
|
| 120 |
+
FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo,
|
| 121 |
+
const FFState* prev_state,
|
| 122 |
+
ScoreComponentCollection* accumulator) const;
|
| 123 |
+
FFState* EvaluateWhenApplied(const ChartHypothesis& cur_hypo,
|
| 124 |
+
int featureID,
|
| 125 |
+
ScoreComponentCollection* accumulator) const;
|
| 126 |
+
|
| 127 |
+
bool Enabled() const {
|
| 128 |
+
return m_enabled;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
bool IsUseable(const FactorMask &mask) const;
|
| 132 |
+
|
| 133 |
+
float CalculateBleu(BleuScoreState*) const;
|
| 134 |
+
float CalculateBleu(Phrase translation) const;
|
| 135 |
+
const FFState* EmptyHypothesisState(const InputType&) const;
|
| 136 |
+
|
| 137 |
+
float GetSourceLengthHistory() {
|
| 138 |
+
return m_source_length_history;
|
| 139 |
+
}
|
| 140 |
+
float GetTargetLengthHistory() {
|
| 141 |
+
return m_target_length_history;
|
| 142 |
+
}
|
| 143 |
+
float GetAverageInputLength() {
|
| 144 |
+
return m_avg_input_length;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
void Load(AllOptions::ptr const& opts);
|
| 148 |
+
|
| 149 |
+
private:
|
| 150 |
+
static std::vector<BleuScoreFeature*> s_staticColl;
|
| 151 |
+
|
| 152 |
+
bool m_enabled;
|
| 153 |
+
bool m_sentence_bleu;
|
| 154 |
+
bool m_simple_history_bleu;
|
| 155 |
+
bool m_is_syntax;
|
| 156 |
+
// counts for pseudo-document
|
| 157 |
+
std::vector< float > m_count_history;
|
| 158 |
+
std::vector< float > m_match_history;
|
| 159 |
+
float m_source_length_history;
|
| 160 |
+
float m_target_length_history;
|
| 161 |
+
float m_ref_length_history;
|
| 162 |
+
|
| 163 |
+
size_t m_cur_source_length;
|
| 164 |
+
size_t m_cur_norm_source_length; // length without <s>, </s>
|
| 165 |
+
RefCounts m_refs;
|
| 166 |
+
NGrams m_cur_ref_ngrams;
|
| 167 |
+
float m_cur_ref_length;
|
| 168 |
+
|
| 169 |
+
// scale BLEU score by history of input length
|
| 170 |
+
bool m_scale_by_input_length;
|
| 171 |
+
bool m_scale_by_avg_input_length;
|
| 172 |
+
|
| 173 |
+
// scale by the inverse of the input length * 100
|
| 174 |
+
bool m_scale_by_inverse_length;
|
| 175 |
+
bool m_scale_by_avg_inverse_length;
|
| 176 |
+
|
| 177 |
+
float m_avg_input_length;
|
| 178 |
+
|
| 179 |
+
float m_scale_by_x;
|
| 180 |
+
|
| 181 |
+
// smoothing factor for history counts
|
| 182 |
+
float m_historySmoothing;
|
| 183 |
+
|
| 184 |
+
enum SmoothingScheme { PLUS_ONE = 1, PLUS_POINT_ONE = 2, PAPINENI = 3 };
|
| 185 |
+
SmoothingScheme m_smoothing_scheme;
|
| 186 |
+
};
|
| 187 |
+
|
| 188 |
+
} // Namespace.
|
| 189 |
+
|
| 190 |
+
#endif //BLUESCOREFEATURE_H
|
| 191 |
+
|
mosesdecoder/moses/FF/ConstrainedDecoding.cpp
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "ConstrainedDecoding.h"
|
| 2 |
+
#include "moses/Hypothesis.h"
|
| 3 |
+
#include "moses/Manager.h"
|
| 4 |
+
#include "moses/ChartHypothesis.h"
|
| 5 |
+
#include "moses/ChartManager.h"
|
| 6 |
+
#include "moses/StaticData.h"
|
| 7 |
+
#include "moses/InputFileStream.h"
|
| 8 |
+
#include "moses/Util.h"
|
| 9 |
+
#include "util/exception.hh"
|
| 10 |
+
|
| 11 |
+
using namespace std;
|
| 12 |
+
|
| 13 |
+
namespace Moses
|
| 14 |
+
{
|
| 15 |
+
ConstrainedDecodingState::ConstrainedDecodingState(const Hypothesis &hypo)
|
| 16 |
+
{
|
| 17 |
+
hypo.GetOutputPhrase(m_outputPhrase);
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
ConstrainedDecodingState::ConstrainedDecodingState(const ChartHypothesis &hypo)
|
| 21 |
+
{
|
| 22 |
+
hypo.GetOutputPhrase(m_outputPhrase);
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
size_t ConstrainedDecodingState::hash() const
|
| 26 |
+
{
|
| 27 |
+
size_t ret = hash_value(m_outputPhrase);
|
| 28 |
+
return ret;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
bool ConstrainedDecodingState::operator==(const FFState& other) const
|
| 32 |
+
{
|
| 33 |
+
const ConstrainedDecodingState &otherFF = static_cast<const ConstrainedDecodingState&>(other);
|
| 34 |
+
bool ret = m_outputPhrase == otherFF.m_outputPhrase;
|
| 35 |
+
return ret;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
//////////////////////////////////////////////////////////////////
|
| 39 |
+
ConstrainedDecoding::ConstrainedDecoding(const std::string &line)
|
| 40 |
+
:StatefulFeatureFunction(1, line)
|
| 41 |
+
,m_maxUnknowns(0)
|
| 42 |
+
,m_negate(false)
|
| 43 |
+
,m_soft(false)
|
| 44 |
+
{
|
| 45 |
+
m_tuneable = false;
|
| 46 |
+
ReadParameters();
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
void ConstrainedDecoding::Load(AllOptions::ptr const& opts)
|
| 50 |
+
{
|
| 51 |
+
m_options = opts;
|
| 52 |
+
const StaticData &staticData = StaticData::Instance();
|
| 53 |
+
bool addBeginEndWord
|
| 54 |
+
= ((opts->search.algo == CYKPlus) || (opts->search.algo == ChartIncremental));
|
| 55 |
+
|
| 56 |
+
for(size_t i = 0; i < m_paths.size(); ++i) {
|
| 57 |
+
InputFileStream constraintFile(m_paths[i]);
|
| 58 |
+
std::string line;
|
| 59 |
+
long sentenceID = opts->output.start_translation_id - 1 ;
|
| 60 |
+
while (getline(constraintFile, line)) {
|
| 61 |
+
vector<string> vecStr = Tokenize(line, "\t");
|
| 62 |
+
|
| 63 |
+
Phrase phrase(0);
|
| 64 |
+
if (vecStr.size() == 1) {
|
| 65 |
+
sentenceID++;
|
| 66 |
+
phrase.CreateFromString(Output, opts->output.factor_order, vecStr[0], NULL);
|
| 67 |
+
} else if (vecStr.size() == 2) {
|
| 68 |
+
sentenceID = Scan<long>(vecStr[0]);
|
| 69 |
+
phrase.CreateFromString(Output, opts->output.factor_order, vecStr[1], NULL);
|
| 70 |
+
} else {
|
| 71 |
+
UTIL_THROW(util::Exception, "Reference file not loaded");
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
if (addBeginEndWord) {
|
| 75 |
+
phrase.InitStartEndWord();
|
| 76 |
+
}
|
| 77 |
+
m_constraints[sentenceID].push_back(phrase);
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
std::vector<float> ConstrainedDecoding::DefaultWeights() const
|
| 83 |
+
{
|
| 84 |
+
UTIL_THROW_IF2(m_numScoreComponents != 1,
|
| 85 |
+
"ConstrainedDecoding must only have 1 score");
|
| 86 |
+
vector<float> ret(1, 1);
|
| 87 |
+
return ret;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
template <class H, class M>
|
| 91 |
+
const std::vector<Phrase> *GetConstraint(const std::map<long,std::vector<Phrase> > &constraints, const H &hypo)
|
| 92 |
+
{
|
| 93 |
+
const M &mgr = hypo.GetManager();
|
| 94 |
+
const InputType &input = mgr.GetSource();
|
| 95 |
+
long id = input.GetTranslationId();
|
| 96 |
+
|
| 97 |
+
map<long,std::vector<Phrase> >::const_iterator iter;
|
| 98 |
+
iter = constraints.find(id);
|
| 99 |
+
|
| 100 |
+
if (iter == constraints.end()) {
|
| 101 |
+
UTIL_THROW(util::Exception, "Couldn't find reference " << id);
|
| 102 |
+
|
| 103 |
+
return NULL;
|
| 104 |
+
} else {
|
| 105 |
+
return &iter->second;
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
FFState* ConstrainedDecoding::EvaluateWhenApplied(
|
| 110 |
+
const Hypothesis& hypo,
|
| 111 |
+
const FFState* prev_state,
|
| 112 |
+
ScoreComponentCollection* accumulator) const
|
| 113 |
+
{
|
| 114 |
+
const std::vector<Phrase> *ref = GetConstraint<Hypothesis, Manager>(m_constraints, hypo);
|
| 115 |
+
assert(ref);
|
| 116 |
+
|
| 117 |
+
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
|
| 118 |
+
const Phrase &outputPhrase = ret->GetPhrase();
|
| 119 |
+
|
| 120 |
+
size_t searchPos = NOT_FOUND;
|
| 121 |
+
size_t i = 0;
|
| 122 |
+
size_t size = 0;
|
| 123 |
+
while(searchPos == NOT_FOUND && i < ref->size()) {
|
| 124 |
+
searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
|
| 125 |
+
size = (*ref)[i].GetSize();
|
| 126 |
+
i++;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
float score;
|
| 130 |
+
if (hypo.IsSourceCompleted()) {
|
| 131 |
+
// translated entire sentence.
|
| 132 |
+
bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
|
| 133 |
+
if (!m_negate) {
|
| 134 |
+
score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
| 135 |
+
} else {
|
| 136 |
+
score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
| 137 |
+
}
|
| 138 |
+
} else if (m_negate) {
|
| 139 |
+
// keep all derivations
|
| 140 |
+
score = 0;
|
| 141 |
+
} else {
|
| 142 |
+
score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
accumulator->PlusEquals(this, score);
|
| 146 |
+
|
| 147 |
+
return ret;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
FFState* ConstrainedDecoding::EvaluateWhenApplied(
|
| 151 |
+
const ChartHypothesis &hypo,
|
| 152 |
+
int /* featureID - used to index the state in the previous hypotheses */,
|
| 153 |
+
ScoreComponentCollection* accumulator) const
|
| 154 |
+
{
|
| 155 |
+
const std::vector<Phrase> *ref = GetConstraint<ChartHypothesis, ChartManager>(m_constraints, hypo);
|
| 156 |
+
assert(ref);
|
| 157 |
+
|
| 158 |
+
const ChartManager &mgr = hypo.GetManager();
|
| 159 |
+
const Sentence &source = static_cast<const Sentence&>(mgr.GetSource());
|
| 160 |
+
|
| 161 |
+
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
|
| 162 |
+
const Phrase &outputPhrase = ret->GetPhrase();
|
| 163 |
+
|
| 164 |
+
size_t searchPos = NOT_FOUND;
|
| 165 |
+
size_t i = 0;
|
| 166 |
+
size_t size = 0;
|
| 167 |
+
while(searchPos == NOT_FOUND && i < ref->size()) {
|
| 168 |
+
searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
|
| 169 |
+
size = (*ref)[i].GetSize();
|
| 170 |
+
i++;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
float score;
|
| 174 |
+
if (hypo.GetCurrSourceRange().GetStartPos() == 0 &&
|
| 175 |
+
hypo.GetCurrSourceRange().GetEndPos() == source.GetSize() - 1) {
|
| 176 |
+
// translated entire sentence.
|
| 177 |
+
bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
|
| 178 |
+
|
| 179 |
+
if (!m_negate) {
|
| 180 |
+
score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
| 181 |
+
} else {
|
| 182 |
+
score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
| 183 |
+
}
|
| 184 |
+
} else if (m_negate) {
|
| 185 |
+
// keep all derivations
|
| 186 |
+
score = 0;
|
| 187 |
+
} else {
|
| 188 |
+
score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
accumulator->PlusEquals(this, score);
|
| 192 |
+
|
| 193 |
+
return ret;
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
void ConstrainedDecoding::SetParameter(const std::string& key, const std::string& value)
|
| 197 |
+
{
|
| 198 |
+
if (key == "path") {
|
| 199 |
+
m_paths = Tokenize(value, ",");
|
| 200 |
+
} else if (key == "max-unknowns") {
|
| 201 |
+
m_maxUnknowns = Scan<int>(value);
|
| 202 |
+
} else if (key == "negate") {
|
| 203 |
+
m_negate = Scan<bool>(value);
|
| 204 |
+
} else if (key == "soft") {
|
| 205 |
+
m_soft = Scan<bool>(value);
|
| 206 |
+
} else {
|
| 207 |
+
StatefulFeatureFunction::SetParameter(key, value);
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
}
|
| 212 |
+
|
mosesdecoder/moses/FF/ControlRecombination.cpp
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "ControlRecombination.h"
|
| 2 |
+
#include "moses/Hypothesis.h"
|
| 3 |
+
#include "moses/Manager.h"
|
| 4 |
+
#include "moses/ChartHypothesis.h"
|
| 5 |
+
#include "moses/ChartManager.h"
|
| 6 |
+
#include "moses/StaticData.h"
|
| 7 |
+
#include "moses/InputFileStream.h"
|
| 8 |
+
#include "moses/Util.h"
|
| 9 |
+
#include "util/exception.hh"
|
| 10 |
+
|
| 11 |
+
using namespace std;
|
| 12 |
+
|
| 13 |
+
namespace Moses
|
| 14 |
+
{
|
| 15 |
+
ControlRecombinationState::ControlRecombinationState(const Hypothesis &hypo, const ControlRecombination &ff)
|
| 16 |
+
:m_ff(ff)
|
| 17 |
+
{
|
| 18 |
+
if (ff.GetType() == SameOutput) {
|
| 19 |
+
//UTIL_THROW(util::Exception, "Implemented not yet completed for phrase-based model. Need to take into account the coverage");
|
| 20 |
+
hypo.GetOutputPhrase(m_outputPhrase);
|
| 21 |
+
} else {
|
| 22 |
+
m_hypo = &hypo;
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
ControlRecombinationState::ControlRecombinationState(const ChartHypothesis &hypo, const ControlRecombination &ff)
|
| 27 |
+
:m_ff(ff)
|
| 28 |
+
{
|
| 29 |
+
if (ff.GetType() == SameOutput) {
|
| 30 |
+
hypo.GetOutputPhrase(m_outputPhrase);
|
| 31 |
+
} else {
|
| 32 |
+
m_hypo = &hypo;
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
size_t ControlRecombinationState::hash() const
|
| 37 |
+
{
|
| 38 |
+
size_t ret;
|
| 39 |
+
if (m_ff.GetType() == SameOutput) {
|
| 40 |
+
ret = hash_value(m_outputPhrase);
|
| 41 |
+
} else {
|
| 42 |
+
// compare hypo address. Won't be equal unless they're actually the same hypo
|
| 43 |
+
ret = (size_t) m_hypo;
|
| 44 |
+
}
|
| 45 |
+
return ret;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
bool ControlRecombinationState::operator==(const FFState& other) const
|
| 49 |
+
{
|
| 50 |
+
const ControlRecombinationState &otherFF = static_cast<const ControlRecombinationState&>(other);
|
| 51 |
+
|
| 52 |
+
if (m_ff.GetType() == SameOutput) {
|
| 53 |
+
return m_outputPhrase == otherFF.m_outputPhrase;
|
| 54 |
+
} else {
|
| 55 |
+
// compare hypo address. Won't be equal unless they're actually the same hypo
|
| 56 |
+
if (m_hypo == otherFF.m_hypo)
|
| 57 |
+
return true;
|
| 58 |
+
return (m_hypo == otherFF.m_hypo);
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
std::vector<float> ControlRecombination::DefaultWeights() const
|
| 63 |
+
{
|
| 64 |
+
UTIL_THROW_IF2(m_numScoreComponents,
|
| 65 |
+
"ControlRecombination should not have any scores");
|
| 66 |
+
vector<float> ret(0);
|
| 67 |
+
return ret;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
FFState* ControlRecombination::EvaluateWhenApplied(
|
| 71 |
+
const Hypothesis& hypo,
|
| 72 |
+
const FFState* prev_state,
|
| 73 |
+
ScoreComponentCollection* accumulator) const
|
| 74 |
+
{
|
| 75 |
+
return new ControlRecombinationState(hypo, *this);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
FFState* ControlRecombination::EvaluateWhenApplied(
|
| 79 |
+
const ChartHypothesis &hypo,
|
| 80 |
+
int /* featureID - used to index the state in the previous hypotheses */,
|
| 81 |
+
ScoreComponentCollection* accumulator) const
|
| 82 |
+
{
|
| 83 |
+
return new ControlRecombinationState(hypo, *this);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
void ControlRecombination::SetParameter(const std::string& key, const std::string& value)
|
| 87 |
+
{
|
| 88 |
+
if (key == "type") {
|
| 89 |
+
m_type = (ControlRecombinationType) Scan<int>(value);
|
| 90 |
+
} else {
|
| 91 |
+
StatefulFeatureFunction::SetParameter(key, value);
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
}
|
| 96 |
+
|
mosesdecoder/moses/FF/ControlRecombination.h
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <string>
|
| 4 |
+
#include <map>
|
| 5 |
+
#include "StatefulFeatureFunction.h"
|
| 6 |
+
#include "FFState.h"
|
| 7 |
+
#include "moses/Phrase.h"
|
| 8 |
+
|
| 9 |
+
namespace Moses
|
| 10 |
+
{
|
| 11 |
+
enum ControlRecombinationType {
|
| 12 |
+
// when to recombine
|
| 13 |
+
SameOutput = 1,
|
| 14 |
+
Never = 2
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
class ControlRecombination;
|
| 18 |
+
|
| 19 |
+
class ControlRecombinationState : public FFState
|
| 20 |
+
{
|
| 21 |
+
public:
|
| 22 |
+
ControlRecombinationState(const ControlRecombination &ff)
|
| 23 |
+
:m_ff(ff) {
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
ControlRecombinationState(const Hypothesis &hypo, const ControlRecombination &ff);
|
| 27 |
+
ControlRecombinationState(const ChartHypothesis &hypo, const ControlRecombination &ff);
|
| 28 |
+
|
| 29 |
+
virtual size_t hash() const;
|
| 30 |
+
virtual bool operator==(const FFState& other) const;
|
| 31 |
+
|
| 32 |
+
const Phrase &GetPhrase() const {
|
| 33 |
+
return m_outputPhrase;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
protected:
|
| 37 |
+
Phrase m_outputPhrase;
|
| 38 |
+
const ControlRecombination &m_ff;
|
| 39 |
+
const void *m_hypo;
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
//////////////////////////////////////////////////////////////////

// Stateful feature with zero score components that only controls when
// hypotheses may recombine. By default (SameOutput) only hypotheses with
// the same output so far recombine; with type Never, recombination is
// disabled.
class ControlRecombination : public StatefulFeatureFunction
{
public:
  ControlRecombination(const std::string &line)
    :StatefulFeatureFunction(0, line)
    ,m_type(SameOutput)

  {
    // Contributes no scores, so there is nothing to tune.
    m_tuneable = false;
    ReadParameters();
  }

  // Works with any factor configuration.
  bool IsUseable(const FactorMask &mask) const {
    return true;
  }

  FFState* EvaluateWhenApplied(
    const Hypothesis& cur_hypo,
    const FFState* prev_state,
    ScoreComponentCollection* accumulator) const;

  FFState* EvaluateWhenApplied(
    const ChartHypothesis& /* cur_hypo */,
    int /* featureID - used to index the state in the previous hypotheses */,
    ScoreComponentCollection* accumulator) const;

  virtual const FFState* EmptyHypothesisState(const InputType &input) const {
    return new ControlRecombinationState(*this);
  }

  std::vector<float> DefaultWeights() const;

  void SetParameter(const std::string& key, const std::string& value);

  // Current recombination policy (default: SameOutput).
  ControlRecombinationType GetType() const {
    return m_type;
  }
protected:
  ControlRecombinationType m_type;
};
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
}
|
| 88 |
+
|
mosesdecoder/moses/FF/CorrectionPattern.cpp
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <sstream>
|
| 2 |
+
#include "CorrectionPattern.h"
|
| 3 |
+
#include "moses/Phrase.h"
|
| 4 |
+
#include "moses/TargetPhrase.h"
|
| 5 |
+
#include "moses/InputPath.h"
|
| 6 |
+
#include "moses/Hypothesis.h"
|
| 7 |
+
#include "moses/ChartHypothesis.h"
|
| 8 |
+
#include "moses/ScoreComponentCollection.h"
|
| 9 |
+
#include "moses/TranslationOption.h"
|
| 10 |
+
#include "util/string_piece_hash.hh"
|
| 11 |
+
#include "util/exception.hh"
|
| 12 |
+
|
| 13 |
+
#include <functional>
|
| 14 |
+
#include <algorithm>
|
| 15 |
+
|
| 16 |
+
#include <boost/foreach.hpp>
|
| 17 |
+
#include <boost/algorithm/string.hpp>
|
| 18 |
+
|
| 19 |
+
#include "Diffs.h"
|
| 20 |
+
|
| 21 |
+
namespace Moses
|
| 22 |
+
{
|
| 23 |
+
|
| 24 |
+
using namespace std;
|
| 25 |
+
|
| 26 |
+
std::string MakePair(const std::string &s1, const std::string &s2, bool general)
|
| 27 |
+
{
|
| 28 |
+
std::vector<std::string> sourceList;
|
| 29 |
+
std::vector<std::string> targetList;
|
| 30 |
+
|
| 31 |
+
if(general) {
|
| 32 |
+
Diffs diffs = CreateDiff(s1, s2);
|
| 33 |
+
|
| 34 |
+
size_t i = 0, j = 0;
|
| 35 |
+
char lastType = 'm';
|
| 36 |
+
|
| 37 |
+
std::string source, target;
|
| 38 |
+
std::string match;
|
| 39 |
+
|
| 40 |
+
int count = 1;
|
| 41 |
+
|
| 42 |
+
BOOST_FOREACH(Diff type, diffs) {
|
| 43 |
+
if(type == 'm') {
|
| 44 |
+
if(lastType != 'm') {
|
| 45 |
+
sourceList.push_back(source);
|
| 46 |
+
targetList.push_back(target);
|
| 47 |
+
}
|
| 48 |
+
source.clear();
|
| 49 |
+
target.clear();
|
| 50 |
+
|
| 51 |
+
if(s1[i] == '+') {
|
| 52 |
+
if(match.size() >= 3) {
|
| 53 |
+
sourceList.push_back("(\\w{3,})·");
|
| 54 |
+
std::string temp = "1";
|
| 55 |
+
sprintf((char*)temp.c_str(), "%d", count);
|
| 56 |
+
targetList.push_back("\\" + temp + "·");
|
| 57 |
+
count++;
|
| 58 |
+
} else {
|
| 59 |
+
sourceList.push_back(match + "·");
|
| 60 |
+
targetList.push_back(match + "·");
|
| 61 |
+
}
|
| 62 |
+
match.clear();
|
| 63 |
+
} else
|
| 64 |
+
match.push_back(s1[i]);
|
| 65 |
+
|
| 66 |
+
i++;
|
| 67 |
+
j++;
|
| 68 |
+
} else if(type == 'd') {
|
| 69 |
+
if(s1[i] == '+')
|
| 70 |
+
source += "·";
|
| 71 |
+
else
|
| 72 |
+
source.push_back(s1[i]);
|
| 73 |
+
i++;
|
| 74 |
+
} else if(type == 'i') {
|
| 75 |
+
if(s2[j] == '+')
|
| 76 |
+
target += "·";
|
| 77 |
+
else
|
| 78 |
+
target.push_back(s2[j]);
|
| 79 |
+
j++;
|
| 80 |
+
}
|
| 81 |
+
if(type != 'm' && !match.empty()) {
|
| 82 |
+
if(match.size() >= 3) {
|
| 83 |
+
sourceList.push_back("(\\w{3,})");
|
| 84 |
+
std::string temp = "1";
|
| 85 |
+
sprintf((char*)temp.c_str(), "%d", count);
|
| 86 |
+
targetList.push_back("\\" + temp);
|
| 87 |
+
count++;
|
| 88 |
+
} else {
|
| 89 |
+
sourceList.push_back(match);
|
| 90 |
+
targetList.push_back(match);
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
match.clear();
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
lastType = type;
|
| 97 |
+
}
|
| 98 |
+
if(lastType != 'm') {
|
| 99 |
+
sourceList.push_back(source);
|
| 100 |
+
targetList.push_back(target);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
if(!match.empty()) {
|
| 104 |
+
if(match.size() >= 3) {
|
| 105 |
+
sourceList.push_back("(\\w{3,})");
|
| 106 |
+
std::string temp = "1";
|
| 107 |
+
sprintf((char*)temp.c_str(), "%d", count);
|
| 108 |
+
targetList.push_back("\\"+ temp);
|
| 109 |
+
count++;
|
| 110 |
+
} else {
|
| 111 |
+
sourceList.push_back(match);
|
| 112 |
+
targetList.push_back(match);
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
match.clear();
|
| 116 |
+
} else {
|
| 117 |
+
std::string cs1 = s1;
|
| 118 |
+
std::string cs2 = s2;
|
| 119 |
+
boost::replace_all(cs1, "+", "·");
|
| 120 |
+
boost::replace_all(cs2, "+", "·");
|
| 121 |
+
|
| 122 |
+
sourceList.push_back(cs1);
|
| 123 |
+
targetList.push_back(cs2);
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
std::stringstream out;
|
| 127 |
+
out << "sub(«";
|
| 128 |
+
out << boost::join(sourceList, "");
|
| 129 |
+
out << "»,«";
|
| 130 |
+
out << boost::join(targetList, "");
|
| 131 |
+
out << "»)";
|
| 132 |
+
|
| 133 |
+
return out.str();
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
std::string CorrectionPattern::CreateSinglePattern(const Tokens &s1, const Tokens &s2) const
|
| 137 |
+
{
|
| 138 |
+
std::stringstream out;
|
| 139 |
+
if(s1.empty()) {
|
| 140 |
+
out << "ins(«" << boost::join(s2, "·") << "»)";
|
| 141 |
+
return out.str();
|
| 142 |
+
} else if(s2.empty()) {
|
| 143 |
+
out << "del(«" << boost::join(s1, "·") << "»)";
|
| 144 |
+
return out.str();
|
| 145 |
+
} else {
|
| 146 |
+
Tokens::value_type v1 = boost::join(s1, "+");
|
| 147 |
+
Tokens::value_type v2 = boost::join(s2, "+");
|
| 148 |
+
out << MakePair(v1, v2, m_general);
|
| 149 |
+
return out.str();
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
// Collect left or right source-side context strings around the span ending
// at position 'pos' (length 'len') relative to the start of the current
// input path. For every context width 1..window one string is produced,
// e.g. "left(«w1·w2»)_" or "_right(«w1·w2»)". Positions outside the
// sentence are padded with <s>/</s>; contexts whose last two entries are
// both boundary symbols are skipped to avoid redundant padded variants.
std::vector<std::string> GetContext(size_t pos,
                                    size_t len,
                                    size_t window,
                                    const InputType &input,
                                    const InputPath &inputPath,
                                    const std::vector<FactorType>& factorTypes,
                                    bool isRight)
{

  // NOTE(review): assumes the input really is a Sentence -- confirm for
  // other input types (confusion networks / lattices would break this cast).
  const Sentence& sentence = static_cast<const Sentence&>(input);
  const Range& range = inputPath.GetWordsRange();

  // First position left of the span and first position right of it.
  int leftPos = range.GetStartPos() + pos - len - 1;
  int rightPos = range.GetStartPos() + pos;

  std::vector<std::string> contexts;

  for(int length = 1; length <= (int)window; ++length) {
    std::vector<std::string> current;
    if(!isRight) {
      // Gather 'length' words to the left (collected right-to-left and
      // reversed below so the output reads left-to-right).
      for(int i = 0; i < length; i++) {
        if(leftPos - i >= 0) {
          current.push_back(sentence.GetWord(leftPos - i).GetString(factorTypes, false));
        } else {
          current.push_back("<s>");
        }
      }

      // Skip contexts consisting of repeated sentence-start padding.
      if(current.back() == "<s>" && current.size() >= 2 && current[current.size()-2] == "<s>")
        continue;

      std::reverse(current.begin(), current.end());
      contexts.push_back("left(«" + boost::join(current, "·") + "»)_");
    }
    if(isRight) {
      // Gather 'length' words to the right.
      for(int i = 0; i < length; i++) {
        if(rightPos + i < (int)sentence.GetSize()) {
          current.push_back(sentence.GetWord(rightPos + i).GetString(factorTypes, false));
        } else {
          current.push_back("</s>");
        }
      }

      // Skip contexts consisting of repeated sentence-end padding.
      if(current.back() == "</s>" && current.size() >= 2 && current[current.size()-2] == "</s>")
        continue;

      contexts.push_back("_right(«" + boost::join(current, "·") + "»)");
    }
  }
  return contexts;
}
|
| 204 |
+
|
| 205 |
+
// Extract correction patterns for a source/target token pair by aligning
// the two sequences with a diff and emitting one pattern per contiguous
// edited region. If m_context > 0, each pattern is additionally emitted
// combined with left, right, and left+right source contexts.
std::vector<std::string>
CorrectionPattern::CreatePattern(const Tokens &s1,
                                 const Tokens &s2,
                                 const InputType &input,
                                 const InputPath &inputPath) const
{

  Diffs diffs = CreateDiff(s1, s2);
  size_t i = 0, j = 0;     // positions in s1 / s2
  char lastType = 'm';
  std::vector<std::string> patternList;
  Tokens source, target;   // tokens of the edited region being collected
  BOOST_FOREACH(Diff type, diffs) {
    if(type == 'm') {
      // A match closes the current edited region (if any): emit its
      // pattern, optionally decorated with surrounding source context.
      if(lastType != 'm') {
        std::string pattern = CreateSinglePattern(source, target);
        patternList.push_back(pattern);

        if(m_context > 0) {
          std::vector<std::string> leftContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, false);
          std::vector<std::string> rightContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, true);

          // Emit pattern with each left context, each right context, and
          // each left+right combination.
          BOOST_FOREACH(std::string left, leftContexts)
          patternList.push_back(left + pattern);

          BOOST_FOREACH(std::string right, rightContexts)
          patternList.push_back(pattern + right);

          BOOST_FOREACH(std::string left, leftContexts)
          BOOST_FOREACH(std::string right, rightContexts)
          patternList.push_back(left + pattern + right);
        }
      }
      source.clear();
      target.clear();
      // NOTE(review): a 'm' op with differing tokens is treated as a
      // one-token substitution -- presumably CreateDiff can emit such
      // fuzzy matches; confirm against Diffs.h.
      if(s1[i] != s2[j]) {
        source.push_back(s1[i]);
        target.push_back(s2[j]);
      }
      i++;
      j++;
    } else if(type == 'd') {
      // Deletion: token only present on the source side.
      source.push_back(s1[i]);
      i++;
    } else if(type == 'i') {
      // Insertion: token only present on the target side.
      target.push_back(s2[j]);
      j++;
    }
    lastType = type;
  }
  // Flush a trailing edited region (same decoration as above).
  if(lastType != 'm') {
    std::string pattern = CreateSinglePattern(source, target);
    patternList.push_back(pattern);

    if(m_context > 0) {
      std::vector<std::string> leftContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, false);
      std::vector<std::string> rightContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, true);

      BOOST_FOREACH(std::string left, leftContexts)
      patternList.push_back(left + pattern);

      BOOST_FOREACH(std::string right, rightContexts)
      patternList.push_back(pattern + right);

      BOOST_FOREACH(std::string left, leftContexts)
      BOOST_FOREACH(std::string right, rightContexts)
      patternList.push_back(left + pattern + right);
    }
  }

  return patternList;
}
|
| 277 |
+
|
| 278 |
+
// Construct from a config line. Defaults: surface factor 0 for both
// pattern and context tokens, no generalization of matched character
// runs, no context windows.
CorrectionPattern::CorrectionPattern(const std::string &line)
  : StatelessFeatureFunction(0, line), m_factors(1, 0), m_general(false),
    m_context(0), m_contextFactors(1, 0)
{
  std::cerr << "Initializing correction pattern feature.." << std::endl;
  ReadParameters();
}
|
| 285 |
+
|
| 286 |
+
void CorrectionPattern::SetParameter(const std::string& key, const std::string& value)
|
| 287 |
+
{
|
| 288 |
+
if (key == "factor") {
|
| 289 |
+
m_factors = std::vector<FactorType>(1, Scan<FactorType>(value));
|
| 290 |
+
} else if (key == "context-factor") {
|
| 291 |
+
m_contextFactors = std::vector<FactorType>(1, Scan<FactorType>(value));
|
| 292 |
+
} else if (key == "general") {
|
| 293 |
+
m_general = Scan<bool>(value);
|
| 294 |
+
} else if (key == "context") {
|
| 295 |
+
m_context = Scan<size_t>(value);
|
| 296 |
+
} else {
|
| 297 |
+
StatelessFeatureFunction::SetParameter(key, value);
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
// Stateless scoring hook with access to the input sentence; all work is
// delegated to ComputeFeatures. stackVec and estimatedFutureScore are
// unused by this feature.
void CorrectionPattern::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedFutureScore) const
{
  ComputeFeatures(input, inputPath, targetPhrase, &scoreBreakdown);
}
|
| 310 |
+
|
| 311 |
+
void CorrectionPattern::ComputeFeatures(
|
| 312 |
+
const InputType &input,
|
| 313 |
+
const InputPath &inputPath,
|
| 314 |
+
const TargetPhrase& target,
|
| 315 |
+
ScoreComponentCollection* accumulator) const
|
| 316 |
+
{
|
| 317 |
+
const Phrase &source = inputPath.GetPhrase();
|
| 318 |
+
|
| 319 |
+
std::vector<std::string> sourceTokens;
|
| 320 |
+
for(size_t i = 0; i < source.GetSize(); ++i)
|
| 321 |
+
sourceTokens.push_back(source.GetWord(i).GetString(m_factors, false));
|
| 322 |
+
|
| 323 |
+
std::vector<std::string> targetTokens;
|
| 324 |
+
for(size_t i = 0; i < target.GetSize(); ++i)
|
| 325 |
+
targetTokens.push_back(target.GetWord(i).GetString(m_factors, false));
|
| 326 |
+
|
| 327 |
+
std::vector<std::string> patternList = CreatePattern(sourceTokens, targetTokens, input, inputPath);
|
| 328 |
+
for(size_t i = 0; i < patternList.size(); ++i)
|
| 329 |
+
accumulator->PlusEquals(this, patternList[i], 1);
|
| 330 |
+
|
| 331 |
+
/*
|
| 332 |
+
BOOST_FOREACH(std::string w, sourceTokens)
|
| 333 |
+
std::cerr << w << " ";
|
| 334 |
+
std::cerr << std::endl;
|
| 335 |
+
BOOST_FOREACH(std::string w, targetTokens)
|
| 336 |
+
std::cerr << w << " ";
|
| 337 |
+
std::cerr << std::endl;
|
| 338 |
+
BOOST_FOREACH(std::string w, patternList)
|
| 339 |
+
std::cerr << w << " ";
|
| 340 |
+
std::cerr << std::endl << std::endl;
|
| 341 |
+
*/
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
bool CorrectionPattern::IsUseable(const FactorMask &mask) const
|
| 345 |
+
{
|
| 346 |
+
bool ret = true;
|
| 347 |
+
for(size_t i = 0; i < m_factors.size(); ++i)
|
| 348 |
+
ret = ret && mask[m_factors[i]];
|
| 349 |
+
for(size_t i = 0; i < m_contextFactors.size(); ++i)
|
| 350 |
+
ret = ret && mask[m_contextFactors[i]];
|
| 351 |
+
return ret;
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
}
|
mosesdecoder/moses/FF/CorrectionPattern.h
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef moses_CorrectionPattern_h
|
| 2 |
+
#define moses_CorrectionPattern_h
|
| 3 |
+
|
| 4 |
+
#include <string>
|
| 5 |
+
#include <boost/unordered_set.hpp>
|
| 6 |
+
|
| 7 |
+
#include "StatelessFeatureFunction.h"
|
| 8 |
+
#include "moses/FactorCollection.h"
|
| 9 |
+
#include "moses/AlignmentInfo.h"
|
| 10 |
+
|
| 11 |
+
namespace Moses
|
| 12 |
+
{
|
| 13 |
+
|
| 14 |
+
typedef std::vector<std::string> Tokens;
|
| 15 |
+
|
| 16 |
+
/** Sparse stateless feature that fires one indicator per "correction
 *  pattern" -- an insertion, deletion, or substitution extracted from a
 *  diff of the source span against the candidate target phrase --
 *  optionally combined with surrounding source context.
 *  (Previous comment was copy-pasted from a phrase-length feature.)
 */
class CorrectionPattern : public StatelessFeatureFunction
{
private:
  std::vector<FactorType> m_factors;        // factors used to render pattern tokens
  bool m_general;                           // generalize long matched character runs
  size_t m_context;                         // max width of context windows (0 = none)
  std::vector<FactorType> m_contextFactors; // factors used to render context tokens

public:
  CorrectionPattern(const std::string &line);

  bool IsUseable(const FactorMask &mask) const;

  // No phrase-table-time score; everything is computed with source context.
  void EvaluateInIsolation(const Phrase &source
                           , const TargetPhrase &targetPhrase
                           , ScoreComponentCollection &scoreBreakdown
                           , ScoreComponentCollection &estimatedFutureScore) const
  {}

  // Main scoring entry point; delegates to ComputeFeatures.
  virtual void EvaluateWithSourceContext(const InputType &input
                                         , const InputPath &inputPath
                                         , const TargetPhrase &targetPhrase
                                         , const StackVec *stackVec
                                         , ScoreComponentCollection &scoreBreakdown
                                         , ScoreComponentCollection *estimatedFutureScore = NULL) const;

  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const
  {}

  // Stateless feature: nothing to add at application time.
  void EvaluateWhenApplied(const Hypothesis& hypo,
                           ScoreComponentCollection* accumulator) const
  {}
  void EvaluateWhenApplied(const ChartHypothesis &hypo,
                           ScoreComponentCollection* accumulator) const
  {}

  // Extract patterns for (source span, target phrase) and accumulate one
  // sparse feature per pattern.
  void ComputeFeatures(const InputType &input,
                       const InputPath &inputPath,
                       const TargetPhrase& targetPhrase,
                       ScoreComponentCollection* accumulator) const;

  void SetParameter(const std::string& key, const std::string& value);

  std::vector<std::string> CreatePattern(const Tokens &s1,
                                         const Tokens &s2,
                                         const InputType &input,
                                         const InputPath &inputPath) const;

  std::string CreateSinglePattern(const Tokens &s1, const Tokens &s2) const;

};
|
| 70 |
+
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
#endif // moses_CorrectionPattern_h
|
mosesdecoder/moses/FF/CoveredReferenceFeature.cpp
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <boost/functional/hash.hpp>
|
| 2 |
+
#include <vector>
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <iterator>
|
| 5 |
+
#include <boost/foreach.hpp>
|
| 6 |
+
#include "CoveredReferenceFeature.h"
|
| 7 |
+
#include "moses/ScoreComponentCollection.h"
|
| 8 |
+
#include "moses/Hypothesis.h"
|
| 9 |
+
#include "moses/Manager.h"
|
| 10 |
+
#include "moses/ChartHypothesis.h"
|
| 11 |
+
#include "moses/ChartManager.h"
|
| 12 |
+
#include "moses/StaticData.h"
|
| 13 |
+
#include "moses/InputFileStream.h"
|
| 14 |
+
#include "moses/Util.h"
|
| 15 |
+
#include "util/exception.hh"
|
| 16 |
+
|
| 17 |
+
using namespace std;
|
| 18 |
+
|
| 19 |
+
namespace Moses
|
| 20 |
+
{
|
| 21 |
+
|
| 22 |
+
// State hashing/equality for hypothesis recombination -- not implemented;
// any attempt to recombine with this feature active will throw.
size_t CoveredReferenceState::hash() const
{
  UTIL_THROW2("TODO:Haven't figure this out yet");
}

bool CoveredReferenceState::operator==(const FFState& other) const
{
  UTIL_THROW2("TODO:Haven't figure this out yet");
}
|
| 31 |
+
|
| 32 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 33 |
+
|
| 34 |
+
void CoveredReferenceFeature::EvaluateWithSourceContext(const InputType &input
|
| 35 |
+
, const InputPath &inputPath
|
| 36 |
+
, const TargetPhrase &targetPhrase
|
| 37 |
+
, const StackVec *stackVec
|
| 38 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 39 |
+
, ScoreComponentCollection *estimatedScores) const
|
| 40 |
+
{
|
| 41 |
+
long id = input.GetTranslationId();
|
| 42 |
+
boost::unordered_map<long, std::multiset<string> >::const_iterator refIt = m_refs.find(id);
|
| 43 |
+
multiset<string> wordsInPhrase = GetWordsInPhrase(targetPhrase);
|
| 44 |
+
multiset<string> covered;
|
| 45 |
+
set_intersection(wordsInPhrase.begin(), wordsInPhrase.end(),
|
| 46 |
+
refIt->second.begin(), refIt->second.end(),
|
| 47 |
+
inserter(covered, covered.begin()));
|
| 48 |
+
vector<float> scores;
|
| 49 |
+
scores.push_back(covered.size());
|
| 50 |
+
|
| 51 |
+
scoreBreakdown.Assign(this, scores);
|
| 52 |
+
estimatedScores->Assign(this, scores);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
// Load one reference sentence per input line from m_path. References are
// keyed by consecutive sentence ids starting at the configured first
// translation id; each reference is stored as a multiset of its tokens.
// (Removed an unused local binding to StaticData::Instance().)
void CoveredReferenceFeature::Load(AllOptions::ptr const& opts)
{
  m_options = opts;
  InputFileStream refFile(m_path);
  std::string line;
  long sentenceID = opts->output.start_translation_id;
  while (getline(refFile, line)) {
    vector<string> words = Tokenize(line, " ");
    multiset<string> wordSet;
    // TODO make Tokenize work with other containers than vector
    copy(words.begin(), words.end(), inserter(wordSet, wordSet.begin()));
    m_refs.insert(make_pair(sentenceID++, wordSet));
  }
}
|
| 70 |
+
|
| 71 |
+
void CoveredReferenceFeature::SetParameter(const std::string& key, const std::string& value)
|
| 72 |
+
{
|
| 73 |
+
if (key == "path") {
|
| 74 |
+
m_path = value;
|
| 75 |
+
} else {
|
| 76 |
+
StatefulFeatureFunction::SetParameter(key, value);
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// Phrase-based scoring: count how many still-uncovered reference words the
// newly applied phrase covers, subtract the per-phrase estimate already
// assigned in EvaluateWithSourceContext (so the total is not counted
// twice), and record the newly covered words in the returned state.
FFState* CoveredReferenceFeature::EvaluateWhenApplied(
  const Hypothesis& cur_hypo,
  const FFState* prev_state,
  ScoreComponentCollection* accumulator) const
{
  const CoveredReferenceState &prev = static_cast<const CoveredReferenceState&>(*prev_state);
  CoveredReferenceState *ret = new CoveredReferenceState(prev);

  const Manager &mgr = cur_hypo.GetManager();
  const InputType &input = mgr.GetSource();
  long id = input.GetTranslationId();

  // which words from the reference remain uncovered
  multiset<string> remaining;
  boost::unordered_map<long, std::multiset<string> >::const_iterator refIt = m_refs.find(id);
  if (refIt == m_refs.end()) UTIL_THROW(util::Exception, "Sentence id out of range: " + SPrint<long>(id));
  set_difference(refIt->second.begin(), refIt->second.end(),
                 ret->m_coveredRef.begin(), ret->m_coveredRef.end(),
                 inserter(remaining, remaining.begin()));

  // which of the remaining words are present in the current phrase
  multiset<string> wordsInPhrase = GetWordsInPhrase(cur_hypo.GetCurrTargetPhrase());
  multiset<string> newCovered;
  set_intersection(wordsInPhrase.begin(), wordsInPhrase.end(),
                   remaining.begin(), remaining.end(),
                   inserter(newCovered, newCovered.begin()));

  // Score correction: actual newly-covered count minus the estimate that
  // EvaluateWithSourceContext already placed on this phrase.
  vector<float> estimateScore =
    cur_hypo.GetCurrTargetPhrase().GetScoreBreakdown().GetScoresForProducer(this);
  vector<float> scores;
  scores.push_back(newCovered.size() - estimateScore[0]);
  accumulator->PlusEquals(this, scores);

  // update feature state
  multiset<string>::const_iterator newCoveredIt;
  for (newCoveredIt = newCovered.begin(); newCoveredIt != newCovered.end(); newCoveredIt++) {
    ret->m_coveredRef.insert(*newCoveredIt);
  }
  return ret;
}

// Chart (hierarchical) decoding is not supported by this feature.
FFState* CoveredReferenceFeature::EvaluateWhenApplied(
  const ChartHypothesis& /* cur_hypo */,
  int /* featureID - used to index the state in the previous hypotheses */,
  ScoreComponentCollection* accumulator) const
{
  UTIL_THROW(util::Exception, "Not implemented");
}
|
| 128 |
+
|
| 129 |
+
}
|
mosesdecoder/moses/FF/DecodeFeature.cpp
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// $Id: PhraseDictionaryMemory.cpp 2477 2009-08-07 16:47:54Z bhaddow $
|
| 2 |
+
// vim:tabstop=2
|
| 3 |
+
|
| 4 |
+
/***********************************************************************
|
| 5 |
+
Moses - factored phrase-based language decoder
|
| 6 |
+
Copyright (C) 2010 University of Edinburgh
|
| 7 |
+
|
| 8 |
+
This library is free software; you can redistribute it and/or
|
| 9 |
+
modify it under the terms of the GNU Lesser General Public
|
| 10 |
+
License as published by the Free Software Foundation; either
|
| 11 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 12 |
+
|
| 13 |
+
This library is distributed in the hope that it will be useful,
|
| 14 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 15 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 16 |
+
Lesser General Public License for more details.
|
| 17 |
+
|
| 18 |
+
You should have received a copy of the GNU Lesser General Public
|
| 19 |
+
License along with this library; if not, write to the Free Software
|
| 20 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 21 |
+
***********************************************************************/
|
| 22 |
+
|
| 23 |
+
#include <iostream>
|
| 24 |
+
|
| 25 |
+
#include "DecodeFeature.h"
|
| 26 |
+
#include "moses/DecodeStep.h"
|
| 27 |
+
#include "moses/StaticData.h"
|
| 28 |
+
|
| 29 |
+
using namespace std;
|
| 30 |
+
|
| 31 |
+
namespace Moses
|
| 32 |
+
{
|
| 33 |
+
|
| 34 |
+
// Construct without factor information; registerNow controls whether the
// feature registers itself immediately (forwarded to the base class).
DecodeFeature::DecodeFeature(const std::string &line, bool registerNow)
  : StatelessFeatureFunction(line, registerNow)
  , m_container(NULL)
{
  VERBOSE(2,"DecodeFeature:" << std::endl);
}

// Construct with a known number of score components; input/output factors
// are expected to arrive later via SetParameter.
DecodeFeature::DecodeFeature(size_t numScoreComponents
                             , const std::string &line)
  : StatelessFeatureFunction(numScoreComponents, line)
  , m_container(NULL)
{
  VERBOSE(2,"DecodeFeature: no factors yet" << std::endl);
}

// Construct with explicit input/output factor lists; the corresponding
// factor masks are derived immediately.
DecodeFeature::DecodeFeature(size_t numScoreComponents
                             , const std::vector<FactorType> &input
                             , const std::vector<FactorType> &output
                             , const std::string &line)
  : StatelessFeatureFunction(numScoreComponents, line)
  , m_input(input), m_output(output)
  , m_container(NULL)
{
  m_inputFactors = FactorMask(input);
  m_outputFactors = FactorMask(output);
  VERBOSE(2,"DecodeFeature: input=" << m_inputFactors << " output=" << m_outputFactors << std::endl);
}
|
| 61 |
+
|
| 62 |
+
void DecodeFeature::SetParameter(const std::string& key, const std::string& value)
|
| 63 |
+
{
|
| 64 |
+
if (key == "input-factor") {
|
| 65 |
+
m_input =Tokenize<FactorType>(value, ",");
|
| 66 |
+
m_inputFactors = FactorMask(m_input);
|
| 67 |
+
} else if (key == "output-factor") {
|
| 68 |
+
m_output =Tokenize<FactorType>(value, ",");
|
| 69 |
+
m_outputFactors = FactorMask(m_output);
|
| 70 |
+
} else {
|
| 71 |
+
StatelessFeatureFunction::SetParameter(key, value);
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
// Mask of factors this feature produces on the output side.
const FactorMask& DecodeFeature::GetOutputFactorMask() const
{
  return m_outputFactors;
}


// Mask of factors this feature consumes on the input side.
const FactorMask& DecodeFeature::GetInputFactorMask() const
{
  return m_inputFactors;
}

// Input factor list as configured (order preserved).
const std::vector<FactorType>& DecodeFeature::GetInput() const
{
  return m_input;
}

// Output factor list as configured (order preserved).
const std::vector<FactorType>& DecodeFeature::GetOutput() const
{
  return m_output;
}
|
| 96 |
+
|
| 97 |
+
bool DecodeFeature::IsUseable(const FactorMask &mask) const
|
| 98 |
+
{
|
| 99 |
+
for (size_t i = 0; i < m_output.size(); ++i) {
|
| 100 |
+
const FactorType &factor = m_output[i];
|
| 101 |
+
if (!mask[factor]) {
|
| 102 |
+
return false;
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
return true;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// Return the decode graph this feature belongs to. Requires that the
// feature has been placed in a decode step (m_container) which itself
// belongs to a graph; asserts otherwise.
const DecodeGraph &DecodeFeature::GetDecodeGraph() const
{
  assert(m_container);
  const DecodeGraph *graph = m_container->GetContainer();
  assert(graph);
  return *graph;
}
|
| 115 |
+
|
| 116 |
+
}
|
| 117 |
+
|
mosesdecoder/moses/FF/DistortionScoreProducer.h
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <string>
|
| 4 |
+
#include "StatefulFeatureFunction.h"
|
| 5 |
+
#include "moses/Range.h"
|
| 6 |
+
|
| 7 |
+
namespace Moses
|
| 8 |
+
{
|
| 9 |
+
|
| 10 |
+
/** Calculates Distortion scores
|
| 11 |
+
*/
|
| 12 |
+
class DistortionScoreProducer : public StatefulFeatureFunction
|
| 13 |
+
{
|
| 14 |
+
protected:
|
| 15 |
+
static std::vector<const DistortionScoreProducer*> s_staticColl;
|
| 16 |
+
|
| 17 |
+
FactorType m_sparseFactorTypeSource;
|
| 18 |
+
FactorType m_sparseFactorTypeTarget;
|
| 19 |
+
bool m_useSparse;
|
| 20 |
+
bool m_sparseDistance;
|
| 21 |
+
bool m_sparseSubordinate;
|
| 22 |
+
FactorType m_sparseFactorTypeTargetSubordinate;
|
| 23 |
+
const Factor* m_subordinateConjunctionTagFactor;
|
| 24 |
+
|
| 25 |
+
public:
|
| 26 |
+
static const std::vector<const DistortionScoreProducer*>& GetDistortionFeatureFunctions() {
|
| 27 |
+
return s_staticColl;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
DistortionScoreProducer(const std::string &line);
|
| 31 |
+
|
| 32 |
+
void SetParameter(const std::string& key, const std::string& value);
|
| 33 |
+
|
| 34 |
+
bool IsUseable(const FactorMask &mask) const {
|
| 35 |
+
return true;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
static float CalculateDistortionScore(const Hypothesis& hypo,
|
| 39 |
+
const Range &prev, const Range &curr, const int FirstGapPosition);
|
| 40 |
+
|
| 41 |
+
virtual const FFState* EmptyHypothesisState(const InputType &input) const;
|
| 42 |
+
|
| 43 |
+
virtual FFState* EvaluateWhenApplied(
|
| 44 |
+
const Hypothesis& cur_hypo,
|
| 45 |
+
const FFState* prev_state,
|
| 46 |
+
ScoreComponentCollection* accumulator) const;
|
| 47 |
+
|
| 48 |
+
virtual FFState* EvaluateWhenApplied(
|
| 49 |
+
const ChartHypothesis& /* cur_hypo */,
|
| 50 |
+
int /* featureID - used to index the state in the previous hypotheses */,
|
| 51 |
+
ScoreComponentCollection*) const {
|
| 52 |
+
UTIL_THROW(util::Exception, "DIstortion not implemented in chart decoder");
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
};
|
| 56 |
+
}
|
| 57 |
+
|
mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.cpp
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <utility>
|
| 2 |
+
#include "moses/StaticData.h"
|
| 3 |
+
#include "moses/InputFileStream.h"
|
| 4 |
+
#include "DynamicCacheBasedLanguageModel.h"
|
| 5 |
+
|
| 6 |
+
namespace Moses
|
| 7 |
+
{
|
| 8 |
+
|
| 9 |
+
std::map< const std::string, DynamicCacheBasedLanguageModel * > DynamicCacheBasedLanguageModel::s_instance_map;
|
| 10 |
+
DynamicCacheBasedLanguageModel *DynamicCacheBasedLanguageModel::s_instance = NULL;
|
| 11 |
+
|
| 12 |
+
// Construct from a moses.ini feature line. Registers this instance in the
// global name->instance map (exactly one instance per name is allowed) and
// precomputes the per-age score table.
DynamicCacheBasedLanguageModel::DynamicCacheBasedLanguageModel(const std::string &line)
  : StatelessFeatureFunction(1, line)
{
  VERBOSE(2,"Initializing DynamicCacheBasedLanguageModel feature..." << std::endl);

  // Defaults; any of these may be overridden by SetParameter during ReadParameters().
  m_query_type = CBLM_QUERY_TYPE_ALLSUBSTRINGS;
  m_score_type = CBLM_SCORE_TYPE_HYPERBOLA;
  m_maxAge     = 1000;
  m_name       = "default";
  m_constant   = false;

  ReadParameters();
  UTIL_THROW_IF2(s_instance_map.find(m_name) != s_instance_map.end(), "Only 1 DynamicCacheBasedLanguageModel feature named " + m_name + " is allowed");
  s_instance_map[m_name] = this;
  s_instance = this; //for back compatibility

  SetPreComputedScores();
}
|
| 30 |
+
|
| 31 |
+
// Nothing to release: the cache and score table clean themselves up.
DynamicCacheBasedLanguageModel::~DynamicCacheBasedLanguageModel()
{
}
|
| 32 |
+
|
| 33 |
+
void DynamicCacheBasedLanguageModel::SetPreComputedScores()
|
| 34 |
+
{
|
| 35 |
+
#ifdef WITH_THREADS
|
| 36 |
+
boost::shared_lock<boost::shared_mutex> lock(m_cacheLock);
|
| 37 |
+
#endif
|
| 38 |
+
precomputedScores.clear();
|
| 39 |
+
for (unsigned int i=0; i<m_maxAge; i++) {
|
| 40 |
+
precomputedScores.push_back(decaying_score(i));
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
if ( m_score_type == CBLM_SCORE_TYPE_HYPERBOLA
|
| 44 |
+
|| m_score_type == CBLM_SCORE_TYPE_POWER
|
| 45 |
+
|| m_score_type == CBLM_SCORE_TYPE_EXPONENTIAL
|
| 46 |
+
|| m_score_type == CBLM_SCORE_TYPE_COSINE ) {
|
| 47 |
+
precomputedScores.push_back(decaying_score(m_maxAge));
|
| 48 |
+
} else { // m_score_type = CBLM_SCORE_TYPE_XXXXXXXXX_REWARD
|
| 49 |
+
precomputedScores.push_back(0.0);
|
| 50 |
+
}
|
| 51 |
+
m_lower_score = precomputedScores[m_maxAge];
|
| 52 |
+
VERBOSE(3, "SetPreComputedScores(): lower_age:|" << m_maxAge << "| lower_score:|" << m_lower_score << "|" << std::endl);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
float DynamicCacheBasedLanguageModel::GetPreComputedScores(const unsigned int age)
|
| 56 |
+
{
|
| 57 |
+
VERBOSE(2, "float DynamicCacheBasedLanguageModel::GetPreComputedScores" << std::endl);
|
| 58 |
+
VERBOSE(2, "age:|"<< age << "|" << std::endl);
|
| 59 |
+
|
| 60 |
+
if (age < m_maxAge) {
|
| 61 |
+
return precomputedScores.at(age);
|
| 62 |
+
} else {
|
| 63 |
+
VERBOSE(2, "is to big reduced to m)_maxAge:|"<< m_maxAge << "|" << std::endl);
|
| 64 |
+
return precomputedScores.at(m_maxAge);
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
void DynamicCacheBasedLanguageModel::SetParameter(const std::string& key, const std::string& value)
|
| 69 |
+
{
|
| 70 |
+
VERBOSE(2, "DynamicCacheBasedLanguageModel::SetParameter key:|" << key << "| value:|" << value << "|" << std::endl);
|
| 71 |
+
if (key == "cblm-query-type") {
|
| 72 |
+
SetQueryType(Scan<size_t>(value));
|
| 73 |
+
} else if (key == "cblm-score-type") {
|
| 74 |
+
SetScoreType(Scan<size_t>(value));
|
| 75 |
+
} else if (key == "cblm-max-age") {
|
| 76 |
+
SetMaxAge(Scan<unsigned int>(value));
|
| 77 |
+
} else if (key == "cblm-file") {
|
| 78 |
+
m_initfiles = Scan<std::string>(value);
|
| 79 |
+
} else if (key == "cblm-name") {
|
| 80 |
+
m_name = Scan<std::string>(value);
|
| 81 |
+
} else if (key == "cblm-constant") {
|
| 82 |
+
m_constant = Scan<bool>(value);
|
| 83 |
+
} else {
|
| 84 |
+
StatelessFeatureFunction::SetParameter(key, value);
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
void DynamicCacheBasedLanguageModel::EvaluateInIsolation(const Phrase &sp
|
| 89 |
+
, const TargetPhrase &tp
|
| 90 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 91 |
+
, ScoreComponentCollection &estimatedScores) const
|
| 92 |
+
{
|
| 93 |
+
float score = m_lower_score;
|
| 94 |
+
switch(m_query_type) {
|
| 95 |
+
case CBLM_QUERY_TYPE_WHOLESTRING:
|
| 96 |
+
score = Evaluate_Whole_String(tp);
|
| 97 |
+
break;
|
| 98 |
+
case CBLM_QUERY_TYPE_ALLSUBSTRINGS:
|
| 99 |
+
score = Evaluate_All_Substrings(tp);
|
| 100 |
+
break;
|
| 101 |
+
default:
|
| 102 |
+
UTIL_THROW_IF2(false, "This score type (" << m_query_type << ") is unknown.");
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
scoreBreakdown.Assign(this, score);
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
float DynamicCacheBasedLanguageModel::Evaluate_Whole_String(const TargetPhrase& tp) const
|
| 109 |
+
{
|
| 110 |
+
//consider all words in the TargetPhrase as one n-gram
|
| 111 |
+
// and compute the decaying_score for the whole n-gram
|
| 112 |
+
// and return this value
|
| 113 |
+
|
| 114 |
+
decaying_cache_t::const_iterator it;
|
| 115 |
+
float score = m_lower_score;
|
| 116 |
+
|
| 117 |
+
std::string w = "";
|
| 118 |
+
size_t endpos = tp.GetSize();
|
| 119 |
+
for (size_t pos = 0 ; pos < endpos ; ++pos) {
|
| 120 |
+
w += tp.GetWord(pos).GetFactor(0)->GetString().as_string();
|
| 121 |
+
if (pos < endpos - 1) {
|
| 122 |
+
w += " ";
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
it = m_cache.find(w);
|
| 126 |
+
|
| 127 |
+
VERBOSE(4,"cblm::Evaluate_Whole_String: searching w:|" << w << "|" << std::endl);
|
| 128 |
+
if (it != m_cache.end()) { //found!
|
| 129 |
+
score = ((*it).second).second;
|
| 130 |
+
VERBOSE(4,"cblm::Evaluate_Whole_String: found w:|" << w << "|" << std::endl);
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
VERBOSE(4,"cblm::Evaluate_Whole_String: returning score:|" << score << "|" << std::endl);
|
| 134 |
+
return score;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
float DynamicCacheBasedLanguageModel::Evaluate_All_Substrings(const TargetPhrase& tp) const
|
| 138 |
+
{
|
| 139 |
+
//loop over all n-grams in the TargetPhrase (no matter of n)
|
| 140 |
+
//and compute the decaying_score for all words
|
| 141 |
+
//and return their sum
|
| 142 |
+
|
| 143 |
+
decaying_cache_t::const_iterator it;
|
| 144 |
+
float score = 0.0;
|
| 145 |
+
|
| 146 |
+
for (size_t startpos = 0 ; startpos < tp.GetSize() ; ++startpos) {
|
| 147 |
+
std::string w = "";
|
| 148 |
+
for (size_t endpos = startpos; endpos < tp.GetSize() ; ++endpos) {
|
| 149 |
+
w += tp.GetWord(endpos).GetFactor(0)->GetString().as_string();
|
| 150 |
+
it = m_cache.find(w);
|
| 151 |
+
|
| 152 |
+
if (it != m_cache.end()) { //found!
|
| 153 |
+
score += ((*it).second).second;
|
| 154 |
+
VERBOSE(3,"cblm::Evaluate_All_Substrings: found w:|" << w << "| actual score:|" << ((*it).second).second << "| score:|" << score << "|" << std::endl);
|
| 155 |
+
} else {
|
| 156 |
+
score += m_lower_score;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
if (endpos == startpos) {
|
| 160 |
+
w += " ";
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
VERBOSE(3,"cblm::Evaluate_All_Substrings: returning score:|" << score << "|" << std::endl);
|
| 166 |
+
return score;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
void DynamicCacheBasedLanguageModel::Print() const
|
| 170 |
+
{
|
| 171 |
+
#ifdef WITH_THREADS
|
| 172 |
+
boost::shared_lock<boost::shared_mutex> read_lock(m_cacheLock);
|
| 173 |
+
#endif
|
| 174 |
+
decaying_cache_t::const_iterator it;
|
| 175 |
+
std::cout << "Content of the cache of Cache-Based Language Model" << std::endl;
|
| 176 |
+
std::cout << "Size of the cache of Cache-Based Language Model:|" << m_cache.size() << "|" << std::endl;
|
| 177 |
+
for ( it=m_cache.begin() ; it != m_cache.end(); it++ ) {
|
| 178 |
+
std::cout << "word:|" << (*it).first << "| age:|" << ((*it).second).first << "| score:|" << ((*it).second).second << "|" << std::endl;
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
void DynamicCacheBasedLanguageModel::Decay()
|
| 183 |
+
{
|
| 184 |
+
#ifdef WITH_THREADS
|
| 185 |
+
boost::shared_lock<boost::shared_mutex> lock(m_cacheLock);
|
| 186 |
+
#endif
|
| 187 |
+
decaying_cache_t::iterator it;
|
| 188 |
+
|
| 189 |
+
unsigned int age;
|
| 190 |
+
float score;
|
| 191 |
+
for ( it=m_cache.begin() ; it != m_cache.end(); it++ ) {
|
| 192 |
+
age=((*it).second).first + 1;
|
| 193 |
+
if (age > m_maxAge) {
|
| 194 |
+
m_cache.erase(it);
|
| 195 |
+
it--;
|
| 196 |
+
} else {
|
| 197 |
+
score = GetPreComputedScores(age);
|
| 198 |
+
// score = decaying_score(age);
|
| 199 |
+
decaying_cache_value_t p (age, score);
|
| 200 |
+
(*it).second = p;
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
void DynamicCacheBasedLanguageModel::Update(std::vector<std::string> words, int age)
|
| 206 |
+
{
|
| 207 |
+
#ifdef WITH_THREADS
|
| 208 |
+
boost::shared_lock<boost::shared_mutex> lock(m_cacheLock);
|
| 209 |
+
#endif
|
| 210 |
+
VERBOSE(3,"words.size():|" << words.size() << "|" << std::endl);
|
| 211 |
+
for (size_t j=0; j<words.size(); j++) {
|
| 212 |
+
words[j] = Trim(words[j]);
|
| 213 |
+
// VERBOSE(3,"CacheBasedLanguageModel::Update word[" << j << "]:"<< words[j] << " age:" << age << " decaying_score(age):" << decaying_score(age) << std::endl);
|
| 214 |
+
// decaying_cache_value_t p (age,decaying_score(age));
|
| 215 |
+
VERBOSE(3,"CacheBasedLanguageModel::Update word[" << j << "]:"<< words[j] << " age:" << age << " GetPreComputedScores(age):" << GetPreComputedScores(age) << std::endl);
|
| 216 |
+
decaying_cache_value_t p (age,GetPreComputedScores(age));
|
| 217 |
+
std::pair<std::string, decaying_cache_value_t> e (words[j],p);
|
| 218 |
+
m_cache.erase(words[j]); //always erase the element (do nothing if the entry does not exist)
|
| 219 |
+
m_cache.insert(e); //insert the entry
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
void DynamicCacheBasedLanguageModel::ClearEntries(std::string &entries)
|
| 224 |
+
{
|
| 225 |
+
if (entries != "") {
|
| 226 |
+
VERBOSE(3,"entries:|" << entries << "|" << std::endl);
|
| 227 |
+
std::vector<std::string> elements = TokenizeMultiCharSeparator(entries, "||");
|
| 228 |
+
VERBOSE(3,"elements.size() after:|" << elements.size() << "|" << std::endl);
|
| 229 |
+
ClearEntries(elements);
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
void DynamicCacheBasedLanguageModel::ClearEntries(std::vector<std::string> words)
|
| 234 |
+
{
|
| 235 |
+
#ifdef WITH_THREADS
|
| 236 |
+
boost::shared_lock<boost::shared_mutex> lock(m_cacheLock);
|
| 237 |
+
#endif
|
| 238 |
+
VERBOSE(3,"words.size():|" << words.size() << "|" << std::endl);
|
| 239 |
+
for (size_t j=0; j<words.size(); j++) {
|
| 240 |
+
words[j] = Trim(words[j]);
|
| 241 |
+
VERBOSE(3,"CacheBasedLanguageModel::ClearEntries word[" << j << "]:"<< words[j] << std::endl);
|
| 242 |
+
m_cache.erase(words[j]); //always erase the element (do nothing if the entry does not exist)
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
void DynamicCacheBasedLanguageModel::Insert(std::string &entries)
|
| 247 |
+
{
|
| 248 |
+
if (entries != "") {
|
| 249 |
+
VERBOSE(3,"entries:|" << entries << "|" << std::endl);
|
| 250 |
+
std::vector<std::string> elements = TokenizeMultiCharSeparator(entries, "||");
|
| 251 |
+
VERBOSE(3,"elements.size() after:|" << elements.size() << "|" << std::endl);
|
| 252 |
+
Insert(elements);
|
| 253 |
+
}
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
void DynamicCacheBasedLanguageModel::Insert(std::vector<std::string> ngrams)
|
| 257 |
+
{
|
| 258 |
+
VERBOSE(3,"DynamicCacheBasedLanguageModel Insert ngrams.size():|" << ngrams.size() << "|" << std::endl);
|
| 259 |
+
if (m_constant == false) {
|
| 260 |
+
Decay();
|
| 261 |
+
}
|
| 262 |
+
Update(ngrams,1);
|
| 263 |
+
IFVERBOSE(3) Print();
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
void DynamicCacheBasedLanguageModel::ExecuteDlt(std::map<std::string, std::string> dlt_meta)
|
| 267 |
+
{
|
| 268 |
+
if (dlt_meta.find("cblm") != dlt_meta.end()) {
|
| 269 |
+
Insert(dlt_meta["cblm"]);
|
| 270 |
+
}
|
| 271 |
+
if (dlt_meta.find("cblm-command") != dlt_meta.end()) {
|
| 272 |
+
Execute(dlt_meta["cblm-command"]);
|
| 273 |
+
}
|
| 274 |
+
if (dlt_meta.find("cblm-file") != dlt_meta.end()) {
|
| 275 |
+
Load(dlt_meta["cblm-file"]);
|
| 276 |
+
}
|
| 277 |
+
if (dlt_meta.find("cblm-clear-entries") != dlt_meta.end()) {
|
| 278 |
+
ClearEntries(dlt_meta["cblm-clear-entries"]);
|
| 279 |
+
}
|
| 280 |
+
if (dlt_meta.find("cblm-clear-all") != dlt_meta.end()) {
|
| 281 |
+
Clear();
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
void DynamicCacheBasedLanguageModel::Execute(std::string command)
|
| 287 |
+
{
|
| 288 |
+
VERBOSE(2,"DynamicCacheBasedLanguageModel::Execute(std::string command:|" << command << "|" << std::endl);
|
| 289 |
+
std::vector<std::string> commands = Tokenize(command, "||");
|
| 290 |
+
Execute(commands);
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
void DynamicCacheBasedLanguageModel::Execute(std::vector<std::string> commands)
|
| 294 |
+
{
|
| 295 |
+
for (size_t j=0; j<commands.size(); j++) {
|
| 296 |
+
Execute_Single_Command(commands[j]);
|
| 297 |
+
}
|
| 298 |
+
IFVERBOSE(2) Print();
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
void DynamicCacheBasedLanguageModel::Execute_Single_Command(std::string command)
|
| 302 |
+
{
|
| 303 |
+
VERBOSE(2,"CacheBasedLanguageModel::Execute_Single_Command(std::string command:|" << command << "|" << std::endl);
|
| 304 |
+
if (command == "clear") {
|
| 305 |
+
VERBOSE(2,"CacheBasedLanguageModel Execute command:|"<< command << "|. Cache cleared." << std::endl);
|
| 306 |
+
Clear();
|
| 307 |
+
} else if (command == "settype_wholestring") {
|
| 308 |
+
VERBOSE(2,"CacheBasedLanguageModel Execute command:|"<< command << "|. Query type set to " << CBLM_QUERY_TYPE_WHOLESTRING << " (CBLM_QUERY_TYPE_WHOLESTRING)." << std::endl);
|
| 309 |
+
SetQueryType(CBLM_QUERY_TYPE_WHOLESTRING);
|
| 310 |
+
} else if (command == "settype_allsubstrings") {
|
| 311 |
+
VERBOSE(2,"CacheBasedLanguageModel Execute command:|"<< command << "|. Query type set to " << CBLM_QUERY_TYPE_ALLSUBSTRINGS << " (CBLM_QUERY_TYPE_ALLSUBSTRINGS)." << std::endl);
|
| 312 |
+
SetQueryType(CBLM_QUERY_TYPE_ALLSUBSTRINGS);
|
| 313 |
+
} else {
|
| 314 |
+
VERBOSE(2,"CacheBasedLanguageModel Execute command:|"<< command << "| is unknown. Skipped." << std::endl);
|
| 315 |
+
}
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
void DynamicCacheBasedLanguageModel::Clear()
|
| 319 |
+
{
|
| 320 |
+
#ifdef WITH_THREADS
|
| 321 |
+
boost::shared_lock<boost::shared_mutex> lock(m_cacheLock);
|
| 322 |
+
#endif
|
| 323 |
+
m_cache.clear();
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
// FeatureFunction entry point: store the options handle, then load any
// cache files given via the "cblm-file" parameter.
void DynamicCacheBasedLanguageModel::Load(AllOptions::ptr const& opts)
{
  m_options = opts;
  VERBOSE(2,"DynamicCacheBasedLanguageModel::Load()" << std::endl);
  Load(m_initfiles);
}
|
| 333 |
+
|
| 334 |
+
void DynamicCacheBasedLanguageModel::Load(const std::string filestr)
|
| 335 |
+
{
|
| 336 |
+
VERBOSE(2,"DynamicCacheBasedLanguageModel::Load(const std::string filestr)" << std::endl);
|
| 337 |
+
// std::vector<std::string> files = Tokenize(m_initfiles, "||");
|
| 338 |
+
std::vector<std::string> files = Tokenize(filestr, "||");
|
| 339 |
+
Load_Multiple_Files(files);
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
void DynamicCacheBasedLanguageModel::Load_Multiple_Files(std::vector<std::string> files)
|
| 344 |
+
{
|
| 345 |
+
VERBOSE(2,"DynamicCacheBasedLanguageModel::Load_Multiple_Files(std::vector<std::string> files)" << std::endl);
|
| 346 |
+
for(size_t j = 0; j < files.size(); ++j) {
|
| 347 |
+
Load_Single_File(files[j]);
|
| 348 |
+
}
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
void DynamicCacheBasedLanguageModel::Load_Single_File(const std::string file)
|
| 352 |
+
{
|
| 353 |
+
VERBOSE(2,"DynamicCacheBasedLanguageModel::Load_Single_File(const std::string file)" << std::endl);
|
| 354 |
+
//file format
|
| 355 |
+
//age || n-gram
|
| 356 |
+
//age || n-gram || n-gram || n-gram || ...
|
| 357 |
+
//....
|
| 358 |
+
//each n-gram is a sequence of n words (no matter of n)
|
| 359 |
+
//
|
| 360 |
+
//there is no limit on the size of n
|
| 361 |
+
//
|
| 362 |
+
//entries can be repeated, but the last entry overwrites the previous
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
VERBOSE(2,"Loading data from the cache file " << file << std::endl);
|
| 366 |
+
InputFileStream cacheFile(file);
|
| 367 |
+
|
| 368 |
+
std::string line;
|
| 369 |
+
int age;
|
| 370 |
+
std::vector<std::string> words;
|
| 371 |
+
|
| 372 |
+
while (getline(cacheFile, line)) {
|
| 373 |
+
std::vector<std::string> vecStr = TokenizeMultiCharSeparator( line , "||" );
|
| 374 |
+
if (vecStr.size() >= 2) {
|
| 375 |
+
age = Scan<int>(vecStr[0]);
|
| 376 |
+
vecStr.erase(vecStr.begin());
|
| 377 |
+
Update(vecStr,age);
|
| 378 |
+
} else {
|
| 379 |
+
UTIL_THROW_IF2(false, "The format of the loaded file is wrong: " << line);
|
| 380 |
+
}
|
| 381 |
+
}
|
| 382 |
+
IFVERBOSE(2) Print();
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
void DynamicCacheBasedLanguageModel::SetQueryType(size_t type)
|
| 386 |
+
{
|
| 387 |
+
#ifdef WITH_THREADS
|
| 388 |
+
boost::shared_lock<boost::shared_mutex> read_lock(m_cacheLock);
|
| 389 |
+
#endif
|
| 390 |
+
|
| 391 |
+
m_query_type = type;
|
| 392 |
+
if ( m_query_type != CBLM_QUERY_TYPE_WHOLESTRING
|
| 393 |
+
&& m_query_type != CBLM_QUERY_TYPE_ALLSUBSTRINGS ) {
|
| 394 |
+
VERBOSE(2, "This query type " << m_query_type << " is unknown. Instead used " << CBLM_QUERY_TYPE_ALLSUBSTRINGS << "." << std::endl);
|
| 395 |
+
m_query_type = CBLM_QUERY_TYPE_ALLSUBSTRINGS;
|
| 396 |
+
}
|
| 397 |
+
VERBOSE(2, "CacheBasedLanguageModel QueryType: " << m_query_type << std::endl);
|
| 398 |
+
|
| 399 |
+
};
|
| 400 |
+
|
| 401 |
+
void DynamicCacheBasedLanguageModel::SetScoreType(size_t type)
|
| 402 |
+
{
|
| 403 |
+
#ifdef WITH_THREADS
|
| 404 |
+
boost::shared_lock<boost::shared_mutex> read_lock(m_cacheLock);
|
| 405 |
+
#endif
|
| 406 |
+
m_score_type = type;
|
| 407 |
+
if ( m_score_type != CBLM_SCORE_TYPE_HYPERBOLA
|
| 408 |
+
&& m_score_type != CBLM_SCORE_TYPE_POWER
|
| 409 |
+
&& m_score_type != CBLM_SCORE_TYPE_EXPONENTIAL
|
| 410 |
+
&& m_score_type != CBLM_SCORE_TYPE_COSINE
|
| 411 |
+
&& m_score_type != CBLM_SCORE_TYPE_HYPERBOLA_REWARD
|
| 412 |
+
&& m_score_type != CBLM_SCORE_TYPE_POWER_REWARD
|
| 413 |
+
&& m_score_type != CBLM_SCORE_TYPE_EXPONENTIAL_REWARD ) {
|
| 414 |
+
VERBOSE(2, "This score type " << m_score_type << " is unknown. Instead used " << CBLM_SCORE_TYPE_HYPERBOLA << "." << std::endl);
|
| 415 |
+
m_score_type = CBLM_SCORE_TYPE_HYPERBOLA;
|
| 416 |
+
}
|
| 417 |
+
VERBOSE(2, "CacheBasedLanguageModel ScoreType: " << m_score_type << std::endl);
|
| 418 |
+
};
|
| 419 |
+
|
| 420 |
+
void DynamicCacheBasedLanguageModel::SetMaxAge(unsigned int age)
|
| 421 |
+
{
|
| 422 |
+
#ifdef WITH_THREADS
|
| 423 |
+
boost::shared_lock<boost::shared_mutex> read_lock(m_cacheLock);
|
| 424 |
+
#endif
|
| 425 |
+
m_maxAge = age;
|
| 426 |
+
VERBOSE(2, "CacheBasedLanguageModel MaxAge: " << m_maxAge << std::endl);
|
| 427 |
+
};
|
| 428 |
+
|
| 429 |
+
float DynamicCacheBasedLanguageModel::decaying_score(const unsigned int age)
|
| 430 |
+
{
|
| 431 |
+
float sc;
|
| 432 |
+
switch(m_score_type) {
|
| 433 |
+
case CBLM_SCORE_TYPE_HYPERBOLA:
|
| 434 |
+
sc = (float) 1.0/age - 1.0;
|
| 435 |
+
break;
|
| 436 |
+
case CBLM_SCORE_TYPE_POWER:
|
| 437 |
+
sc = (float) pow(age, -0.25) - 1.0;
|
| 438 |
+
break;
|
| 439 |
+
case CBLM_SCORE_TYPE_EXPONENTIAL:
|
| 440 |
+
sc = (age == 1) ? 0.0 : (float) exp( 1.0/age ) / exp(1.0) - 1.0;
|
| 441 |
+
break;
|
| 442 |
+
case CBLM_SCORE_TYPE_COSINE:
|
| 443 |
+
sc = (float) cos( (age-1) * (PI/2) / m_maxAge ) - 1.0;
|
| 444 |
+
break;
|
| 445 |
+
case CBLM_SCORE_TYPE_HYPERBOLA_REWARD:
|
| 446 |
+
sc = (float) 1.0/age;
|
| 447 |
+
break;
|
| 448 |
+
case CBLM_SCORE_TYPE_POWER_REWARD:
|
| 449 |
+
sc = (float) pow(age, -0.25);
|
| 450 |
+
break;
|
| 451 |
+
case CBLM_SCORE_TYPE_EXPONENTIAL_REWARD:
|
| 452 |
+
sc = (age == 1) ? 1.0 : (float) exp( 1.0/age ) / exp(1.0);
|
| 453 |
+
break;
|
| 454 |
+
default:
|
| 455 |
+
sc = -1.0;
|
| 456 |
+
}
|
| 457 |
+
return sc;
|
| 458 |
+
}
|
| 459 |
+
}
|
mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.h
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// $Id$
|
| 2 |
+
|
| 3 |
+
#ifndef moses_DynamicCacheBasedLanguageModel_h
|
| 4 |
+
#define moses_DynamicCacheBasedLanguageModel_h
|
| 5 |
+
|
| 6 |
+
#include "moses/Util.h"
|
| 7 |
+
#include "FeatureFunction.h"
|
| 8 |
+
|
| 9 |
+
#ifdef WITH_THREADS
|
| 10 |
+
#include <boost/thread/shared_mutex.hpp>
|
| 11 |
+
#include <boost/thread/locks.hpp>
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
typedef std::pair<int, float> decaying_cache_value_t;
|
| 15 |
+
typedef std::map<std::string, decaying_cache_value_t > decaying_cache_t;
|
| 16 |
+
|
| 17 |
+
#define CBLM_QUERY_TYPE_UNDEFINED (-1)
|
| 18 |
+
#define CBLM_QUERY_TYPE_ALLSUBSTRINGS 0
|
| 19 |
+
#define CBLM_QUERY_TYPE_WHOLESTRING 1
|
| 20 |
+
|
| 21 |
+
#define CBLM_SCORE_TYPE_UNDEFINED (-1)
|
| 22 |
+
#define CBLM_SCORE_TYPE_HYPERBOLA 0
|
| 23 |
+
#define CBLM_SCORE_TYPE_POWER 1
|
| 24 |
+
#define CBLM_SCORE_TYPE_EXPONENTIAL 2
|
| 25 |
+
#define CBLM_SCORE_TYPE_COSINE 3
|
| 26 |
+
#define CBLM_SCORE_TYPE_HYPERBOLA_REWARD 10
|
| 27 |
+
#define CBLM_SCORE_TYPE_POWER_REWARD 11
|
| 28 |
+
#define CBLM_SCORE_TYPE_EXPONENTIAL_REWARD 12
|
| 29 |
+
#define PI 3.14159265
|
| 30 |
+
|
| 31 |
+
namespace Moses
|
| 32 |
+
{
|
| 33 |
+
|
| 34 |
+
class Range;
|
| 35 |
+
|
| 36 |
+
/** Calculates score for the Dynamic Cache-Based pseudo LM
|
| 37 |
+
*/
|
| 38 |
+
class DynamicCacheBasedLanguageModel : public StatelessFeatureFunction
|
| 39 |
+
{
|
| 40 |
+
// data structure for the cache;
|
| 41 |
+
// the key is the word and the value is the decaying score
|
| 42 |
+
decaying_cache_t m_cache;
|
| 43 |
+
size_t m_query_type; //way of querying the cache
|
| 44 |
+
size_t m_score_type; //way of scoring entries of the cache
|
| 45 |
+
std::string m_initfiles; // vector of files loaded in the initialization phase
|
| 46 |
+
std::string m_name; // internal name to identify this instance of the Cache-based pseudo LM
|
| 47 |
+
float m_lower_score; //lower_bound_score for no match
|
| 48 |
+
bool m_constant; //flag for setting a non-decaying cache
|
| 49 |
+
std::vector<float> precomputedScores;
|
| 50 |
+
unsigned int m_maxAge;
|
| 51 |
+
|
| 52 |
+
#ifdef WITH_THREADS
|
| 53 |
+
//multiple readers - single writer lock
|
| 54 |
+
mutable boost::shared_mutex m_cacheLock;
|
| 55 |
+
#endif
|
| 56 |
+
|
| 57 |
+
float decaying_score(unsigned int age);
|
| 58 |
+
void SetPreComputedScores();
|
| 59 |
+
float GetPreComputedScores(const unsigned int age);
|
| 60 |
+
|
| 61 |
+
float Evaluate_Whole_String( const TargetPhrase&) const;
|
| 62 |
+
float Evaluate_All_Substrings( const TargetPhrase&) const;
|
| 63 |
+
|
| 64 |
+
void Decay();
|
| 65 |
+
void Update(std::vector<std::string> words, int age);
|
| 66 |
+
|
| 67 |
+
void ClearEntries(std::vector<std::string> entries);
|
| 68 |
+
|
| 69 |
+
void Execute(std::vector<std::string> commands);
|
| 70 |
+
void Execute_Single_Command(std::string command);
|
| 71 |
+
|
| 72 |
+
void Load_Multiple_Files(std::vector<std::string> files);
|
| 73 |
+
void Load_Single_File(const std::string file);
|
| 74 |
+
|
| 75 |
+
void Insert(std::vector<std::string> ngrams);
|
| 76 |
+
|
| 77 |
+
// void EvaluateInIsolation(const Phrase&, const TargetPhrase&, ScoreComponentCollection&, ScoreComponentCollection& ) const;
|
| 78 |
+
void Print() const;
|
| 79 |
+
|
| 80 |
+
protected:
|
| 81 |
+
static DynamicCacheBasedLanguageModel* s_instance;
|
| 82 |
+
static std::map< const std::string, DynamicCacheBasedLanguageModel* > s_instance_map;
|
| 83 |
+
|
| 84 |
+
public:
|
| 85 |
+
DynamicCacheBasedLanguageModel(const std::string &line);
|
| 86 |
+
~DynamicCacheBasedLanguageModel();
|
| 87 |
+
|
| 88 |
+
inline const std::string GetName() {
|
| 89 |
+
return m_name;
|
| 90 |
+
};
|
| 91 |
+
inline void SetName(const std::string name) {
|
| 92 |
+
m_name = name;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
static const DynamicCacheBasedLanguageModel* Instance(const std::string& name) {
|
| 96 |
+
if (s_instance_map.find(name) == s_instance_map.end()) {
|
| 97 |
+
return NULL;
|
| 98 |
+
}
|
| 99 |
+
return s_instance_map[name];
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
static DynamicCacheBasedLanguageModel* InstanceNonConst(const std::string& name) {
|
| 103 |
+
if (s_instance_map.find(name) == s_instance_map.end()) {
|
| 104 |
+
return NULL;
|
| 105 |
+
}
|
| 106 |
+
return s_instance_map[name];
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
static const DynamicCacheBasedLanguageModel& Instance() {
|
| 112 |
+
return *s_instance;
|
| 113 |
+
}
|
| 114 |
+
static DynamicCacheBasedLanguageModel& InstanceNonConst() {
|
| 115 |
+
return *s_instance;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
bool IsUseable(const FactorMask &mask) const {
|
| 119 |
+
return true;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
void Load(AllOptions::ptr const& opts);
|
| 123 |
+
void Load(const std::string filestr);
|
| 124 |
+
void Execute(std::string command);
|
| 125 |
+
void SetParameter(const std::string& key, const std::string& value);
|
| 126 |
+
void ExecuteDlt(std::map<std::string, std::string> dlt_meta);
|
| 127 |
+
|
| 128 |
+
void ClearEntries(std::string &entries);
|
| 129 |
+
void Insert(std::string &entries);
|
| 130 |
+
void Clear();
|
| 131 |
+
|
| 132 |
+
virtual void EvaluateInIsolation(const Phrase &source
|
| 133 |
+
, const TargetPhrase &targetPhrase
|
| 134 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 135 |
+
, ScoreComponentCollection &estimatedScores) const;
|
| 136 |
+
|
| 137 |
+
void EvaluateWithSourceContext(const InputType &input
|
| 138 |
+
, const InputPath &inputPath
|
| 139 |
+
, const TargetPhrase &targetPhrase
|
| 140 |
+
, const StackVec *stackVec
|
| 141 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 142 |
+
, ScoreComponentCollection *estimatedScores = NULL) const {
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
void EvaluateTranslationOptionListWithSourceContext(const InputType &input
|
| 146 |
+
, const TranslationOptionList &translationOptionList) const {
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
void EvaluateWhenApplied(const Hypothesis& hypo,
|
| 150 |
+
ScoreComponentCollection* accumulator) const {
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
void EvaluateWhenApplied(const ChartHypothesis &hypo,
|
| 154 |
+
ScoreComponentCollection* accumulator) const {
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
void SetQueryType(size_t type);
|
| 158 |
+
void SetScoreType(size_t type);
|
| 159 |
+
void SetMaxAge(unsigned int age);
|
| 160 |
+
};
|
| 161 |
+
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
#endif
|
mosesdecoder/moses/FF/EditOps.h
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef moses_EditOps_h
#define moses_EditOps_h

#include <string>
#include <vector>  // fix: Tokens typedef below uses std::vector; was only transitively included
#include <boost/unordered_set.hpp>

#include "StatelessFeatureFunction.h"
#include "moses/FactorCollection.h"
#include "moses/AlignmentInfo.h"

namespace Moses
{

// A tokenised phrase: one std::string per token.
typedef std::vector<std::string> Tokens;

/** Calculates string edit operations that transform source phrase into target
 * phrase using the LCS algorithm. Potentially useful for monolingual tasks
 * like paraphrasing, summarization, correction.
 */
class EditOps : public StatelessFeatureFunction
{
private:
  FactorType m_factorType;  // which factor of each word is compared
  bool m_chars;             // if true, compare characters instead of tokens
  std::string m_scores;     // which edit-op score types to emit (set via SetParameter)

public:
  // 'line' is the feature's configuration line from moses.ini (key=value pairs).
  EditOps(const std::string &line);

  // True if the factors required by this feature are present in 'mask'.
  bool IsUseable(const FactorMask &mask) const;

  void Load();

  // Scores the source/target pair at phrase-table load time; the actual
  // feature computation is delegated to ComputeFeatures (defined in the .cpp).
  virtual void EvaluateInIsolation(const Phrase &source
                                   , const TargetPhrase &targetPhrase
                                   , ScoreComponentCollection &scoreBreakdown
                                   , ScoreComponentCollection &estimatedFutureScore) const;

  // The remaining evaluation hooks are intentionally no-ops: this feature
  // only scores phrase pairs in isolation.
  void EvaluateWithSourceContext(const InputType &input
                                 , const InputPath &inputPath
                                 , const TargetPhrase &targetPhrase
                                 , const StackVec *stackVec
                                 , ScoreComponentCollection &scoreBreakdown
                                 , ScoreComponentCollection *estimatedFutureScore = NULL) const
  {}
  void EvaluateWhenApplied(const Hypothesis& hypo,
                           ScoreComponentCollection* accumulator) const
  {}
  void EvaluateWhenApplied(const ChartHypothesis &hypo,
                           ScoreComponentCollection* accumulator) const
  {}
  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const
  {}

  // Computes the edit-operation features for one phrase pair and adds them
  // to 'accumulator'.
  void ComputeFeatures(const Phrase &source,
                       const TargetPhrase& targetPhrase,
                       ScoreComponentCollection* accumulator) const;
  void SetParameter(const std::string& key, const std::string& value);
};

}

#endif // moses_EditOps_h
|
mosesdecoder/moses/FF/ExampleStatelessFF.cpp
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <limits>  // fix: std::numeric_limits is used below; was only transitively included
#include <vector>
#include "ExampleStatelessFF.h"
#include "moses/ScoreComponentCollection.h"
#include "moses/TargetPhrase.h"

using namespace std;

namespace Moses
{
// Template/example implementation of a stateless feature function with
// two dense score components (see StatelessFeatureFunction(2, line)).
ExampleStatelessFF::ExampleStatelessFF(const std::string &line)
  :StatelessFeatureFunction(2, line)
{
  ReadParameters();
}

// Demonstrates adding both dense and sparse scores at phrase-table load time.
void ExampleStatelessFF::EvaluateInIsolation(const Phrase &source
    , const TargetPhrase &targetPhrase
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection &estimatedScores) const
{
  // dense scores (fixed example values)
  vector<float> newScores(m_numScoreComponents);
  newScores[0] = 1.5;
  newScores[1] = 0.3;
  scoreBreakdown.PlusEquals(this, newScores);

  // sparse scores
  scoreBreakdown.PlusEquals(this, "sparse-name", 2.4);

}

// Demonstrates penalising rules: any target phrase containing non-terminals
// gets -inf on the first component, effectively disallowing it.
void ExampleStatelessFF::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedScores) const
{
  if (targetPhrase.GetNumNonTerminals()) {
    vector<float> newScores(m_numScoreComponents);
    newScores[0] = - std::numeric_limits<float>::infinity();
    scoreBreakdown.PlusEquals(this, newScores);
  }
}

// Remaining hooks are intentionally empty in this example.
void ExampleStatelessFF::EvaluateTranslationOptionListWithSourceContext(const InputType &input

    , const TranslationOptionList &translationOptionList) const
{}

void ExampleStatelessFF::EvaluateWhenApplied(const Hypothesis& hypo,
    ScoreComponentCollection* accumulator) const
{}

void ExampleStatelessFF::EvaluateWhenApplied(const ChartHypothesis &hypo,
    ScoreComponentCollection* accumulator) const
{}

// Handles feature-specific key=value options; unknown keys are passed to the
// base class, which reports an error for truly unknown ones.
void ExampleStatelessFF::SetParameter(const std::string& key, const std::string& value)
{
  if (key == "arg") {
    // set value here
  } else {
    StatelessFeatureFunction::SetParameter(key, value);
  }
}

}
|
| 69 |
+
|
mosesdecoder/moses/FF/ExampleTranslationOptionListFeature.h
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <string>
#include <vector>  // fix: std::vector<float> is used below; was only transitively included
#include "StatelessFeatureFunction.h"

namespace Moses
{

/** Example feature that scores a whole TranslationOptionList at once:
 *  it adds the list's size as a dense score to every option in the list.
 *  One dense score component (see StatelessFeatureFunction(1, line)).
 */
class ExampleTranslationOptionListFeature : public StatelessFeatureFunction
{
public:
  ExampleTranslationOptionListFeature(const std::string &line)
    :StatelessFeatureFunction(1, line) {
    ReadParameters();
  }

  // Works with any factor configuration.
  bool IsUseable(const FactorMask &mask) const {
    return true;
  }

  // Per-phrase hooks are intentionally no-ops; this feature only acts on
  // the collected option list below.
  void EvaluateInIsolation(const Phrase &source
                           , const TargetPhrase &targetPhrase
                           , ScoreComponentCollection &scoreBreakdown
                           , ScoreComponentCollection &estimatedFutureScore) const {
  }

  void EvaluateWithSourceContext(const InputType &input
                                 , const InputPath &inputPath
                                 , const TargetPhrase &targetPhrase
                                 , const StackVec *stackVec
                                 , ScoreComponentCollection &scoreBreakdown
                                 , ScoreComponentCollection *estimatedFutureScore = NULL) const {
  }

  // Adds the option-list size (implicitly converted to float) to each
  // option's score breakdown, then refreshes the option's total score.
  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const {
    std::vector<float> newScores(m_numScoreComponents);
    newScores[0] = translationOptionList.size();

    TranslationOptionList::const_iterator iterTransOpt;
    for(iterTransOpt = translationOptionList.begin() ;
        iterTransOpt != translationOptionList.end() ; ++iterTransOpt) {
      TranslationOption &transOpt = **iterTransOpt;

      ScoreComponentCollection &scoreBreakDown = transOpt.GetScoreBreakdown();
      scoreBreakDown.PlusEquals(this, newScores);

      transOpt.UpdateScore();
    }
  }

  void EvaluateWhenApplied(const Hypothesis& hypo,
                           ScoreComponentCollection* accumulator) const {
  }

  void EvaluateWhenApplied(const ChartHypothesis &hypo,
                           ScoreComponentCollection* accumulator) const {
  }


  // This example takes no feature-specific parameters.
  void SetParameter(const std::string& key, const std::string& value) {
  }

};

}
|
| 67 |
+
|
mosesdecoder/moses/FF/FFState.cpp
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "moses/FF/FFState.h"

namespace Moses
{

// Out-of-line definition of the (otherwise empty) virtual destructor so the
// class has a single translation unit anchoring its vtable.
FFState::~FFState() {}

}
|
| 9 |
+
|
mosesdecoder/moses/FF/FFState.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef moses_FFState_h
#define moses_FFState_h

#include <vector>
#include <stddef.h>
#include "util/exception.hh"

namespace Moses
{

/** Abstract base for the per-hypothesis state carried by stateful feature
 *  functions. Subclasses must provide hash() and operator==, which the
 *  decoder uses to decide whether two hypotheses can be recombined.
 *  NOTE(review): hash() should be consistent with operator== (equal states
 *  must hash equally) — this contract is implied by the recombination use,
 *  not enforced here.
 */
class FFState
{
public:
  virtual ~FFState();
  virtual size_t hash() const = 0;
  virtual bool operator==(const FFState& other) const = 0;

  // Default inequality in terms of the subclass-provided equality.
  virtual bool operator!=(const FFState& other) const {
    return !(*this == other);
  }
};

/** Trivial state: every DummyState compares equal to anything and hashes
 *  to 0, i.e. it never blocks hypothesis recombination.
 */
class DummyState : public FFState
{
public:
  DummyState() {}

  virtual size_t hash() const {
    return 0;
  }

  virtual bool operator==(const FFState& other) const {
    return true;
  }

};

}
#endif
|
mosesdecoder/moses/FF/FeatureFunction.h
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// -*- c++ -*-
#ifndef moses_FeatureFunction_h
#define moses_FeatureFunction_h

#include <vector>
#include <set>
#include <string>
#include "moses/FeatureVector.h"
#include "moses/TypeDef.h"
#include "moses/parameters/AllOptions.h"
#include <boost/shared_ptr.hpp>

namespace Moses
{

class AllOptions;
class Phrase;
class TargetPhrase;
class TranslationOptionList;
class TranslationOption;
class Hypothesis;
class ChartHypothesis;
class InputType;
class ScoreComponentCollection;
class Bitmap;
class Range;
class FactorMask;
class InputPath;
class StackVec;
class DistortionScoreProducer;
class TranslationTask;

/** base class for all feature functions.
 */
class FeatureFunction
{
protected:
  /**< all the score producers in this run */
  static std::vector<FeatureFunction*> s_staticColl;

  std::string m_description, m_argLine;                 // feature name and raw config line
  std::vector<std::vector<std::string> > m_args;        // parsed key=value argument pairs
  bool m_tuneable;                                      // whether MERT/tuning may adjust the weights
  bool m_requireSortingAfterSourceContext;
  size_t m_verbosity;
  size_t m_numScoreComponents;                          // how many dense scores this FF produces
  size_t m_index; // index into vector covering ALL feature function values
  std::vector<bool> m_tuneableComponents;               // per-component tuneability flags
  size_t m_numTuneableComponents;
  AllOptions::ptr m_options;
  //In case there's multiple producers with the same description
  static std::multiset<std::string> description_counts;

public:
  static void Register(FeatureFunction* ff);
private:
  // void Initialize(const std::string &line);
  void ParseLine(const std::string &line);

public:
  static const std::vector<FeatureFunction*>& GetFeatureFunctions() {
    return s_staticColl;
  }

  static FeatureFunction &FindFeatureFunction(const std::string& name);
  static void Destroy();

  FeatureFunction(const std::string &line, bool registerNow);
  FeatureFunction(size_t numScoreComponents, const std::string &line, bool registerNow = true);
  virtual bool IsStateless() const = 0;
  virtual ~FeatureFunction();

  //! override to load model files
  virtual void Load(AllOptions::ptr const& opts) {
    m_options = opts;
  }

  AllOptions::ptr const&
  options() const {
    return m_options;
  }

  static void ResetDescriptionCounts() {
    description_counts.clear();
  }

  //! returns the number of scores that a subclass produces.
  //! For example, a language model conventionally produces 1, a translation table some arbitrary number, etc
  size_t GetNumScoreComponents() const {
    return m_numScoreComponents;
  }

  //! returns a string description of this producer
  const std::string& GetScoreProducerDescription() const {
    return m_description;
  }

  //! builds a fully-qualified sparse feature name (producer description + name)
  FName GetFeatureName(const std::string& name) const {
    return FName(GetScoreProducerDescription(), name);
  }


  //! if false, then this feature is not displayed in the n-best list.
  // use with care
  virtual bool IsTuneable() const {
    return m_tuneable;
  }

  // NOTE: returns a size_t implicitly converted to bool — true iff at least
  // one component is tuneable.
  virtual bool HasTuneableComponents() const {
    return m_numTuneableComponents;
  }

  // Fast path: when all components are tuneable, m_tuneableComponents may be
  // empty, so it is only indexed in the partial case.
  virtual bool IsTuneableComponent(size_t i) const {
    if (m_numTuneableComponents == m_numScoreComponents) {
      return true;
    }
    return m_tuneableComponents[i];
  }

  virtual bool RequireSortingAfterSourceContext() const {
    return m_requireSortingAfterSourceContext;
  }

  virtual std::vector<float> DefaultWeights() const;

  size_t GetIndex() const;
  size_t SetIndex(size_t const idx);

protected:
  virtual void
  CleanUpAfterSentenceProcessing(InputType const& source) { }

public:
  //! Called before search and collecting of translation options
  virtual void
  InitializeForInput(ttasksptr const& ttask) { };

  // clean up temporary memory, called after processing each sentence
  virtual void
  CleanUpAfterSentenceProcessing(ttasksptr const& ttask);

  const std::string &
  GetArgLine() const {
    return m_argLine;
  }

  // given a target phrase containing only factors specified in mask
  // return true if the feature function can be evaluated
  virtual bool IsUseable(const FactorMask &mask) const = 0;

  // used by stateless ff and stateful ff. Calculate initial score
  // estimate during loading of phrase table
  //
  // source phrase is the substring that the phrase table uses to look
  // up the target phrase,
  //
  // may have more factors than actually need, but not guaranteed.
  // For SCFG decoding, the source contains non-terminals, NOT the raw
  // source from the input sentence
  virtual void
  EvaluateInIsolation(const Phrase &source, const TargetPhrase &targetPhrase,
                      ScoreComponentCollection& scoreBreakdown,
                      ScoreComponentCollection& estimatedScores) const = 0;

  // for context-dependent processing
  static void SetupAll(TranslationTask const& task);
  virtual void Setup(TranslationTask const& task) const { };

  // This method is called once all the translation options are retrieved from the phrase table, and
  // just before search.
  // 'inputPath' is guaranteed to be the raw substring from the input. No factors were added or taken away
  // 'stackVec' is a vector of chart cells that the RHS non-terms cover.
  // It is guaranteed to be in the same order as the non-terms in the source phrase.
  // For pb models, stackvec is NULL.
  // No FF should set estimatedScores in both overloads!
  virtual void EvaluateWithSourceContext(const InputType &input
                                         , const InputPath &inputPath
                                         , const TargetPhrase &targetPhrase
                                         , const StackVec *stackVec
                                         , ScoreComponentCollection &scoreBreakdown
                                         , ScoreComponentCollection *estimatedScores = NULL) const = 0;

  // This method is called once all the translation options are retrieved from the phrase table, and
  // just before search.
  // 'inputPath' is guaranteed to be the raw substring from the input. No factors were added or taken away
  // 'stackVec' is a vector of chart cells that the RHS non-terms cover.
  // It is guaranteed to be in the same order as the non-terms in the source phrase.
  // For pb models, stackvec is NULL.
  // No FF should set estimatedScores in both overloads!
  virtual void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const = 0;

  virtual void SetParameter(const std::string& key, const std::string& value);
  virtual void ReadParameters();
  virtual void SetTuneableComponents(const std::string& value);
};

}

#endif
|
mosesdecoder/moses/FF/GlobalLexicalModel.h
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef moses_GlobalLexicalModel_h
#define moses_GlobalLexicalModel_h

#include <stdexcept>
#include <string>
#include <vector>
#include <memory>
#include "StatelessFeatureFunction.h"
#include "moses/Factor.h"
#include "moses/Phrase.h"
#include "moses/TypeDef.h"
#include "moses/Util.h"
#include "moses/Range.h"
#include "moses/FactorTypeSet.h"
#include "moses/Sentence.h"

#ifdef WITH_THREADS
#include <boost/thread/tss.hpp>
#endif

namespace Moses
{

class Factor;
class Phrase;
class Hypothesis;
class InputType;

/** Discriminatively trained global lexicon model
 * This is a implementation of Mauser et al., 2009's model that predicts
 * each output word from _all_ the input words. The intuition behind this
 * feature is that it uses context words for disambiguation
 */
class GlobalLexicalModel : public StatelessFeatureFunction
{
  // target-word -> (source-word -> weight); UnorderedComparer<Word> serves as
  // both hasher and equality predicate for the Word* keys.
  typedef boost::unordered_map< const Word*,
          boost::unordered_map< const Word*, float, UnorderedComparer<Word> , UnorderedComparer<Word> >,
          UnorderedComparer<Word>, UnorderedComparer<Word> > DoubleHash;
  typedef boost::unordered_map< const Word*, float, UnorderedComparer<Word>, UnorderedComparer<Word> > SingleHash;
  typedef std::map< const TargetPhrase*, float > LexiconCache;

  // Per-thread scratch: memoised phrase scores plus the sentence currently
  // being decoded (set in InitializeForInput).
  struct ThreadLocalStorage {
    LexiconCache cache;
    const Sentence *input;
  };

private:
  DoubleHash m_hash;  // the trained lexicon weights
#ifdef WITH_THREADS
  boost::thread_specific_ptr<ThreadLocalStorage> m_local;
#else
  // NOTE(review): std::auto_ptr is deprecated/removed in modern C++;
  // std::unique_ptr would be the drop-in replacement if the build requires
  // C++11 — confirm the project's minimum standard before changing.
  std::auto_ptr<ThreadLocalStorage> m_local;
#endif
  Word *m_bias;  // optional bias "word"; presumably set during Load — TODO confirm

  FactorMask m_inputFactors, m_outputFactors;
  std::vector<FactorType> m_inputFactorsVec, m_outputFactorsVec;
  std::string m_filePath;  // model file, set via SetParameter

  void Load(AllOptions::ptr const& opts);

  float ScorePhrase( const TargetPhrase& targetPhrase ) const;
  float GetFromCacheOrScorePhrase( const TargetPhrase& targetPhrase ) const;

public:
  GlobalLexicalModel(const std::string &line);
  virtual ~GlobalLexicalModel();

  void SetParameter(const std::string& key, const std::string& value);

  void InitializeForInput(ttasksptr const& ttask);

  bool IsUseable(const FactorMask &mask) const;

  // Scoring happens with source context only; the isolation/applied hooks
  // are intentionally no-ops.
  void EvaluateInIsolation(const Phrase &source
                           , const TargetPhrase &targetPhrase
                           , ScoreComponentCollection &scoreBreakdown
                           , ScoreComponentCollection &estimatedScores) const {
  }

  void EvaluateWhenApplied(const Hypothesis& hypo,
                           ScoreComponentCollection* accumulator) const {
  }
  void EvaluateWhenApplied(const ChartHypothesis &hypo,
                           ScoreComponentCollection* accumulator) const {
  }

  void EvaluateWithSourceContext(const InputType &input
                                 , const InputPath &inputPath
                                 , const TargetPhrase &targetPhrase
                                 , const StackVec *stackVec
                                 , ScoreComponentCollection &scoreBreakdown
                                 , ScoreComponentCollection *estimatedScores = NULL) const;

  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const {
  }

};

}
#endif
|
mosesdecoder/moses/FF/GlobalLexicalModelUnlimited.cpp
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "GlobalLexicalModelUnlimited.h"
|
| 2 |
+
#include <fstream>
|
| 3 |
+
#include "moses/StaticData.h"
|
| 4 |
+
#include "moses/InputFileStream.h"
|
| 5 |
+
#include "moses/Hypothesis.h"
|
| 6 |
+
#include "moses/TranslationTask.h"
|
| 7 |
+
#include "util/string_piece_hash.hh"
|
| 8 |
+
#include "util/string_stream.hh"
|
| 9 |
+
|
| 10 |
+
using namespace std;
|
| 11 |
+
|
| 12 |
+
namespace Moses
|
| 13 |
+
{
|
| 14 |
+
// Constructor: this feature has been deliberately disabled pending a port to
// the key=value feature-function framework — the UTIL_THROW on entry makes
// everything after it unreachable.
GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const std::string &line)
  :StatelessFeatureFunction(0, line)
{
  UTIL_THROW(util::Exception,
             "GlobalLexicalModelUnlimited hasn't been refactored for new feature function framework yet"); // TODO need to update arguments to key=value

  // --- Everything below is dead code kept as a reference for the port. ---
  // It parses the legacy space-separated spec:
  //   <factor-src>-<factor-tgt> [ignore-punct] [use-bias] [context-type]
  //   [filename-src filename-tgt]
  const vector<string> modelSpec = Tokenize(line);

  for (size_t i = 0; i < modelSpec.size(); i++ ) {
    bool ignorePunctuation = true, biasFeature = false, restricted = false;
    size_t context = 0;
    string filenameSource, filenameTarget;
    vector< string > factors;
    vector< string > spec = Tokenize(modelSpec[i]," ");

    // read optional punctuation and bias specifications
    if (spec.size() > 0) {
      if (spec.size() != 2 && spec.size() != 3 && spec.size() != 4 && spec.size() != 6) {
        std::cerr << "Format of glm feature is <factor-src>-<factor-tgt> [ignore-punct] [use-bias] "
                  << "[context-type] [filename-src filename-tgt]";
        //return false;
      }

      factors = Tokenize(spec[0],"-");
      if (spec.size() >= 2)
        ignorePunctuation = Scan<size_t>(spec[1]);
      if (spec.size() >= 3)
        biasFeature = Scan<size_t>(spec[2]);
      if (spec.size() >= 4)
        context = Scan<size_t>(spec[3]);
      if (spec.size() == 6) {
        filenameSource = spec[4];
        filenameTarget = spec[5];
        restricted = true;
      }
    } else
      factors = Tokenize(modelSpec[i],"-");

    if ( factors.size() != 2 ) {
      std::cerr << "Wrong factor definition for global lexical model unlimited: " << modelSpec[i];
      //return false;
    }

    const vector<FactorType> inputFactors = Tokenize<FactorType>(factors[0],",");
    const vector<FactorType> outputFactors = Tokenize<FactorType>(factors[1],",");
    throw runtime_error("GlobalLexicalModelUnlimited should be reimplemented as a stateful feature");
    // NOTE(review): glmu is left NULL, so the glmu->Load(...) call below would
    // dereference a null pointer if this path ever became reachable — must be
    // fixed when the feature is resurrected.
    GlobalLexicalModelUnlimited* glmu = NULL; // new GlobalLexicalModelUnlimited(inputFactors, outputFactors, biasFeature, ignorePunctuation, context);

    if (restricted) {
      cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
      if (!glmu->Load(filenameSource, filenameTarget)) {
        std::cerr << "Unable to load word lists for word translation feature from files "
                  << filenameSource
                  << " and "
                  << filenameTarget;
        //return false;
      }
    }
  }
}
|
| 74 |
+
|
| 75 |
+
bool GlobalLexicalModelUnlimited::Load(const std::string &filePathSource,
|
| 76 |
+
const std::string &filePathTarget)
|
| 77 |
+
{
|
| 78 |
+
// restricted source word vocabulary
|
| 79 |
+
ifstream inFileSource(filePathSource.c_str());
|
| 80 |
+
if (!inFileSource) {
|
| 81 |
+
cerr << "could not open file " << filePathSource << endl;
|
| 82 |
+
return false;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
std::string line;
|
| 86 |
+
while (getline(inFileSource, line)) {
|
| 87 |
+
m_vocabSource.insert(line);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
inFileSource.close();
|
| 91 |
+
|
| 92 |
+
// restricted target word vocabulary
|
| 93 |
+
ifstream inFileTarget(filePathTarget.c_str());
|
| 94 |
+
if (!inFileTarget) {
|
| 95 |
+
cerr << "could not open file " << filePathTarget << endl;
|
| 96 |
+
return false;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
while (getline(inFileTarget, line)) {
|
| 100 |
+
m_vocabTarget.insert(line);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
inFileTarget.close();
|
| 104 |
+
|
| 105 |
+
m_unrestricted = false;
|
| 106 |
+
return true;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
// Per-sentence setup: verifies the input is plain sentence input, then
// (re)creates the thread-local storage and points it at the current sentence.
void GlobalLexicalModelUnlimited::InitializeForInput(ttasksptr const& ttask)
{
  UTIL_THROW_IF2(ttask->GetSource()->GetType() != SentenceInput,
                 "GlobalLexicalModel works only with sentence input.");
  // NOTE(review): reinterpret_cast here relies on Sentence deriving from the
  // source type with no pointer adjustment; static_cast would express the
  // checked downcast more safely — confirm the class hierarchy.
  Sentence const* s = reinterpret_cast<Sentence const*>(ttask->GetSource().get());
  m_local.reset(new ThreadLocalStorage);
  m_local->input = s;
}
|
| 117 |
+
|
| 118 |
+
void GlobalLexicalModelUnlimited::EvaluateWhenApplied(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const
|
| 119 |
+
{
|
| 120 |
+
const Sentence& input = *(m_local->input);
|
| 121 |
+
const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
|
| 122 |
+
|
| 123 |
+
for(size_t targetIndex = 0; targetIndex < targetPhrase.GetSize(); targetIndex++ ) {
|
| 124 |
+
StringPiece targetString = targetPhrase.GetWord(targetIndex).GetString(0); // TODO: change for other factors
|
| 125 |
+
|
| 126 |
+
if (m_ignorePunctuation) {
|
| 127 |
+
// check if first char is punctuation
|
| 128 |
+
char firstChar = targetString[0];
|
| 129 |
+
CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
|
| 130 |
+
if(charIterator != m_punctuationHash.end())
|
| 131 |
+
continue;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
if (m_biasFeature) {
|
| 135 |
+
util::StringStream feature;
|
| 136 |
+
feature << "glm_";
|
| 137 |
+
feature << targetString;
|
| 138 |
+
feature << "~";
|
| 139 |
+
feature << "**BIAS**";
|
| 140 |
+
accumulator->SparsePlusEquals(feature.str(), 1);
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
boost::unordered_set<uint64_t> alreadyScored;
|
| 144 |
+
for(size_t sourceIndex = 0; sourceIndex < input.GetSize(); sourceIndex++ ) {
|
| 145 |
+
const StringPiece sourceString = input.GetWord(sourceIndex).GetString(0);
|
| 146 |
+
// TODO: change for other factors
|
| 147 |
+
|
| 148 |
+
if (m_ignorePunctuation) {
|
| 149 |
+
// check if first char is punctuation
|
| 150 |
+
char firstChar = sourceString[0];
|
| 151 |
+
CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
|
| 152 |
+
if(charIterator != m_punctuationHash.end())
|
| 153 |
+
continue;
|
| 154 |
+
}
|
| 155 |
+
const uint64_t sourceHash = util::MurmurHashNative(sourceString.data(), sourceString.size());
|
| 156 |
+
|
| 157 |
+
if ( alreadyScored.find(sourceHash) == alreadyScored.end()) {
|
| 158 |
+
bool sourceExists, targetExists;
|
| 159 |
+
if (!m_unrestricted) {
|
| 160 |
+
sourceExists = FindStringPiece(m_vocabSource, sourceString ) != m_vocabSource.end();
|
| 161 |
+
targetExists = FindStringPiece(m_vocabTarget, targetString) != m_vocabTarget.end();
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
// no feature if vocab is in use and both words are not in restricted vocabularies
|
| 165 |
+
if (m_unrestricted || (sourceExists && targetExists)) {
|
| 166 |
+
if (m_sourceContext) {
|
| 167 |
+
if (sourceIndex == 0) {
|
| 168 |
+
// add <s> trigger feature for source
|
| 169 |
+
util::StringStream feature;
|
| 170 |
+
feature << "glm_";
|
| 171 |
+
feature << targetString;
|
| 172 |
+
feature << "~";
|
| 173 |
+
feature << "<s>,";
|
| 174 |
+
feature << sourceString;
|
| 175 |
+
accumulator->SparsePlusEquals(feature.str(), 1);
|
| 176 |
+
alreadyScored.insert(sourceHash);
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
// add source words to the right of current source word as context
|
| 180 |
+
for(int contextIndex = sourceIndex+1; contextIndex < input.GetSize(); contextIndex++ ) {
|
| 181 |
+
StringPiece contextString = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors
|
| 182 |
+
bool contextExists;
|
| 183 |
+
if (!m_unrestricted)
|
| 184 |
+
contextExists = FindStringPiece(m_vocabSource, contextString ) != m_vocabSource.end();
|
| 185 |
+
|
| 186 |
+
if (m_unrestricted || contextExists) {
|
| 187 |
+
util::StringStream feature;
|
| 188 |
+
feature << "glm_";
|
| 189 |
+
feature << targetString;
|
| 190 |
+
feature << "~";
|
| 191 |
+
feature << sourceString;
|
| 192 |
+
feature << ",";
|
| 193 |
+
feature << contextString;
|
| 194 |
+
accumulator->SparsePlusEquals(feature.str(), 1);
|
| 195 |
+
alreadyScored.insert(sourceHash);
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
} else if (m_biphrase) {
|
| 199 |
+
// --> look backwards for constructing context
|
| 200 |
+
int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;
|
| 201 |
+
|
| 202 |
+
// 1) source-target pair, trigger source word (can be discont.) and adjacent target word (bigram)
|
| 203 |
+
StringPiece targetContext;
|
| 204 |
+
if (globalTargetIndex > 0)
|
| 205 |
+
targetContext = cur_hypo.GetWord(globalTargetIndex-1).GetString(0); // TODO: change for other factors
|
| 206 |
+
else
|
| 207 |
+
targetContext = "<s>";
|
| 208 |
+
|
| 209 |
+
if (sourceIndex == 0) {
|
| 210 |
+
StringPiece sourceTrigger = "<s>";
|
| 211 |
+
AddFeature(accumulator, sourceTrigger, sourceString,
|
| 212 |
+
targetContext, targetString);
|
| 213 |
+
} else
|
| 214 |
+
for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) {
|
| 215 |
+
StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors
|
| 216 |
+
bool sourceTriggerExists = false;
|
| 217 |
+
if (!m_unrestricted)
|
| 218 |
+
sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();
|
| 219 |
+
|
| 220 |
+
if (m_unrestricted || sourceTriggerExists)
|
| 221 |
+
AddFeature(accumulator, sourceTrigger, sourceString,
|
| 222 |
+
targetContext, targetString);
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
// 2) source-target pair, adjacent source word (bigram) and trigger target word (can be discont.)
|
| 226 |
+
StringPiece sourceContext;
|
| 227 |
+
if (sourceIndex-1 >= 0)
|
| 228 |
+
sourceContext = input.GetWord(sourceIndex-1).GetString(0); // TODO: change for other factors
|
| 229 |
+
else
|
| 230 |
+
sourceContext = "<s>";
|
| 231 |
+
|
| 232 |
+
if (globalTargetIndex == 0) {
|
| 233 |
+
string targetTrigger = "<s>";
|
| 234 |
+
AddFeature(accumulator, sourceContext, sourceString,
|
| 235 |
+
targetTrigger, targetString);
|
| 236 |
+
} else
|
| 237 |
+
for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
|
| 238 |
+
StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors
|
| 239 |
+
bool targetTriggerExists = false;
|
| 240 |
+
if (!m_unrestricted)
|
| 241 |
+
targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();
|
| 242 |
+
|
| 243 |
+
if (m_unrestricted || targetTriggerExists)
|
| 244 |
+
AddFeature(accumulator, sourceContext, sourceString,
|
| 245 |
+
targetTrigger, targetString);
|
| 246 |
+
}
|
| 247 |
+
} else if (m_bitrigger) {
|
| 248 |
+
// allow additional discont. triggers on both sides
|
| 249 |
+
int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;
|
| 250 |
+
|
| 251 |
+
if (sourceIndex == 0) {
|
| 252 |
+
StringPiece sourceTrigger = "<s>";
|
| 253 |
+
bool sourceTriggerExists = true;
|
| 254 |
+
|
| 255 |
+
if (globalTargetIndex == 0) {
|
| 256 |
+
string targetTrigger = "<s>";
|
| 257 |
+
bool targetTriggerExists = true;
|
| 258 |
+
|
| 259 |
+
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
|
| 260 |
+
AddFeature(accumulator, sourceTrigger, sourceString,
|
| 261 |
+
targetTrigger, targetString);
|
| 262 |
+
} else {
|
| 263 |
+
// iterate backwards over target
|
| 264 |
+
for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
|
| 265 |
+
StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors
|
| 266 |
+
bool targetTriggerExists = false;
|
| 267 |
+
if (!m_unrestricted)
|
| 268 |
+
targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();
|
| 269 |
+
|
| 270 |
+
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
|
| 271 |
+
AddFeature(accumulator, sourceTrigger, sourceString,
|
| 272 |
+
targetTrigger, targetString);
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
// iterate over both source and target
|
| 277 |
+
else {
|
| 278 |
+
// iterate backwards over source
|
| 279 |
+
for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) {
|
| 280 |
+
StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors
|
| 281 |
+
bool sourceTriggerExists = false;
|
| 282 |
+
if (!m_unrestricted)
|
| 283 |
+
sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();
|
| 284 |
+
|
| 285 |
+
if (globalTargetIndex == 0) {
|
| 286 |
+
string targetTrigger = "<s>";
|
| 287 |
+
bool targetTriggerExists = true;
|
| 288 |
+
|
| 289 |
+
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
|
| 290 |
+
AddFeature(accumulator, sourceTrigger, sourceString,
|
| 291 |
+
targetTrigger, targetString);
|
| 292 |
+
} else {
|
| 293 |
+
// iterate backwards over target
|
| 294 |
+
for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
|
| 295 |
+
StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors
|
| 296 |
+
bool targetTriggerExists = false;
|
| 297 |
+
if (!m_unrestricted)
|
| 298 |
+
targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();
|
| 299 |
+
|
| 300 |
+
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
|
| 301 |
+
AddFeature(accumulator, sourceTrigger, sourceString,
|
| 302 |
+
targetTrigger, targetString);
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
} else {
|
| 308 |
+
util::StringStream feature;
|
| 309 |
+
feature << "glm_";
|
| 310 |
+
feature << targetString;
|
| 311 |
+
feature << "~";
|
| 312 |
+
feature << sourceString;
|
| 313 |
+
accumulator->SparsePlusEquals(feature.str(), 1);
|
| 314 |
+
alreadyScored.insert(sourceHash);
|
| 315 |
+
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
}
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
// Fires one sparse feature of the form
//   glm_<targetTrigger>,<targetWord>~<sourceTrigger>,<sourceWord>
// with weight 1.
void GlobalLexicalModelUnlimited::AddFeature(ScoreComponentCollection* accumulator,
    StringPiece sourceTrigger, StringPiece sourceWord,
    StringPiece targetTrigger, StringPiece targetWord) const
{
  util::StringStream name;
  name << "glm_" << targetTrigger << "," << targetWord
       << "~" << sourceTrigger << "," << sourceWord;
  accumulator->SparsePlusEquals(name.str(), 1);
}
|
| 339 |
+
|
| 340 |
+
}
|
mosesdecoder/moses/FF/HyperParameterAsWeight.cpp
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "HyperParameterAsWeight.h"
#include "moses/StaticData.h"

using namespace std;

namespace Moses
{

// "Feature" that abuses the tuning machinery: its two dense weights are not
// used for scoring at all; instead they are read once here and written into
// the decoder's global search options, so that a tuner (MERT/MIRA) can
// optimise decoder hyper-parameters as if they were model weights.
HyperParameterAsWeight::HyperParameterAsWeight(const std::string &line)
  :StatelessFeatureFunction(2, line)
{
  ReadParameters();

  // hack into StaticData and change anything you want
  // as an example, we have 2 weights and change
  // 1. stack size
  // 2. beam width
  StaticData &staticData = StaticData::InstanceNonConst();

  vector<float> weights = staticData.GetWeights(this);

  // Scale the raw tuner weights into hyper-parameter ranges
  // (weight 0.1 -> stack size 100, beam width 1, etc.).
  staticData.m_options->search.stack_size = weights[0] * 1000;
  staticData.m_options->search.beam_width = weights[1] * 10;

}


}
|
| 29 |
+
|
mosesdecoder/moses/FF/InputFeature.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

// NOTE: the original header #include'd itself ("InputFeature.h" inside
// InputFeature.h) — harmless under #pragma once but wrong; removed.
#include "StatelessFeatureFunction.h"

namespace Moses
{


// Stateless feature function that passes scores attached to the input itself
// (e.g. confusion-network / lattice arc scores) through to the model.
// Exposed process-wide via the s_instance singleton pointer.
class InputFeature : public StatelessFeatureFunction
{
protected:
  static InputFeature *s_instance;   // singleton; set on construction (see .cpp)

  size_t m_numInputScores;     // number of dense scores carried by the input
  size_t m_numRealWordCount;   // number of real-word-count scores
  bool m_legacy;               // legacy scoring mode flag (see SetParameter)

public:
  static const InputFeature *InstancePtr() {
    return s_instance;
  }

  InputFeature(const std::string &line);

  void Load(AllOptions::ptr const& opts);

  void SetParameter(const std::string& key, const std::string& value);

  // Usable with any factor configuration.
  bool IsUseable(const FactorMask &mask) const {
    return true;
  }

  size_t GetNumInputScores() const {
    return m_numInputScores;
  }
  size_t GetNumRealWordsInInput() const {
    return m_numRealWordCount;
  }

  // No phrase-internal score: everything comes from the input path.
  void EvaluateInIsolation(const Phrase &source
                           , const TargetPhrase &targetPhrase
                           , ScoreComponentCollection &scoreBreakdown
                           , ScoreComponentCollection &estimatedScores) const {
  }

  // Transfers the input-path scores into the score breakdown (defined in .cpp).
  void EvaluateWithSourceContext(const InputType &input
                                 , const InputPath &inputPath
                                 , const TargetPhrase &targetPhrase
                                 , const StackVec *stackVec
                                 , ScoreComponentCollection &scoreBreakdown
                                 , ScoreComponentCollection *estimatedScores = NULL) const;

  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const {
  }

  // Nothing to do at hypothesis application time.
  void EvaluateWhenApplied(const Hypothesis& hypo,
                           ScoreComponentCollection* accumulator) const {
  }
  void EvaluateWhenApplied(const ChartHypothesis &hypo,
                           ScoreComponentCollection* accumulator) const {
  }


};


}
|
| 70 |
+
|
mosesdecoder/moses/FF/Model1Feature.cpp
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "Model1Feature.h"
|
| 2 |
+
#include "moses/StaticData.h"
|
| 3 |
+
#include "moses/InputFileStream.h"
|
| 4 |
+
#include "moses/ScoreComponentCollection.h"
|
| 5 |
+
#include "moses/FactorCollection.h"
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
using namespace std;
|
| 9 |
+
|
| 10 |
+
namespace Moses
|
| 11 |
+
{
|
| 12 |
+
|
| 13 |
+
// Token that MGIZA uses for the artificial "empty word" (NULL alignment
// target); it is registered with vocabulary id 0 below.
const std::string Model1Vocabulary::GIZANULL = "GIZANULL";

Model1Vocabulary::Model1Vocabulary()
{
  // Intern the NULL word as a Moses Factor and reserve id 0 for it,
  // matching GIZA's convention.
  FactorCollection &factorCollection = FactorCollection::Instance();
  m_NULL = factorCollection.AddFactor(GIZANULL,false);
  Store(m_NULL,0);
}
|
| 21 |
+
|
| 22 |
+
bool Model1Vocabulary::Store(const Factor* word, const unsigned id)
|
| 23 |
+
{
|
| 24 |
+
boost::unordered_map<const Factor*, unsigned>::iterator iter = m_lookup.find( word );
|
| 25 |
+
if ( iter != m_lookup.end() ) {
|
| 26 |
+
return false;
|
| 27 |
+
}
|
| 28 |
+
m_lookup[ word ] = id;
|
| 29 |
+
if ( m_vocab.size() <= id ) {
|
| 30 |
+
m_vocab.resize(id+1);
|
| 31 |
+
}
|
| 32 |
+
m_vocab[id] = word;
|
| 33 |
+
return true;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
unsigned Model1Vocabulary::StoreIfNew(const Factor* word)
|
| 37 |
+
{
|
| 38 |
+
boost::unordered_map<const Factor*, unsigned>::iterator iter = m_lookup.find( word );
|
| 39 |
+
|
| 40 |
+
if ( iter != m_lookup.end() ) {
|
| 41 |
+
return iter->second;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
unsigned id = m_vocab.size();
|
| 45 |
+
m_vocab.push_back( word );
|
| 46 |
+
m_lookup[ word ] = id;
|
| 47 |
+
return id;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
unsigned Model1Vocabulary::GetWordID(const Factor* word) const
|
| 51 |
+
{
|
| 52 |
+
boost::unordered_map<const Factor*, unsigned>::const_iterator iter = m_lookup.find( word );
|
| 53 |
+
if ( iter == m_lookup.end() ) {
|
| 54 |
+
return INVALID_ID;
|
| 55 |
+
}
|
| 56 |
+
return iter->second;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
// Word for `id`, or NULL when the id is out of range.
const Factor* Model1Vocabulary::GetWord(unsigned id) const
{
  return (id < m_vocab.size()) ? m_vocab[id] : NULL;
}
|
| 66 |
+
|
| 67 |
+
// Reads an MGIZA ".vcb" vocabulary file: one entry per line in the format
// "<id> <word> <count>".  Each word is interned as a Factor and registered
// under its file id; duplicate ids raise an exception.
void Model1Vocabulary::Load(const std::string& fileName)
{
  InputFileStream inFile(fileName);
  FactorCollection &factorCollection = FactorCollection::Instance();
  std::string line;

  unsigned i = 0;  // line counter, used only in error messages
  if ( getline(inFile, line) ) { // first line of MGIZA vocabulary files seems to be special : "1 UNK 0" -- skip if it's this
    ++i;
    std::vector<std::string> tokens = Tokenize(line);
    UTIL_THROW_IF2(tokens.size()!=3, "Line " << i << " in " << fileName << " has wrong number of tokens.");
    unsigned id = atoll( tokens[0].c_str() );
    if (! ( (id == 1) && (tokens[1] == "UNK") )) {
      const Factor* factor = factorCollection.AddFactor(tokens[1],false); // TODO: can we assume that the vocabulary is know and filter the model on loading?
      bool stored = Store(factor, id);
      UTIL_THROW_IF2(!stored, "Line " << i << " in " << fileName << " overwrites existing vocabulary entry.");
    }
  }
  // Remaining lines: register every <id, word> pair.
  while ( getline(inFile, line) ) {
    ++i;
    std::vector<std::string> tokens = Tokenize(line);
    UTIL_THROW_IF2(tokens.size()!=3, "Line " << i << " in " << fileName << " has wrong number of tokens.");
    unsigned id = atoll( tokens[0].c_str() );
    const Factor* factor = factorCollection.AddFactor(tokens[1],false); // TODO: can we assume that the vocabulary is know and filter the model on loading?
    bool stored = Store(factor, id);
    UTIL_THROW_IF2(!stored, "Line " << i << " in " << fileName << " overwrites existing vocabulary entry.");
  }
  inFile.Close();
}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
// Reads an MGIZA lexical translation table ("t-table"): one line per word
// pair in the format "<sourceId> <targetId> <p(target|source)>".  Ids are
// resolved through the two vocabularies loaded beforehand; an id missing
// from either vocabulary raises an exception.
void Model1LexicalTable::Load(const std::string &fileName, const Model1Vocabulary& vcbS, const Model1Vocabulary& vcbT)
{
  InputFileStream inFile(fileName);
  std::string line;

  unsigned i = 0;  // line counter, used only in error messages
  while ( getline(inFile, line) ) {
    ++i;
    std::vector<std::string> tokens = Tokenize(line);
    UTIL_THROW_IF2(tokens.size()!=3, "Line " << i << " in " << fileName << " has wrong number of tokens.");
    unsigned idS = atoll( tokens[0].c_str() );
    unsigned idT = atoll( tokens[1].c_str() );
    const Factor* wordS = vcbS.GetWord(idS);
    const Factor* wordT = vcbT.GetWord(idT);
    float prob = std::atof( tokens[2].c_str() );
    if ( (wordS != NULL) && (wordT != NULL) ) {
      m_ltable[ wordS ][ wordT ] = prob;
    }
    UTIL_THROW_IF2((wordS == NULL) || (wordT == NULL), "Line " << i << " in " << fileName << " has unknown vocabulary."); // TODO: can we assume that the vocabulary is know and filter the model on loading? Then remove this line.
  }
  inFile.Close();
}
|
| 120 |
+
|
| 121 |
+
// p( wordT | wordS )
|
| 122 |
+
float Model1LexicalTable::GetProbability(const Factor* wordS, const Factor* wordT) const
|
| 123 |
+
{
|
| 124 |
+
float prob = m_floor;
|
| 125 |
+
|
| 126 |
+
boost::unordered_map< const Factor*, boost::unordered_map< const Factor*, float > >::const_iterator iter1 = m_ltable.find( wordS );
|
| 127 |
+
|
| 128 |
+
if ( iter1 != m_ltable.end() ) {
|
| 129 |
+
boost::unordered_map< const Factor*, float >::const_iterator iter2 = iter1->second.find( wordT );
|
| 130 |
+
if ( iter2 != iter1->second.end() ) {
|
| 131 |
+
prob = iter2->second;
|
| 132 |
+
if ( prob < m_floor ) {
|
| 133 |
+
prob = m_floor;
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
return prob;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
// One dense score: the Model 1 log-probability of the target phrase given
// the source sentence.  File names are filled in by SetParameter(); the
// actual model is loaded later in Load().
Model1Feature::Model1Feature(const std::string &line)
  : StatelessFeatureFunction(1, line)
  , m_skipTargetPunctuation(false)
  , m_is_syntax(false)
{
  VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ...");
  ReadParameters();
  VERBOSE(1, " Done.");
}
|
| 150 |
+
|
| 151 |
+
void Model1Feature::SetParameter(const std::string& key, const std::string& value)
|
| 152 |
+
{
|
| 153 |
+
if (key == "path") {
|
| 154 |
+
m_fileNameModel1 = value;
|
| 155 |
+
} else if (key == "source-vocabulary") {
|
| 156 |
+
m_fileNameVcbS = value;
|
| 157 |
+
} else if (key == "target-vocabulary") {
|
| 158 |
+
m_fileNameVcbT = value;
|
| 159 |
+
} else if (key == "skip-target-punctuation") {
|
| 160 |
+
m_skipTargetPunctuation = Scan<bool>(value);
|
| 161 |
+
} else {
|
| 162 |
+
StatelessFeatureFunction::SetParameter(key, value);
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
// Loads both vocabularies and the Model 1 lexical table, resolves the GIZA
// empty-word factor, and (when skip-target-punctuation is set) collects the
// punctuation factors to ignore during scoring.  Called once at start-up.
void Model1Feature::Load(AllOptions::ptr const& opts)
{
  m_options = opts;
  m_is_syntax = is_syntax(opts->search.algo);

  FEATUREVERBOSE(2, GetScoreProducerDescription() << ": Loading source vocabulary from file " << m_fileNameVcbS << " ...");
  Model1Vocabulary vcbS;
  vcbS.Load(m_fileNameVcbS);
  FEATUREVERBOSE2(2, " Done." << std::endl);
  FEATUREVERBOSE(2, GetScoreProducerDescription() << ": Loading target vocabulary from file " << m_fileNameVcbT << " ...");
  Model1Vocabulary vcbT;
  vcbT.Load(m_fileNameVcbT);
  FEATUREVERBOSE2(2, " Done." << std::endl);
  FEATUREVERBOSE(2, GetScoreProducerDescription() << ": Loading model 1 lexical translation table from file " << m_fileNameModel1 << " ...");
  m_model1.Load(m_fileNameModel1,vcbS,vcbT);
  FEATUREVERBOSE2(2, " Done." << std::endl);
  FactorCollection &factorCollection = FactorCollection::Instance();
  m_emptyWord = factorCollection.GetFactor(Model1Vocabulary::GIZANULL,false);
  UTIL_THROW_IF2(m_emptyWord==NULL, GetScoreProducerDescription()
                 << ": Factor for GIZA empty word does not exist.");

  if (m_skipTargetPunctuation) {
    const std::string punctuation = ",;.:!?";
    for (size_t i=0; i<punctuation.size(); ++i) {
      const std::string punct = punctuation.substr(i,1);
      // Reuse the FactorCollection reference obtained above — the original
      // re-fetched the singleton on every iteration (shadowing declaration)
      // and stored insert()'s return in an unused variable.  insert()
      // silently ignores duplicates, so its result needs no inspection.
      const Factor* punctFactor = factorCollection.AddFactor(punct,false);
      m_punctuation.insert(punctFactor);
    }
  }
}
|
| 197 |
+
|
| 198 |
+
// Adds the length-normalised IBM Model 1 score of the target phrase given
// the whole source sentence: for every (terminal, non-punctuation) target
// word t,  score += TransformScore( sum_s p(t|s) ) - TransformScore(1+|sent|),
// where s ranges over the source words plus the GIZA empty word.
// Per-word scores are cached per input sentence under a reader-writer lock;
// the cache is cleared in CleanUpAfterSentenceProcessing().
// NOTE(review): the static_cast assumes the input is always a Sentence —
// confirm this feature is never combined with lattice/confusion-net input.
void Model1Feature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedScores) const
{
  const Sentence& sentence = static_cast<const Sentence&>(input);
  float score = 0.0;
  float norm = TransformScore(1+sentence.GetSize());  // per-word normaliser

  for (size_t posT=0; posT<targetPhrase.GetSize(); ++posT) {
    const Word &wordT = targetPhrase.GetWord(posT);
    if (m_skipTargetPunctuation) {
      // punctuation tokens contribute nothing to the score
      std::set<const Factor*>::const_iterator foundPunctuation = m_punctuation.find(wordT[0]);
      if (foundPunctuation != m_punctuation.end()) {
        continue;
      }
    }
    if ( !wordT.IsNonTerminal() ) {
      float thisWordProb = m_model1.GetProbability(m_emptyWord,wordT[0]); // probability conditioned on empty word

      // cache lookup
      bool foundInCache = false;
      {
#ifdef WITH_THREADS
        // shared (read) lock: many decoder threads may probe the cache at once
        boost::shared_lock<boost::shared_mutex> read_lock(m_accessLock);
#endif
        boost::unordered_map<const InputType*, boost::unordered_map<const Factor*, float> >::const_iterator sentenceCache = m_cache.find(&input);
        if (sentenceCache != m_cache.end()) {
          boost::unordered_map<const Factor*, float>::const_iterator cacheHit = sentenceCache->second.find(wordT[0]);
          if (cacheHit != sentenceCache->second.end()) {
            foundInCache = true;
            score += cacheHit->second;
            FEATUREVERBOSE(3, "Cached score( " << wordT << " ) = " << cacheHit->second << std::endl);
          }
        }
      }

      if (!foundInCache) {
        // sum p(t|s) over the source words; syntax input carries <s>/</s>
        // at the sentence edges, which are excluded from the sum
        for (size_t posS=(m_is_syntax?1:0); posS<(m_is_syntax?sentence.GetSize()-1:sentence.GetSize()); ++posS) { // ignore <s> and </s>
          const Word &wordS = sentence.GetWord(posS);
          float modelProb = m_model1.GetProbability(wordS[0],wordT[0]);
          FEATUREVERBOSE(4, "p( " << wordT << " | " << wordS << " ) = " << modelProb << std::endl);
          thisWordProb += modelProb;
        }
        float thisWordScore = TransformScore(thisWordProb) - norm;
        FEATUREVERBOSE(3, "score( " << wordT << " ) = " << thisWordScore << std::endl);
        {
#ifdef WITH_THREADS
          // need to update cache; write lock
          boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
#endif
          m_cache[&input][wordT[0]] = thisWordScore;
        }
        score += thisWordScore;
      }
    }
  }

  scoreBreakdown.PlusEquals(this, score);
}
|
| 260 |
+
|
| 261 |
+
void Model1Feature::CleanUpAfterSentenceProcessing(const InputType& source)
|
| 262 |
+
{
|
| 263 |
+
#ifdef WITH_THREADS
|
| 264 |
+
// need to update cache; write lock
|
| 265 |
+
boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
|
| 266 |
+
#endif
|
| 267 |
+
// clear cache
|
| 268 |
+
boost::unordered_map<const InputType*, boost::unordered_map<const Factor*, float> >::iterator sentenceCache = m_cache.find(&source);
|
| 269 |
+
if (sentenceCache != m_cache.end()) {
|
| 270 |
+
sentenceCache->second.clear();
|
| 271 |
+
m_cache.erase(sentenceCache);
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
}
|
| 276 |
+
|
mosesdecoder/moses/FF/Model1Feature.h
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <string>
|
| 4 |
+
#include <limits>
|
| 5 |
+
#include <set>
|
| 6 |
+
#include <boost/unordered_map.hpp>
|
| 7 |
+
#include "StatelessFeatureFunction.h"
|
| 8 |
+
#include "moses/Factor.h"
|
| 9 |
+
|
| 10 |
+
#ifdef WITH_THREADS
|
| 11 |
+
#include <boost/thread/shared_mutex.hpp>
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
namespace Moses
|
| 15 |
+
{
|
| 16 |
+
|
| 17 |
+
// Bidirectional word <-> id mapping for one side of an MGIZA vocabulary
// (".vcb") file.  Words are interned as Moses Factor pointers; id 0 is
// reserved for the GIZA empty word (GIZANULL).
class Model1Vocabulary
{
public:

#define INVALID_ID std::numeric_limits<unsigned>::max() // UINT_MAX
  static const std::string GIZANULL;  // marker token for GIZA's empty word

  Model1Vocabulary();
  // Registers `word` under `id`; returns false if the word already exists.
  bool Store(const Factor* word, const unsigned id);
  // Returns the existing id of `word`, or registers it with the next free id.
  unsigned StoreIfNew(const Factor* word);
  // Id of `word`, or INVALID_ID when unknown.
  unsigned GetWordID(const Factor* word) const;
  // Word for `id`, or NULL when out of range.
  const Factor* GetWord(unsigned id) const;
  // Reads an MGIZA ".vcb" file ("<id> <word> <count>" per line).
  void Load(const std::string& fileName);

protected:
  boost::unordered_map<const Factor*, unsigned> m_lookup;  // word -> id
  std::vector< const Factor* > m_vocab;                    // id -> word
  const Factor* m_NULL;                                    // interned GIZANULL
};
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
// Sparse IBM Model 1 lexical translation table p(targetWord | sourceWord)
// backed by a nested hash map, with a probability floor returned for
// unseen pairs (and as a lower bound on stored values).
class Model1LexicalTable
{
public:
  Model1LexicalTable(float floor=1e-7) : m_floor(floor)
  {}

  // Reads an MGIZA t-table ("<srcId> <tgtId> <prob>" per line), resolving
  // ids through the given vocabularies.
  void Load(const std::string& fileName, const Model1Vocabulary& vcbS, const Model1Vocabulary& vcbT);

  // p( wordT | wordS )
  float GetProbability(const Factor* wordS, const Factor* wordT) const;

protected:
  boost::unordered_map< const Factor*, boost::unordered_map< const Factor*, float > > m_ltable;
  const float m_floor;  // lower bound / unseen-pair probability
};
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
// Stateless feature function scoring target phrases with an IBM Model 1
// lexical translation model against the entire source sentence.
// Configured via "path", "source-vocabulary", "target-vocabulary" and
// "skip-target-punctuation" (see SetParameter in the .cpp).
class Model1Feature : public StatelessFeatureFunction
{
public:
  Model1Feature(const std::string &line);

  // Works with any factor configuration.
  bool IsUseable(const FactorMask &mask) const {
    return true;
  }

  void SetParameter(const std::string& key, const std::string& value);

  // No context-free score: the model needs the full source sentence.
  void EvaluateInIsolation(const Phrase &source
                           , const TargetPhrase &targetPhrase
                           , ScoreComponentCollection &scoreBreakdown
                           , ScoreComponentCollection &estimatedScores) const
  {};

  // Main entry point: adds the Model 1 score of targetPhrase given the
  // input sentence (implemented in the .cpp, with per-sentence caching).
  void EvaluateWithSourceContext(const InputType &input
                                 , const InputPath &inputPath
                                 , const TargetPhrase &targetPhrase
                                 , const StackVec *stackVec
                                 , ScoreComponentCollection &scoreBreakdown
                                 , ScoreComponentCollection *estimatedScores = NULL) const;

  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const
  {}

  // Nothing to do at hypothesis application time (phrase-based or chart).
  void EvaluateWhenApplied(
    const Hypothesis& cur_hypo,
    ScoreComponentCollection* accumulator) const
  {}

  void EvaluateWhenApplied(
    const ChartHypothesis& cur_hypo,
    ScoreComponentCollection* accumulator) const
  {}

  // Drops this sentence's entries from the score cache.
  void CleanUpAfterSentenceProcessing(const InputType& source);

private:
  std::string m_fileNameVcbS;    // source vocabulary file
  std::string m_fileNameVcbT;    // target vocabulary file
  std::string m_fileNameModel1;  // lexical table file
  Model1LexicalTable m_model1;
  const Factor* m_emptyWord;     // interned GIZA empty-word factor
  bool m_skipTargetPunctuation;
  std::set<const Factor*> m_punctuation;  // punctuation factors to skip
  bool m_is_syntax;              // input carries <s>/</s> sentence markers

  void Load(AllOptions::ptr const& opts);

  // cache
  // per-sentence map of target-word Factor -> precomputed word score
  mutable boost::unordered_map<const InputType*, boost::unordered_map<const Factor*, float> > m_cache;
#ifdef WITH_THREADS
  // reader-writer lock
  mutable boost::shared_mutex m_accessLock;
#endif
};
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
}
|
| 118 |
+
|
mosesdecoder/moses/FF/PhraseBoundaryFeature.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef moses_PhraseBoundaryFeature_h
|
| 2 |
+
#define moses_PhraseBoundaryFeature_h
|
| 3 |
+
|
| 4 |
+
#include <stdexcept>
|
| 5 |
+
#include <sstream>
|
| 6 |
+
#include <string>
|
| 7 |
+
|
| 8 |
+
#include "StatefulFeatureFunction.h"
|
| 9 |
+
#include "moses/FF/FFState.h"
|
| 10 |
+
#include "moses/Word.h"
|
| 11 |
+
|
| 12 |
+
namespace Moses
|
| 13 |
+
{
|
| 14 |
+
|
| 15 |
+
class PhraseBoundaryState : public FFState
|
| 16 |
+
{
|
| 17 |
+
public:
|
| 18 |
+
PhraseBoundaryState(const Word* sourceWord, const Word* targetWord) :
|
| 19 |
+
m_sourceWord(sourceWord), m_targetWord(targetWord) {}
|
| 20 |
+
const Word* GetSourceWord() const {
|
| 21 |
+
return m_sourceWord;
|
| 22 |
+
}
|
| 23 |
+
const Word* GetTargetWord() const {
|
| 24 |
+
return m_targetWord;
|
| 25 |
+
}
|
| 26 |
+
virtual size_t hash() const;
|
| 27 |
+
virtual bool operator==(const FFState& other) const;
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
private:
|
| 31 |
+
const Word* m_sourceWord;
|
| 32 |
+
const Word* m_targetWord;
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
/**
|
| 37 |
+
* Concatenations of factors on boundaries of phrases.
|
| 38 |
+
**/
|
| 39 |
+
class PhraseBoundaryFeature : public StatefulFeatureFunction
|
| 40 |
+
{
|
| 41 |
+
public:
|
| 42 |
+
PhraseBoundaryFeature(const std::string &line);
|
| 43 |
+
|
| 44 |
+
bool IsUseable(const FactorMask &mask) const;
|
| 45 |
+
|
| 46 |
+
virtual const FFState* EmptyHypothesisState(const InputType &) const;
|
| 47 |
+
|
| 48 |
+
virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state,
|
| 49 |
+
ScoreComponentCollection* accumulator) const;
|
| 50 |
+
|
| 51 |
+
virtual FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */,
|
| 52 |
+
int /* featureID */,
|
| 53 |
+
ScoreComponentCollection* ) const {
|
| 54 |
+
throw std::logic_error("PhraseBoundaryState not supported in chart decoder, yet");
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
void SetParameter(const std::string& key, const std::string& value);
|
| 58 |
+
|
| 59 |
+
private:
|
| 60 |
+
void AddFeatures(
|
| 61 |
+
const Word* leftWord, const Word* rightWord, const FactorList& factors,
|
| 62 |
+
const std::string& side, ScoreComponentCollection* scores) const ;
|
| 63 |
+
FactorList m_sourceFactors;
|
| 64 |
+
FactorList m_targetFactors;
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
#endif
|
mosesdecoder/moses/FF/PhraseDistanceFeature.cpp
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "PhraseDistanceFeature.h"
|
| 2 |
+
|
| 3 |
+
#include <vector>
|
| 4 |
+
#include <boost/foreach.hpp>
|
| 5 |
+
#include "moses/InputType.h"
|
| 6 |
+
#include "moses/ScoreComponentCollection.h"
|
| 7 |
+
#include "moses/StaticData.h"
|
| 8 |
+
#include "util/exception.hh"
|
| 9 |
+
|
| 10 |
+
using namespace std;
|
| 11 |
+
|
| 12 |
+
namespace Moses
|
| 13 |
+
{
|
| 14 |
+
PhraseDistanceFeature::PhraseDistanceFeature(const string &line)
|
| 15 |
+
: StatelessFeatureFunction(2, line)
|
| 16 |
+
, m_space("")
|
| 17 |
+
, m_spaceID(0)
|
| 18 |
+
, m_measure(EuclideanDistance)
|
| 19 |
+
{
|
| 20 |
+
ReadParameters();
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
void PhraseDistanceFeature::EvaluateWithSourceContext(const InputType &input
|
| 24 |
+
, const InputPath &inputPath
|
| 25 |
+
, const TargetPhrase &targetPhrase
|
| 26 |
+
, const StackVec *stackVec
|
| 27 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 28 |
+
, ScoreComponentCollection *estimatedScores) const
|
| 29 |
+
{
|
| 30 |
+
vector<float> scores(m_numScoreComponents, 0);
|
| 31 |
+
bool broken = false;
|
| 32 |
+
// Input coord
|
| 33 |
+
map<size_t const, vector<float> >::const_iterator ii;
|
| 34 |
+
if (input.m_coordMap) {
|
| 35 |
+
ii = input.m_coordMap->find(m_spaceID);
|
| 36 |
+
} else {
|
| 37 |
+
TRACE_ERR("No coordinates for space " << m_space << " on input (specify with coord XML tag)" << endl);
|
| 38 |
+
TRACE_ERR("Scores for " << m_description << " will be incorrect and probably all zeros" << endl);
|
| 39 |
+
broken = true;
|
| 40 |
+
}
|
| 41 |
+
if (ii == input.m_coordMap->end()) {
|
| 42 |
+
TRACE_ERR("No coordinates for space " << m_space << " on input (specify with coord XML tag)" << endl);
|
| 43 |
+
TRACE_ERR("Scores for " << m_description << " will be incorrect and probably all zeros" << endl);
|
| 44 |
+
broken = true;
|
| 45 |
+
}
|
| 46 |
+
// Target phrase coord
|
| 47 |
+
vector<SPTR<vector<float> > > const* tpp = targetPhrase.GetCoordList(m_spaceID);
|
| 48 |
+
if (tpp == NULL) {
|
| 49 |
+
TRACE_ERR("No coordinates for space " << m_space << " on target phrase (PhraseDictionary implementation needs to set)" << endl);
|
| 50 |
+
TRACE_ERR("Scores for " << m_description << " will be incorrect and probably all zeros" << endl);
|
| 51 |
+
broken = true;
|
| 52 |
+
}
|
| 53 |
+
// Compute scores
|
| 54 |
+
if (!broken) {
|
| 55 |
+
vector<float> const& inputCoord = ii->second;
|
| 56 |
+
vector<SPTR<vector<float> > > const& tpCoord = *tpp;
|
| 57 |
+
// Centroid of target phrase instances (from phrase extraction)
|
| 58 |
+
vector<float> centroid = vector<float>(inputCoord.size(), 0);
|
| 59 |
+
BOOST_FOREACH(SPTR<vector<float> > const coord, tpCoord) {
|
| 60 |
+
for (size_t i = 0; i < inputCoord.size(); ++i) {
|
| 61 |
+
centroid[i] += (*coord)[i];
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
for (size_t i = 0; i < inputCoord.size(); ++i) {
|
| 65 |
+
centroid[i] /= tpCoord.size();
|
| 66 |
+
}
|
| 67 |
+
// Average distance from the target phrase instances to (1) the input and
|
| 68 |
+
// (2) the target phrase centroid
|
| 69 |
+
float inputDistance = 0;
|
| 70 |
+
float centroidDistance = 0;
|
| 71 |
+
if (m_measure == EuclideanDistance) {
|
| 72 |
+
BOOST_FOREACH(SPTR<vector<float> > const coord, tpCoord) {
|
| 73 |
+
float pointInputDistance = 0;
|
| 74 |
+
float pointCentroidDistance = 0;
|
| 75 |
+
for (size_t i = 0; i < inputCoord.size(); ++i) {
|
| 76 |
+
pointInputDistance += pow(inputCoord[i] - (*coord)[i], 2);
|
| 77 |
+
pointCentroidDistance += pow(centroid[i] - (*coord)[i], 2);
|
| 78 |
+
}
|
| 79 |
+
inputDistance += sqrt(pointInputDistance);
|
| 80 |
+
centroidDistance += sqrt(pointCentroidDistance);
|
| 81 |
+
}
|
| 82 |
+
} else if (m_measure == TotalVariationDistance) {
|
| 83 |
+
BOOST_FOREACH(SPTR<vector<float> > const coord, tpCoord) {
|
| 84 |
+
float pointInputDistance = 0;
|
| 85 |
+
float pointCentroidDistance = 0;
|
| 86 |
+
for (size_t i = 0; i < inputCoord.size(); ++i) {
|
| 87 |
+
pointInputDistance += abs(inputCoord[i] - (*coord)[i]);
|
| 88 |
+
pointCentroidDistance += abs(centroid[i] - (*coord)[i]);
|
| 89 |
+
}
|
| 90 |
+
inputDistance += pointInputDistance / 2;
|
| 91 |
+
centroidDistance += pointCentroidDistance / 2;
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
inputDistance /= tpCoord.size();
|
| 95 |
+
centroidDistance /= tpCoord.size();
|
| 96 |
+
// Log transform scores, max with float epsilon to avoid domain error
|
| 97 |
+
scores[0] = log(max(inputDistance, Moses::FLOAT_EPSILON));
|
| 98 |
+
scores[1] = log(max(centroidDistance, Moses::FLOAT_EPSILON));
|
| 99 |
+
}
|
| 100 |
+
// Set scores
|
| 101 |
+
scoreBreakdown.Assign(this, scores);
|
| 102 |
+
return;
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
void PhraseDistanceFeature::SetParameter(const string& key, const string& value)
|
| 106 |
+
{
|
| 107 |
+
if (key == "space") {
|
| 108 |
+
m_space = value;
|
| 109 |
+
m_spaceID = StaticData::InstanceNonConst().MapCoordSpace(m_space);
|
| 110 |
+
} else if (key == "measure") {
|
| 111 |
+
if (value == "euc") {
|
| 112 |
+
m_measure = EuclideanDistance;
|
| 113 |
+
} else if (value == "var") {
|
| 114 |
+
m_measure = TotalVariationDistance;
|
| 115 |
+
} else {
|
| 116 |
+
UTIL_THROW2("Unknown measure " << value << ", choices: euc var");
|
| 117 |
+
}
|
| 118 |
+
} else {
|
| 119 |
+
StatelessFeatureFunction::SetParameter(key, value);
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
} // namespace
|
mosesdecoder/moses/FF/PhraseDistanceFeature.h
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "StatelessFeatureFunction.h"
|
| 4 |
+
|
| 5 |
+
namespace Moses
|
| 6 |
+
{
|
| 7 |
+
|
| 8 |
+
class PhraseDistanceFeature : public StatelessFeatureFunction
|
| 9 |
+
{
|
| 10 |
+
enum Measure {
|
| 11 |
+
EuclideanDistance,
|
| 12 |
+
TotalVariationDistance,
|
| 13 |
+
};
|
| 14 |
+
|
| 15 |
+
public:
|
| 16 |
+
PhraseDistanceFeature(const std::string &line);
|
| 17 |
+
|
| 18 |
+
bool IsUseable(const FactorMask &mask) const {
|
| 19 |
+
return true;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
virtual void EvaluateInIsolation(const Phrase &source
|
| 23 |
+
, const TargetPhrase &targetPhrase
|
| 24 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 25 |
+
, ScoreComponentCollection &estimatedScores) const {
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
void EvaluateWhenApplied(const Hypothesis& hypo,
|
| 29 |
+
ScoreComponentCollection* accumulator) const {
|
| 30 |
+
}
|
| 31 |
+
void EvaluateWhenApplied(const ChartHypothesis &hypo,
|
| 32 |
+
ScoreComponentCollection* accumulator) const {
|
| 33 |
+
}
|
| 34 |
+
void EvaluateWhenApplied(const Syntax::SHyperedge &hyperedge,
|
| 35 |
+
ScoreComponentCollection* accumulator) const {
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
void EvaluateWithSourceContext(const InputType &input
|
| 39 |
+
, const InputPath &inputPath
|
| 40 |
+
, const TargetPhrase &targetPhrase
|
| 41 |
+
, const StackVec *stackVec
|
| 42 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 43 |
+
, ScoreComponentCollection *estimatedScores = NULL) const;
|
| 44 |
+
|
| 45 |
+
void EvaluateTranslationOptionListWithSourceContext(const InputType &input
|
| 46 |
+
, const TranslationOptionList &translationOptionList) const {
|
| 47 |
+
}
|
| 48 |
+
void SetParameter(const std::string& key, const std::string& value);
|
| 49 |
+
|
| 50 |
+
protected:
|
| 51 |
+
Measure m_measure;
|
| 52 |
+
std::string m_space;
|
| 53 |
+
size_t m_spaceID;
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
} //namespace
|
mosesdecoder/moses/FF/PhraseLengthFeature.cpp
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <sstream>
|
| 2 |
+
#include "PhraseLengthFeature.h"
|
| 3 |
+
#include "moses/Hypothesis.h"
|
| 4 |
+
#include "moses/ScoreComponentCollection.h"
|
| 5 |
+
#include "moses/TranslationOption.h"
|
| 6 |
+
#include "util/string_stream.hh"
|
| 7 |
+
|
| 8 |
+
namespace Moses
|
| 9 |
+
{
|
| 10 |
+
|
| 11 |
+
using namespace std;
|
| 12 |
+
|
| 13 |
+
PhraseLengthFeature::PhraseLengthFeature(const std::string &line)
|
| 14 |
+
:StatelessFeatureFunction(0, line)
|
| 15 |
+
{
|
| 16 |
+
ReadParameters();
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
void PhraseLengthFeature::EvaluateInIsolation(const Phrase &source
|
| 20 |
+
, const TargetPhrase &targetPhrase
|
| 21 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 22 |
+
, ScoreComponentCollection &estimatedScores) const
|
| 23 |
+
{
|
| 24 |
+
// get length of source and target phrase
|
| 25 |
+
size_t targetLength = targetPhrase.GetSize();
|
| 26 |
+
size_t sourceLength = source.GetSize();
|
| 27 |
+
|
| 28 |
+
// create feature names
|
| 29 |
+
util::StringStream nameSource;
|
| 30 |
+
nameSource << "s" << sourceLength;
|
| 31 |
+
|
| 32 |
+
util::StringStream nameTarget;
|
| 33 |
+
nameTarget << "t" << targetLength;
|
| 34 |
+
|
| 35 |
+
util::StringStream nameBoth;
|
| 36 |
+
nameBoth << sourceLength << "," << targetLength;
|
| 37 |
+
|
| 38 |
+
// increase feature counts
|
| 39 |
+
scoreBreakdown.PlusEquals(this,nameSource.str(),1);
|
| 40 |
+
scoreBreakdown.PlusEquals(this,nameTarget.str(),1);
|
| 41 |
+
scoreBreakdown.PlusEquals(this,nameBoth.str(),1);
|
| 42 |
+
|
| 43 |
+
//cerr << nameSource.str() << " " << nameTarget.str() << " " << nameBoth.str() << endl;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
}
|
mosesdecoder/moses/FF/PhrasePenalty.cpp
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <vector>
|
| 2 |
+
#include "PhrasePenalty.h"
|
| 3 |
+
#include "moses/ScoreComponentCollection.h"
|
| 4 |
+
#include "moses/TranslationModel/PhraseDictionary.h"
|
| 5 |
+
#include "util/exception.hh"
|
| 6 |
+
|
| 7 |
+
using namespace std;
|
| 8 |
+
|
| 9 |
+
namespace Moses
|
| 10 |
+
{
|
| 11 |
+
PhrasePenalty::PhrasePenalty(const std::string &line)
|
| 12 |
+
: StatelessFeatureFunction(1, line)
|
| 13 |
+
, m_perPhraseTable(false)
|
| 14 |
+
{
|
| 15 |
+
ReadParameters();
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
void PhrasePenalty::EvaluateInIsolation(const Phrase &source
|
| 19 |
+
, const TargetPhrase &targetPhrase
|
| 20 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 21 |
+
, ScoreComponentCollection &estimatedScores) const
|
| 22 |
+
{
|
| 23 |
+
if (m_perPhraseTable) {
|
| 24 |
+
const PhraseDictionary *pt = targetPhrase.GetContainer();
|
| 25 |
+
if (pt) {
|
| 26 |
+
size_t ptId = pt->GetId();
|
| 27 |
+
UTIL_THROW_IF2(ptId >= m_numScoreComponents, "Wrong number of scores");
|
| 28 |
+
|
| 29 |
+
vector<float> scores(m_numScoreComponents, 0);
|
| 30 |
+
scores[ptId] = 1.0f;
|
| 31 |
+
|
| 32 |
+
scoreBreakdown.Assign(this, scores);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
} else {
|
| 36 |
+
scoreBreakdown.Assign(this, 1.0f);
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
void PhrasePenalty::SetParameter(const std::string& key, const std::string& value)
|
| 41 |
+
{
|
| 42 |
+
if (key == "per-phrase-table") {
|
| 43 |
+
m_perPhraseTable =Scan<bool>(value);
|
| 44 |
+
} else {
|
| 45 |
+
StatelessFeatureFunction::SetParameter(key, value);
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
} // namespace
|
| 51 |
+
|
mosesdecoder/moses/FF/PhrasePenalty.h
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "StatelessFeatureFunction.h"
|
| 4 |
+
|
| 5 |
+
namespace Moses
|
| 6 |
+
{
|
| 7 |
+
|
| 8 |
+
class PhrasePenalty : public StatelessFeatureFunction
|
| 9 |
+
{
|
| 10 |
+
public:
|
| 11 |
+
PhrasePenalty(const std::string &line);
|
| 12 |
+
|
| 13 |
+
bool IsUseable(const FactorMask &mask) const {
|
| 14 |
+
return true;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
virtual void EvaluateInIsolation(const Phrase &source
|
| 18 |
+
, const TargetPhrase &targetPhrase
|
| 19 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 20 |
+
, ScoreComponentCollection &estimatedScores) const;
|
| 21 |
+
|
| 22 |
+
void EvaluateWhenApplied(const Hypothesis& hypo,
|
| 23 |
+
ScoreComponentCollection* accumulator) const {
|
| 24 |
+
}
|
| 25 |
+
void EvaluateWhenApplied(const ChartHypothesis &hypo,
|
| 26 |
+
ScoreComponentCollection* accumulator) const {
|
| 27 |
+
}
|
| 28 |
+
void EvaluateWhenApplied(const Syntax::SHyperedge &hyperedge,
|
| 29 |
+
ScoreComponentCollection* accumulator) const {
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
void EvaluateWithSourceContext(const InputType &input
|
| 33 |
+
, const InputPath &inputPath
|
| 34 |
+
, const TargetPhrase &targetPhrase
|
| 35 |
+
, const StackVec *stackVec
|
| 36 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 37 |
+
, ScoreComponentCollection *estimatedScores = NULL) const {
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
void EvaluateTranslationOptionListWithSourceContext(const InputType &input
|
| 41 |
+
, const TranslationOptionList &translationOptionList) const {
|
| 42 |
+
}
|
| 43 |
+
void SetParameter(const std::string& key, const std::string& value);
|
| 44 |
+
|
| 45 |
+
protected:
|
| 46 |
+
bool m_perPhraseTable;
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
} //namespace
|
| 50 |
+
|
mosesdecoder/moses/FF/ReferenceComparison.cpp
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "ReferenceComparison.h"
|
| 2 |
+
|
| 3 |
+
namespace Moses
|
| 4 |
+
{
|
| 5 |
+
ReferenceComparison::ReferenceComparison(const std::string &line)
|
| 6 |
+
:StatelessFeatureFunction(0, line)
|
| 7 |
+
{
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
}
|
| 11 |
+
|
mosesdecoder/moses/FF/RuleScope.h
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <string>
|
| 3 |
+
#include "StatelessFeatureFunction.h"
|
| 4 |
+
|
| 5 |
+
namespace Moses
|
| 6 |
+
{
|
| 7 |
+
|
| 8 |
+
// Rule Scope - not quite completely implemented yet
|
| 9 |
+
class RuleScope : public StatelessFeatureFunction
|
| 10 |
+
{
|
| 11 |
+
public:
|
| 12 |
+
RuleScope(const std::string &line);
|
| 13 |
+
|
| 14 |
+
virtual bool IsUseable(const FactorMask &mask) const {
|
| 15 |
+
return true;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
virtual void EvaluateInIsolation(const Phrase &source
|
| 19 |
+
, const TargetPhrase &targetPhrase
|
| 20 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 21 |
+
, ScoreComponentCollection &estimatedScores) const;
|
| 22 |
+
|
| 23 |
+
virtual void EvaluateWithSourceContext(const InputType &input
|
| 24 |
+
, const InputPath &inputPath
|
| 25 |
+
, const TargetPhrase &targetPhrase
|
| 26 |
+
, const StackVec *stackVec
|
| 27 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 28 |
+
, ScoreComponentCollection *estimatedScores = NULL) const {
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
void EvaluateTranslationOptionListWithSourceContext(const InputType &input
|
| 32 |
+
, const TranslationOptionList &translationOptionList) const {
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
virtual void EvaluateWhenApplied(const Hypothesis& hypo,
|
| 37 |
+
ScoreComponentCollection* accumulator) const {
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
virtual void EvaluateWhenApplied(const ChartHypothesis &hypo,
|
| 41 |
+
ScoreComponentCollection* accumulator) const {
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
void SetParameter(const std::string& key, const std::string& value);
|
| 45 |
+
|
| 46 |
+
protected:
|
| 47 |
+
bool m_sourceSyntax;
|
| 48 |
+
bool m_perScope;
|
| 49 |
+
bool m_futureCostOnly;
|
| 50 |
+
|
| 51 |
+
bool IsGlueRule(const Phrase &source) const;
|
| 52 |
+
|
| 53 |
+
};
|
| 54 |
+
|
| 55 |
+
}
|
| 56 |
+
|
mosesdecoder/moses/FF/SoftMatchingFeature.h
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "moses/Word.h"
|
| 4 |
+
#include "StatelessFeatureFunction.h"
|
| 5 |
+
|
| 6 |
+
#ifdef WITH_THREADS
|
| 7 |
+
#include <boost/thread/shared_mutex.hpp>
|
| 8 |
+
#endif
|
| 9 |
+
|
| 10 |
+
namespace Moses
|
| 11 |
+
{
|
| 12 |
+
|
| 13 |
+
class SoftMatchingFeature : public StatelessFeatureFunction
|
| 14 |
+
{
|
| 15 |
+
public:
|
| 16 |
+
SoftMatchingFeature(const std::string &line);
|
| 17 |
+
|
| 18 |
+
bool IsUseable(const FactorMask &mask) const {
|
| 19 |
+
return true;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
virtual void EvaluateWhenApplied(const ChartHypothesis& hypo,
|
| 23 |
+
ScoreComponentCollection* accumulator) const;
|
| 24 |
+
|
| 25 |
+
void EvaluateInIsolation(const Phrase &source
|
| 26 |
+
, const TargetPhrase &targetPhrase
|
| 27 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 28 |
+
, ScoreComponentCollection &estimatedScores) const {};
|
| 29 |
+
void EvaluateWithSourceContext(const InputType &input
|
| 30 |
+
, const InputPath &inputPath
|
| 31 |
+
, const TargetPhrase &targetPhrase
|
| 32 |
+
, const StackVec *stackVec
|
| 33 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 34 |
+
, ScoreComponentCollection *estimatedScores = NULL) const {};
|
| 35 |
+
|
| 36 |
+
void EvaluateTranslationOptionListWithSourceContext(const InputType &input
|
| 37 |
+
, const TranslationOptionList &translationOptionList) const {
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
void EvaluateWhenApplied(const Hypothesis& hypo,
|
| 41 |
+
ScoreComponentCollection* accumulator) const {};
|
| 42 |
+
|
| 43 |
+
bool Load(const std::string &filePath);
|
| 44 |
+
|
| 45 |
+
std::vector<std::vector<Word> >& GetSoftMatches() {
|
| 46 |
+
return m_softMatches;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
void ResizeCache() const;
|
| 50 |
+
|
| 51 |
+
const std::string& GetOrSetFeatureName(const Word& RHS, const Word& LHS) const;
|
| 52 |
+
void SetParameter(const std::string& key, const std::string& value);
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
private:
|
| 56 |
+
mutable std::vector<std::vector<Word> > m_softMatches; // map RHS of new rule to list of possible LHS of old rule (subtree)
|
| 57 |
+
mutable std::vector<std::vector<std::string> > m_nameCache;
|
| 58 |
+
bool m_scoreIdentical;
|
| 59 |
+
|
| 60 |
+
#ifdef WITH_THREADS
|
| 61 |
+
//reader-writer lock
|
| 62 |
+
mutable boost::shared_mutex m_accessLock;
|
| 63 |
+
#endif
|
| 64 |
+
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
}
|
| 68 |
+
|
mosesdecoder/moses/FF/SoftSourceSyntacticConstraintsFeature.cpp
ADDED
|
@@ -0,0 +1,651 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <vector>
|
| 2 |
+
#include <limits>
|
| 3 |
+
#include <cassert>
|
| 4 |
+
#include "SoftSourceSyntacticConstraintsFeature.h"
|
| 5 |
+
#include "moses/StaticData.h"
|
| 6 |
+
#include "moses/InputFileStream.h"
|
| 7 |
+
#include "moses/ScoreComponentCollection.h"
|
| 8 |
+
#include "moses/Hypothesis.h"
|
| 9 |
+
#include "moses/ChartHypothesis.h"
|
| 10 |
+
#include "moses/ChartManager.h"
|
| 11 |
+
#include "moses/FactorCollection.h"
|
| 12 |
+
#include "moses/TreeInput.h"
|
| 13 |
+
#include "moses/PP/SourceLabelsPhraseProperty.h"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
using namespace std;
|
| 17 |
+
|
| 18 |
+
namespace Moses
|
| 19 |
+
{
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
SoftSourceSyntacticConstraintsFeature::SoftSourceSyntacticConstraintsFeature(const std::string &line)
|
| 23 |
+
: StatelessFeatureFunction(6, line)
|
| 24 |
+
, m_useCoreSourceLabels(false)
|
| 25 |
+
, m_useLogprobs(true)
|
| 26 |
+
, m_useSparse(false)
|
| 27 |
+
, m_useSparseLabelPairs(false)
|
| 28 |
+
, m_noMismatches(false)
|
| 29 |
+
, m_floor(1e-7)
|
| 30 |
+
{
|
| 31 |
+
VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ...");
|
| 32 |
+
ReadParameters();
|
| 33 |
+
VERBOSE(1, " Done.");
|
| 34 |
+
VERBOSE(1, " Config:");
|
| 35 |
+
VERBOSE(1, " Log probabilities");
|
| 36 |
+
if ( m_useLogprobs ) {
|
| 37 |
+
VERBOSE(1, " active.");
|
| 38 |
+
} else {
|
| 39 |
+
VERBOSE(1, " inactive.");
|
| 40 |
+
}
|
| 41 |
+
VERBOSE(1, " Sparse scores");
|
| 42 |
+
if ( m_useSparse ) {
|
| 43 |
+
VERBOSE(1, " active.");
|
| 44 |
+
} else {
|
| 45 |
+
VERBOSE(1, " inactive.");
|
| 46 |
+
}
|
| 47 |
+
VERBOSE(1, " Sparse label pair scores");
|
| 48 |
+
if ( m_useSparseLabelPairs ) {
|
| 49 |
+
VERBOSE(1, " active.");
|
| 50 |
+
} else {
|
| 51 |
+
VERBOSE(1, " inactive.");
|
| 52 |
+
}
|
| 53 |
+
VERBOSE(1, " Core labels");
|
| 54 |
+
if ( m_useCoreSourceLabels ) {
|
| 55 |
+
VERBOSE(1, " active.");
|
| 56 |
+
} else {
|
| 57 |
+
VERBOSE(1, " inactive.");
|
| 58 |
+
}
|
| 59 |
+
VERBOSE(1, " No mismatches");
|
| 60 |
+
if ( m_noMismatches ) {
|
| 61 |
+
VERBOSE(1, " active.");
|
| 62 |
+
} else {
|
| 63 |
+
VERBOSE(1, " inactive.");
|
| 64 |
+
}
|
| 65 |
+
VERBOSE(1, std::endl);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
void SoftSourceSyntacticConstraintsFeature::SetParameter(const std::string& key, const std::string& value)
|
| 70 |
+
{
|
| 71 |
+
if (key == "sourceLabelSetFile") {
|
| 72 |
+
m_sourceLabelSetFile = value;
|
| 73 |
+
} else if (key == "coreSourceLabelSetFile") {
|
| 74 |
+
m_coreSourceLabelSetFile = value;
|
| 75 |
+
m_useCoreSourceLabels = true;
|
| 76 |
+
} else if (key == "targetSourceLeftHandSideJointCountFile") {
|
| 77 |
+
m_targetSourceLHSJointCountFile = value;
|
| 78 |
+
} else if (key == "noMismatches") {
|
| 79 |
+
m_noMismatches = Scan<bool>(value); // for a hard constraint, allow no mismatches (also set: weights 1 0 0 0 0 0, tuneable=false)
|
| 80 |
+
} else if (key == "logProbabilities") {
|
| 81 |
+
m_useLogprobs = Scan<bool>(value);
|
| 82 |
+
} else if (key == "sparse") {
|
| 83 |
+
m_useSparse = Scan<bool>(value);
|
| 84 |
+
} else if (key == "sparseLabelPairs") {
|
| 85 |
+
m_useSparseLabelPairs = Scan<bool>(value);
|
| 86 |
+
} else {
|
| 87 |
+
StatelessFeatureFunction::SetParameter(key, value);
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
void SoftSourceSyntacticConstraintsFeature::Load(AllOptions::ptr const& opts)
|
| 92 |
+
{
|
| 93 |
+
m_options = opts;
|
| 94 |
+
// don't change the loading order!
|
| 95 |
+
LoadSourceLabelSet();
|
| 96 |
+
if (!m_coreSourceLabelSetFile.empty()) {
|
| 97 |
+
LoadCoreSourceLabelSet();
|
| 98 |
+
}
|
| 99 |
+
if (!m_targetSourceLHSJointCountFile.empty()) {
|
| 100 |
+
LoadTargetSourceLeftHandSideJointCountFile();
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
void SoftSourceSyntacticConstraintsFeature::LoadSourceLabelSet()
|
| 105 |
+
{
|
| 106 |
+
FEATUREVERBOSE(2, "Loading source label set from file " << m_sourceLabelSetFile << " ...");
|
| 107 |
+
InputFileStream inFile(m_sourceLabelSetFile);
|
| 108 |
+
|
| 109 |
+
FactorCollection &factorCollection = FactorCollection::Instance();
|
| 110 |
+
|
| 111 |
+
// read source label set
|
| 112 |
+
std::string line;
|
| 113 |
+
m_sourceLabels.clear();
|
| 114 |
+
m_sourceLabelsByIndex.clear();
|
| 115 |
+
m_sourceLabelsByIndex_RHS_1.clear();
|
| 116 |
+
m_sourceLabelsByIndex_RHS_0.clear();
|
| 117 |
+
m_sourceLabelsByIndex_LHS_1.clear();
|
| 118 |
+
m_sourceLabelsByIndex_LHS_0.clear();
|
| 119 |
+
m_sourceLabelIndexesByFactor.clear();
|
| 120 |
+
while (getline(inFile, line)) {
|
| 121 |
+
std::istringstream tokenizer(line);
|
| 122 |
+
std::string label;
|
| 123 |
+
size_t index;
|
| 124 |
+
try {
|
| 125 |
+
tokenizer >> label >> index;
|
| 126 |
+
} catch (const std::exception &e) {
|
| 127 |
+
UTIL_THROW2(GetScoreProducerDescription()
|
| 128 |
+
<< ": Error reading source label set file " << m_sourceLabelSetFile << " .");
|
| 129 |
+
}
|
| 130 |
+
std::pair< boost::unordered_map<std::string,size_t>::iterator, bool > inserted = m_sourceLabels.insert( std::pair<std::string,size_t>(label,index) );
|
| 131 |
+
UTIL_THROW_IF2(!inserted.second, GetScoreProducerDescription()
|
| 132 |
+
<< ": Source label set file " << m_sourceLabelSetFile << " should contain each syntactic label only once.");
|
| 133 |
+
|
| 134 |
+
if (index >= m_sourceLabelsByIndex.size()) {
|
| 135 |
+
m_sourceLabelsByIndex.resize(index+1);
|
| 136 |
+
m_sourceLabelsByIndex_RHS_1.resize(index+1);
|
| 137 |
+
m_sourceLabelsByIndex_RHS_0.resize(index+1);
|
| 138 |
+
m_sourceLabelsByIndex_LHS_1.resize(index+1);
|
| 139 |
+
m_sourceLabelsByIndex_LHS_0.resize(index+1);
|
| 140 |
+
}
|
| 141 |
+
m_sourceLabelsByIndex[index] = label;
|
| 142 |
+
m_sourceLabelsByIndex_RHS_1[index] = "RHS_1_" + label;
|
| 143 |
+
m_sourceLabelsByIndex_RHS_0[index] = "RHS_0_" + label;
|
| 144 |
+
m_sourceLabelsByIndex_LHS_1[index] = "LHS_1_" + label;
|
| 145 |
+
m_sourceLabelsByIndex_LHS_0[index] = "LHS_0_" + label;
|
| 146 |
+
const Factor* sourceLabelFactor = factorCollection.AddFactor(label,true);
|
| 147 |
+
m_sourceLabelIndexesByFactor[sourceLabelFactor] = index;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
inFile.Close();
|
| 151 |
+
|
| 152 |
+
std::list<std::string> specialLabels;
|
| 153 |
+
specialLabels.push_back("GlueTop");
|
| 154 |
+
specialLabels.push_back("GlueX");
|
| 155 |
+
// specialLabels.push_back("XRHS");
|
| 156 |
+
// specialLabels.push_back("XLHS");
|
| 157 |
+
for (std::list<std::string>::const_iterator iter=specialLabels.begin();
|
| 158 |
+
iter!=specialLabels.end(); ++iter) {
|
| 159 |
+
boost::unordered_map<std::string,size_t>::iterator found = m_sourceLabels.find(*iter);
|
| 160 |
+
UTIL_THROW_IF2(found == m_sourceLabels.end(), GetScoreProducerDescription()
|
| 161 |
+
<< ": Source label set file " << m_sourceLabelSetFile << " should contain an entry for the special label \"" << *iter << "\".");
|
| 162 |
+
if (!(found->first).compare("GlueTop")) {
|
| 163 |
+
m_GlueTopLabel = found->second;
|
| 164 |
+
// } else if (!(found->first).compare("XRHS")) {
|
| 165 |
+
// m_XRHSLabel = found->second;
|
| 166 |
+
// } else if (!(found->first).compare("XLHS")) {
|
| 167 |
+
// m_XLHSLabel = found->second;
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
FEATUREVERBOSE2(2, " Done." << std::endl);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
// Loads the "core" subset of the source label set from m_coreSourceLabelSetFile.
// The resulting index set (m_coreSourceLabels) is used elsewhere in this
// feature to restrict sparse label scoring when m_useCoreSourceLabels is set.
void SoftSourceSyntacticConstraintsFeature::LoadCoreSourceLabelSet()
{
  FEATUREVERBOSE(2, "Loading core source label set from file " << m_coreSourceLabelSetFile << " ...");
  // read core source label set; labels are resolved against m_sourceLabels
  LoadLabelSet(m_coreSourceLabelSetFile, m_coreSourceLabels);
  FEATUREVERBOSE2(2, " Done." << std::endl);
}
|
| 181 |
+
|
| 182 |
+
void SoftSourceSyntacticConstraintsFeature::LoadLabelSet(std::string &filename,
|
| 183 |
+
boost::unordered_set<size_t> &labelSet)
|
| 184 |
+
{
|
| 185 |
+
InputFileStream inFile(filename);
|
| 186 |
+
std::string line;
|
| 187 |
+
labelSet.clear();
|
| 188 |
+
while (getline(inFile, line)) {
|
| 189 |
+
istringstream tokenizer(line);
|
| 190 |
+
std::string label;
|
| 191 |
+
tokenizer >> label;
|
| 192 |
+
boost::unordered_map<std::string,size_t>::iterator foundSourceLabelIndex = m_sourceLabels.find( label );
|
| 193 |
+
if ( foundSourceLabelIndex != m_sourceLabels.end() ) {
|
| 194 |
+
labelSet.insert(foundSourceLabelIndex->second);
|
| 195 |
+
} else {
|
| 196 |
+
FEATUREVERBOSE(2, "Ignoring undefined source label \"" << label << "\" "
|
| 197 |
+
<< "from core source label set file " << filename << "."
|
| 198 |
+
<< std::endl);
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
inFile.Close();
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
// Loads joint counts of (target LHS label, source LHS label) pairs from
// m_targetSourceLHSJointCountFile (whitespace-separated: targetLabel
// sourceLabel count per line) and converts them into two conditional
// probabilities per pair, stored in m_labelPairProbabilities:
//   .first  = count / total count of the target label  (p(source|target))
//   .second = count / total count of the source label  (p(target|source))
// Any previously loaded table is deleted first (the map owns its vectors).
void SoftSourceSyntacticConstraintsFeature::LoadTargetSourceLeftHandSideJointCountFile()
{

  FEATUREVERBOSE(2, "Loading target/source label joint counts from file " << m_targetSourceLHSJointCountFile << " ...");
  InputFileStream inFile(m_targetSourceLHSJointCountFile);

  // release the raw-owned per-target vectors before re-filling the map
  for (boost::unordered_map<const Factor*, std::vector< std::pair<float,float> >* >::iterator iter=m_labelPairProbabilities.begin();
       iter!=m_labelPairProbabilities.end(); ++iter) {
    delete iter->second;
  }
  m_labelPairProbabilities.clear();

  // read joint counts
  std::string line;
  FactorCollection &factorCollection = FactorCollection::Instance();
  // marginal counts, accumulated while reading; used for normalization below
  boost::unordered_map<const Factor*,float> targetLHSCounts;
  std::vector<float> sourceLHSCounts(m_sourceLabels.size(),0.0);

  while (getline(inFile, line)) {
    // NOTE(review): istringstream is unqualified here (relies on a
    // using-declaration elsewhere) and extraction results are unchecked;
    // a malformed line leaves count uninitialized — confirm input is trusted.
    istringstream tokenizer(line);
    std::string targetLabel;
    std::string sourceLabel;
    float count;
    tokenizer >> targetLabel;
    tokenizer >> sourceLabel;
    tokenizer >> count;

    // source labels must already be known from the source label set file
    boost::unordered_map<std::string,size_t>::iterator foundSourceLabelIndex = m_sourceLabels.find( sourceLabel );
    UTIL_THROW_IF2(foundSourceLabelIndex == m_sourceLabels.end(), GetScoreProducerDescription()
                   << ": Target/source label joint count file " << m_targetSourceLHSJointCountFile
                   << " contains undefined source label \"" << sourceLabel << "\".");

    // target labels are interned on the fly
    const Factor* targetLabelFactor = factorCollection.AddFactor(targetLabel,true);

    sourceLHSCounts[foundSourceLabelIndex->second] += count;
    std::pair< boost::unordered_map<const Factor*,float >::iterator, bool > insertedTargetLHSCount =
      targetLHSCounts.insert( std::pair<const Factor*,float>(targetLabelFactor,count) );
    if (!insertedTargetLHSCount.second) {
      // target label seen before: accumulate marginal and joint counts
      (insertedTargetLHSCount.first)->second += count;
      boost::unordered_map<const Factor*, std::vector< std::pair<float,float> >* >::iterator jointCountIt =
        m_labelPairProbabilities.find( targetLabelFactor );
      assert(jointCountIt != m_labelPairProbabilities.end());
      (jointCountIt->second)->at(foundSourceLabelIndex->second).first += count;
      (jointCountIt->second)->at(foundSourceLabelIndex->second).second += count;
    } else {
      // first occurrence of this target label: allocate its per-source vector
      std::pair<float,float> init(0.0,0.0);
      std::vector< std::pair<float,float> >* sourceVector = new std::vector< std::pair<float,float> >(m_sourceLabels.size(),init);
      sourceVector->at(foundSourceLabelIndex->second) = std::pair<float,float>(count,count);
      std::pair< boost::unordered_map<const Factor*, std::vector< std::pair<float,float> >* >::iterator, bool > insertedJointCount =
        m_labelPairProbabilities.insert( std::pair<const Factor*, std::vector< std::pair<float,float> >* >(targetLabelFactor,sourceVector) );
      UTIL_THROW_IF2(!insertedJointCount.second, GetScoreProducerDescription()
                     << ": Loading target/source label joint counts from file " << m_targetSourceLHSJointCountFile << " failed.");
    }
  }

  // normalization: turn joint counts into the two conditional probabilities
  for (boost::unordered_map<const Factor*, std::vector< std::pair<float,float> >* >::iterator iter=m_labelPairProbabilities.begin();
       iter!=m_labelPairProbabilities.end(); ++iter) {
    float targetLHSCount = 0;
    boost::unordered_map<const Factor*,float >::const_iterator targetLHSCountIt = targetLHSCounts.find( iter->first );
    if ( targetLHSCountIt != targetLHSCounts.end() ) {
      targetLHSCount = targetLHSCountIt->second;
    }
    std::vector< std::pair<float,float> > &probabilities = *(iter->second);
    for (size_t index=0; index<probabilities.size(); ++index) {

      // zero entries stay zero (never-seen pairs); GetLabelPairProbabilities
      // substitutes floor values for them at lookup time
      if ( probabilities[index].first != 0 ) {
        assert(targetLHSCount != 0);
        probabilities[index].first /= targetLHSCount;
      }
      if ( probabilities[index].second != 0 ) {
        assert(sourceLHSCounts[index] != 0);
        probabilities[index].second /= sourceLHSCounts[index];
      }
    }
  }

  inFile.Close();
  FEATUREVERBOSE2(2, " Done." << std::endl);
}
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
// Scores a rule application against the syntactic labels of the input tree.
// Dense components (6): [0] no complete tree match (or -inf under
// m_noMismatches), [1] LHS mismatch (binary), [2] number of mismatched RHS
// non-terminals, [3] normalized labelled-rule probability, [4] normalized LHS
// match probability, [5] accumulated normalized RHS match probability.
// Sparse components (optional): per-label match/mismatch features and
// LHS label-pair features. Rules without a "SourceLabels" property must
// translate an unknown word, otherwise this throws.
void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedScores) const
{
  assert(stackVec);

  // diagnostic dump of the rule, input path and non-terminal coverage
  IFFEATUREVERBOSE(3) {
    FEATUREVERBOSE(3, targetPhrase << std::endl);
    FEATUREVERBOSE(3, inputPath << std::endl);
    for (size_t i = 0; i < stackVec->size(); ++i) {
      const ChartCellLabel &cell = *stackVec->at(i);
      const Range &ntRange = cell.GetCoverage();
      FEATUREVERBOSE(3, "stackVec[ " << i << " ] : " << ntRange.GetStartPos() << " - " << ntRange.GetEndPos() << std::endl);
    }

    for (AlignmentInfo::const_iterator it=targetPhrase.GetAlignNonTerm().begin();
         it!=targetPhrase.GetAlignNonTerm().end(); ++it) {
      FEATUREVERBOSE(3, "alignNonTerm " << it->first << " " << it->second << std::endl);
    }
  }

  // dense scores
  std::vector<float> newScores(m_numScoreComponents,0);

  const TreeInput& treeInput = static_cast<const TreeInput&>(input);
  // const StaticData& staticData = StaticData::Instance();
  // const Word& outputDefaultNonTerminal = staticData.GetOutputDefaultNonTerminal();

  // nNTs counts the LHS plus all RHS non-terminals of the rule
  size_t nNTs = 1;
  bool treeInputMismatchLHSBinary = true;
  size_t treeInputMismatchRHSCount = 0;
  bool hasCompleteTreeInputMatch = false;
  float ruleLabelledProbability = 0.0;
  float treeInputMatchProbRHS = 0.0;
  float treeInputMatchProbLHS = 0.0;

  // read SourceLabels property
  const Factor* targetLHS = targetPhrase.GetTargetLHS()[0];
  bool isGlueGrammarRule = false;
  bool isUnkRule = false;

  if (const PhraseProperty *property = targetPhrase.GetProperty("SourceLabels")) {

    const SourceLabelsPhraseProperty *sourceLabelsPhraseProperty = static_cast<const SourceLabelsPhraseProperty*>(property);

    nNTs = sourceLabelsPhraseProperty->GetNumberOfNonTerminals();
    float totalCount = sourceLabelsPhraseProperty->GetTotalCount();

    // prepare for input tree label matching: label indexes observed in the
    // input tree, per RHS non-terminal and for the rule's whole span (LHS)
    std::vector< boost::unordered_set<size_t> > treeInputLabelsRHS(nNTs-1);
    boost::unordered_set<size_t> treeInputLabelsLHS;

    // get index map for underlying hypotheses
    const Range& range = inputPath.GetWordsRange();
    size_t startPos = range.GetStartPos();
    size_t endPos = range.GetEndPos();
    const Phrase *sourcePhrase = targetPhrase.GetRuleSource();

    if (nNTs > 1) { // rule has right-hand side non-terminals, i.e. it's a hierarchical rule
      size_t nonTerminalNumber = 0;
      size_t sourceSentPos = startPos;

      for (size_t sourcePhrasePos=0; sourcePhrasePos<sourcePhrase->GetSize(); ++sourcePhrasePos) {
        // consult rule for either word or non-terminal
        const Word &word = sourcePhrase->GetWord(sourcePhrasePos);
        size_t symbolStartPos = sourceSentPos;
        size_t symbolEndPos = sourceSentPos;
        if ( word.IsNonTerminal() ) {
          // retrieve information that is required for input tree label matching (RHS)
          const ChartCellLabel &cell = *stackVec->at(nonTerminalNumber);
          const Range& prevWordsRange = cell.GetCoverage();
          symbolStartPos = prevWordsRange.GetStartPos();
          symbolEndPos = prevWordsRange.GetEndPos();
        }

        const NonTerminalSet& treeInputLabels = treeInput.GetLabelSet(symbolStartPos,symbolEndPos);

        for (NonTerminalSet::const_iterator treeInputLabelsIt = treeInputLabels.begin();
             treeInputLabelsIt != treeInputLabels.end(); ++treeInputLabelsIt) {
          if (*treeInputLabelsIt != m_options->syntax.output_default_non_terminal) {
            boost::unordered_map<const Factor*,size_t>::const_iterator foundTreeInputLabel
              = m_sourceLabelIndexesByFactor.find((*treeInputLabelsIt)[0]);
            if (foundTreeInputLabel != m_sourceLabelIndexesByFactor.end()) {
              size_t treeInputLabelIndex = foundTreeInputLabel->second;
              // NOTE(review): indexing by sourcePhrasePos looks suspicious —
              // treeInputLabelsRHS is sized nNTs-1 (one slot per RHS
              // non-terminal) and is read below via nonTerminalNumber, yet
              // sourcePhrasePos also advances over terminals and can exceed
              // nNTs-2. Confirm against upstream whether this should be
              // treeInputLabelsRHS[nonTerminalNumber] guarded by
              // word.IsNonTerminal().
              treeInputLabelsRHS[sourcePhrasePos].insert(treeInputLabelIndex);
            }
          }
        }

        if ( word.IsNonTerminal() ) {
          ++nonTerminalNumber;
        }
        sourceSentPos = symbolEndPos + 1;
      }
    }

    // retrieve information that is required for input tree label matching (LHS)
    const NonTerminalSet& treeInputLabels = treeInput.GetLabelSet(startPos,endPos);

    for (NonTerminalSet::const_iterator treeInputLabelsIt = treeInputLabels.begin();
         treeInputLabelsIt != treeInputLabels.end(); ++treeInputLabelsIt) {
      if (*treeInputLabelsIt != m_options->syntax.output_default_non_terminal) {
        boost::unordered_map<const Factor*,size_t>::const_iterator foundTreeInputLabel
          = m_sourceLabelIndexesByFactor.find((*treeInputLabelsIt)[0]);
        if (foundTreeInputLabel != m_sourceLabelIndexesByFactor.end()) {
          size_t treeInputLabelIndex = foundTreeInputLabel->second;
          treeInputLabelsLHS.insert(treeInputLabelIndex);
        }
      }
    }


    // inspect source-labelled rule items

    // labels already rewarded with a sparse match feature (avoid double-firing
    // across multiple sourceLabelItems)
    std::vector< boost::unordered_set<size_t> > sparseScoredTreeInputLabelsRHS(nNTs-1);
    boost::unordered_set<size_t> sparseScoredTreeInputLabelsLHS;

    // NOTE(review): sourceLabelSeenAsLHS appears unused in this function
    std::vector<bool> sourceLabelSeenAsLHS(m_sourceLabels.size(),false);
    std::vector<bool> treeInputMatchRHSCountByNonTerminal(nNTs-1,false);
    std::vector<float> treeInputMatchProbRHSByNonTerminal(nNTs-1,0.0);

    const std::list<SourceLabelsPhrasePropertyItem> &sourceLabelItems = sourceLabelsPhraseProperty->GetSourceLabelItems();

    // iterate over the labelled variants of this rule; stops early once a
    // variant matches the input tree completely
    for (std::list<SourceLabelsPhrasePropertyItem>::const_iterator sourceLabelItem = sourceLabelItems.begin();
         sourceLabelItem != sourceLabelItems.end() && !hasCompleteTreeInputMatch; ++sourceLabelItem) {

      const std::list<size_t> &sourceLabelsRHS = sourceLabelItem->GetSourceLabelsRHS();
      const std::list< std::pair<size_t,float> > &sourceLabelsLHSList = sourceLabelItem->GetSourceLabelsLHSList();
      float sourceLabelsRHSCount = sourceLabelItem->GetSourceLabelsRHSCount();

      assert(sourceLabelsRHS.size() == nNTs-1);

      bool currentSourceLabelItemIsCompleteTreeInputMatch = true;

      size_t nonTerminalNumber=0;
      for (std::list<size_t>::const_iterator sourceLabelsRHSIt = sourceLabelsRHS.begin();
           sourceLabelsRHSIt != sourceLabelsRHS.end(); ++sourceLabelsRHSIt, ++nonTerminalNumber) {

        if (treeInputLabelsRHS[nonTerminalNumber].find(*sourceLabelsRHSIt) != treeInputLabelsRHS[nonTerminalNumber].end()) {

          treeInputMatchRHSCountByNonTerminal[nonTerminalNumber] = true;
          treeInputMatchProbRHSByNonTerminal[nonTerminalNumber] += sourceLabelsRHSCount; // to be normalized later on

          if ( m_useSparse &&
               (!m_useCoreSourceLabels || m_coreSourceLabels.find(*sourceLabelsRHSIt) != m_coreSourceLabels.end()) ) {
            // score sparse features: RHS match
            if (sparseScoredTreeInputLabelsRHS[nonTerminalNumber].find(*sourceLabelsRHSIt) == sparseScoredTreeInputLabelsRHS[nonTerminalNumber].end()) {
              // (only if no match has been scored for this tree input label and rule non-terminal with a previous sourceLabelItem)
              float score_RHS_1 = (float)1/treeInputLabelsRHS[nonTerminalNumber].size();
              scoreBreakdown.PlusEquals(this,
                                        m_sourceLabelsByIndex_RHS_1[*sourceLabelsRHSIt],
                                        score_RHS_1);
              sparseScoredTreeInputLabelsRHS[nonTerminalNumber].insert(*sourceLabelsRHSIt);
            }
          }

        } else {

          currentSourceLabelItemIsCompleteTreeInputMatch = false;

        }
      }

      for (std::list< std::pair<size_t,float> >::const_iterator sourceLabelsLHSIt = sourceLabelsLHSList.begin();
           sourceLabelsLHSIt != sourceLabelsLHSList.end(); ++sourceLabelsLHSIt) {

        // the glue grammar's top label marks glue rules (exempt from penalties)
        if ( sourceLabelsLHSIt->first == m_GlueTopLabel ) {
          isGlueGrammarRule = true;
        }

        if (treeInputLabelsLHS.find(sourceLabelsLHSIt->first) != treeInputLabelsLHS.end()) {

          treeInputMismatchLHSBinary = false;
          treeInputMatchProbLHS += sourceLabelsLHSIt->second; // to be normalized later on

          if ( m_useSparse &&
               (!m_useCoreSourceLabels || m_coreSourceLabels.find(sourceLabelsLHSIt->first) != m_coreSourceLabels.end()) ) {
            // score sparse features: LHS match
            if (sparseScoredTreeInputLabelsLHS.find(sourceLabelsLHSIt->first) == sparseScoredTreeInputLabelsLHS.end()) {
              // (only if no match has been scored for this tree input label and rule non-terminal with a previous sourceLabelItem)
              float score_LHS_1 = (float)1/treeInputLabelsLHS.size();
              scoreBreakdown.PlusEquals(this,
                                        m_sourceLabelsByIndex_LHS_1[sourceLabelsLHSIt->first],
                                        score_LHS_1);
              sparseScoredTreeInputLabelsLHS.insert(sourceLabelsLHSIt->first);
            }
          }

          // a variant matches completely only if all RHS labels and this LHS
          // label are present in the input tree
          if ( currentSourceLabelItemIsCompleteTreeInputMatch ) {
            ruleLabelledProbability += sourceLabelsLHSIt->second; // to be normalized later on
            hasCompleteTreeInputMatch = true;
          }

        }
      }
    }

    // normalization
    for (std::vector<float>::iterator treeInputMatchProbRHSByNonTerminalIt = treeInputMatchProbRHSByNonTerminal.begin();
         treeInputMatchProbRHSByNonTerminalIt != treeInputMatchProbRHSByNonTerminal.end(); ++treeInputMatchProbRHSByNonTerminalIt) {
      *treeInputMatchProbRHSByNonTerminalIt /= totalCount;
      if ( *treeInputMatchProbRHSByNonTerminalIt != 0 ) {
        treeInputMatchProbRHS += ( m_useLogprobs ? TransformScore(*treeInputMatchProbRHSByNonTerminalIt) : *treeInputMatchProbRHSByNonTerminalIt );
      }
    }
    treeInputMatchProbLHS /= totalCount;
    ruleLabelledProbability /= totalCount;

    // input tree matching (RHS): count RHS non-terminals without any match
    if ( !hasCompleteTreeInputMatch ) {
      treeInputMismatchRHSCount = nNTs-1;
      for (std::vector<bool>::const_iterator treeInputMatchRHSCountByNonTerminalIt = treeInputMatchRHSCountByNonTerminal.begin();
           treeInputMatchRHSCountByNonTerminalIt != treeInputMatchRHSCountByNonTerminal.end(); ++treeInputMatchRHSCountByNonTerminalIt) {
        if (*treeInputMatchRHSCountByNonTerminalIt) {
          --treeInputMismatchRHSCount;
        }
      }
    }

    // score sparse features: mismatches (input tree labels never matched above)
    if ( m_useSparse ) {

      // RHS

      for (size_t nonTerminalNumber = 0; nonTerminalNumber < nNTs-1; ++nonTerminalNumber) {
        // nNTs-1 because nNTs also counts the left-hand side non-terminal

        float score_RHS_0 = (float)1/treeInputLabelsRHS[nonTerminalNumber].size();
        for (boost::unordered_set<size_t>::const_iterator treeInputLabelsRHSIt = treeInputLabelsRHS[nonTerminalNumber].begin();
             treeInputLabelsRHSIt != treeInputLabelsRHS[nonTerminalNumber].end(); ++treeInputLabelsRHSIt) {

          if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsRHSIt) != m_coreSourceLabels.end() ) {

            if (sparseScoredTreeInputLabelsRHS[nonTerminalNumber].find(*treeInputLabelsRHSIt) == sparseScoredTreeInputLabelsRHS[nonTerminalNumber].end()) {
              // score sparse features: RHS mismatch
              scoreBreakdown.PlusEquals(this,
                                        m_sourceLabelsByIndex_RHS_0[*treeInputLabelsRHSIt],
                                        score_RHS_0);
            }
          }
        }
      }

      // LHS

      float score_LHS_0 = (float)1/treeInputLabelsLHS.size();
      for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
           treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {

        if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsLHSIt) != m_coreSourceLabels.end() ) {

          if (sparseScoredTreeInputLabelsLHS.find(*treeInputLabelsLHSIt) == sparseScoredTreeInputLabelsLHS.end()) {
            // score sparse features: LHS mismatch
            scoreBreakdown.PlusEquals(this,
                                      m_sourceLabelsByIndex_LHS_0[*treeInputLabelsLHSIt],
                                      score_LHS_0);
          }
        }
      }

    }

    if ( m_useSparseLabelPairs && !isGlueGrammarRule ) {

      // left-hand side label pairs (target NT, source NT)
      float t2sLabelsScore = 0.0;
      float s2tLabelsScore = 0.0;
      for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
           treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {

        scoreBreakdown.PlusEquals(this,
                                  "LHSPAIR_" + targetLHS->GetString().as_string() + "_" + m_sourceLabelsByIndex[*treeInputLabelsLHSIt],
                                  (float)1/treeInputLabelsLHS.size());

        if (!m_targetSourceLHSJointCountFile.empty()) {
          std::pair<float,float> probPair = GetLabelPairProbabilities( targetLHS, *treeInputLabelsLHSIt);
          t2sLabelsScore += probPair.first;
          s2tLabelsScore += probPair.second;
        }
      }
      if ( treeInputLabelsLHS.size() == 0 ) {
        // no usable input label over this span: pair with the default
        // non-terminal and fall back to floor probabilities
        scoreBreakdown.PlusEquals(this,
                                  "LHSPAIR_" + targetLHS->GetString().as_string() + "_"
                                  + m_options->syntax.output_default_non_terminal[0]
                                  ->GetString().as_string(),
                                  1);
        if (!m_targetSourceLHSJointCountFile.empty()) {
          t2sLabelsScore = TransformScore(m_floor);
          s2tLabelsScore = TransformScore(m_floor);
        }
      } else {
        if (!m_targetSourceLHSJointCountFile.empty()) {
          // average the summed probabilities in log space
          float norm = TransformScore(treeInputLabelsLHS.size());
          t2sLabelsScore = TransformScore(t2sLabelsScore) - norm;
          s2tLabelsScore = TransformScore(s2tLabelsScore) - norm;
        }
      }
      if (!m_targetSourceLHSJointCountFile.empty()) {
        scoreBreakdown.PlusEquals(this, "LHST2S", t2sLabelsScore);
        scoreBreakdown.PlusEquals(this, "LHSS2T", s2tLabelsScore);
      }
    }

  } else {

    // abort with error message if the phrase does not translate an unknown word
    UTIL_THROW_IF2(!targetPhrase.GetWord(0).IsOOV(), GetScoreProducerDescription()
                   << ": Missing SourceLabels property. "
                   << "Please check phrase table and glue rules.");

    // unknown word
    isUnkRule = true;
    // ruleLabelledProbability = 1;

  }

  // add scores

  // input tree matching
  newScores[0] = !hasCompleteTreeInputMatch;
  if ( m_noMismatches ) {
    // hard constraint mode: forbid (score -inf) anything but complete
    // matches, glue rules and unknown-word rules
    newScores[0] = ( (hasCompleteTreeInputMatch || isGlueGrammarRule || isUnkRule) ? 0 : -std::numeric_limits<float>::infinity() );
  }
  newScores[1] = treeInputMismatchLHSBinary;
  newScores[2] = treeInputMismatchRHSCount;

  if ( m_useLogprobs ) {
    // zero probabilities are kept as 0 rather than transformed to -inf
    if ( ruleLabelledProbability != 0 ) {
      ruleLabelledProbability = TransformScore(ruleLabelledProbability);
    }
    if ( treeInputMatchProbLHS != 0 ) {
      treeInputMatchProbLHS = TransformScore(treeInputMatchProbLHS);
    }
  }

  newScores[3] = ruleLabelledProbability;
  newScores[4] = treeInputMatchProbLHS;
  newScores[5] = treeInputMatchProbRHS;

  scoreBreakdown.PlusEquals(this, newScores);
}
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
std::pair<float,float> SoftSourceSyntacticConstraintsFeature::GetLabelPairProbabilities(
|
| 634 |
+
const Factor* target,
|
| 635 |
+
const size_t source) const
|
| 636 |
+
{
|
| 637 |
+
boost::unordered_map<const Factor*, std::vector< std::pair<float,float> >* >::const_iterator found =
|
| 638 |
+
m_labelPairProbabilities.find(target);
|
| 639 |
+
if ( found == m_labelPairProbabilities.end() ) {
|
| 640 |
+
return std::pair<float,float>(m_floor,m_floor); // floor values
|
| 641 |
+
}
|
| 642 |
+
std::pair<float,float> ret = found->second->at(source);
|
| 643 |
+
if ( ret == std::pair<float,float>(0,0) ) {
|
| 644 |
+
return std::pair<float,float>(m_floor,m_floor); // floor values
|
| 645 |
+
}
|
| 646 |
+
return ret;
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
}
|
| 651 |
+
|
mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.cpp
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <map>
|
| 2 |
+
#include <vector>
|
| 3 |
+
#include <cassert>
|
| 4 |
+
#include "SourceGHKMTreeInputMatchFeature.h"
|
| 5 |
+
#include "moses/StaticData.h"
|
| 6 |
+
#include "moses/InputFileStream.h"
|
| 7 |
+
#include "moses/ScoreComponentCollection.h"
|
| 8 |
+
#include "moses/Hypothesis.h"
|
| 9 |
+
#include "moses/ChartHypothesis.h"
|
| 10 |
+
#include "moses/Factor.h"
|
| 11 |
+
#include "moses/FactorCollection.h"
|
| 12 |
+
#include "moses/InputPath.h"
|
| 13 |
+
#include "moses/TreeInput.h"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
using namespace std;
|
| 17 |
+
|
| 18 |
+
namespace Moses
|
| 19 |
+
{
|
| 20 |
+
|
| 21 |
+
// Constructs the feature from its moses.ini line.
// Two dense score components: [0] fires on LHS label match, [1] on mismatch.
SourceGHKMTreeInputMatchFeature::SourceGHKMTreeInputMatchFeature(const std::string &line)
  : StatelessFeatureFunction(2, line)
{
  // Fix: the description and the message were previously concatenated with no
  // separator, printing e.g. "SourceGHKMTreeInputMatch0Initializing feature...".
  std::cerr << GetScoreProducerDescription() << ": Initializing feature...";
  ReadParameters();
  std::cerr << " Done." << std::endl;
}
|
| 28 |
+
|
| 29 |
+
// This feature accepts no configuration parameters; any key=value pair on its
// moses.ini line is rejected with an exception.
void SourceGHKMTreeInputMatchFeature::SetParameter(const std::string& key, const std::string& value)
{
  UTIL_THROW(util::Exception, GetScoreProducerDescription() << ": Unknown parameter " << key << "=" << value);
}
|
| 33 |
+
|
| 34 |
+
// assumes that source-side syntax labels are stored in the target non-terminal field of the rules
|
| 35 |
+
void SourceGHKMTreeInputMatchFeature::EvaluateWithSourceContext(const InputType &input
|
| 36 |
+
, const InputPath &inputPath
|
| 37 |
+
, const TargetPhrase &targetPhrase
|
| 38 |
+
, const StackVec *stackVec
|
| 39 |
+
, ScoreComponentCollection &scoreBreakdown
|
| 40 |
+
, ScoreComponentCollection *estimatedScores) const
|
| 41 |
+
{
|
| 42 |
+
const Range& range = inputPath.GetWordsRange();
|
| 43 |
+
size_t startPos = range.GetStartPos();
|
| 44 |
+
size_t endPos = range.GetEndPos();
|
| 45 |
+
const TreeInput& treeInput = static_cast<const TreeInput&>(input);
|
| 46 |
+
const NonTerminalSet& treeInputLabels = treeInput.GetLabelSet(startPos,endPos);
|
| 47 |
+
const Word& lhsLabel = targetPhrase.GetTargetLHS();
|
| 48 |
+
|
| 49 |
+
const StaticData& staticData = StaticData::Instance();
|
| 50 |
+
|
| 51 |
+
std::vector<float> newScores(m_numScoreComponents,0.0);
|
| 52 |
+
// m_numScoreComponents == 2 // first fires for matches, second for mismatches
|
| 53 |
+
|
| 54 |
+
if ( (treeInputLabels.find(lhsLabel) != treeInputLabels.end())
|
| 55 |
+
&& (lhsLabel != m_options->syntax.output_default_non_terminal) ) {
|
| 56 |
+
// match
|
| 57 |
+
newScores[0] = 1.0;
|
| 58 |
+
} else {
|
| 59 |
+
// mismatch
|
| 60 |
+
newScores[1] = 1.0;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
scoreBreakdown.PlusEquals(this, newScores);
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
// Stores the global decoder options; EvaluateWithSourceContext reads
// m_options->syntax.output_default_non_terminal from them.
void
SourceGHKMTreeInputMatchFeature::
Load(AllOptions::ptr const& opts)
{
  m_options = opts;
  // m_output_default_nonterminal = opts->syntax.output_default_non_terminal;
}
|
| 73 |
+
|
| 74 |
+
}
|
| 75 |
+
|
mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "StatelessFeatureFunction.h"
|
| 4 |
+
#include "moses/parameters/AllOptions.h"
|
| 5 |
+
|
| 6 |
+
namespace Moses
|
| 7 |
+
{
|
| 8 |
+
|
| 9 |
+
// assumes that source-side syntax labels are stored in the target non-terminal field of the rules
|
| 10 |
+
// Stateless feature with two dense components that reward/penalize agreement
// between a rule's LHS label and the labels of the input parse tree over the
// rule's source span (see the .cpp for the scoring logic).
class SourceGHKMTreeInputMatchFeature : public StatelessFeatureFunction
{
public:
  SourceGHKMTreeInputMatchFeature(const std::string &line);

  // usable regardless of which factors are available
  bool IsUseable(const FactorMask &mask) const {
    return true;
  }

  // rejects all parameters (this feature has none)
  void SetParameter(const std::string& key, const std::string& value);

  // no context-independent score contribution
  void EvaluateInIsolation(const Phrase &source
                           , const TargetPhrase &targetPhrase
                           , ScoreComponentCollection &scoreBreakdown
                           , ScoreComponentCollection &estimatedScores) const {};

  // the actual scoring: compares the rule's LHS label with the input tree
  // labels over the covered span (implemented in the .cpp)
  void EvaluateWithSourceContext(const InputType &input
                                 , const InputPath &inputPath
                                 , const TargetPhrase &targetPhrase
                                 , const StackVec *stackVec
                                 , ScoreComponentCollection &scoreBreakdown
                                 , ScoreComponentCollection *estimatedScores = NULL) const;

  // no per-option-list score contribution
  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const {
  }


  // no additional score at hypothesis application time (phrase-based)
  void EvaluateWhenApplied(const Hypothesis& hypo,
                           ScoreComponentCollection* accumulator) const {};

  // no additional score at hypothesis application time (chart-based)
  void EvaluateWhenApplied(const ChartHypothesis &hypo,
                           ScoreComponentCollection* accumulator) const {};

  // captures the global decoder options (default output non-terminal)
  void Load(AllOptions::ptr const& opts);
};
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
}
|
| 49 |
+
|
mosesdecoder/moses/FF/SourceWordDeletionFeature.h
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <string>
|
| 4 |
+
#include <boost/unordered_set.hpp>
|
| 5 |
+
|
| 6 |
+
#include "StatelessFeatureFunction.h"
|
| 7 |
+
#include "moses/FactorCollection.h"
|
| 8 |
+
#include "moses/AlignmentInfo.h"
|
| 9 |
+
|
| 10 |
+
namespace Moses
|
| 11 |
+
{
|
| 12 |
+
|
| 13 |
+
/** Sets the features for source word deletion
 *
 * Stateless sparse feature: fires for source words that end up with no
 * alignment point in the target phrase ("deleted" during translation).
 * Feature computation itself lives in ComputeFeatures(); the Evaluate*
 * overrides here are either declared (implemented in the .cpp) or
 * intentionally empty no-ops.
 */
class SourceWordDeletionFeature : public StatelessFeatureFunction
{
private:
  // Restricted vocabulary of words the feature may fire for.
  // Empty / unused when m_unrestricted is true.
  boost::unordered_set<std::string> m_vocab;
  // Which factor of the source word is inspected (e.g. surface form).
  FactorType m_factorType;
  // True => fire for any word; false => only for words in m_vocab.
  bool m_unrestricted;
  // Path of the vocabulary file loaded in Load(); empty when unrestricted.
  std::string m_filename;

public:
  SourceWordDeletionFeature(const std::string &line);

  // Loads m_vocab from m_filename (see .cpp).
  void Load(AllOptions::ptr const& opts);

  bool IsUseable(const FactorMask &mask) const;

  // Scores a rule in isolation; implemented in the .cpp via ComputeFeatures().
  void EvaluateInIsolation(const Phrase &source
                           , const TargetPhrase &targetPhrase
                           , ScoreComponentCollection &scoreBreakdown
                           , ScoreComponentCollection &estimatedScores) const;

  // Intentionally a no-op: all scoring happens in EvaluateInIsolation().
  void EvaluateWithSourceContext(const InputType &input
                                 , const InputPath &inputPath
                                 , const TargetPhrase &targetPhrase
                                 , const StackVec *stackVec
                                 , ScoreComponentCollection &scoreBreakdown
                                 , ScoreComponentCollection *estimatedScores = NULL) const {
  }

  // Intentionally a no-op.
  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const {
  }


  // Intentionally a no-op (phrase-based search).
  void EvaluateWhenApplied(const Hypothesis& hypo,
                           ScoreComponentCollection* accumulator) const {
  }
  // Intentionally a no-op (chart/hierarchical search).
  void EvaluateWhenApplied(const ChartHypothesis &hypo,
                           ScoreComponentCollection* accumulator) const {
  }

  // Core scoring: emits sparse deletion features for unaligned source words,
  // using alignmentInfo to find which source positions are unaligned.
  void ComputeFeatures(const Phrase &source,
                       const TargetPhrase& targetPhrase,
                       ScoreComponentCollection* accumulator,
                       const AlignmentInfo &alignmentInfo) const;
  // Parses "factor", "path", ... options from the feature line.
  void SetParameter(const std::string& key, const std::string& value);

};
|
| 61 |
+
|
| 62 |
+
}
|
| 63 |
+
|
mosesdecoder/moses/FF/SpanLength.cpp
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <boost/shared_ptr.hpp>
|
| 2 |
+
#include "SpanLength.h"
|
| 3 |
+
#include "moses/StaticData.h"
|
| 4 |
+
#include "moses/Word.h"
|
| 5 |
+
#include "moses/ChartCellLabel.h"
|
| 6 |
+
#include "moses/Range.h"
|
| 7 |
+
#include "moses/StackVec.h"
|
| 8 |
+
#include "moses/TargetPhrase.h"
|
| 9 |
+
#include "moses/PP/PhraseProperty.h"
|
| 10 |
+
#include "moses/PP/SpanLengthPhraseProperty.h"
|
| 11 |
+
|
| 12 |
+
using namespace std;
|
| 13 |
+
|
| 14 |
+
namespace Moses
|
| 15 |
+
{
|
| 16 |
+
// Constructs the feature from its config line. One dense score component.
// Defaults: no smoothing, smoothing constant 0. ReadParameters() routes
// each key=value pair of `line` through SetParameter() below.
SpanLength::SpanLength(const std::string &line)
  :StatelessFeatureFunction(1, line)
  ,m_smoothingMethod(None)
  ,m_const(0)
{
  ReadParameters();
}
|
| 23 |
+
|
| 24 |
+
// Does not score anything here (scoring needs chart context); it only caches
// the rule's source side on the target phrase so EvaluateWithSourceContext()
// can assert it is available later.
// NOTE(review): SetRuleSource() is called on a const TargetPhrase& —
// presumably it writes a mutable member; confirm against TargetPhrase.
void SpanLength::EvaluateInIsolation(const Phrase &source
                                     , const TargetPhrase &targetPhrase
                                     , ScoreComponentCollection &scoreBreakdown
                                     , ScoreComponentCollection &estimatedScores) const
{
  targetPhrase.SetRuleSource(source);
}
|
| 31 |
+
|
| 32 |
+
// Adds the span-length score for a rule application: for each non-terminal,
// look up the probability of the source span width it covers (from the
// "SpanLength" phrase property) and accumulate the transformed log score.
// Rules without the property are silently skipped.
void SpanLength::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedScores) const
{
  assert(stackVec);

  const PhraseProperty *property = targetPhrase.GetProperty("SpanLength");
  if (property == NULL) {
    return;
  }

  // Safe by construction: "SpanLength" properties are parsed into this type.
  const SpanLengthPhraseProperty *slProp = static_cast<const SpanLengthPhraseProperty*>(property);

  // Set earlier by EvaluateInIsolation().
  assert(targetPhrase.GetRuleSource());

  float score = 0;
  // One entry per non-terminal: i-th NT covers the i-th stack cell's range.
  for (size_t i = 0; i < stackVec->size(); ++i) {
    const ChartCellLabel &cell = *stackVec->at(i);
    const Range &ntRange = cell.GetCoverage();
    size_t sourceWidth = ntRange.GetNumWordsCovered();
    // m_const is the plus-constant smoothing value (0 when smoothing=none).
    float prob = slProp->GetProb(i, sourceWidth, m_const);
    score += TransformScore(prob);
  }

  // Floor extremely bad scores at -100, but only when the feature weight is
  // negative — otherwise flooring would *raise* the weighted contribution.
  if (score < -100.0f) {
    float weight = StaticData::Instance().GetWeight(this);
    if (weight < 0) {
      score = -100;
    }
  }

  scoreBreakdown.PlusEquals(this, score);

}
|
| 69 |
+
|
| 70 |
+
void SpanLength::SetParameter(const std::string& key, const std::string& value)
|
| 71 |
+
{
|
| 72 |
+
if (key == "smoothing") {
|
| 73 |
+
if (value == "plus-constant") {
|
| 74 |
+
m_smoothingMethod = PlusConst;
|
| 75 |
+
} else if (value == "none") {
|
| 76 |
+
m_smoothingMethod = None;
|
| 77 |
+
} else {
|
| 78 |
+
UTIL_THROW(util::Exception, "Unknown smoothing type " << value);
|
| 79 |
+
}
|
| 80 |
+
} else if (key == "constant") {
|
| 81 |
+
m_const = Scan<float>(value);
|
| 82 |
+
} else {
|
| 83 |
+
StatelessFeatureFunction::SetParameter(key, value);
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
}
|
| 88 |
+
|
mosesdecoder/moses/FF/SpanLength.h
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <string>
|
| 3 |
+
#include "StatelessFeatureFunction.h"
|
| 4 |
+
|
| 5 |
+
namespace Moses
|
| 6 |
+
{
|
| 7 |
+
|
| 8 |
+
// Rule Scope - not quite completely implemented yet
/** Stateless feature scoring the source span widths covered by each
 *  non-terminal of a hierarchical rule, using the "SpanLength" phrase
 *  property. Only the chart/source-context path does real work; the
 *  EvaluateWhenApplied overrides are deliberate no-ops.
 */
class SpanLength : public StatelessFeatureFunction
{
public:
  SpanLength(const std::string &line);

  // Works with any factor configuration.
  virtual bool IsUseable(const FactorMask &mask) const {
    return true;
  }

  // Caches the rule source on the target phrase (no scoring here).
  virtual void EvaluateInIsolation(const Phrase &source
                                   , const TargetPhrase &targetPhrase
                                   , ScoreComponentCollection &scoreBreakdown
                                   , ScoreComponentCollection &estimatedScores) const;

  // Adds the span-length score; needs stackVec (chart decoding).
  virtual void EvaluateWithSourceContext(const InputType &input
                                         , const InputPath &inputPath
                                         , const TargetPhrase &targetPhrase
                                         , const StackVec *stackVec
                                         , ScoreComponentCollection &scoreBreakdown
                                         , ScoreComponentCollection *estimatedScores = NULL) const;

  // Intentionally a no-op.
  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const {
  }


  // Intentionally a no-op (phrase-based search).
  virtual void EvaluateWhenApplied(const Hypothesis& hypo,
                                   ScoreComponentCollection* accumulator) const {
  }

  // Intentionally a no-op (chart search; scoring already done above).
  virtual void EvaluateWhenApplied(const ChartHypothesis &hypo,
                                   ScoreComponentCollection* accumulator) const {
  }

  // Handles "smoothing" and "constant" options.
  void SetParameter(const std::string& key, const std::string& value);

protected:
  // How span-length probabilities are smoothed.
  enum SmoothingMethod {
    None,
    PlusConst,
  };
  SmoothingMethod m_smoothingMethod;

  // Constant for PlusConst smoothing (0 when smoothing is None).
  float m_const;
};
|
| 54 |
+
|
| 55 |
+
}
|
| 56 |
+
|
mosesdecoder/moses/FF/SparseHieroReorderingFeature.cpp
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <iostream>
|
| 2 |
+
|
| 3 |
+
#include "moses/ChartHypothesis.h"
|
| 4 |
+
#include "moses/ChartManager.h"
|
| 5 |
+
#include "moses/FactorCollection.h"
|
| 6 |
+
#include "moses/Sentence.h"
|
| 7 |
+
|
| 8 |
+
#include "util/exception.hh"
|
| 9 |
+
#include "util/string_stream.hh"
|
| 10 |
+
|
| 11 |
+
#include "SparseHieroReorderingFeature.h"
|
| 12 |
+
|
| 13 |
+
using namespace std;
|
| 14 |
+
|
| 15 |
+
namespace Moses
|
| 16 |
+
{
|
| 17 |
+
|
| 18 |
+
// Constructs the sparse hiero-reordering feature. Zero dense score
// components: every score is a named sparse feature emitted in
// EvaluateWhenApplied(). Vocabulary files (optional) restrict which
// boundary words generate features; out-of-vocab words map to ##OTHER##.
SparseHieroReorderingFeature::SparseHieroReorderingFeature(const std::string &line)
  :StatelessFeatureFunction(0, line),
   m_type(SourceCombined),
   m_sourceFactor(0),
   m_targetFactor(0),
   m_sourceVocabFile(""),
   m_targetVocabFile("")
{

  /*
    Configuration of features.
      factor - Which factor should it apply to
      type - what type of sparse reordering feature. e.g. block (modelled on Matthias
      Huck's EAMT 2012 features)
      word - which words to include, e.g. src_bdry, src_all, tgt_bdry , ...
      vocab - vocab file to limit it to
      orientation - e.g. lr, etc.
  */
  cerr << "Constructing a Sparse Reordering feature" << endl;
  // Parses the key=value options via SetParameter() before the vocab
  // filenames below are used.
  ReadParameters();
  // Placeholder factor substituted for out-of-vocabulary boundary words.
  m_otherFactor = FactorCollection::Instance().AddFactor("##OTHER##");
  LoadVocabulary(m_sourceVocabFile, m_sourceVocab);
  LoadVocabulary(m_targetVocabFile, m_targetVocab);
}
|
| 42 |
+
|
| 43 |
+
void SparseHieroReorderingFeature::SetParameter(const std::string& key, const std::string& value)
|
| 44 |
+
{
|
| 45 |
+
if (key == "input-factor") {
|
| 46 |
+
m_sourceFactor = Scan<FactorType>(value);
|
| 47 |
+
} else if (key == "output-factor") {
|
| 48 |
+
m_targetFactor = Scan<FactorType>(value);
|
| 49 |
+
} else if (key == "input-vocab-file") {
|
| 50 |
+
m_sourceVocabFile = value;
|
| 51 |
+
} else if (key == "output-vocab-file") {
|
| 52 |
+
m_targetVocabFile = value;
|
| 53 |
+
} else if (key == "type") {
|
| 54 |
+
if (value == "SourceCombined") {
|
| 55 |
+
m_type = SourceCombined;
|
| 56 |
+
} else if (value == "SourceLeft") {
|
| 57 |
+
m_type = SourceLeft;
|
| 58 |
+
} else if (value == "SourceRight") {
|
| 59 |
+
m_type = SourceRight;
|
| 60 |
+
} else {
|
| 61 |
+
UTIL_THROW(util::Exception, "Unknown sparse reordering type " << value);
|
| 62 |
+
}
|
| 63 |
+
} else {
|
| 64 |
+
FeatureFunction::SetParameter(key, value);
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// Loads a vocabulary file (one token per line) into `vocab` as interned
// Factor pointers. An empty filename means "no restriction" and is a no-op.
// Throws util::Exception when the file cannot be opened.
// NOTE(review): uses std::ifstream but this .cpp does not include <fstream>
// directly — presumably pulled in transitively via a project header; confirm.
void SparseHieroReorderingFeature::LoadVocabulary(const std::string& filename, Vocab& vocab)
{
  if (filename.empty()) return;
  ifstream in(filename.c_str());
  UTIL_THROW_IF(!in, util::Exception, "Unable to open vocab file: " << filename);
  string line;
  while(getline(in,line)) {
    // AddFactor interns the string, so the set stores stable Factor*.
    vocab.insert(FactorCollection::Instance().AddFactor(line));
  }
  in.close();
}
|
| 79 |
+
|
| 80 |
+
// Returns the word's factor, collapsed to ##OTHER## when a vocabulary
// restriction is active and the factor is not in it. An empty vocab
// means "no restriction".
const Factor* SparseHieroReorderingFeature::GetFactor(const Word& word, const Vocab& vocab, FactorType factorType) const
{
  const Factor* factor = word.GetFactor(factorType);
  if (vocab.empty()) {
    return factor;
  }
  return (vocab.find(factor) == vocab.end()) ? m_otherFactor : factor;
}
|
| 86 |
+
|
| 87 |
+
// Emits one sparse reordering feature per pair of adjacent source "blocks"
// of the applied rule (Huck-style features). The feature name encodes the
// boundary word(s) of the block pair and whether the pair is translated
// monotonically ("M") or swapped ("S").
void SparseHieroReorderingFeature::EvaluateWhenApplied(
  const ChartHypothesis& cur_hypo ,
  ScoreComponentCollection* accumulator) const
{
  // get index map for underlying hypotheses
  //const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
  //  cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap();

  //The Huck features. For a rule with source side:
  //  abXcdXef
  //We first have to split into blocks:
  //  ab X cd X ef
  //Then we extract features based in the boundary words of the neighbouring blocks
  //For the block pair, we use the right word of the left block, and the left
  //word of the right block.

  //Need to get blocks, and their alignment. Each block has a word range (on the
  // on the source), a non-terminal flag, and a set of alignment points in the target phrase

  //We need to be able to map source word position to target word position, as
  //much as possible (don't need interior of non-terminals). The alignment info
  //objects just give us the mappings between *rule* positions. So if we can
  //map source word position to source rule position, and target rule position
  //to target word position, then we can map right through.

  size_t sourceStart = cur_hypo.GetCurrSourceRange().GetStartPos();
  size_t sourceSize = cur_hypo.GetCurrSourceRange().GetNumWordsCovered();

  // Source spans covered by this rule's non-terminals (one per prev hypo).
  vector<Range> sourceNTSpans;
  for (size_t prevHypoId = 0; prevHypoId < cur_hypo.GetPrevHypos().size(); ++prevHypoId) {
    sourceNTSpans.push_back(cur_hypo.GetPrevHypo(prevHypoId)->GetCurrSourceRange());
  }
  //put in source order. Is this necessary?
  sort(sourceNTSpans.begin(), sourceNTSpans.end());
  //cerr << "Source NTs: ";
  //for (size_t i = 0; i < sourceNTSpans.size(); ++i) cerr << sourceNTSpans[i] << " ";
  //cerr << endl;

  typedef pair<Range,bool> Block;//flag indicates NT
  // Split the rule's full source span into alternating terminal/NT blocks by
  // repeatedly carving each sorted NT span out of the trailing block.
  vector<Block> sourceBlocks;
  sourceBlocks.push_back(Block(cur_hypo.GetCurrSourceRange(),false));
  for (vector<Range>::const_iterator i = sourceNTSpans.begin();
       i != sourceNTSpans.end(); ++i) {
    const Range& prevHypoRange = *i;
    Block lastBlock = sourceBlocks.back();
    sourceBlocks.pop_back();
    //split this range into before NT, NT and after NT
    if (prevHypoRange.GetStartPos() > lastBlock.first.GetStartPos()) {
      sourceBlocks.push_back(Block(Range(lastBlock.first.GetStartPos(),prevHypoRange.GetStartPos()-1),false));
    }
    sourceBlocks.push_back(Block(prevHypoRange,true));
    if (prevHypoRange.GetEndPos() < lastBlock.first.GetEndPos()) {
      sourceBlocks.push_back(Block(Range(prevHypoRange.GetEndPos()+1,lastBlock.first.GetEndPos()), false));
    }
  }
  /*
  cerr << "Source Blocks: ";
  for (size_t i = 0; i < sourceBlocks.size(); ++i) cerr << sourceBlocks[i].first << " "
      << (sourceBlocks[i].second ? "NT" : "T") << " ";
  cerr << endl;
  */

  //Mapping from source word to target rule position
  vector<size_t> sourceWordToTargetRulePos(sourceSize);
  map<size_t,size_t> alignMap;
  alignMap.insert(
    cur_hypo.GetCurrTargetPhrase().GetAlignTerm().begin(),
    cur_hypo.GetCurrTargetPhrase().GetAlignTerm().end());
  alignMap.insert(
    cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().begin(),
    cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().end());
  //vector<size_t> alignMapTerm = cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm()
  // Walk the blocks in source order, advancing the *rule* position once per
  // terminal word and once per whole NT block; every word inside an NT block
  // shares that block's single rule position.
  // NOTE(review): alignMap[] default-inserts 0 for unaligned rule positions —
  // presumably intended; confirm.
  size_t sourceRulePos = 0;
  //cerr << "SW->RP ";
  for (vector<Block>::const_iterator sourceBlockIt = sourceBlocks.begin();
       sourceBlockIt != sourceBlocks.end(); ++sourceBlockIt) {
    for (size_t sourceWordPos = sourceBlockIt->first.GetStartPos();
         sourceWordPos <= sourceBlockIt->first.GetEndPos(); ++sourceWordPos) {
      sourceWordToTargetRulePos[sourceWordPos - sourceStart] = alignMap[sourceRulePos];
      //  cerr << sourceWordPos - sourceStart << "-" << alignMap[sourceRulePos] << " ";
      if (! sourceBlockIt->second) {
        //T
        ++sourceRulePos;
      }
    }
    if ( sourceBlockIt->second) {
      //NT
      ++sourceRulePos;
    }
  }
  //cerr << endl;

  //Iterate through block pairs
  const Sentence& sentence =
    static_cast<const Sentence&>(cur_hypo.GetManager().GetSource());
  //const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
  for (size_t i = 0; i < sourceBlocks.size()-1; ++i) {
    Block& leftSourceBlock = sourceBlocks[i];
    Block& rightSourceBlock = sourceBlocks[i+1];
    // Boundary words: right edge of the left block, left edge of the right.
    size_t sourceLeftBoundaryPos = leftSourceBlock.first.GetEndPos();
    size_t sourceRightBoundaryPos = rightSourceBlock.first.GetStartPos();
    const Word& sourceLeftBoundaryWord = sentence.GetWord(sourceLeftBoundaryPos);
    const Word& sourceRightBoundaryWord = sentence.GetWord(sourceRightBoundaryPos);
    // Rebase to rule-local positions for the lookup table below.
    sourceLeftBoundaryPos -= sourceStart;
    sourceRightBoundaryPos -= sourceStart;

    // Need to figure out where these map to on the target.
    size_t targetLeftRulePos =
      sourceWordToTargetRulePos[sourceLeftBoundaryPos];
    size_t targetRightRulePos =
      sourceWordToTargetRulePos[sourceRightBoundaryPos];

    // Monotone iff the source order of the boundary words is preserved on
    // the target side.
    bool isMonotone = true;
    if ((sourceLeftBoundaryPos < sourceRightBoundaryPos  &&
         targetLeftRulePos > targetRightRulePos) ||
        ((sourceLeftBoundaryPos > sourceRightBoundaryPos  &&
          targetLeftRulePos < targetRightRulePos))) {
      isMonotone = false;
    }
    // Feature name: "h_" + [left word "_"] + [right word "_"] + M|S.
    util::StringStream buf;
    buf << "h_"; //sparse reordering, Huck
    if (m_type == SourceLeft  || m_type == SourceCombined) {
      buf << GetFactor(sourceLeftBoundaryWord,m_sourceVocab,m_sourceFactor)->GetString();
      buf << "_";
    }
    if (m_type == SourceRight || m_type == SourceCombined) {
      buf << GetFactor(sourceRightBoundaryWord,m_sourceVocab,m_sourceFactor)->GetString();
      buf << "_";
    }
    buf << (isMonotone ? "M" : "S");
    accumulator->PlusEquals(this,buf.str(), 1);
  }
//  cerr << endl;
}
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
}
|
| 224 |
+
|
mosesdecoder/moses/FF/SparseHieroReorderingFeatureTest.cpp
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - factored phrase-based language decoder
|
| 3 |
+
Copyright (C) 2013- University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
#include <iostream>
|
| 20 |
+
|
| 21 |
+
#include <boost/test/unit_test.hpp>
|
| 22 |
+
|
| 23 |
+
#include "SparseHieroReorderingFeature.h"
|
| 24 |
+
|
| 25 |
+
using namespace Moses;
|
| 26 |
+
using namespace std;
|
| 27 |
+
|
| 28 |
+
BOOST_AUTO_TEST_SUITE(shrf)

// Smoke test only: constructing the feature from a minimal config line
// must not throw. No scoring behaviour is exercised here.
BOOST_AUTO_TEST_CASE(lexical_rule)
{
  SparseHieroReorderingFeature feature("name=shrf");

}

BOOST_AUTO_TEST_SUITE_END()
|