diff --git a/fairseq-0.10.2/examples/backtranslation/prepare-wmt18en2de.sh b/fairseq-0.10.2/examples/backtranslation/prepare-wmt18en2de.sh new file mode 100644 index 0000000000000000000000000000000000000000..f6fd275307db50ca84c299440ae02dce49064030 --- /dev/null +++ b/fairseq-0.10.2/examples/backtranslation/prepare-wmt18en2de.sh @@ -0,0 +1,135 @@ +#!/bin/bash +# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh + +echo 'Cloning Moses github repository (for tokenization scripts)...' +git clone https://github.com/moses-smt/mosesdecoder.git + +echo 'Cloning Subword NMT repository (for BPE pre-processing)...' +git clone https://github.com/rsennrich/subword-nmt.git + +SCRIPTS=mosesdecoder/scripts +TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl +CLEAN=$SCRIPTS/training/clean-corpus-n.perl +NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl +REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl +BPEROOT=subword-nmt/subword_nmt +BPE_TOKENS=32000 + +URLS=( + "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz" + "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz" + "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" + "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz" + "http://data.statmt.org/wmt17/translation-task/dev.tgz" + "http://statmt.org/wmt14/test-full.tgz" +) +FILES=( + "training-parallel-europarl-v7.tgz" + "training-parallel-commoncrawl.tgz" + "training-parallel-nc-v13.tgz" + "rapid2016.tgz" + "dev.tgz" + "test-full.tgz" +) +CORPORA=( + "training/europarl-v7.de-en" + "commoncrawl.de-en" + "training-parallel-nc-v13/news-commentary-v13.de-en" + "rapid2016.de-en" +) + +if [ ! -d "$SCRIPTS" ]; then + echo "Please set SCRIPTS variable correctly to point to Moses scripts." 
+ exit 1 +fi + +OUTDIR=wmt18_en_de + +src=en +tgt=de +lang=en-de +prep=$OUTDIR +tmp=$prep/tmp +orig=orig + +mkdir -p $orig $tmp $prep + +cd $orig + +for ((i=0;i<${#URLS[@]};++i)); do + file=${FILES[i]} + if [ -f $file ]; then + echo "$file already exists, skipping download" + else + url=${URLS[i]} + wget "$url" + if [ -f $file ]; then + echo "$url successfully downloaded." + else + echo "$url not successfully downloaded." + exit 1 + fi + if [ ${file: -4} == ".tgz" ]; then + tar zxvf $file + elif [ ${file: -4} == ".tar" ]; then + tar xvf $file + fi + fi +done +cd .. + +echo "pre-processing train data..." +for l in $src $tgt; do + rm $tmp/train.tags.$lang.tok.$l + for f in "${CORPORA[@]}"; do + cat $orig/$f.$l | \ + perl $NORM_PUNC $l | \ + perl $REM_NON_PRINT_CHAR | \ + perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l + done +done + +echo "pre-processing test data..." +for l in $src $tgt; do + if [ "$l" == "$src" ]; then + t="src" + else + t="ref" + fi + grep '<seg id' $orig/test-full/newstest2014-deen-$t.$lang.sgm | \ + sed -e 's/<seg id="[0-9]*">\s*//g' | \ + sed -e 's/\s*<\/seg>\s*//g' | \ + sed -e "s/\’/\'/g" | \ + perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l + echo "" +done + +echo "splitting train and valid..." +for l in $src $tgt; do + awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l + awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l +done + +TRAIN=$tmp/train.de-en +BPE_CODE=$prep/code +rm -f $TRAIN +for l in $src $tgt; do + cat $tmp/train.$l >> $TRAIN +done + +echo "learn_bpe.py on ${TRAIN}..." +python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE + +for L in $src $tgt; do + for f in train.$L valid.$L test.$L; do + echo "apply_bpe.py to ${f}..." 
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f + done +done + +perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250 +perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250 + +for L in $src $tgt; do + cp $tmp/bpe.test.$L $prep/test.$L +done diff --git a/fairseq-0.10.2/examples/byte_level_bpe/README.md b/fairseq-0.10.2/examples/byte_level_bpe/README.md new file mode 100644 index 0000000000000000000000000000000000000000..657092660eae42d20f67647417623b8b8cb7b66c --- /dev/null +++ b/fairseq-0.10.2/examples/byte_level_bpe/README.md @@ -0,0 +1,88 @@ +# Neural Machine Translation with Byte-Level Subwords + +https://arxiv.org/abs/1909.03341 + +We provide an implementation of byte-level byte-pair encoding (BBPE), taking IWSLT 2017 Fr-En translation as +example. + +## Data +Get data and generate fairseq binary dataset: +```bash +bash ./get_data.sh +``` + +## Model Training +Train Transformer model with Bi-GRU embedding contextualization (implemented in `gru_transformer.py`): +```bash +# VOCAB=bytes +# VOCAB=chars +VOCAB=bbpe2048 +# VOCAB=bpe2048 +# VOCAB=bbpe4096 +# VOCAB=bpe4096 +# VOCAB=bpe16384 +``` +```bash +fairseq-train "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \ + --arch gru_transformer --encoder-layers 2 --decoder-layers 2 --dropout 0.3 --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9, 0.98)' \ + --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --log-format 'simple' --log-interval 100 --save-dir "checkpoints/${VOCAB}" \ + --batch-size 100 --max-update 100000 --update-freq 2 +``` + +## Generation +`fairseq-generate` requires bytes (BBPE) decoder to convert byte-level representation back to characters: +```bash +# BPE=--bpe bytes +# BPE=--bpe characters +BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe2048.model +# BPE=--bpe sentencepiece --sentencepiece-model 
data/spm_bpe2048.model +# BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe4096.model +# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe4096.model +# BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe16384.model +``` + +```bash +fairseq-generate "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \ + --source-lang fr --gen-subset test --sacrebleu --path "checkpoints/${VOCAB}/checkpoint_last.pt" \ + --tokenizer moses --moses-target-lang en ${BPE} +``` +When using `fairseq-interactive`, bytes (BBPE) encoder/decoder is required to tokenize input data and detokenize model predictions: +```bash +fairseq-interactive "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \ + --path "checkpoints/${VOCAB}/checkpoint_last.pt" --input data/test.fr --tokenizer moses --moses-source-lang fr \ + --moses-target-lang en ${BPE} --buffer-size 1000 --max-tokens 10000 +``` + +## Results +| Vocabulary | Model | BLEU | +|:-------------:|:-------------:|:-------------:| +| Joint BPE 16k ([Kudo, 2018](https://arxiv.org/abs/1804.10959)) | 512d LSTM 2+2 | 33.81 | +| Joint BPE 16k | Transformer base 2+2 (w/ GRU) | 36.64 (36.72) | +| Joint BPE 4k | Transformer base 2+2 (w/ GRU) | 35.49 (36.10) | +| Joint BBPE 4k | Transformer base 2+2 (w/ GRU) | 35.61 (35.82) | +| Joint BPE 2k | Transformer base 2+2 (w/ GRU) | 34.87 (36.13) | +| Joint BBPE 2k | Transformer base 2+2 (w/ GRU) | 34.98 (35.43) | +| Characters | Transformer base 2+2 (w/ GRU) | 31.78 (33.30) | +| Bytes | Transformer base 2+2 (w/ GRU) | 31.57 (33.62) | + + +## Citation +``` +@misc{wang2019neural, + title={Neural Machine Translation with Byte-Level Subwords}, + author={Changhan Wang and Kyunghyun Cho and Jiatao Gu}, + year={2019}, + eprint={1909.03341}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` + + +## Contact +Changhan Wang ([changhan@fb.com](mailto:changhan@fb.com)), +Kyunghyun Cho 
([kyunghyuncho@fb.com](mailto:kyunghyuncho@fb.com)), +Jiatao Gu ([jgu@fb.com](mailto:jgu@fb.com)) diff --git a/fairseq-0.10.2/examples/byte_level_bpe/get_bitext.py b/fairseq-0.10.2/examples/byte_level_bpe/get_bitext.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac1eeec1e6167ec6bafd76b37173ee6987cae7e --- /dev/null +++ b/fairseq-0.10.2/examples/byte_level_bpe/get_bitext.py @@ -0,0 +1,254 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import os +import os.path as op +from collections import namedtuple +from multiprocessing import cpu_count +from typing import List, Optional + +import sentencepiece as sp +from fairseq.data.encoders.byte_bpe import ByteBPE +from fairseq.data.encoders.byte_utils import byte_encode +from fairseq.data.encoders.bytes import Bytes +from fairseq.data.encoders.characters import Characters +from fairseq.data.encoders.moses_tokenizer import MosesTokenizer +from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE + + +SPLITS = ["train", "valid", "test"] + + +def _convert_xml(in_path: str, out_path: str): + with open(in_path) as f, open(out_path, "w") as f_o: + for s in f: + ss = s.strip() + if not ss.startswith("<seg"): + continue + ss = ss.replace("</seg>", "").split('">') + assert len(ss) == 2 + f_o.write(ss[1].strip() + "\n") + + +def _convert_train(in_path: str, out_path: str): + with open(in_path) as f, open(out_path, "w") as f_o: + for s in f: + ss = s.strip() + if ss.startswith("<"): + continue + f_o.write(ss.strip() + "\n") + + +def _get_bytes(in_path: str, out_path: str): + with open(in_path) as f, open(out_path, "w") as f_o: + for s in f: + f_o.write(Bytes.encode(s.strip()) + "\n") + + +def _get_chars(in_path: str, out_path: str): + with open(in_path) as f, open(out_path, "w") as f_o: + for s in f: + f_o.write(Characters.encode(s.strip()) + "\n") + + +def pretokenize(in_path: str, 
out_path: str, src: str, tgt: str): + Args = namedtuple( + "Args", + [ + "moses_source_lang", + "moses_target_lang", + "moses_no_dash_splits", + "moses_no_escape", + ], + ) + args = Args( + moses_source_lang=src, + moses_target_lang=tgt, + moses_no_dash_splits=False, + moses_no_escape=False, + ) + pretokenizer = MosesTokenizer(args) + with open(in_path) as f, open(out_path, "w") as f_o: + for s in f: + f_o.write(pretokenizer.encode(s.strip()) + "\n") + + +def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str): + with open(out_path, "w") as f_o: + for lang in [src, tgt]: + with open(f"{in_path_prefix}.{lang}") as f: + for s in f: + f_o.write(byte_encode(s.strip()) + "\n") + + +def _get_bpe(in_path: str, model_prefix: str, vocab_size: int): + arguments = [ + f"--input={in_path}", + f"--model_prefix={model_prefix}", + f"--model_type=bpe", + f"--vocab_size={vocab_size}", + "--character_coverage=1.0", + "--normalization_rule_name=identity", + f"--num_threads={cpu_count()}", + ] + sp.SentencePieceTrainer.Train(" ".join(arguments)) + + +def _apply_bbpe(model_path: str, in_path: str, out_path: str): + Args = namedtuple("Args", ["sentencepiece_model_path"]) + args = Args(sentencepiece_model_path=model_path) + tokenizer = ByteBPE(args) + with open(in_path) as f, open(out_path, "w") as f_o: + for s in f: + f_o.write(tokenizer.encode(s.strip()) + "\n") + + +def _apply_bpe(model_path: str, in_path: str, out_path: str): + Args = namedtuple("Args", ["sentencepiece_model"]) + args = Args(sentencepiece_model=model_path) + tokenizer = SentencepieceBPE(args) + with open(in_path) as f, open(out_path, "w") as f_o: + for s in f: + f_o.write(tokenizer.encode(s.strip()) + "\n") + + +def _concat_files(in_paths: List[str], out_path: str): + with open(out_path, "w") as f_o: + for p in in_paths: + with open(p) as f: + for r in f: + f_o.write(r) + + +def preprocess_iwslt17( + root: str, + src: str, + tgt: str, + bpe_size: Optional[int], + need_chars: bool, + bbpe_size: 
Optional[int], + need_bytes: bool, +): + # extract bitext + in_root = op.join(root, f"{src}-{tgt}") + for lang in [src, tgt]: + _convert_train( + op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"), + op.join(root, f"train.{lang}"), + ) + _convert_xml( + op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"), + op.join(root, f"valid.{lang}"), + ) + _convert_xml( + op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"), + op.join(root, f"test.{lang}"), + ) + # pre-tokenize + for lang in [src, tgt]: + for split in SPLITS: + pretokenize( + op.join(root, f"{split}.{lang}"), + op.join(root, f"{split}.moses.{lang}"), + src, + tgt, + ) + # tokenize with BPE vocabulary + if bpe_size is not None: + # learn vocabulary + concated_train_path = op.join(root, "train.all") + _concat_files( + [op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")], + concated_train_path, + ) + bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}") + _get_bpe(concated_train_path, bpe_model_prefix, bpe_size) + os.remove(concated_train_path) + # apply + for lang in [src, tgt]: + for split in SPLITS: + _apply_bpe( + bpe_model_prefix + ".model", + op.join(root, f"{split}.moses.{lang}"), + op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"), + ) + # tokenize with bytes vocabulary + if need_bytes: + for lang in [src, tgt]: + for split in SPLITS: + _get_bytes( + op.join(root, f"{split}.moses.{lang}"), + op.join(root, f"{split}.moses.bytes.{lang}"), + ) + # tokenize with characters vocabulary + if need_chars: + for lang in [src, tgt]: + for split in SPLITS: + _get_chars( + op.join(root, f"{split}.moses.{lang}"), + op.join(root, f"{split}.moses.chars.{lang}"), + ) + # tokenize with byte-level BPE vocabulary + if bbpe_size is not None: + # learn vocabulary + bchar_path = op.join(root, "train.bchar") + _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path) + bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}") + _get_bpe(bchar_path, bbpe_model_prefix, 
bbpe_size) + os.remove(bchar_path) + # apply + for lang in [src, tgt]: + for split in SPLITS: + _apply_bbpe( + bbpe_model_prefix + ".model", + op.join(root, f"{split}.moses.{lang}"), + op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"), + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--root", type=str, default="data") + parser.add_argument( + "--bpe-vocab", + default=None, + type=int, + help="Generate tokenized bitext with BPE of size K." + "Default to None (disabled).", + ) + parser.add_argument( + "--bbpe-vocab", + default=None, + type=int, + help="Generate tokenized bitext with BBPE of size K." + "Default to None (disabled).", + ) + parser.add_argument( + "--byte-vocab", + action="store_true", + help="Generate tokenized bitext with bytes vocabulary", + ) + parser.add_argument( + "--char-vocab", + action="store_true", + help="Generate tokenized bitext with chars vocabulary", + ) + args = parser.parse_args() + + preprocess_iwslt17( + args.root, + "fr", + "en", + args.bpe_vocab, + args.char_vocab, + args.bbpe_vocab, + args.byte_vocab, + ) + + +if __name__ == "__main__": + main() diff --git a/fairseq-0.10.2/examples/byte_level_bpe/get_data.sh b/fairseq-0.10.2/examples/byte_level_bpe/get_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..c3d55d4925a6e6e23d12d293f093c1ae14acf76e --- /dev/null +++ b/fairseq-0.10.2/examples/byte_level_bpe/get_data.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +PY_BIN_ROOT= + +# PyPI dependency +${PY_BIN_ROOT}pip install sentencepiece sacremoses + +# Get data +if [ ! -d "data" ]; then + mkdir data +fi + +if [ ! 
-f "data/fr-en.tgz" ]; then + wget https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz -P data + tar xvf data/fr-en.tgz -C data +fi +${PY_BIN_ROOT}python get_bitext.py --bpe-vocab 16384 --byte-vocab --char-vocab +for VOCAB_SIZE in 2048 4096; do + ${PY_BIN_ROOT}python get_bitext.py --bpe-vocab ${VOCAB_SIZE} --bbpe-vocab ${VOCAB_SIZE} +done +rm -r data/fr-en data/fr-en.tgz + +# Generate binary dataset +${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bpe16384 --joined-dictionary \ + --workers "$(nproc)" --trainpref data/train.moses.bpe16384 --validpref data/valid.moses.bpe16384 \ + --testpref data/test.moses.bpe16384 + +${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bytes --joined-dictionary \ + --workers "$(nproc)" --trainpref data/train.moses.bytes --validpref data/valid.moses.bytes \ + --testpref data/test.moses.bytes + +${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_chars --joined-dictionary \ + --workers "$(nproc)" --trainpref data/train.moses.chars --validpref data/valid.moses.chars \ + --testpref data/test.moses.chars + +for VOCAB_SIZE in 2048 4096; do + for TYPE in bbpe bpe; do + ${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir "data/bin_${TYPE}${VOCAB_SIZE}" \ + --joined-dictionary --workers "$(nproc)" --trainpref "data/train.moses.${TYPE}${VOCAB_SIZE}" \ + --validpref "data/valid.moses.${TYPE}${VOCAB_SIZE}" --testpref "data/test.moses.${TYPE}${VOCAB_SIZE}" + done +done diff --git a/fairseq-0.10.2/examples/byte_level_bpe/gru_transformer.py b/fairseq-0.10.2/examples/byte_level_bpe/gru_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d4efa93a4d75da71c78e786d7f62101ef3266af4 --- /dev/null +++ b/fairseq-0.10.2/examples/byte_level_bpe/gru_transformer.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import torch.nn.functional as F +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer import TransformerEncoder, TransformerModel + + +@register_model("gru_transformer") +class GRUTransformerModel(TransformerModel): + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return GRUTransformerEncoder(args, src_dict, embed_tokens) + + +class GRUTransformerEncoder(TransformerEncoder): + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + self.emb_ctx = nn.GRU( + input_size=embed_tokens.embedding_dim, + hidden_size=embed_tokens.embedding_dim // 2, + num_layers=1, + bidirectional=True, + ) + + def forward_embedding(self, src_tokens): + # embed tokens and positions + x = embed = self.embed_scale * self.embed_tokens(src_tokens) + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + + # contextualize embeddings + x = x.transpose(0, 1) + x = self.dropout_module(x) + x, _ = self.emb_ctx.forward(x) + x = x.transpose(0, 1) + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout_module(x) + return x, embed + + +@register_model_architecture("gru_transformer", "gru_transformer") +def gru_transformer_base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, 
"encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.no_cross_attention = getattr(args, "no_cross_attention", False) + args.cross_self_attention = getattr(args, "cross_self_attention", False) + args.layer_wise_attention = getattr(args, "layer_wise_attention", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = 
getattr(args, "layernorm_embedding", False) + + +@register_model_architecture("gru_transformer", "gru_transformer_big") +def gru_transformer_big(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) + gru_transformer_base_architecture(args) diff --git a/fairseq-0.10.2/examples/conv_seq2seq/README.md b/fairseq-0.10.2/examples/conv_seq2seq/README.md new file mode 100644 index 0000000000000000000000000000000000000000..95fe7e7909a77ee0e50fe31d4b8be38daa8f3be7 --- /dev/null +++ b/fairseq-0.10.2/examples/conv_seq2seq/README.md @@ -0,0 +1,25 @@ +# Convolutional Sequence to Sequence Learning (Gehring et al., 2017) + +## Pre-trained models + +Description | Dataset | Model | Test set(s) +---|---|---|--- +Convolutional
<br>([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014:<br>[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2)<br>newstest2012/2013:<br>[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2) +Convolutional<br>([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014:<br>[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2) +Convolutional<br>([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014:<br>[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2) + +## Example usage + +See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and +WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures. + +## Citation + +```bibtex +@inproceedings{gehring2017convs2s, + title = {Convolutional Sequence to Sequence Learning}, + author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N}, + booktitle = {Proc. of ICML}, + year = 2017, +} +``` diff --git a/fairseq-0.10.2/examples/layerdrop/README.md b/fairseq-0.10.2/examples/layerdrop/README.md new file mode 100644 index 0000000000000000000000000000000000000000..394e710b0f522981dbb073f28eaf550ee28760cf --- /dev/null +++ b/fairseq-0.10.2/examples/layerdrop/README.md @@ -0,0 +1,154 @@ +# Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019) +This page contains information for how to train models with LayerDrop, based on this [paper](https://arxiv.org/abs/1909.11556). 
+ +## Citation: +If you found this technique useful, please cite our paper: +```bibtex +@article{fan2019reducing, + title={Reducing Transformer Depth on Demand with Structured Dropout}, + author={Fan, Angela and Grave, Edouard and Joulin, Armand}, + journal={arXiv preprint arXiv:1909.11556}, + year={2019} +} +``` + +## Pre-trained models + +Model | Description | Download +---|---|--- +`layerdrop_wmt_en_de_12_6` | Transformer + LayerDrop 0.2 trained on WMT16 en-de with 12 encoder and 6 decoder layers | [layerdrop_wmt_en_de_12_6.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/layerdrop_wmt_en_de_12_6.tar.gz) +`roberta_layerdrop.base` | RoBERTa Base + LayerDrop 0.2 | [roberta_layerdrop.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.base.qnli.tar.gz) +`roberta_layerdrop.large` | RoBERTa Large + LayerDrop 0.2 | [roberta_layerdrop.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.tar.gz) +`roberta_layerdrop.large.mnli` | `roberta_layerdrop.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | [roberta_layerdrop.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.mnli.tar.gz) +`roberta_layerdrop.large.qnli` | `roberta_layerdrop.large` finetuned on [QNLI](https://arxiv.org/abs/1804.07461) | [roberta_layerdrop.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta_layerdrop.large.qnli.tar.gz) + + +Evaluate performance of these pre-trained models: +```bash +# Example for Machine Translation +fairseq-generate /path/to/bped/wmt/data --path nmt_checkpoint.pt \ + --beam 8 --lenpen 0.4 \ + --batch-size 64 \ + --remove-bpe \ + --gen-subset test > wmt16_gen.txt +bash scripts/compound_split_bleu.sh wmt16_gen.txt +# prints BLEU4 = 30.17 +``` + +```python +# Example for RoBERTa + LayerDrop finetuned on MNLI: +from fairseq.models.roberta import RobertaModel + +roberta_layerdrop = RobertaModel.from_pretrained( + '/path/to/MNLI/model', + 
checkpoint_file='mnli_checkpoint.pt', + data_name_or_path='/path/to/MNLI/data/MNLI-bin' +) +label_map = {0: 'contradiction', 2: 'neutral', 1: 'entailment'} +ncorrect, nsamples = 0, 0 +roberta_layerdrop.cuda() +roberta_layerdrop.eval() +with open('/path/to/MNLI/data/dev_matched.tsv') as fin: + fin.readline() + for index, line in enumerate(fin): + tokens = line.strip().split('\t') + sent1, sent2, target = tokens[8], tokens[9], tokens[-1] + tokens = roberta_layerdrop.encode(sent1, sent2) + prediction = roberta_layerdrop.predict('sentence_classification_head', tokens).argmax().item() + prediction_label = label_map[prediction] + ncorrect += int(prediction_label == target) + nsamples += 1 +print('| Accuracy: ', float(ncorrect)/float(nsamples)) +# prints | Accuracy: 0.9026999490575649 + + +# Example for RoBERTa + LayerDrop finetuned on QNLI: +roberta = RobertaModel.from_pretrained( + '/path/to/QNLI/model', + checkpoint_file='qnli_checkpoint.pt', + data_name_or_path='/path/to/QNLI/data/QNLI-bin' +) + +label_fn = lambda label: roberta.task.label_dictionary.string( + [label + roberta.task.target_dictionary.nspecial] +) +ncorrect, nsamples = 0, 0 +roberta.cuda() +roberta.eval() +with open('/path/to/QNLI/data/dev.tsv') as fin: + fin.readline() + for index, line in enumerate(fin): + tokens = line.strip().split('\t') + sent1, sent2, target = tokens[1], tokens[2], tokens[3] + tokens = roberta.encode(sent1, sent2) + prediction = roberta.predict('sentence_classification_head', tokens).argmax().item() + prediction_label = label_fn(prediction) + ncorrect += int(prediction_label == target) + nsamples += 1 +print('| Accuracy: ', float(ncorrect)/float(nsamples)) +# prints | Accuracy: 0.9480139117700896 +``` + + +## Example usage + +To train a model with LayerDrop, add the following flags. We recommend 0.2, a value that worked well in our experiments. For Language Models that are decoder-only, you need only the decoder flag. For RoBERTa, an encoder, you need only the encoder flag. 
The encoder and decoder LayerDrop values can be set differently. +``` +--encoder-layerdrop 0.2 --decoder-layerdrop 0.2 +``` + +To prune a model that has been trained with LayerDrop, add the following flags followed by a comma separated list of which layers you would like to keep. +``` +--encoder-layers-to-keep 0,2,4,6,8,10,12,14 --decoder-layers-to-keep 0,2,4,6,8,10,12,14 +``` +Setting these flags should print a message such as: +``` +| Pruning model to specified layer configuration +``` +You should also see a smaller number of parameters in the model, for example the 16-Layer Transformer Language Model prints: +``` +num. model params: 246933504 +``` +while a model pruned to 8 Layers prints: +``` +num. model params: 146163712 +``` + +If you would like to pick up training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass an override command. A specific example would be for language modeling: +```bash +fairseq-eval-lm /path/to/wikitext-103 \ + --path /path/to/model/checkpoint.pt \ + --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}" +``` +This model override command overrides the training parameters and updates the model arguments so that the pruned model is run instead of the full model. + +## Reproduce Paper Results + +Looking to reproduce the results in the paper? + +1. For Translation on WMT16 en-de, we followed this setting [here](https://github.com/pytorch/fairseq/blob/master/examples/scaling_nmt/README.md) +2. To train RoBERTa, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta) +3. To train Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model) + + +## Tips + +1. If you would like to train large models with better performance, LayerDrop should be set to a smaller value such as 0.1 or 0.2. 
Too much LayerDrop will mean the model has too much regularization, so may not reach the best performance. Since LayerDrop adds regularization, you may achieve the best performance by slightly reducing the amount of standard dropout (for example, reduce by 0.1). + +2. If you would like to train large models to be pruned and made smaller, LayerDrop should be set to a larger value such as 0.5 if you want to prune very aggressively (such as removing half the network or more). If you would like to prune fewer layers away, LayerDrop can be set to a smaller value such as 0.2. Our experiments were conducted with low values of LayerDrop (such as 0.1 and 0.2), for reference. + +3. When pruning layers at inference time, it is best to spread out the layers remaining so they are evenly spaced throughout the network. For example, if you want to remove 50% of the network, keeping every other layer is good. + + +## FAQ + +1. How did the sharing layers experiment work? In an appendix (https://openreview.net/pdf?id=SylO2yStDr) we added an experiment on Wikitext-103 language modeling that combined LayerDrop with Weight Sharing. We shared chunks of 2 layers such that every other layer had shared weights. For example, if our network has layers 1 through 6, then layer 1 and 2 are shared, layer 3 and 4 are shared, and layer 5 and 6 are shared. + +2. LayerDrop hasn't been helping in my setting? During training time, LayerDrop can help regularize your network. This is most important if your network is already overfitting - if your network is underfitting, it is possible LayerDrop is adding too much regularization. We recommend using smaller values (such as 0.1 or 0.2) and also decreasing the quantity of standard dropout (for example, reduce by 0.1). + +3. Can you train a model without LayerDrop and finetune with LayerDrop (e.g. for BERT)? In our experiments, we did not see great performance. 
Models such as RoBERTa have trained for a long time in the pre-training setting, so only finetuning with LayerDrop for a few epochs on a downstream task such as MNLI does not achieve the robustness required for successful pruning. + + +## Having an issue or have a question? + +Please open an issue in this repository with the details of your question. Thanks! diff --git a/fairseq-0.10.2/examples/noisychannel/rerank_score_lm.py b/fairseq-0.10.2/examples/noisychannel/rerank_score_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3aa64462623b18820d301cccb42bf3483d97a7 --- /dev/null +++ b/fairseq-0.10.2/examples/noisychannel/rerank_score_lm.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os + +from fairseq import options + +from . import rerank_options, rerank_utils + + +def score_lm(args): + using_nbest = args.nbest_list is not None + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + args.shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + predictions_bpe_file = pre_gen + "/generate_output_bpe.txt" + if using_nbest: + print("Using predefined n-best list from interactive.py") + predictions_bpe_file = args.nbest_list + + gen_output = rerank_utils.BitextOutputFromGen( + predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest + ) + + if args.language_model is not None: + lm_score_file = rerank_utils.rescore_file_name( + pre_gen, args.prefix_len, args.lm_name, lm_file=True + ) + + if args.language_model is not None and not os.path.isfile(lm_score_file): + print("STEP 4.5: language modeling for P(T)") 
+ if args.lm_bpe_code is None: + bpe_status = "no bpe" + elif args.lm_bpe_code == "shared": + bpe_status = "shared" + else: + bpe_status = "different" + + rerank_utils.lm_scoring( + lm_preprocessed_dir, + bpe_status, + gen_output, + pre_gen, + args.lm_dict, + args.lm_name, + args.language_model, + args.lm_bpe_code, + 128, + lm_score_file, + args.target_lang, + args.source_lang, + prefix_len=args.prefix_len, + ) + + +def cli_main(): + parser = rerank_options.get_reranking_parser() + args = options.parse_args_and_arch(parser) + score_lm(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq-0.10.2/examples/quant_noise/README.md b/fairseq-0.10.2/examples/quant_noise/README.md new file mode 100644 index 0000000000000000000000000000000000000000..057ea620ab2a53a58565e5ab1c745696ce8dde2a --- /dev/null +++ b/fairseq-0.10.2/examples/quant_noise/README.md @@ -0,0 +1,298 @@ +# Training with Quantization Noise for Extreme Model Compression ({Fan\*, Stock\*} *et al.*, 2020) +This page contains information for how to train and quantize models with Quantization Noise, for both scalar quantization like `int8` and Iterative Product Quantization. +Check out our paper [here](https://arxiv.org/abs/2004.07320). + +Looking for pretrained models? They will be added shortly. +Looking for code to train vision models? We are working on open sourcing our code as part of ClassyVision. Please check back, but note that both the Scalar and Iterative Product Quantization counterparts of the `nn.Conv2d` module are already included in this release. 
+ +**Contents**: +- [Walk through of code](#walk-through-the-code) +- [Reproduce NLP Results](#looking-to-reproduce-the-nlp-results-in-the-paper) +- [Reproduce Vision Results](#looking-to-reproduce-the-vision-results-in-the-paper) + + +## Citation +```bibtex +@article{fan2020training, + title={Training with Quantization Noise for Extreme Model Compression}, + author={Angela Fan* and Pierre Stock* and Benjamin Graham and Edouard Grave and Remi Gribonval and Herve Jegou and Armand Joulin}, + year={2020}, + eprint={2004.07320}, + archivePrefix={arXiv}, + primaryClass={cs.ML} +} +``` + +## Walk through the code + +Training a model with Quant-Noise improves the performance in subsequent inference-time quantization by training models to be robust to quantization. This technique is useful for both scalar and product quantization methods, as well as multiple domains. We detail below our approach to train, quantize models and integrate our code to quantize your favorite models. + +### Scalar Quantization + +Unlike the section [Iterative Product Quantization](#iterative-product-quantization) which gives state-of-the-art compression, this section showcases the usefulness of our approach for simple scalar quantization baselines such as int8 using on-GPU Fake Quantization. + +#### Training + +Scalar quantization with Quant-Noise consists in randomly quantizing a proportion `p` of the weights during training. Scalar quantization is implemented [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/scalar) under the form of Fake Quantization, meaning that we emulate int8 on GPU by quantizing and de-quantizing both the weights and the activations. We rely on PyTorch's [quantization primitives](https://github.com/pytorch/pytorch/tree/master/torch/quantization). 
+ +To train a model with Quant-Noise, add the following flag: +``` +--quant-noise-scalar 0.5 +``` +Large values of noise make the network easier to quantize but may result in higher non-quantized test and validation perplexities. + +#### Quantization + +When evaluating a network, all quantized modules and activation hooks automatically switch to `p=1` so the validation accuracy reported by Fairseq is actually the quantized one, nothing more to do. + + +#### Integration with your own code + +Looking to quantize your own models with Quant-Noise + Scalar Quantization? +- Use the function `quantize_model_` implemented [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/scalar/utils.py) to (1) replace all your modules by their quantized counterparts and (2) add hooks to those modules to quantize the activations. +- Then, perform your training as usual. Note that in `eval()` mode, the network is always fully quantized (weights and activations) by default (`p=1`). + + + +### Iterative Product Quantization + + +Iterative Product Quantization with Quant-Noise proceeds in two steps. First, a model must be trained uncompressed with Quant-Noise. Second, the model must be quantized with iPQ. Note that we implement here the simplest form of noise, which consists in randomly dropping a proportion `p` of blocks, and that worked as well as assigning those blocks to their current centroid. + +#### Training + +To train a model with Quant-Noise, add the following flags: +``` +--quant-noise-pq 0.1 --quant-noise-pq-block-size 8 +``` +`quant-noise-pq` controls how much dropout is applied to the blocks of the weight matrix. `quant-noise-pq-block-size` controls the size of the weight matrix blocks. +We recommend training with 0.05 to 0.2 Quant-Noise, a value that worked well in our experiments. For the block-size, we recommend training with block-size of 8. 
Note that the block size must be a multiple of `input_features`, see the size checks [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quant_noise.py). Large block sizes result in higher compression ratio but may induce a loss in accuracy. + +We currently support training Transformer based models, such as sequence-to-sequence, language models, and BERT architectures. The `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quant_noise.py) wraps a module. It splits a weight matrix into blocks and applies random dropout to these blocks. +In the Transformer architectures, quant-noise is applied to the input and output embeddings, the attention, and the FFN. + +Quant-Noise can also be combined with **LayerDrop** (see [here](https://github.com/pytorch/fairseq/tree/master/examples/layerdrop)) to add its pruning effect to the quantized model and make the model even smaller. We recommend training with LayerDrop 0.1 or 0.2. + +#### Quantization + +We implement an improved version of product quantization from Stock et al, **iPQ**, described [here](https://arxiv.org/abs/1907.05686), see code with old API [here](https://github.com/facebookresearch/kill-the-bits). Note that we improved the iPQ API in terms of both compute speed and usability as described below. + +For the particular case of PQ, quantization is made sequentially. We recommend first quantizing the FFNs, then the EMBs, and finally the ATTNs. Quantization is done in two sub-steps: +- First, perform `n` steps of Product Quantization (generally `n=20` is enough). +- Then, finetune the obtained centroids. + +#### Integration with your own code + +Looking to quantize your own models with Quant-Noise + iPQ? +- First wrap your modules with the `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quant_noise.py), which is module-agnostic and train your favorite model. 
+- Then, quantize your trained model using the code [here](https://github.com/pytorch/fairseq/tree/master/fairseq/modules/quantization/pq). This can be done *without any changes to your training loop*. Below is an example code for integration. +Note that we tried our approach only on Transformers and various Convolutional Models such as EfficientNets. + +```python +from fairseq.modules.quantization.pq import quantize_model_, SizeTracker + +# get configuration parameters +n_centroids_config = config["n_centroids"] +block_sizes_config = config["block_sizes"] +layers_to_quantize = config["layers_to_quantize"] + +# size tracker for keeping track of assignments, centroids and non-compressed sizes +size_tracker = SizeTracker(model) + +# Quantize model by stages +for step in range(len(layers_to_quantize)): + + # quantize model in-place + quantized_layers = quantize_model_( + model, + size_tracker, + layers_to_quantize, + block_sizes_config, + n_centroids_config, + step=step, + ) + logger.info(f"Finetuning stage {step}, quantized layers: {quantized_layers}") + logger.info(f"{size_tracker}") + + # Don't forget to re-create/update trainer/optimizer since model parameters have changed + optimizer = ... + + # Finetune the centroids with your usual training loop for a few epochs + trainer.train_epoch() +``` + + +## Looking to reproduce the NLP results in the paper? + +We detail below how to reproduce the state-of-the-art results reported in the paper for Quant-Noise + Iterative Product Quantization. + +### Training with Quant-Noise + +To **train** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta). 
+The following command can be used to train a RoBERTa Base + QuantNoise model: + +```bash +TOTAL_UPDATES=125000 +WARMUP_UPDATES=10000 +PEAK_LR=0.0005 +TOKENS_PER_SAMPLE=512 +MAX_POSITIONS=512 +MAX_SENTENCES=16 +UPDATE_FREQ=2 +DATA_DIR=/path/to/data/here + +fairseq-train $DATA_DIR \ + --task masked_lm --criterion masked_lm --arch roberta_base \ + --sample-break-mode complete \ + --tokens-per-sample $TOKENS_PER_SAMPLE --max-positions $MAX_POSITIONS \ + --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 \ + --clip-norm 0.0 \ + --lr-scheduler polynomial_decay --lr $PEAK_LR \ + --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \ + --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.01 \ + --batch-size $MAX_SENTENCES \ + --update-freq $UPDATE_FREQ --max-update $TOTAL_UPDATES \ + --save-dir checkpoint/roberta \ + --ddp-backend no_c10d --encoder-layerdrop 0.2 \ + --quant-noise-pq 0.2 --quant-noise-pq-block-size 8 --untie-weights-roberta +``` + +To **finetune** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.glue.md). 
+The following command can be used to finetune a RoBERTa Base + QuantNoise model on the RTE dataset: + +```bash +TOTAL_NUM_UPDATES=2036 +WARMUP_UPDATES=122 +LR=2e-05 +NUM_CLASSES=2 +MAX_SENTENCES=16 +ROBERTA_PATH=/path/to/roberta_quantnoise/model.pt + +fairseq-train /path/to/rte/data/ \ + --restore-file $ROBERTA_PATH \ + --max-positions 512 \ + --batch-size $MAX_SENTENCES \ + --max-tokens 4400 \ + --task sentence_prediction \ + --reset-optimizer --reset-dataloader --reset-meters \ + --required-batch-size-multiple 1 \ + --init-token 0 --separator-token 2 \ + --arch roberta_large \ + --criterion sentence_prediction \ + --num-classes $NUM_CLASSES \ + --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ + --clip-norm 0.0 \ + --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ + --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ + --max-epoch 10 \ + --find-unused-parameters \ + --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ + --ddp-backend no_c10d \ + --quant-noise-pq 0.2 --quant-noise-pq-block-size 8 +``` + +To **train** Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model). 
+The following command can be used to train a Transformer + QuantNoise model on Wikitext-103: + +```bash +fairseq-train --task language_modeling /path/to/wikitext-103/data \ + --save-dir checkpoints/transformer_wikitext-103 \ + --adaptive-input --adaptive-input-cutoff 20000,60000 --adaptive-input-factor 4 \ + --adaptive-softmax-cutoff 20000,60000 --adaptive-softmax-dropout 0.2 --adaptive-softmax-factor 4.0 \ + --tie-adaptive-proj --tie-adaptive-weights \ + --arch transformer_lm_gbw \ + --attention-dropout 0.1 --dropout 0.2 --relu-dropout 0.1 \ + --clip-norm 0.1 --criterion adaptive_loss \ + --ddp-backend no_c10d \ + --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 \ + --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ + --lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 1.0 --t-mult 2.0 \ + --max-tokens 3072 --tokens-per-sample 3072 --momentum 0.99 --optimizer nag \ + --sample-break-mode none --update-freq 3 \ + --warmup-init-lr 1e-07 --warmup-updates 16000 \ + --weight-decay 0 --seed 1 --min-lr 1e-09 \ + --quant-noise-pq 0.05 --quant-noise-pq-block-size 8 +``` + +To **evaluate** this model, note you need to use the `eval.py` script. The following command can be used to evaluate: + +```bash +fairseq-eval-lm /path/to/wikitext-103/data --path /path/to/model/checkpoint \ + --sample-break-mode complete \ + --max-tokens 3072 \ + --context-window 2560 \ + --softmax-batch 1024 \ + --gen-subset valid +``` +and change the `--gen-subset` to `test` if you would like to evaluate on the test set instead. + + +### Iterative Product Quantization + +To quantize the finetuned RoBERTa model, we use this command on 1 GPU. This should run in a day. 
+```bash +TOTAL_NUM_UPDATES=6108 # 2036 updates for each iteration +WARMUP_UPDATES=122 +LR=2e-05 +NUM_CLASSES=2 +MAX_SENTENCES=16 +fairseq-train --task sentence_prediction /path/to/data/ \ + --restore-file $ROBERTA_PATH \ + --save-dir checkpoints/roberta_finetuned \ + --max-positions 512 \ + --batch-size $MAX_SENTENCES \ + --max-tokens 4400 \ + --init-token 0 --separator-token 2 \ + --arch roberta_large \ + --criterion sentence_prediction \ + --num-classes $NUM_CLASSES \ + --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ + --clip-norm 0.0 --lr-scheduler polynomial_decay \ + --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ + --no-progress-bar --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \ + --quantization-config-path /path/to/config/yaml +``` + +To quantize the trained Language Model, we use this command on 8 V100 23GB GPUs. This should run in a couple of hours. +```bash +fairseq-train --task language_modeling /path/to/wikitext-103/data \ + --save-dir checkpoints/transformer_wikitext-103 \ + --adaptive-input --adaptive-input-cutoff 20000,60000 --adaptive-input-factor 4 \ + --adaptive-softmax-cutoff 20000,60000 --adaptive-softmax-dropout 0.2 --adaptive-softmax-factor 4.0 \ + --arch transformer_lm_gbw \ + --attention-dropout 0.1 --dropout 0.2 --relu-dropout 0.1 \ + --bucket-cap-mb 25 --char-embedder-highway-layers 2 --character-embedding-dim 4 \ + --clip-norm 0.1 --criterion adaptive_loss \ + --ddp-backend no_c10d \ + --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ + --fp16 --keep-last-epochs -1 \ + --lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 0.05 --min-lr 1e-09 \ + --max-tokens 2944 --tokens-per-sample 2944\ + --momentum 0.99 --no-epoch-checkpoints --no-progress-bar 
--optimizer nag --required-batch-size-multiple 8 \ + --sample-break-mode none --t-mult 2.0 --skip-invalid-size-inputs-valid-test \ + --tie-adaptive-proj --tie-adaptive-weights --update-freq 3 --weight-decay 0 --seed 1 \ + --log-interval 100 --no-progress-bar --skip-invalid-size-inputs-valid-test \ + --restore-file path/to/trained/lm/with/quant/noise \ + --max-update 13500 --quantization-config-path /path/to/config/yaml +``` +If you have less capacity or if your distributed training freezes, try reducing `--max-tokens` and `--tokens-per-sample` (this may reduce the quantized accuracy a bit). + +### Remarks + +We try to keep the open-sourced code as readable and as easy-to-plug as possible. Therefore, we did not test it for the following cases: +- Scalar quantization with RoBERTa. +- Quantization with iPQ and `int8` combined. + +If you have trouble adapting it, we will be more than happy to help! + +## Looking to reproduce the Vision results in the paper? + +We are working on open sourcing our code as part of ClassyVision. Please check back. + + +## Having an issue or have a question? + +Please open an issue in this repository with the details of your question. Thanks! diff --git a/fairseq-0.10.2/examples/quant_noise/transformer_quantization_config.yaml b/fairseq-0.10.2/examples/quant_noise/transformer_quantization_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4be14a93a3593f8e6dc66c3b05061bfdde3e0e0 --- /dev/null +++ b/fairseq-0.10.2/examples/quant_noise/transformer_quantization_config.yaml @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# This file defines example configuration arguments for quantizing +# a transformer model with product quantization + +# Number of Centroids for Product Quantization, by default 256 (byte-aligned) +n_centroids: + Linear: + key: in_features + value: {"*": 256} + Embedding: + key: embedding_dim + value: {"*": 256} + +# Block Sizes for Product Quantization +# We suggest: 8 for FFN, 4 for ATTN, 4 for embedding projections, 8 for embeddings +block_sizes: + Linear: + key: fuzzy_name + value: {fc: 8, attn: 4, emb: 4} + Embedding: + key: fuzzy_name + value: {emb: 8} + +# Layers to Quantize Sequentially +# We suggest: first FFN, then EMB, then ATTN +layers_to_quantize: + - decoder\\.layers\\.\d+\\.fc[12] + - decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01] + - decoder\\.layers\\.\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj) diff --git a/mosesdecoder/moses/FF/BleuScoreFeature.h b/mosesdecoder/moses/FF/BleuScoreFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..074ad0c457252b42ba5cab94de37400aec047559 --- /dev/null +++ b/mosesdecoder/moses/FF/BleuScoreFeature.h @@ -0,0 +1,191 @@ +#ifndef BLUESCOREFEATURE_H +#define BLUESCOREFEATURE_H + +#include +#include +#include + +#include + +#include "StatefulFeatureFunction.h" + +#include "moses/FF/FFState.h" +#include "moses/Phrase.h" +#include "moses/ChartHypothesis.h" + +namespace Moses +{ + +class BleuScoreFeature; + +class BleuScoreState : public FFState +{ +public: + friend class BleuScoreFeature; + static size_t bleu_order; + + BleuScoreState(bool is_syntax); + size_t hash() const; + virtual bool operator==(const FFState& other) const; + + void print(std::ostream& out) const; + +private: + Phrase m_words; + size_t m_source_length; + size_t m_target_length; + bool m_is_syntax; + // scaled reference length is needed for scoring incomplete hypotheses against reference translation + float m_scaled_ref_length; + + std::vector< size_t > m_ngram_counts; + std::vector< size_t > m_ngram_matches; + + 
void AddNgramCountAndMatches(std::vector< size_t >& counts, std::vector< size_t >& matches); +}; + + +std::ostream& operator<<(std::ostream& out, const BleuScoreState& state); + +typedef boost::unordered_map< Phrase, size_t > NGrams; + +class RefValue : public std::pair,NGrams> +{ +public: + RefValue& operator=( const RefValue& rhs ) { + first = rhs.first; + second = rhs.second; + return *this; + } +}; + + +class BleuScoreFeature : public StatefulFeatureFunction +{ +public: + static const std::vector& GetColl() { + return s_staticColl; + } + + typedef boost::unordered_map RefCounts; + typedef boost::unordered_map Matches; + + BleuScoreFeature(const std::string &line); + + void SetParameter(const std::string& key, const std::string& value); + + std::vector DefaultWeights() const; + + void PrintHistory(std::ostream& out) const; + void LoadReferences(const std::vector< std::vector< std::string > > &); + void SetCurrSourceLength(size_t); + void SetCurrNormSourceLength(size_t); + void SetCurrShortestRefLength(size_t); + void SetCurrAvgRefLength(size_t sent_id); + void SetAvgInputLength (float l) { + m_avg_input_length = l; + } + void SetCurrReferenceNgrams(size_t sent_id); + size_t GetShortestRefIndex(size_t ref_id); + size_t GetClosestRefLength(size_t ref_id, int hypoLength); + void UpdateHistory(const std::vector< const Word* >&); + void UpdateHistory(const std::vector< std::vector< const Word* > >& hypos, std::vector& sourceLengths, std::vector& ref_ids, size_t rank, size_t epoch); + void PrintRefLength(const std::vector& ref_ids); + void SetBleuParameters(bool disable, bool sentenceBleu, bool scaleByInputLength, bool scaleByAvgInputLength, + bool scaleByInverseLength, bool scaleByAvgInverseLength, + float scaleByX, float historySmoothing, size_t scheme, bool simpleHistoryBleu); + + void GetNgramMatchCounts(Phrase&, + const NGrams&, + std::vector< size_t >&, + std::vector< size_t >&, + size_t skip = 0) const; + void GetNgramMatchCounts_prefix(Phrase&, + const 
NGrams&, + std::vector< size_t >&, + std::vector< size_t >&, + size_t new_start_indices, + size_t last_end_index) const; + void GetNgramMatchCounts_overlap(Phrase& phrase, + const NGrams& ref_ngram_counts, + std::vector< size_t >& ret_counts, + std::vector< size_t >& ret_matches, + size_t overlap_index) const; + void GetClippedNgramMatchesAndCounts(Phrase&, + const NGrams&, + std::vector< size_t >&, + std::vector< size_t >&, + size_t skip = 0) const; + + FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, + const FFState* prev_state, + ScoreComponentCollection* accumulator) const; + FFState* EvaluateWhenApplied(const ChartHypothesis& cur_hypo, + int featureID, + ScoreComponentCollection* accumulator) const; + + bool Enabled() const { + return m_enabled; + } + + bool IsUseable(const FactorMask &mask) const; + + float CalculateBleu(BleuScoreState*) const; + float CalculateBleu(Phrase translation) const; + const FFState* EmptyHypothesisState(const InputType&) const; + + float GetSourceLengthHistory() { + return m_source_length_history; + } + float GetTargetLengthHistory() { + return m_target_length_history; + } + float GetAverageInputLength() { + return m_avg_input_length; + } + + void Load(AllOptions::ptr const& opts); + +private: + static std::vector s_staticColl; + + bool m_enabled; + bool m_sentence_bleu; + bool m_simple_history_bleu; + bool m_is_syntax; + // counts for pseudo-document + std::vector< float > m_count_history; + std::vector< float > m_match_history; + float m_source_length_history; + float m_target_length_history; + float m_ref_length_history; + + size_t m_cur_source_length; + size_t m_cur_norm_source_length; // length without , + RefCounts m_refs; + NGrams m_cur_ref_ngrams; + float m_cur_ref_length; + + // scale BLEU score by history of input length + bool m_scale_by_input_length; + bool m_scale_by_avg_input_length; + + // scale by the inverse of the input length * 100 + bool m_scale_by_inverse_length; + bool m_scale_by_avg_inverse_length; + 
+ float m_avg_input_length; + + float m_scale_by_x; + + // smoothing factor for history counts + float m_historySmoothing; + + enum SmoothingScheme { PLUS_ONE = 1, PLUS_POINT_ONE = 2, PAPINENI = 3 }; + SmoothingScheme m_smoothing_scheme; +}; + +} // Namespace. + +#endif //BLUESCOREFEATURE_H + diff --git a/mosesdecoder/moses/FF/ConstrainedDecoding.cpp b/mosesdecoder/moses/FF/ConstrainedDecoding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2eefbc44468d0823acefed151cc1db701cd6e982 --- /dev/null +++ b/mosesdecoder/moses/FF/ConstrainedDecoding.cpp @@ -0,0 +1,212 @@ +#include "ConstrainedDecoding.h" +#include "moses/Hypothesis.h" +#include "moses/Manager.h" +#include "moses/ChartHypothesis.h" +#include "moses/ChartManager.h" +#include "moses/StaticData.h" +#include "moses/InputFileStream.h" +#include "moses/Util.h" +#include "util/exception.hh" + +using namespace std; + +namespace Moses +{ +ConstrainedDecodingState::ConstrainedDecodingState(const Hypothesis &hypo) +{ + hypo.GetOutputPhrase(m_outputPhrase); +} + +ConstrainedDecodingState::ConstrainedDecodingState(const ChartHypothesis &hypo) +{ + hypo.GetOutputPhrase(m_outputPhrase); +} + +size_t ConstrainedDecodingState::hash() const +{ + size_t ret = hash_value(m_outputPhrase); + return ret; +} + +bool ConstrainedDecodingState::operator==(const FFState& other) const +{ + const ConstrainedDecodingState &otherFF = static_cast(other); + bool ret = m_outputPhrase == otherFF.m_outputPhrase; + return ret; +} + +////////////////////////////////////////////////////////////////// +ConstrainedDecoding::ConstrainedDecoding(const std::string &line) + :StatefulFeatureFunction(1, line) + ,m_maxUnknowns(0) + ,m_negate(false) + ,m_soft(false) +{ + m_tuneable = false; + ReadParameters(); +} + +void ConstrainedDecoding::Load(AllOptions::ptr const& opts) +{ + m_options = opts; + const StaticData &staticData = StaticData::Instance(); + bool addBeginEndWord + = ((opts->search.algo == CYKPlus) || 
(opts->search.algo == ChartIncremental)); + + for(size_t i = 0; i < m_paths.size(); ++i) { + InputFileStream constraintFile(m_paths[i]); + std::string line; + long sentenceID = opts->output.start_translation_id - 1 ; + while (getline(constraintFile, line)) { + vector vecStr = Tokenize(line, "\t"); + + Phrase phrase(0); + if (vecStr.size() == 1) { + sentenceID++; + phrase.CreateFromString(Output, opts->output.factor_order, vecStr[0], NULL); + } else if (vecStr.size() == 2) { + sentenceID = Scan(vecStr[0]); + phrase.CreateFromString(Output, opts->output.factor_order, vecStr[1], NULL); + } else { + UTIL_THROW(util::Exception, "Reference file not loaded"); + } + + if (addBeginEndWord) { + phrase.InitStartEndWord(); + } + m_constraints[sentenceID].push_back(phrase); + } + } +} + +std::vector ConstrainedDecoding::DefaultWeights() const +{ + UTIL_THROW_IF2(m_numScoreComponents != 1, + "ConstrainedDecoding must only have 1 score"); + vector ret(1, 1); + return ret; +} + +template +const std::vector *GetConstraint(const std::map > &constraints, const H &hypo) +{ + const M &mgr = hypo.GetManager(); + const InputType &input = mgr.GetSource(); + long id = input.GetTranslationId(); + + map >::const_iterator iter; + iter = constraints.find(id); + + if (iter == constraints.end()) { + UTIL_THROW(util::Exception, "Couldn't find reference " << id); + + return NULL; + } else { + return &iter->second; + } +} + +FFState* ConstrainedDecoding::EvaluateWhenApplied( + const Hypothesis& hypo, + const FFState* prev_state, + ScoreComponentCollection* accumulator) const +{ + const std::vector *ref = GetConstraint(m_constraints, hypo); + assert(ref); + + ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo); + const Phrase &outputPhrase = ret->GetPhrase(); + + size_t searchPos = NOT_FOUND; + size_t i = 0; + size_t size = 0; + while(searchPos == NOT_FOUND && i < ref->size()) { + searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns); + size = (*ref)[i].GetSize(); + i++; + } + + 
float score; + if (hypo.IsSourceCompleted()) { + // translated entire sentence. + bool match = (searchPos == 0) && (size == outputPhrase.GetSize()); + if (!m_negate) { + score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits::infinity()); + } else { + score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits::infinity()); + } + } else if (m_negate) { + // keep all derivations + score = 0; + } else { + score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits::infinity()); + } + + accumulator->PlusEquals(this, score); + + return ret; +} + +FFState* ConstrainedDecoding::EvaluateWhenApplied( + const ChartHypothesis &hypo, + int /* featureID - used to index the state in the previous hypotheses */, + ScoreComponentCollection* accumulator) const +{ + const std::vector *ref = GetConstraint(m_constraints, hypo); + assert(ref); + + const ChartManager &mgr = hypo.GetManager(); + const Sentence &source = static_cast(mgr.GetSource()); + + ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo); + const Phrase &outputPhrase = ret->GetPhrase(); + + size_t searchPos = NOT_FOUND; + size_t i = 0; + size_t size = 0; + while(searchPos == NOT_FOUND && i < ref->size()) { + searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns); + size = (*ref)[i].GetSize(); + i++; + } + + float score; + if (hypo.GetCurrSourceRange().GetStartPos() == 0 && + hypo.GetCurrSourceRange().GetEndPos() == source.GetSize() - 1) { + // translated entire sentence. + bool match = (searchPos == 0) && (size == outputPhrase.GetSize()); + + if (!m_negate) { + score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits::infinity()); + } else { + score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits::infinity()); + } + } else if (m_negate) { + // keep all derivations + score = 0; + } else { + score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 
1 : std::numeric_limits::infinity()); + } + + accumulator->PlusEquals(this, score); + + return ret; +} + +void ConstrainedDecoding::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "path") { + m_paths = Tokenize(value, ","); + } else if (key == "max-unknowns") { + m_maxUnknowns = Scan(value); + } else if (key == "negate") { + m_negate = Scan(value); + } else if (key == "soft") { + m_soft = Scan(value); + } else { + StatefulFeatureFunction::SetParameter(key, value); + } +} + +} + diff --git a/mosesdecoder/moses/FF/ControlRecombination.cpp b/mosesdecoder/moses/FF/ControlRecombination.cpp new file mode 100644 index 0000000000000000000000000000000000000000..10c2898b123f8ace2803ff4a0f0452dd043aa324 --- /dev/null +++ b/mosesdecoder/moses/FF/ControlRecombination.cpp @@ -0,0 +1,96 @@ +#include "ControlRecombination.h" +#include "moses/Hypothesis.h" +#include "moses/Manager.h" +#include "moses/ChartHypothesis.h" +#include "moses/ChartManager.h" +#include "moses/StaticData.h" +#include "moses/InputFileStream.h" +#include "moses/Util.h" +#include "util/exception.hh" + +using namespace std; + +namespace Moses +{ +ControlRecombinationState::ControlRecombinationState(const Hypothesis &hypo, const ControlRecombination &ff) + :m_ff(ff) +{ + if (ff.GetType() == SameOutput) { + //UTIL_THROW(util::Exception, "Implemented not yet completed for phrase-based model. Need to take into account the coverage"); + hypo.GetOutputPhrase(m_outputPhrase); + } else { + m_hypo = &hypo; + } +} + +ControlRecombinationState::ControlRecombinationState(const ChartHypothesis &hypo, const ControlRecombination &ff) + :m_ff(ff) +{ + if (ff.GetType() == SameOutput) { + hypo.GetOutputPhrase(m_outputPhrase); + } else { + m_hypo = &hypo; + } +} + +size_t ControlRecombinationState::hash() const +{ + size_t ret; + if (m_ff.GetType() == SameOutput) { + ret = hash_value(m_outputPhrase); + } else { + // compare hypo address. 
Won't be equal unless they're actually the same hypo + ret = (size_t) m_hypo; + } + return ret; +} + +bool ControlRecombinationState::operator==(const FFState& other) const +{ + const ControlRecombinationState &otherFF = static_cast(other); + + if (m_ff.GetType() == SameOutput) { + return m_outputPhrase == otherFF.m_outputPhrase; + } else { + // compare hypo address. Won't be equal unless they're actually the same hypo + if (m_hypo == otherFF.m_hypo) + return true; + return (m_hypo == otherFF.m_hypo); + } +} + +std::vector ControlRecombination::DefaultWeights() const +{ + UTIL_THROW_IF2(m_numScoreComponents, + "ControlRecombination should not have any scores"); + vector ret(0); + return ret; +} + +FFState* ControlRecombination::EvaluateWhenApplied( + const Hypothesis& hypo, + const FFState* prev_state, + ScoreComponentCollection* accumulator) const +{ + return new ControlRecombinationState(hypo, *this); +} + +FFState* ControlRecombination::EvaluateWhenApplied( + const ChartHypothesis &hypo, + int /* featureID - used to index the state in the previous hypotheses */, + ScoreComponentCollection* accumulator) const +{ + return new ControlRecombinationState(hypo, *this); +} + +void ControlRecombination::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "type") { + m_type = (ControlRecombinationType) Scan(value); + } else { + StatefulFeatureFunction::SetParameter(key, value); + } +} + +} + diff --git a/mosesdecoder/moses/FF/ControlRecombination.h b/mosesdecoder/moses/FF/ControlRecombination.h new file mode 100644 index 0000000000000000000000000000000000000000..034b1a790ae2198f91af9b6fb45d490e320ff115 --- /dev/null +++ b/mosesdecoder/moses/FF/ControlRecombination.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include "StatefulFeatureFunction.h" +#include "FFState.h" +#include "moses/Phrase.h" + +namespace Moses +{ +enum ControlRecombinationType { + // when to recombine + SameOutput = 1, + Never = 2 +}; + +class ControlRecombination; 
+ +class ControlRecombinationState : public FFState +{ +public: + ControlRecombinationState(const ControlRecombination &ff) + :m_ff(ff) { + } + + ControlRecombinationState(const Hypothesis &hypo, const ControlRecombination &ff); + ControlRecombinationState(const ChartHypothesis &hypo, const ControlRecombination &ff); + + virtual size_t hash() const; + virtual bool operator==(const FFState& other) const; + + const Phrase &GetPhrase() const { + return m_outputPhrase; + } + +protected: + Phrase m_outputPhrase; + const ControlRecombination &m_ff; + const void *m_hypo; +}; + +////////////////////////////////////////////////////////////////// + +// only allow recombination for the same output +class ControlRecombination : public StatefulFeatureFunction +{ +public: + ControlRecombination(const std::string &line) + :StatefulFeatureFunction(0, line) + ,m_type(SameOutput) + + { + m_tuneable = false; + ReadParameters(); + } + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + FFState* EvaluateWhenApplied( + const Hypothesis& cur_hypo, + const FFState* prev_state, + ScoreComponentCollection* accumulator) const; + + FFState* EvaluateWhenApplied( + const ChartHypothesis& /* cur_hypo */, + int /* featureID - used to index the state in the previous hypotheses */, + ScoreComponentCollection* accumulator) const; + + virtual const FFState* EmptyHypothesisState(const InputType &input) const { + return new ControlRecombinationState(*this); + } + + std::vector DefaultWeights() const; + + void SetParameter(const std::string& key, const std::string& value); + + ControlRecombinationType GetType() const { + return m_type; + } +protected: + ControlRecombinationType m_type; +}; + + +} + diff --git a/mosesdecoder/moses/FF/CorrectionPattern.cpp b/mosesdecoder/moses/FF/CorrectionPattern.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9770f7d66989c27118dbfbfbfcadf6cf9c5e40b0 --- /dev/null +++ b/mosesdecoder/moses/FF/CorrectionPattern.cpp @@ -0,0 
+1,354 @@ +#include +#include "CorrectionPattern.h" +#include "moses/Phrase.h" +#include "moses/TargetPhrase.h" +#include "moses/InputPath.h" +#include "moses/Hypothesis.h" +#include "moses/ChartHypothesis.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/TranslationOption.h" +#include "util/string_piece_hash.hh" +#include "util/exception.hh" + +#include +#include + +#include +#include + +#include "Diffs.h" + +namespace Moses +{ + +using namespace std; + +std::string MakePair(const std::string &s1, const std::string &s2, bool general) +{ + std::vector sourceList; + std::vector targetList; + + if(general) { + Diffs diffs = CreateDiff(s1, s2); + + size_t i = 0, j = 0; + char lastType = 'm'; + + std::string source, target; + std::string match; + + int count = 1; + + BOOST_FOREACH(Diff type, diffs) { + if(type == 'm') { + if(lastType != 'm') { + sourceList.push_back(source); + targetList.push_back(target); + } + source.clear(); + target.clear(); + + if(s1[i] == '+') { + if(match.size() >= 3) { + sourceList.push_back("(\\w{3,})·"); + std::string temp = "1"; + sprintf((char*)temp.c_str(), "%d", count); + targetList.push_back("\\" + temp + "·"); + count++; + } else { + sourceList.push_back(match + "·"); + targetList.push_back(match + "·"); + } + match.clear(); + } else + match.push_back(s1[i]); + + i++; + j++; + } else if(type == 'd') { + if(s1[i] == '+') + source += "·"; + else + source.push_back(s1[i]); + i++; + } else if(type == 'i') { + if(s2[j] == '+') + target += "·"; + else + target.push_back(s2[j]); + j++; + } + if(type != 'm' && !match.empty()) { + if(match.size() >= 3) { + sourceList.push_back("(\\w{3,})"); + std::string temp = "1"; + sprintf((char*)temp.c_str(), "%d", count); + targetList.push_back("\\" + temp); + count++; + } else { + sourceList.push_back(match); + targetList.push_back(match); + } + + match.clear(); + } + + lastType = type; + } + if(lastType != 'm') { + sourceList.push_back(source); + targetList.push_back(target); + } + + 
if(!match.empty()) { + if(match.size() >= 3) { + sourceList.push_back("(\\w{3,})"); + std::string temp = "1"; + sprintf((char*)temp.c_str(), "%d", count); + targetList.push_back("\\"+ temp); + count++; + } else { + sourceList.push_back(match); + targetList.push_back(match); + } + } + match.clear(); + } else { + std::string cs1 = s1; + std::string cs2 = s2; + boost::replace_all(cs1, "+", "·"); + boost::replace_all(cs2, "+", "·"); + + sourceList.push_back(cs1); + targetList.push_back(cs2); + } + + std::stringstream out; + out << "sub(«"; + out << boost::join(sourceList, ""); + out << "»,«"; + out << boost::join(targetList, ""); + out << "»)"; + + return out.str(); +} + +std::string CorrectionPattern::CreateSinglePattern(const Tokens &s1, const Tokens &s2) const +{ + std::stringstream out; + if(s1.empty()) { + out << "ins(«" << boost::join(s2, "·") << "»)"; + return out.str(); + } else if(s2.empty()) { + out << "del(«" << boost::join(s1, "·") << "»)"; + return out.str(); + } else { + Tokens::value_type v1 = boost::join(s1, "+"); + Tokens::value_type v2 = boost::join(s2, "+"); + out << MakePair(v1, v2, m_general); + return out.str(); + } +} + +std::vector GetContext(size_t pos, + size_t len, + size_t window, + const InputType &input, + const InputPath &inputPath, + const std::vector& factorTypes, + bool isRight) +{ + + const Sentence& sentence = static_cast(input); + const Range& range = inputPath.GetWordsRange(); + + int leftPos = range.GetStartPos() + pos - len - 1; + int rightPos = range.GetStartPos() + pos; + + std::vector contexts; + + for(int length = 1; length <= (int)window; ++length) { + std::vector current; + if(!isRight) { + for(int i = 0; i < length; i++) { + if(leftPos - i >= 0) { + current.push_back(sentence.GetWord(leftPos - i).GetString(factorTypes, false)); + } else { + current.push_back(""); + } + } + + if(current.back() == "" && current.size() >= 2 && current[current.size()-2] == "") + continue; + + std::reverse(current.begin(), current.end()); + 
contexts.push_back("left(«" + boost::join(current, "·") + "»)_"); + } + if(isRight) { + for(int i = 0; i < length; i++) { + if(rightPos + i < (int)sentence.GetSize()) { + current.push_back(sentence.GetWord(rightPos + i).GetString(factorTypes, false)); + } else { + current.push_back(""); + } + } + + if(current.back() == "" && current.size() >= 2 && current[current.size()-2] == "") + continue; + + contexts.push_back("_right(«" + boost::join(current, "·") + "»)"); + } + } + return contexts; +} + +std::vector +CorrectionPattern::CreatePattern(const Tokens &s1, + const Tokens &s2, + const InputType &input, + const InputPath &inputPath) const +{ + + Diffs diffs = CreateDiff(s1, s2); + size_t i = 0, j = 0; + char lastType = 'm'; + std::vector patternList; + Tokens source, target; + BOOST_FOREACH(Diff type, diffs) { + if(type == 'm') { + if(lastType != 'm') { + std::string pattern = CreateSinglePattern(source, target); + patternList.push_back(pattern); + + if(m_context > 0) { + std::vector leftContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, false); + std::vector rightContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, true); + + BOOST_FOREACH(std::string left, leftContexts) + patternList.push_back(left + pattern); + + BOOST_FOREACH(std::string right, rightContexts) + patternList.push_back(pattern + right); + + BOOST_FOREACH(std::string left, leftContexts) + BOOST_FOREACH(std::string right, rightContexts) + patternList.push_back(left + pattern + right); + } + } + source.clear(); + target.clear(); + if(s1[i] != s2[j]) { + source.push_back(s1[i]); + target.push_back(s2[j]); + } + i++; + j++; + } else if(type == 'd') { + source.push_back(s1[i]); + i++; + } else if(type == 'i') { + target.push_back(s2[j]); + j++; + } + lastType = type; + } + if(lastType != 'm') { + std::string pattern = CreateSinglePattern(source, target); + patternList.push_back(pattern); + + if(m_context > 0) { + std::vector 
leftContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, false); + std::vector rightContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, true); + + BOOST_FOREACH(std::string left, leftContexts) + patternList.push_back(left + pattern); + + BOOST_FOREACH(std::string right, rightContexts) + patternList.push_back(pattern + right); + + BOOST_FOREACH(std::string left, leftContexts) + BOOST_FOREACH(std::string right, rightContexts) + patternList.push_back(left + pattern + right); + } + } + + return patternList; +} + +CorrectionPattern::CorrectionPattern(const std::string &line) + : StatelessFeatureFunction(0, line), m_factors(1, 0), m_general(false), + m_context(0), m_contextFactors(1, 0) +{ + std::cerr << "Initializing correction pattern feature.." << std::endl; + ReadParameters(); +} + +void CorrectionPattern::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "factor") { + m_factors = std::vector(1, Scan(value)); + } else if (key == "context-factor") { + m_contextFactors = std::vector(1, Scan(value)); + } else if (key == "general") { + m_general = Scan(value); + } else if (key == "context") { + m_context = Scan(value); + } else { + StatelessFeatureFunction::SetParameter(key, value); + } +} + +void CorrectionPattern::EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore) const +{ + ComputeFeatures(input, inputPath, targetPhrase, &scoreBreakdown); +} + +void CorrectionPattern::ComputeFeatures( + const InputType &input, + const InputPath &inputPath, + const TargetPhrase& target, + ScoreComponentCollection* accumulator) const +{ + const Phrase &source = inputPath.GetPhrase(); + + std::vector sourceTokens; + for(size_t i = 0; i < source.GetSize(); ++i) + 
sourceTokens.push_back(source.GetWord(i).GetString(m_factors, false)); + + std::vector targetTokens; + for(size_t i = 0; i < target.GetSize(); ++i) + targetTokens.push_back(target.GetWord(i).GetString(m_factors, false)); + + std::vector patternList = CreatePattern(sourceTokens, targetTokens, input, inputPath); + for(size_t i = 0; i < patternList.size(); ++i) + accumulator->PlusEquals(this, patternList[i], 1); + + /* + BOOST_FOREACH(std::string w, sourceTokens) + std::cerr << w << " "; + std::cerr << std::endl; + BOOST_FOREACH(std::string w, targetTokens) + std::cerr << w << " "; + std::cerr << std::endl; + BOOST_FOREACH(std::string w, patternList) + std::cerr << w << " "; + std::cerr << std::endl << std::endl; + */ +} + +bool CorrectionPattern::IsUseable(const FactorMask &mask) const +{ + bool ret = true; + for(size_t i = 0; i < m_factors.size(); ++i) + ret = ret && mask[m_factors[i]]; + for(size_t i = 0; i < m_contextFactors.size(); ++i) + ret = ret && mask[m_contextFactors[i]]; + return ret; +} + +} diff --git a/mosesdecoder/moses/FF/CorrectionPattern.h b/mosesdecoder/moses/FF/CorrectionPattern.h new file mode 100644 index 0000000000000000000000000000000000000000..516a56ce2e7ac896e345af6fe95c3e79ac5f0e0b --- /dev/null +++ b/mosesdecoder/moses/FF/CorrectionPattern.h @@ -0,0 +1,73 @@ +#ifndef moses_CorrectionPattern_h +#define moses_CorrectionPattern_h + +#include +#include + +#include "StatelessFeatureFunction.h" +#include "moses/FactorCollection.h" +#include "moses/AlignmentInfo.h" + +namespace Moses +{ + +typedef std::vector Tokens; + +/** Sets the features for length of source phrase, target phrase, both. 
+ */ +class CorrectionPattern : public StatelessFeatureFunction +{ +private: + std::vector m_factors; + bool m_general; + size_t m_context; + std::vector m_contextFactors; + +public: + CorrectionPattern(const std::string &line); + + bool IsUseable(const FactorMask &mask) const; + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedFutureScore) const + {} + + virtual void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore = NULL) const; + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const + {} + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const + {} + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const + {} + + void ComputeFeatures(const InputType &input, + const InputPath &inputPath, + const TargetPhrase& targetPhrase, + ScoreComponentCollection* accumulator) const; + + void SetParameter(const std::string& key, const std::string& value); + + std::vector CreatePattern(const Tokens &s1, + const Tokens &s2, + const InputType &input, + const InputPath &inputPath) const; + + std::string CreateSinglePattern(const Tokens &s1, const Tokens &s2) const; + +}; + +} + +#endif // moses_CorrectionPattern_h diff --git a/mosesdecoder/moses/FF/CoveredReferenceFeature.cpp b/mosesdecoder/moses/FF/CoveredReferenceFeature.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e05d62a8674d5ba21bfcb86a441a5517ad72b76b --- /dev/null +++ b/mosesdecoder/moses/FF/CoveredReferenceFeature.cpp @@ -0,0 +1,129 @@ +#include +#include +#include +#include +#include +#include 
"CoveredReferenceFeature.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/Hypothesis.h" +#include "moses/Manager.h" +#include "moses/ChartHypothesis.h" +#include "moses/ChartManager.h" +#include "moses/StaticData.h" +#include "moses/InputFileStream.h" +#include "moses/Util.h" +#include "util/exception.hh" + +using namespace std; + +namespace Moses +{ + +size_t CoveredReferenceState::hash() const +{ + UTIL_THROW2("TODO:Haven't figure this out yet"); +} + +bool CoveredReferenceState::operator==(const FFState& other) const +{ + UTIL_THROW2("TODO:Haven't figure this out yet"); +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +void CoveredReferenceFeature::EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores) const +{ + long id = input.GetTranslationId(); + boost::unordered_map >::const_iterator refIt = m_refs.find(id); + multiset wordsInPhrase = GetWordsInPhrase(targetPhrase); + multiset covered; + set_intersection(wordsInPhrase.begin(), wordsInPhrase.end(), + refIt->second.begin(), refIt->second.end(), + inserter(covered, covered.begin())); + vector scores; + scores.push_back(covered.size()); + + scoreBreakdown.Assign(this, scores); + estimatedScores->Assign(this, scores); +} + +void CoveredReferenceFeature::Load(AllOptions::ptr const& opts) +{ + m_options = opts; + InputFileStream refFile(m_path); + std::string line; + const StaticData &staticData = StaticData::Instance(); + long sentenceID = opts->output.start_translation_id; + while (getline(refFile, line)) { + vector words = Tokenize(line, " "); + multiset wordSet; + // TODO make Tokenize work with other containers than vector + copy(words.begin(), words.end(), inserter(wordSet, wordSet.begin())); + m_refs.insert(make_pair(sentenceID++, wordSet)); + 
} +} + +void CoveredReferenceFeature::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "path") { + m_path = value; + } else { + StatefulFeatureFunction::SetParameter(key, value); + } +} + +FFState* CoveredReferenceFeature::EvaluateWhenApplied( + const Hypothesis& cur_hypo, + const FFState* prev_state, + ScoreComponentCollection* accumulator) const +{ + const CoveredReferenceState &prev = static_cast(*prev_state); + CoveredReferenceState *ret = new CoveredReferenceState(prev); + + const Manager &mgr = cur_hypo.GetManager(); + const InputType &input = mgr.GetSource(); + long id = input.GetTranslationId(); + + // which words from the reference remain uncovered + multiset remaining; + boost::unordered_map >::const_iterator refIt = m_refs.find(id); + if (refIt == m_refs.end()) UTIL_THROW(util::Exception, "Sentence id out of range: " + SPrint(id)); + set_difference(refIt->second.begin(), refIt->second.end(), + ret->m_coveredRef.begin(), ret->m_coveredRef.end(), + inserter(remaining, remaining.begin())); + + // which of the remaining words are present in the current phrase + multiset wordsInPhrase = GetWordsInPhrase(cur_hypo.GetCurrTargetPhrase()); + multiset newCovered; + set_intersection(wordsInPhrase.begin(), wordsInPhrase.end(), + remaining.begin(), remaining.end(), + inserter(newCovered, newCovered.begin())); + + vector estimateScore = + cur_hypo.GetCurrTargetPhrase().GetScoreBreakdown().GetScoresForProducer(this); + vector scores; + scores.push_back(newCovered.size() - estimateScore[0]); + accumulator->PlusEquals(this, scores); + + // update feature state + multiset::const_iterator newCoveredIt; + for (newCoveredIt = newCovered.begin(); newCoveredIt != newCovered.end(); newCoveredIt++) { + ret->m_coveredRef.insert(*newCoveredIt); + } + return ret; +} + +FFState* CoveredReferenceFeature::EvaluateWhenApplied( + const ChartHypothesis& /* cur_hypo */, + int /* featureID - used to index the state in the previous hypotheses */, + 
ScoreComponentCollection* accumulator) const +{ + UTIL_THROW(util::Exception, "Not implemented"); +} + +} diff --git a/mosesdecoder/moses/FF/DecodeFeature.cpp b/mosesdecoder/moses/FF/DecodeFeature.cpp new file mode 100644 index 0000000000000000000000000000000000000000..64b120519b2834c1ba9fa7656f05e60626fabb74 --- /dev/null +++ b/mosesdecoder/moses/FF/DecodeFeature.cpp @@ -0,0 +1,117 @@ +// $Id: PhraseDictionaryMemory.cpp 2477 2009-08-07 16:47:54Z bhaddow $ +// vim:tabstop=2 + +/*********************************************************************** +Moses - factored phrase-based language decoder +Copyright (C) 2010 University of Edinburgh + +This library is free software; you can redistribute it and/or +modify it under the terms of the GNU Lesser General Public +License as published by the Free Software Foundation; either +version 2.1 of the License, or (at your option) any later version. + +This library is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +Lesser General Public License for more details. 
+ +You should have received a copy of the GNU Lesser General Public +License along with this library; if not, write to the Free Software +Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +***********************************************************************/ + +#include + +#include "DecodeFeature.h" +#include "moses/DecodeStep.h" +#include "moses/StaticData.h" + +using namespace std; + +namespace Moses +{ + +DecodeFeature::DecodeFeature(const std::string &line, bool registerNow) + : StatelessFeatureFunction(line, registerNow) + , m_container(NULL) +{ + VERBOSE(2,"DecodeFeature:" << std::endl); +} + +DecodeFeature::DecodeFeature(size_t numScoreComponents + , const std::string &line) + : StatelessFeatureFunction(numScoreComponents, line) + , m_container(NULL) +{ + VERBOSE(2,"DecodeFeature: no factors yet" << std::endl); +} + +DecodeFeature::DecodeFeature(size_t numScoreComponents + , const std::vector &input + , const std::vector &output + , const std::string &line) + : StatelessFeatureFunction(numScoreComponents, line) + , m_input(input), m_output(output) + , m_container(NULL) +{ + m_inputFactors = FactorMask(input); + m_outputFactors = FactorMask(output); + VERBOSE(2,"DecodeFeature: input=" << m_inputFactors << " output=" << m_outputFactors << std::endl); +} + +void DecodeFeature::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "input-factor") { + m_input =Tokenize(value, ","); + m_inputFactors = FactorMask(m_input); + } else if (key == "output-factor") { + m_output =Tokenize(value, ","); + m_outputFactors = FactorMask(m_output); + } else { + StatelessFeatureFunction::SetParameter(key, value); + } +} + + +const FactorMask& DecodeFeature::GetOutputFactorMask() const +{ + return m_outputFactors; +} + + +const FactorMask& DecodeFeature::GetInputFactorMask() const +{ + return m_inputFactors; +} + +const std::vector& DecodeFeature::GetInput() const +{ + return m_input; +} + +const std::vector& 
DecodeFeature::GetOutput() const +{ + return m_output; +} + +bool DecodeFeature::IsUseable(const FactorMask &mask) const +{ + for (size_t i = 0; i < m_output.size(); ++i) { + const FactorType &factor = m_output[i]; + if (!mask[factor]) { + return false; + } + } + return true; +} + +const DecodeGraph &DecodeFeature::GetDecodeGraph() const +{ + assert(m_container); + const DecodeGraph *graph = m_container->GetContainer(); + assert(graph); + return *graph; +} + +} + diff --git a/mosesdecoder/moses/FF/DistortionScoreProducer.h b/mosesdecoder/moses/FF/DistortionScoreProducer.h new file mode 100644 index 0000000000000000000000000000000000000000..d59214df76f09eb6ae1e8e1ae59f2e00d927eb9c --- /dev/null +++ b/mosesdecoder/moses/FF/DistortionScoreProducer.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include "StatefulFeatureFunction.h" +#include "moses/Range.h" + +namespace Moses +{ + +/** Calculates Distortion scores + */ +class DistortionScoreProducer : public StatefulFeatureFunction +{ +protected: + static std::vector s_staticColl; + + FactorType m_sparseFactorTypeSource; + FactorType m_sparseFactorTypeTarget; + bool m_useSparse; + bool m_sparseDistance; + bool m_sparseSubordinate; + FactorType m_sparseFactorTypeTargetSubordinate; + const Factor* m_subordinateConjunctionTagFactor; + +public: + static const std::vector& GetDistortionFeatureFunctions() { + return s_staticColl; + } + + DistortionScoreProducer(const std::string &line); + + void SetParameter(const std::string& key, const std::string& value); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + static float CalculateDistortionScore(const Hypothesis& hypo, + const Range &prev, const Range &curr, const int FirstGapPosition); + + virtual const FFState* EmptyHypothesisState(const InputType &input) const; + + virtual FFState* EvaluateWhenApplied( + const Hypothesis& cur_hypo, + const FFState* prev_state, + ScoreComponentCollection* accumulator) const; + + virtual FFState* EvaluateWhenApplied( + 
const ChartHypothesis& /* cur_hypo */, + int /* featureID - used to index the state in the previous hypotheses */, + ScoreComponentCollection*) const { + UTIL_THROW(util::Exception, "DIstortion not implemented in chart decoder"); + } + +}; +} + diff --git a/mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.cpp b/mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f45e02bc12a807af2372ed202df6b7492b9502f4 --- /dev/null +++ b/mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.cpp @@ -0,0 +1,459 @@ +#include +#include "moses/StaticData.h" +#include "moses/InputFileStream.h" +#include "DynamicCacheBasedLanguageModel.h" + +namespace Moses +{ + +std::map< const std::string, DynamicCacheBasedLanguageModel * > DynamicCacheBasedLanguageModel::s_instance_map; +DynamicCacheBasedLanguageModel *DynamicCacheBasedLanguageModel::s_instance = NULL; + +DynamicCacheBasedLanguageModel::DynamicCacheBasedLanguageModel(const std::string &line) + : StatelessFeatureFunction(1, line) +{ + VERBOSE(2,"Initializing DynamicCacheBasedLanguageModel feature..." 
<< std::endl); + + m_query_type = CBLM_QUERY_TYPE_ALLSUBSTRINGS; + m_score_type = CBLM_SCORE_TYPE_HYPERBOLA; + m_maxAge = 1000; + m_name = "default"; + m_constant = false; + + ReadParameters(); + UTIL_THROW_IF2(s_instance_map.find(m_name) != s_instance_map.end(), "Only 1 DynamicCacheBasedLanguageModel feature named " + m_name + " is allowed"); + s_instance_map[m_name] = this; + s_instance = this; //for back compatibility + + SetPreComputedScores(); +} + +DynamicCacheBasedLanguageModel::~DynamicCacheBasedLanguageModel() {}; + +void DynamicCacheBasedLanguageModel::SetPreComputedScores() +{ +#ifdef WITH_THREADS + boost::shared_lock lock(m_cacheLock); +#endif + precomputedScores.clear(); + for (unsigned int i=0; i(value)); + } else if (key == "cblm-score-type") { + SetScoreType(Scan(value)); + } else if (key == "cblm-max-age") { + SetMaxAge(Scan(value)); + } else if (key == "cblm-file") { + m_initfiles = Scan(value); + } else if (key == "cblm-name") { + m_name = Scan(value); + } else if (key == "cblm-constant") { + m_constant = Scan(value); + } else { + StatelessFeatureFunction::SetParameter(key, value); + } +} + +void DynamicCacheBasedLanguageModel::EvaluateInIsolation(const Phrase &sp + , const TargetPhrase &tp + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const +{ + float score = m_lower_score; + switch(m_query_type) { + case CBLM_QUERY_TYPE_WHOLESTRING: + score = Evaluate_Whole_String(tp); + break; + case CBLM_QUERY_TYPE_ALLSUBSTRINGS: + score = Evaluate_All_Substrings(tp); + break; + default: + UTIL_THROW_IF2(false, "This score type (" << m_query_type << ") is unknown."); + } + + scoreBreakdown.Assign(this, score); +} + +float DynamicCacheBasedLanguageModel::Evaluate_Whole_String(const TargetPhrase& tp) const +{ + //consider all words in the TargetPhrase as one n-gram + // and compute the decaying_score for the whole n-gram + // and return this value + + decaying_cache_t::const_iterator it; + float score = 
m_lower_score; + + std::string w = ""; + size_t endpos = tp.GetSize(); + for (size_t pos = 0 ; pos < endpos ; ++pos) { + w += tp.GetWord(pos).GetFactor(0)->GetString().as_string(); + if (pos < endpos - 1) { + w += " "; + } + } + it = m_cache.find(w); + + VERBOSE(4,"cblm::Evaluate_Whole_String: searching w:|" << w << "|" << std::endl); + if (it != m_cache.end()) { //found! + score = ((*it).second).second; + VERBOSE(4,"cblm::Evaluate_Whole_String: found w:|" << w << "|" << std::endl); + } + + VERBOSE(4,"cblm::Evaluate_Whole_String: returning score:|" << score << "|" << std::endl); + return score; +} + +float DynamicCacheBasedLanguageModel::Evaluate_All_Substrings(const TargetPhrase& tp) const +{ + //loop over all n-grams in the TargetPhrase (no matter of n) + //and compute the decaying_score for all words + //and return their sum + + decaying_cache_t::const_iterator it; + float score = 0.0; + + for (size_t startpos = 0 ; startpos < tp.GetSize() ; ++startpos) { + std::string w = ""; + for (size_t endpos = startpos; endpos < tp.GetSize() ; ++endpos) { + w += tp.GetWord(endpos).GetFactor(0)->GetString().as_string(); + it = m_cache.find(w); + + if (it != m_cache.end()) { //found! 
+ score += ((*it).second).second; + VERBOSE(3,"cblm::Evaluate_All_Substrings: found w:|" << w << "| actual score:|" << ((*it).second).second << "| score:|" << score << "|" << std::endl); + } else { + score += m_lower_score; + } + + if (endpos == startpos) { + w += " "; + } + + } + } + VERBOSE(3,"cblm::Evaluate_All_Substrings: returning score:|" << score << "|" << std::endl); + return score; +} + +void DynamicCacheBasedLanguageModel::Print() const +{ +#ifdef WITH_THREADS + boost::shared_lock read_lock(m_cacheLock); +#endif + decaying_cache_t::const_iterator it; + std::cout << "Content of the cache of Cache-Based Language Model" << std::endl; + std::cout << "Size of the cache of Cache-Based Language Model:|" << m_cache.size() << "|" << std::endl; + for ( it=m_cache.begin() ; it != m_cache.end(); it++ ) { + std::cout << "word:|" << (*it).first << "| age:|" << ((*it).second).first << "| score:|" << ((*it).second).second << "|" << std::endl; + } +} + +void DynamicCacheBasedLanguageModel::Decay() +{ +#ifdef WITH_THREADS + boost::shared_lock lock(m_cacheLock); +#endif + decaying_cache_t::iterator it; + + unsigned int age; + float score; + for ( it=m_cache.begin() ; it != m_cache.end(); it++ ) { + age=((*it).second).first + 1; + if (age > m_maxAge) { + m_cache.erase(it); + it--; + } else { + score = GetPreComputedScores(age); +// score = decaying_score(age); + decaying_cache_value_t p (age, score); + (*it).second = p; + } + } +} + +void DynamicCacheBasedLanguageModel::Update(std::vector words, int age) +{ +#ifdef WITH_THREADS + boost::shared_lock lock(m_cacheLock); +#endif + VERBOSE(3,"words.size():|" << words.size() << "|" << std::endl); + for (size_t j=0; j e (words[j],p); + m_cache.erase(words[j]); //always erase the element (do nothing if the entry does not exist) + m_cache.insert(e); //insert the entry + } +} + +void DynamicCacheBasedLanguageModel::ClearEntries(std::string &entries) +{ + if (entries != "") { + VERBOSE(3,"entries:|" << entries << "|" << std::endl); + 
std::vector elements = TokenizeMultiCharSeparator(entries, "||"); + VERBOSE(3,"elements.size() after:|" << elements.size() << "|" << std::endl); + ClearEntries(elements); + } +} + +void DynamicCacheBasedLanguageModel::ClearEntries(std::vector words) +{ +#ifdef WITH_THREADS + boost::shared_lock lock(m_cacheLock); +#endif + VERBOSE(3,"words.size():|" << words.size() << "|" << std::endl); + for (size_t j=0; j elements = TokenizeMultiCharSeparator(entries, "||"); + VERBOSE(3,"elements.size() after:|" << elements.size() << "|" << std::endl); + Insert(elements); + } +} + +void DynamicCacheBasedLanguageModel::Insert(std::vector ngrams) +{ + VERBOSE(3,"DynamicCacheBasedLanguageModel Insert ngrams.size():|" << ngrams.size() << "|" << std::endl); + if (m_constant == false) { + Decay(); + } + Update(ngrams,1); + IFVERBOSE(3) Print(); +} + +void DynamicCacheBasedLanguageModel::ExecuteDlt(std::map dlt_meta) +{ + if (dlt_meta.find("cblm") != dlt_meta.end()) { + Insert(dlt_meta["cblm"]); + } + if (dlt_meta.find("cblm-command") != dlt_meta.end()) { + Execute(dlt_meta["cblm-command"]); + } + if (dlt_meta.find("cblm-file") != dlt_meta.end()) { + Load(dlt_meta["cblm-file"]); + } + if (dlt_meta.find("cblm-clear-entries") != dlt_meta.end()) { + ClearEntries(dlt_meta["cblm-clear-entries"]); + } + if (dlt_meta.find("cblm-clear-all") != dlt_meta.end()) { + Clear(); + } + +} + +void DynamicCacheBasedLanguageModel::Execute(std::string command) +{ + VERBOSE(2,"DynamicCacheBasedLanguageModel::Execute(std::string command:|" << command << "|" << std::endl); + std::vector commands = Tokenize(command, "||"); + Execute(commands); +} + +void DynamicCacheBasedLanguageModel::Execute(std::vector commands) +{ + for (size_t j=0; j lock(m_cacheLock); +#endif + m_cache.clear(); +} + +void DynamicCacheBasedLanguageModel::Load(AllOptions::ptr const& opts) +{ + m_options = opts; +// SetPreComputedScores(); + VERBOSE(2,"DynamicCacheBasedLanguageModel::Load()" << std::endl); + Load(m_initfiles); +} + +void 
DynamicCacheBasedLanguageModel::Load(const std::string filestr) +{ + VERBOSE(2,"DynamicCacheBasedLanguageModel::Load(const std::string filestr)" << std::endl); +// std::vector files = Tokenize(m_initfiles, "||"); + std::vector files = Tokenize(filestr, "||"); + Load_Multiple_Files(files); +} + + +void DynamicCacheBasedLanguageModel::Load_Multiple_Files(std::vector files) +{ + VERBOSE(2,"DynamicCacheBasedLanguageModel::Load_Multiple_Files(std::vector files)" << std::endl); + for(size_t j = 0; j < files.size(); ++j) { + Load_Single_File(files[j]); + } +} + +void DynamicCacheBasedLanguageModel::Load_Single_File(const std::string file) +{ + VERBOSE(2,"DynamicCacheBasedLanguageModel::Load_Single_File(const std::string file)" << std::endl); + //file format + //age || n-gram + //age || n-gram || n-gram || n-gram || ... + //.... + //each n-gram is a sequence of n words (no matter of n) + // + //there is no limit on the size of n + // + //entries can be repeated, but the last entry overwrites the previous + + + VERBOSE(2,"Loading data from the cache file " << file << std::endl); + InputFileStream cacheFile(file); + + std::string line; + int age; + std::vector words; + + while (getline(cacheFile, line)) { + std::vector vecStr = TokenizeMultiCharSeparator( line , "||" ); + if (vecStr.size() >= 2) { + age = Scan(vecStr[0]); + vecStr.erase(vecStr.begin()); + Update(vecStr,age); + } else { + UTIL_THROW_IF2(false, "The format of the loaded file is wrong: " << line); + } + } + IFVERBOSE(2) Print(); +} + +void DynamicCacheBasedLanguageModel::SetQueryType(size_t type) +{ +#ifdef WITH_THREADS + boost::shared_lock read_lock(m_cacheLock); +#endif + + m_query_type = type; + if ( m_query_type != CBLM_QUERY_TYPE_WHOLESTRING + && m_query_type != CBLM_QUERY_TYPE_ALLSUBSTRINGS ) { + VERBOSE(2, "This query type " << m_query_type << " is unknown. Instead used " << CBLM_QUERY_TYPE_ALLSUBSTRINGS << "." 
<< std::endl); + m_query_type = CBLM_QUERY_TYPE_ALLSUBSTRINGS; + } + VERBOSE(2, "CacheBasedLanguageModel QueryType: " << m_query_type << std::endl); + +}; + +void DynamicCacheBasedLanguageModel::SetScoreType(size_t type) +{ +#ifdef WITH_THREADS + boost::shared_lock read_lock(m_cacheLock); +#endif + m_score_type = type; + if ( m_score_type != CBLM_SCORE_TYPE_HYPERBOLA + && m_score_type != CBLM_SCORE_TYPE_POWER + && m_score_type != CBLM_SCORE_TYPE_EXPONENTIAL + && m_score_type != CBLM_SCORE_TYPE_COSINE + && m_score_type != CBLM_SCORE_TYPE_HYPERBOLA_REWARD + && m_score_type != CBLM_SCORE_TYPE_POWER_REWARD + && m_score_type != CBLM_SCORE_TYPE_EXPONENTIAL_REWARD ) { + VERBOSE(2, "This score type " << m_score_type << " is unknown. Instead used " << CBLM_SCORE_TYPE_HYPERBOLA << "." << std::endl); + m_score_type = CBLM_SCORE_TYPE_HYPERBOLA; + } + VERBOSE(2, "CacheBasedLanguageModel ScoreType: " << m_score_type << std::endl); +}; + +void DynamicCacheBasedLanguageModel::SetMaxAge(unsigned int age) +{ +#ifdef WITH_THREADS + boost::shared_lock read_lock(m_cacheLock); +#endif + m_maxAge = age; + VERBOSE(2, "CacheBasedLanguageModel MaxAge: " << m_maxAge << std::endl); +}; + +float DynamicCacheBasedLanguageModel::decaying_score(const unsigned int age) +{ + float sc; + switch(m_score_type) { + case CBLM_SCORE_TYPE_HYPERBOLA: + sc = (float) 1.0/age - 1.0; + break; + case CBLM_SCORE_TYPE_POWER: + sc = (float) pow(age, -0.25) - 1.0; + break; + case CBLM_SCORE_TYPE_EXPONENTIAL: + sc = (age == 1) ? 0.0 : (float) exp( 1.0/age ) / exp(1.0) - 1.0; + break; + case CBLM_SCORE_TYPE_COSINE: + sc = (float) cos( (age-1) * (PI/2) / m_maxAge ) - 1.0; + break; + case CBLM_SCORE_TYPE_HYPERBOLA_REWARD: + sc = (float) 1.0/age; + break; + case CBLM_SCORE_TYPE_POWER_REWARD: + sc = (float) pow(age, -0.25); + break; + case CBLM_SCORE_TYPE_EXPONENTIAL_REWARD: + sc = (age == 1) ? 
1.0 : (float) exp( 1.0/age ) / exp(1.0); + break; + default: + sc = -1.0; + } + return sc; +} +} diff --git a/mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.h b/mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.h new file mode 100644 index 0000000000000000000000000000000000000000..be3d0726995929d07a6dcc8662e2d8a2ade48e28 --- /dev/null +++ b/mosesdecoder/moses/FF/DynamicCacheBasedLanguageModel.h @@ -0,0 +1,164 @@ +// $Id$ + +#ifndef moses_DynamicCacheBasedLanguageModel_h +#define moses_DynamicCacheBasedLanguageModel_h + +#include "moses/Util.h" +#include "FeatureFunction.h" + +#ifdef WITH_THREADS +#include +#include +#endif + +typedef std::pair decaying_cache_value_t; +typedef std::map decaying_cache_t; + +#define CBLM_QUERY_TYPE_UNDEFINED (-1) +#define CBLM_QUERY_TYPE_ALLSUBSTRINGS 0 +#define CBLM_QUERY_TYPE_WHOLESTRING 1 + +#define CBLM_SCORE_TYPE_UNDEFINED (-1) +#define CBLM_SCORE_TYPE_HYPERBOLA 0 +#define CBLM_SCORE_TYPE_POWER 1 +#define CBLM_SCORE_TYPE_EXPONENTIAL 2 +#define CBLM_SCORE_TYPE_COSINE 3 +#define CBLM_SCORE_TYPE_HYPERBOLA_REWARD 10 +#define CBLM_SCORE_TYPE_POWER_REWARD 11 +#define CBLM_SCORE_TYPE_EXPONENTIAL_REWARD 12 +#define PI 3.14159265 + +namespace Moses +{ + +class Range; + +/** Calculates score for the Dynamic Cache-Based pseudo LM + */ +class DynamicCacheBasedLanguageModel : public StatelessFeatureFunction +{ + // data structure for the cache; + // the key is the word and the value is the decaying score + decaying_cache_t m_cache; + size_t m_query_type; //way of querying the cache + size_t m_score_type; //way of scoring entries of the cache + std::string m_initfiles; // vector of files loaded in the initialization phase + std::string m_name; // internal name to identify this instance of the Cache-based pseudo LM + float m_lower_score; //lower_bound_score for no match + bool m_constant; //flag for setting a non-decaying cache + std::vector precomputedScores; + unsigned int m_maxAge; + +#ifdef WITH_THREADS + //multiple readers - 
single writer lock + mutable boost::shared_mutex m_cacheLock; +#endif + + float decaying_score(unsigned int age); + void SetPreComputedScores(); + float GetPreComputedScores(const unsigned int age); + + float Evaluate_Whole_String( const TargetPhrase&) const; + float Evaluate_All_Substrings( const TargetPhrase&) const; + + void Decay(); + void Update(std::vector words, int age); + + void ClearEntries(std::vector entries); + + void Execute(std::vector commands); + void Execute_Single_Command(std::string command); + + void Load_Multiple_Files(std::vector files); + void Load_Single_File(const std::string file); + + void Insert(std::vector ngrams); + +// void EvaluateInIsolation(const Phrase&, const TargetPhrase&, ScoreComponentCollection&, ScoreComponentCollection& ) const; + void Print() const; + +protected: + static DynamicCacheBasedLanguageModel* s_instance; + static std::map< const std::string, DynamicCacheBasedLanguageModel* > s_instance_map; + +public: + DynamicCacheBasedLanguageModel(const std::string &line); + ~DynamicCacheBasedLanguageModel(); + + inline const std::string GetName() { + return m_name; + }; + inline void SetName(const std::string name) { + m_name = name; + } + + static const DynamicCacheBasedLanguageModel* Instance(const std::string& name) { + if (s_instance_map.find(name) == s_instance_map.end()) { + return NULL; + } + return s_instance_map[name]; + } + + static DynamicCacheBasedLanguageModel* InstanceNonConst(const std::string& name) { + if (s_instance_map.find(name) == s_instance_map.end()) { + return NULL; + } + return s_instance_map[name]; + } + + + + static const DynamicCacheBasedLanguageModel& Instance() { + return *s_instance; + } + static DynamicCacheBasedLanguageModel& InstanceNonConst() { + return *s_instance; + } + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + void Load(AllOptions::ptr const& opts); + void Load(const std::string filestr); + void Execute(std::string command); + void SetParameter(const 
std::string& key, const std::string& value); + void ExecuteDlt(std::map dlt_meta); + + void ClearEntries(std::string &entries); + void Insert(std::string &entries); + void Clear(); + + virtual void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const; + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const { + } + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + + void SetQueryType(size_t type); + void SetScoreType(size_t type); + void SetMaxAge(unsigned int age); +}; + +} + +#endif diff --git a/mosesdecoder/moses/FF/EditOps.h b/mosesdecoder/moses/FF/EditOps.h new file mode 100644 index 0000000000000000000000000000000000000000..e7e7dd3152be766eb63cfa1cf14eb375835d7c5e --- /dev/null +++ b/mosesdecoder/moses/FF/EditOps.h @@ -0,0 +1,64 @@ +#ifndef moses_EditOps_h +#define moses_EditOps_h + +#include +#include + +#include "StatelessFeatureFunction.h" +#include "moses/FactorCollection.h" +#include "moses/AlignmentInfo.h" + +namespace Moses +{ + +typedef std::vector Tokens; + +/** Calculates string edit operations that transform source phrase into target + * phrase using the LCS algorithm. Potentially usefule for monolingual tasks + * like paraphrasing, summarization, correction. 
+ */ +class EditOps : public StatelessFeatureFunction +{ +private: + FactorType m_factorType; + bool m_chars; + std::string m_scores; + +public: + EditOps(const std::string &line); + + bool IsUseable(const FactorMask &mask) const; + + void Load(); + + virtual void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedFutureScore) const; + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore = NULL) const + {} + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const + {} + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const + {} + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const + {} + + void ComputeFeatures(const Phrase &source, + const TargetPhrase& targetPhrase, + ScoreComponentCollection* accumulator) const; + void SetParameter(const std::string& key, const std::string& value); +}; + +} + +#endif // moses_CorrectionPattern_h diff --git a/mosesdecoder/moses/FF/ExampleStatelessFF.cpp b/mosesdecoder/moses/FF/ExampleStatelessFF.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e62ad0ad8f325724a95ed681b7fb1ed5cb6e505 --- /dev/null +++ b/mosesdecoder/moses/FF/ExampleStatelessFF.cpp @@ -0,0 +1,69 @@ +#include +#include "ExampleStatelessFF.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/TargetPhrase.h" + +using namespace std; + +namespace Moses +{ +ExampleStatelessFF::ExampleStatelessFF(const std::string &line) + :StatelessFeatureFunction(2, line) +{ + ReadParameters(); +} + +void ExampleStatelessFF::EvaluateInIsolation(const Phrase &source + , const 
TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const +{ + // dense scores + vector newScores(m_numScoreComponents); + newScores[0] = 1.5; + newScores[1] = 0.3; + scoreBreakdown.PlusEquals(this, newScores); + + // sparse scores + scoreBreakdown.PlusEquals(this, "sparse-name", 2.4); + +} + +void ExampleStatelessFF::EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores) const +{ + if (targetPhrase.GetNumNonTerminals()) { + vector newScores(m_numScoreComponents); + newScores[0] = - std::numeric_limits::infinity(); + scoreBreakdown.PlusEquals(this, newScores); + } +} + +void ExampleStatelessFF::EvaluateTranslationOptionListWithSourceContext(const InputType &input + + , const TranslationOptionList &translationOptionList) const +{} + +void ExampleStatelessFF::EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const +{} + +void ExampleStatelessFF::EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const +{} + +void ExampleStatelessFF::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "arg") { + // set value here + } else { + StatelessFeatureFunction::SetParameter(key, value); + } +} + +} + diff --git a/mosesdecoder/moses/FF/ExampleTranslationOptionListFeature.h b/mosesdecoder/moses/FF/ExampleTranslationOptionListFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..7686eb3ffaaa1df99c1ec9a37c47c4832da4d7ba --- /dev/null +++ b/mosesdecoder/moses/FF/ExampleTranslationOptionListFeature.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include "StatelessFeatureFunction.h" + +namespace Moses +{ + +class ExampleTranslationOptionListFeature : public StatelessFeatureFunction +{ +public: + 
ExampleTranslationOptionListFeature(const std::string &line) + :StatelessFeatureFunction(1, line) { + ReadParameters(); + } + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedFutureScore) const { + } + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore = NULL) const { + } + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + std::vector newScores(m_numScoreComponents); + newScores[0] = translationOptionList.size(); + + TranslationOptionList::const_iterator iterTransOpt; + for(iterTransOpt = translationOptionList.begin() ; + iterTransOpt != translationOptionList.end() ; ++iterTransOpt) { + TranslationOption &transOpt = **iterTransOpt; + + ScoreComponentCollection &scoreBreakDown = transOpt.GetScoreBreakdown(); + scoreBreakDown.PlusEquals(this, newScores); + + transOpt.UpdateScore(); + } + } + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + + + void SetParameter(const std::string& key, const std::string& value) { + } + +}; + +} + diff --git a/mosesdecoder/moses/FF/FFState.cpp b/mosesdecoder/moses/FF/FFState.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4df3787af74fa063b7640fbde991c8f1eeac6509 --- /dev/null +++ b/mosesdecoder/moses/FF/FFState.cpp @@ -0,0 +1,9 @@ +#include "moses/FF/FFState.h" + +namespace Moses +{ + +FFState::~FFState() {} + +} + diff --git a/mosesdecoder/moses/FF/FFState.h 
b/mosesdecoder/moses/FF/FFState.h new file mode 100644 index 0000000000000000000000000000000000000000..ffecb2e8ab3bb59c5aa910e028e7d152c7be75a6 --- /dev/null +++ b/mosesdecoder/moses/FF/FFState.h @@ -0,0 +1,39 @@ +#ifndef moses_FFState_h +#define moses_FFState_h + +#include +#include +#include "util/exception.hh" + +namespace Moses +{ + +class FFState +{ +public: + virtual ~FFState(); + virtual size_t hash() const = 0; + virtual bool operator==(const FFState& other) const = 0; + + virtual bool operator!=(const FFState& other) const { + return !(*this == other); + } +}; + +class DummyState : public FFState +{ +public: + DummyState() {} + + virtual size_t hash() const { + return 0; + } + + virtual bool operator==(const FFState& other) const { + return true; + } + +}; + +} +#endif diff --git a/mosesdecoder/moses/FF/FeatureFunction.h b/mosesdecoder/moses/FF/FeatureFunction.h new file mode 100644 index 0000000000000000000000000000000000000000..06649634a4cdacab5bf583f0083a51232dd440b0 --- /dev/null +++ b/mosesdecoder/moses/FF/FeatureFunction.h @@ -0,0 +1,200 @@ +// -*- c++ -*- +#ifndef moses_FeatureFunction_h +#define moses_FeatureFunction_h + +#include +#include +#include +#include "moses/FeatureVector.h" +#include "moses/TypeDef.h" +#include "moses/parameters/AllOptions.h" +#include + +namespace Moses +{ + +class AllOptions; +class Phrase; +class TargetPhrase; +class TranslationOptionList; +class TranslationOption; +class Hypothesis; +class ChartHypothesis; +class InputType; +class ScoreComponentCollection; +class Bitmap; +class Range; +class FactorMask; +class InputPath; +class StackVec; +class DistortionScoreProducer; +class TranslationTask; + +/** base class for all feature functions. 
+ */ +class FeatureFunction +{ +protected: + /**< all the score producers in this run */ + static std::vector s_staticColl; + + std::string m_description, m_argLine; + std::vector > m_args; + bool m_tuneable; + bool m_requireSortingAfterSourceContext; + size_t m_verbosity; + size_t m_numScoreComponents; + size_t m_index; // index into vector covering ALL feature function values + std::vector m_tuneableComponents; + size_t m_numTuneableComponents; + AllOptions::ptr m_options; + //In case there's multiple producers with the same description + static std::multiset description_counts; + +public: + static void Register(FeatureFunction* ff); +private: + // void Initialize(const std::string &line); + void ParseLine(const std::string &line); + +public: + static const std::vector& GetFeatureFunctions() { + return s_staticColl; + } + + static FeatureFunction &FindFeatureFunction(const std::string& name); + static void Destroy(); + + FeatureFunction(const std::string &line, bool registerNow); + FeatureFunction(size_t numScoreComponents, const std::string &line, bool registerNow = true); + virtual bool IsStateless() const = 0; + virtual ~FeatureFunction(); + + //! override to load model files + virtual void Load(AllOptions::ptr const& opts) { + m_options = opts; + } + + AllOptions::ptr const& + options() const { + return m_options; + } + + static void ResetDescriptionCounts() { + description_counts.clear(); + } + + //! returns the number of scores that a subclass produces. + //! For example, a language model conventionally produces 1, a translation table some arbitrary number, etc + size_t GetNumScoreComponents() const { + return m_numScoreComponents; + } + + //! returns a string description of this producer + const std::string& GetScoreProducerDescription() const { + return m_description; + } + + FName GetFeatureName(const std::string& name) const { + return FName(GetScoreProducerDescription(), name); + } + + + //! 
if false, then this feature is not displayed in the n-best list. + // use with care + virtual bool IsTuneable() const { + return m_tuneable; + } + + virtual bool HasTuneableComponents() const { + return m_numTuneableComponents; + } + + virtual bool IsTuneableComponent(size_t i) const { + if (m_numTuneableComponents == m_numScoreComponents) { + return true; + } + return m_tuneableComponents[i]; + } + + virtual bool RequireSortingAfterSourceContext() const { + return m_requireSortingAfterSourceContext; + } + + virtual std::vector DefaultWeights() const; + + size_t GetIndex() const; + size_t SetIndex(size_t const idx); + +protected: + virtual void + CleanUpAfterSentenceProcessing(InputType const& source) { } + +public: + //! Called before search and collecting of translation options + virtual void + InitializeForInput(ttasksptr const& ttask) { }; + + // clean up temporary memory, called after processing each sentence + virtual void + CleanUpAfterSentenceProcessing(ttasksptr const& ttask); + + const std::string & + GetArgLine() const { + return m_argLine; + } + + // given a target phrase containing only factors specified in mask + // return true if the feature function can be evaluated + virtual bool IsUseable(const FactorMask &mask) const = 0; + + // used by stateless ff and stateful ff. Calculate initial score + // estimate during loading of phrase table + // + // source phrase is the substring that the phrase table uses to look + // up the target phrase, + // + // may have more factors than actually need, but not guaranteed. 
+ // For SCFG decoding, the source contains non-terminals, NOT the raw + // source from the input sentence + virtual void + EvaluateInIsolation(const Phrase &source, const TargetPhrase &targetPhrase, + ScoreComponentCollection& scoreBreakdown, + ScoreComponentCollection& estimatedScores) const = 0; + + // for context-dependent processing + static void SetupAll(TranslationTask const& task); + virtual void Setup(TranslationTask const& task) const { }; + + // This method is called once all the translation options are retrieved from the phrase table, and + // just before search. + // 'inputPath' is guaranteed to be the raw substring from the input. No factors were added or taken away + // 'stackVec' is a vector of chart cells that the RHS non-terms cover. + // It is guaranteed to be in the same order as the non-terms in the source phrase. + // For pb models, stackvec is NULL. + // No FF should set estimatedScores in both overloads! + virtual void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const = 0; + + // This method is called once all the translation options are retrieved from the phrase table, and + // just before search. + // 'inputPath' is guaranteed to be the raw substring from the input. No factors were added or taken away + // 'stackVec' is a vector of chart cells that the RHS non-terms cover. + // It is guaranteed to be in the same order as the non-terms in the source phrase. + // For pb models, stackvec is NULL. + // No FF should set estimatedScores in both overloads! 
+ virtual void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const = 0; + + virtual void SetParameter(const std::string& key, const std::string& value); + virtual void ReadParameters(); + virtual void SetTuneableComponents(const std::string& value); +}; + +} + +#endif diff --git a/mosesdecoder/moses/FF/GlobalLexicalModel.h b/mosesdecoder/moses/FF/GlobalLexicalModel.h new file mode 100644 index 0000000000000000000000000000000000000000..8391609a2c3d431adf5d8ffaf9ea5d9e7c8ed35a --- /dev/null +++ b/mosesdecoder/moses/FF/GlobalLexicalModel.h @@ -0,0 +1,102 @@ +#ifndef moses_GlobalLexicalModel_h +#define moses_GlobalLexicalModel_h + +#include +#include +#include +#include +#include "StatelessFeatureFunction.h" +#include "moses/Factor.h" +#include "moses/Phrase.h" +#include "moses/TypeDef.h" +#include "moses/Util.h" +#include "moses/Range.h" +#include "moses/FactorTypeSet.h" +#include "moses/Sentence.h" + +#ifdef WITH_THREADS +#include +#endif + +namespace Moses +{ + +class Factor; +class Phrase; +class Hypothesis; +class InputType; + +/** Discriminatively trained global lexicon model + * This is a implementation of Mauser et al., 2009's model that predicts + * each output word from _all_ the input words. 
The intuition behind this + * feature is that it uses context words for disambiguation + */ +class GlobalLexicalModel : public StatelessFeatureFunction +{ + typedef boost::unordered_map< const Word*, + boost::unordered_map< const Word*, float, UnorderedComparer , UnorderedComparer >, + UnorderedComparer, UnorderedComparer > DoubleHash; + typedef boost::unordered_map< const Word*, float, UnorderedComparer, UnorderedComparer > SingleHash; + typedef std::map< const TargetPhrase*, float > LexiconCache; + + struct ThreadLocalStorage { + LexiconCache cache; + const Sentence *input; + }; + +private: + DoubleHash m_hash; +#ifdef WITH_THREADS + boost::thread_specific_ptr m_local; +#else + std::auto_ptr m_local; +#endif + Word *m_bias; + + FactorMask m_inputFactors, m_outputFactors; + std::vector m_inputFactorsVec, m_outputFactorsVec; + std::string m_filePath; + + void Load(AllOptions::ptr const& opts); + + float ScorePhrase( const TargetPhrase& targetPhrase ) const; + float GetFromCacheOrScorePhrase( const TargetPhrase& targetPhrase ) const; + +public: + GlobalLexicalModel(const std::string &line); + virtual ~GlobalLexicalModel(); + + void SetParameter(const std::string& key, const std::string& value); + + void InitializeForInput(ttasksptr const& ttask); + + bool IsUseable(const FactorMask &mask) const; + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const { + } + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const; + + void 
EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + +}; + +} +#endif diff --git a/mosesdecoder/moses/FF/GlobalLexicalModelUnlimited.cpp b/mosesdecoder/moses/FF/GlobalLexicalModelUnlimited.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15eec019c881bc3ea8c3d3eccf3da70fb4366422 --- /dev/null +++ b/mosesdecoder/moses/FF/GlobalLexicalModelUnlimited.cpp @@ -0,0 +1,340 @@ +#include "GlobalLexicalModelUnlimited.h" +#include +#include "moses/StaticData.h" +#include "moses/InputFileStream.h" +#include "moses/Hypothesis.h" +#include "moses/TranslationTask.h" +#include "util/string_piece_hash.hh" +#include "util/string_stream.hh" + +using namespace std; + +namespace Moses +{ +GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const std::string &line) + :StatelessFeatureFunction(0, line) +{ + UTIL_THROW(util::Exception, + "GlobalLexicalModelUnlimited hasn't been refactored for new feature function framework yet"); // TODO need to update arguments to key=value + + const vector modelSpec = Tokenize(line); + + for (size_t i = 0; i < modelSpec.size(); i++ ) { + bool ignorePunctuation = true, biasFeature = false, restricted = false; + size_t context = 0; + string filenameSource, filenameTarget; + vector< string > factors; + vector< string > spec = Tokenize(modelSpec[i]," "); + + // read optional punctuation and bias specifications + if (spec.size() > 0) { + if (spec.size() != 2 && spec.size() != 3 && spec.size() != 4 && spec.size() != 6) { + std::cerr << "Format of glm feature is - [ignore-punct] [use-bias] " + << "[context-type] [filename-src filename-tgt]"; + //return false; + } + + factors = Tokenize(spec[0],"-"); + if (spec.size() >= 2) + ignorePunctuation = Scan(spec[1]); + if (spec.size() >= 3) + biasFeature = Scan(spec[2]); + if (spec.size() >= 4) + context = Scan(spec[3]); + if (spec.size() == 6) { + filenameSource = spec[4]; + filenameTarget = spec[5]; + 
restricted = true; + } + } else + factors = Tokenize(modelSpec[i],"-"); + + if ( factors.size() != 2 ) { + std::cerr << "Wrong factor definition for global lexical model unlimited: " << modelSpec[i]; + //return false; + } + + const vector inputFactors = Tokenize(factors[0],","); + const vector outputFactors = Tokenize(factors[1],","); + throw runtime_error("GlobalLexicalModelUnlimited should be reimplemented as a stateful feature"); + GlobalLexicalModelUnlimited* glmu = NULL; // new GlobalLexicalModelUnlimited(inputFactors, outputFactors, biasFeature, ignorePunctuation, context); + + if (restricted) { + cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl; + if (!glmu->Load(filenameSource, filenameTarget)) { + std::cerr << "Unable to load word lists for word translation feature from files " + << filenameSource + << " and " + << filenameTarget; + //return false; + } + } + } +} + +bool GlobalLexicalModelUnlimited::Load(const std::string &filePathSource, + const std::string &filePathTarget) +{ + // restricted source word vocabulary + ifstream inFileSource(filePathSource.c_str()); + if (!inFileSource) { + cerr << "could not open file " << filePathSource << endl; + return false; + } + + std::string line; + while (getline(inFileSource, line)) { + m_vocabSource.insert(line); + } + + inFileSource.close(); + + // restricted target word vocabulary + ifstream inFileTarget(filePathTarget.c_str()); + if (!inFileTarget) { + cerr << "could not open file " << filePathTarget << endl; + return false; + } + + while (getline(inFileTarget, line)) { + m_vocabTarget.insert(line); + } + + inFileTarget.close(); + + m_unrestricted = false; + return true; +} + +void GlobalLexicalModelUnlimited::InitializeForInput(ttasksptr const& ttask) +{ + UTIL_THROW_IF2(ttask->GetSource()->GetType() != SentenceInput, + "GlobalLexicalModel works only with sentence input."); + Sentence const* s = reinterpret_cast(ttask->GetSource().get()); + 
m_local.reset(new ThreadLocalStorage); + m_local->input = s; +} + +void GlobalLexicalModelUnlimited::EvaluateWhenApplied(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const +{ + const Sentence& input = *(m_local->input); + const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase(); + + for(size_t targetIndex = 0; targetIndex < targetPhrase.GetSize(); targetIndex++ ) { + StringPiece targetString = targetPhrase.GetWord(targetIndex).GetString(0); // TODO: change for other factors + + if (m_ignorePunctuation) { + // check if first char is punctuation + char firstChar = targetString[0]; + CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); + if(charIterator != m_punctuationHash.end()) + continue; + } + + if (m_biasFeature) { + util::StringStream feature; + feature << "glm_"; + feature << targetString; + feature << "~"; + feature << "**BIAS**"; + accumulator->SparsePlusEquals(feature.str(), 1); + } + + boost::unordered_set alreadyScored; + for(size_t sourceIndex = 0; sourceIndex < input.GetSize(); sourceIndex++ ) { + const StringPiece sourceString = input.GetWord(sourceIndex).GetString(0); + // TODO: change for other factors + + if (m_ignorePunctuation) { + // check if first char is punctuation + char firstChar = sourceString[0]; + CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); + if(charIterator != m_punctuationHash.end()) + continue; + } + const uint64_t sourceHash = util::MurmurHashNative(sourceString.data(), sourceString.size()); + + if ( alreadyScored.find(sourceHash) == alreadyScored.end()) { + bool sourceExists, targetExists; + if (!m_unrestricted) { + sourceExists = FindStringPiece(m_vocabSource, sourceString ) != m_vocabSource.end(); + targetExists = FindStringPiece(m_vocabTarget, targetString) != m_vocabTarget.end(); + } + + // no feature if vocab is in use and both words are not in restricted vocabularies + if (m_unrestricted || (sourceExists && targetExists)) { + if 
(m_sourceContext) { + if (sourceIndex == 0) { + // add trigger feature for source + util::StringStream feature; + feature << "glm_"; + feature << targetString; + feature << "~"; + feature << ","; + feature << sourceString; + accumulator->SparsePlusEquals(feature.str(), 1); + alreadyScored.insert(sourceHash); + } + + // add source words to the right of current source word as context + for(int contextIndex = sourceIndex+1; contextIndex < input.GetSize(); contextIndex++ ) { + StringPiece contextString = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors + bool contextExists; + if (!m_unrestricted) + contextExists = FindStringPiece(m_vocabSource, contextString ) != m_vocabSource.end(); + + if (m_unrestricted || contextExists) { + util::StringStream feature; + feature << "glm_"; + feature << targetString; + feature << "~"; + feature << sourceString; + feature << ","; + feature << contextString; + accumulator->SparsePlusEquals(feature.str(), 1); + alreadyScored.insert(sourceHash); + } + } + } else if (m_biphrase) { + // --> look backwards for constructing context + int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex; + + // 1) source-target pair, trigger source word (can be discont.) 
and adjacent target word (bigram) + StringPiece targetContext; + if (globalTargetIndex > 0) + targetContext = cur_hypo.GetWord(globalTargetIndex-1).GetString(0); // TODO: change for other factors + else + targetContext = ""; + + if (sourceIndex == 0) { + StringPiece sourceTrigger = ""; + AddFeature(accumulator, sourceTrigger, sourceString, + targetContext, targetString); + } else + for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) { + StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors + bool sourceTriggerExists = false; + if (!m_unrestricted) + sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end(); + + if (m_unrestricted || sourceTriggerExists) + AddFeature(accumulator, sourceTrigger, sourceString, + targetContext, targetString); + } + + // 2) source-target pair, adjacent source word (bigram) and trigger target word (can be discont.) + StringPiece sourceContext; + if (sourceIndex-1 >= 0) + sourceContext = input.GetWord(sourceIndex-1).GetString(0); // TODO: change for other factors + else + sourceContext = ""; + + if (globalTargetIndex == 0) { + string targetTrigger = ""; + AddFeature(accumulator, sourceContext, sourceString, + targetTrigger, targetString); + } else + for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { + StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors + bool targetTriggerExists = false; + if (!m_unrestricted) + targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end(); + + if (m_unrestricted || targetTriggerExists) + AddFeature(accumulator, sourceContext, sourceString, + targetTrigger, targetString); + } + } else if (m_bitrigger) { + // allow additional discont. 
triggers on both sides + int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex; + + if (sourceIndex == 0) { + StringPiece sourceTrigger = ""; + bool sourceTriggerExists = true; + + if (globalTargetIndex == 0) { + string targetTrigger = ""; + bool targetTriggerExists = true; + + if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) + AddFeature(accumulator, sourceTrigger, sourceString, + targetTrigger, targetString); + } else { + // iterate backwards over target + for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { + StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors + bool targetTriggerExists = false; + if (!m_unrestricted) + targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end(); + + if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) + AddFeature(accumulator, sourceTrigger, sourceString, + targetTrigger, targetString); + } + } + } + // iterate over both source and target + else { + // iterate backwards over source + for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) { + StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors + bool sourceTriggerExists = false; + if (!m_unrestricted) + sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end(); + + if (globalTargetIndex == 0) { + string targetTrigger = ""; + bool targetTriggerExists = true; + + if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) + AddFeature(accumulator, sourceTrigger, sourceString, + targetTrigger, targetString); + } else { + // iterate backwards over target + for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { + StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors + bool 
targetTriggerExists = false; + if (!m_unrestricted) + targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end(); + + if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) + AddFeature(accumulator, sourceTrigger, sourceString, + targetTrigger, targetString); + } + } + } + } + } else { + util::StringStream feature; + feature << "glm_"; + feature << targetString; + feature << "~"; + feature << sourceString; + accumulator->SparsePlusEquals(feature.str(), 1); + alreadyScored.insert(sourceHash); + + } + } + } + } + } +} + +void GlobalLexicalModelUnlimited::AddFeature(ScoreComponentCollection* accumulator, + StringPiece sourceTrigger, StringPiece sourceWord, + StringPiece targetTrigger, StringPiece targetWord) const +{ + util::StringStream feature; + feature << "glm_"; + feature << targetTrigger; + feature << ","; + feature << targetWord; + feature << "~"; + feature << sourceTrigger; + feature << ","; + feature << sourceWord; + accumulator->SparsePlusEquals(feature.str(), 1); + +} + +} diff --git a/mosesdecoder/moses/FF/HyperParameterAsWeight.cpp b/mosesdecoder/moses/FF/HyperParameterAsWeight.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0436313564eebf38a58e6f8bc4bbe7105c431857 --- /dev/null +++ b/mosesdecoder/moses/FF/HyperParameterAsWeight.cpp @@ -0,0 +1,29 @@ +#include "HyperParameterAsWeight.h" +#include "moses/StaticData.h" + +using namespace std; + +namespace Moses +{ + +HyperParameterAsWeight::HyperParameterAsWeight(const std::string &line) + :StatelessFeatureFunction(2, line) +{ + ReadParameters(); + + // hack into StaticData and change anything you want + // as an example, we have 2 weights and change + // 1. stack size + // 2. 
beam width + StaticData &staticData = StaticData::InstanceNonConst(); + + vector weights = staticData.GetWeights(this); + + staticData.m_options->search.stack_size = weights[0] * 1000; + staticData.m_options->search.beam_width = weights[1] * 10; + +} + + +} + diff --git a/mosesdecoder/moses/FF/InputFeature.h b/mosesdecoder/moses/FF/InputFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..b2b3b4ff48cf07c8b109191a632946c9800451c9 --- /dev/null +++ b/mosesdecoder/moses/FF/InputFeature.h @@ -0,0 +1,70 @@ +#pragma once + +#include "InputFeature.h" +#include "StatelessFeatureFunction.h" + +namespace Moses +{ + + +class InputFeature : public StatelessFeatureFunction +{ +protected: + static InputFeature *s_instance; + + size_t m_numInputScores; + size_t m_numRealWordCount; + bool m_legacy; + +public: + static const InputFeature *InstancePtr() { + return s_instance; + } + + InputFeature(const std::string &line); + + void Load(AllOptions::ptr const& opts); + + void SetParameter(const std::string& key, const std::string& value); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + size_t GetNumInputScores() const { + return m_numInputScores; + } + size_t GetNumRealWordsInInput() const { + return m_numRealWordCount; + } + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const { + } + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const; + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + void 
EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + + +}; + + +} + diff --git a/mosesdecoder/moses/FF/Model1Feature.cpp b/mosesdecoder/moses/FF/Model1Feature.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4dce6f1bc3aee8b2cabbaabb4962516f40013b58 --- /dev/null +++ b/mosesdecoder/moses/FF/Model1Feature.cpp @@ -0,0 +1,276 @@ +#include "Model1Feature.h" +#include "moses/StaticData.h" +#include "moses/InputFileStream.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/FactorCollection.h" + + +using namespace std; + +namespace Moses +{ + +const std::string Model1Vocabulary::GIZANULL = "GIZANULL"; + +Model1Vocabulary::Model1Vocabulary() +{ + FactorCollection &factorCollection = FactorCollection::Instance(); + m_NULL = factorCollection.AddFactor(GIZANULL,false); + Store(m_NULL,0); +} + +bool Model1Vocabulary::Store(const Factor* word, const unsigned id) +{ + boost::unordered_map::iterator iter = m_lookup.find( word ); + if ( iter != m_lookup.end() ) { + return false; + } + m_lookup[ word ] = id; + if ( m_vocab.size() <= id ) { + m_vocab.resize(id+1); + } + m_vocab[id] = word; + return true; +} + +unsigned Model1Vocabulary::StoreIfNew(const Factor* word) +{ + boost::unordered_map::iterator iter = m_lookup.find( word ); + + if ( iter != m_lookup.end() ) { + return iter->second; + } + + unsigned id = m_vocab.size(); + m_vocab.push_back( word ); + m_lookup[ word ] = id; + return id; +} + +unsigned Model1Vocabulary::GetWordID(const Factor* word) const +{ + boost::unordered_map::const_iterator iter = m_lookup.find( word ); + if ( iter == m_lookup.end() ) { + return INVALID_ID; + } + return iter->second; +} + +const Factor* Model1Vocabulary::GetWord(unsigned id) const +{ + if (id >= m_vocab.size()) { + return NULL; + } + return m_vocab[ id ]; +} + +void Model1Vocabulary::Load(const std::string& fileName) +{ + InputFileStream inFile(fileName); + FactorCollection &factorCollection = 
FactorCollection::Instance(); + std::string line; + + unsigned i = 0; + if ( getline(inFile, line) ) { // first line of MGIZA vocabulary files seems to be special : "1 UNK 0" -- skip if it's this + ++i; + std::vector tokens = Tokenize(line); + UTIL_THROW_IF2(tokens.size()!=3, "Line " << i << " in " << fileName << " has wrong number of tokens."); + unsigned id = atoll( tokens[0].c_str() ); + if (! ( (id == 1) && (tokens[1] == "UNK") )) { + const Factor* factor = factorCollection.AddFactor(tokens[1],false); // TODO: can we assume that the vocabulary is know and filter the model on loading? + bool stored = Store(factor, id); + UTIL_THROW_IF2(!stored, "Line " << i << " in " << fileName << " overwrites existing vocabulary entry."); + } + } + while ( getline(inFile, line) ) { + ++i; + std::vector tokens = Tokenize(line); + UTIL_THROW_IF2(tokens.size()!=3, "Line " << i << " in " << fileName << " has wrong number of tokens."); + unsigned id = atoll( tokens[0].c_str() ); + const Factor* factor = factorCollection.AddFactor(tokens[1],false); // TODO: can we assume that the vocabulary is know and filter the model on loading? 
+ bool stored = Store(factor, id); + UTIL_THROW_IF2(!stored, "Line " << i << " in " << fileName << " overwrites existing vocabulary entry."); + } + inFile.Close(); +} + + +void Model1LexicalTable::Load(const std::string &fileName, const Model1Vocabulary& vcbS, const Model1Vocabulary& vcbT) +{ + InputFileStream inFile(fileName); + std::string line; + + unsigned i = 0; + while ( getline(inFile, line) ) { + ++i; + std::vector tokens = Tokenize(line); + UTIL_THROW_IF2(tokens.size()!=3, "Line " << i << " in " << fileName << " has wrong number of tokens."); + unsigned idS = atoll( tokens[0].c_str() ); + unsigned idT = atoll( tokens[1].c_str() ); + const Factor* wordS = vcbS.GetWord(idS); + const Factor* wordT = vcbT.GetWord(idT); + float prob = std::atof( tokens[2].c_str() ); + if ( (wordS != NULL) && (wordT != NULL) ) { + m_ltable[ wordS ][ wordT ] = prob; + } + UTIL_THROW_IF2((wordS == NULL) || (wordT == NULL), "Line " << i << " in " << fileName << " has unknown vocabulary."); // TODO: can we assume that the vocabulary is know and filter the model on loading? Then remove this line. 
+ } + inFile.Close(); +} + +// p( wordT | wordS ) +float Model1LexicalTable::GetProbability(const Factor* wordS, const Factor* wordT) const +{ + float prob = m_floor; + + boost::unordered_map< const Factor*, boost::unordered_map< const Factor*, float > >::const_iterator iter1 = m_ltable.find( wordS ); + + if ( iter1 != m_ltable.end() ) { + boost::unordered_map< const Factor*, float >::const_iterator iter2 = iter1->second.find( wordT ); + if ( iter2 != iter1->second.end() ) { + prob = iter2->second; + if ( prob < m_floor ) { + prob = m_floor; + } + } + } + return prob; +} + + +Model1Feature::Model1Feature(const std::string &line) + : StatelessFeatureFunction(1, line) + , m_skipTargetPunctuation(false) + , m_is_syntax(false) +{ + VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ..."); + ReadParameters(); + VERBOSE(1, " Done."); +} + +void Model1Feature::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "path") { + m_fileNameModel1 = value; + } else if (key == "source-vocabulary") { + m_fileNameVcbS = value; + } else if (key == "target-vocabulary") { + m_fileNameVcbT = value; + } else if (key == "skip-target-punctuation") { + m_skipTargetPunctuation = Scan(value); + } else { + StatelessFeatureFunction::SetParameter(key, value); + } +} + +void Model1Feature::Load(AllOptions::ptr const& opts) +{ + m_options = opts; + m_is_syntax = is_syntax(opts->search.algo); + + FEATUREVERBOSE(2, GetScoreProducerDescription() << ": Loading source vocabulary from file " << m_fileNameVcbS << " ..."); + Model1Vocabulary vcbS; + vcbS.Load(m_fileNameVcbS); + FEATUREVERBOSE2(2, " Done." << std::endl); + FEATUREVERBOSE(2, GetScoreProducerDescription() << ": Loading target vocabulary from file " << m_fileNameVcbT << " ..."); + Model1Vocabulary vcbT; + vcbT.Load(m_fileNameVcbT); + FEATUREVERBOSE2(2, " Done." 
<< std::endl); + FEATUREVERBOSE(2, GetScoreProducerDescription() << ": Loading model 1 lexical translation table from file " << m_fileNameModel1 << " ..."); + m_model1.Load(m_fileNameModel1,vcbS,vcbT); + FEATUREVERBOSE2(2, " Done." << std::endl); + FactorCollection &factorCollection = FactorCollection::Instance(); + m_emptyWord = factorCollection.GetFactor(Model1Vocabulary::GIZANULL,false); + UTIL_THROW_IF2(m_emptyWord==NULL, GetScoreProducerDescription() + << ": Factor for GIZA empty word does not exist."); + + if (m_skipTargetPunctuation) { + const std::string punctuation = ",;.:!?"; + for (size_t i=0; i::iterator,bool> inserted = m_punctuation.insert(punctFactor); + } + } +} + +void Model1Feature::EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores) const +{ + const Sentence& sentence = static_cast(input); + float score = 0.0; + float norm = TransformScore(1+sentence.GetSize()); + + for (size_t posT=0; posT::const_iterator foundPunctuation = m_punctuation.find(wordT[0]); + if (foundPunctuation != m_punctuation.end()) { + continue; + } + } + if ( !wordT.IsNonTerminal() ) { + float thisWordProb = m_model1.GetProbability(m_emptyWord,wordT[0]); // probability conditioned on empty word + + // cache lookup + bool foundInCache = false; + { +#ifdef WITH_THREADS + boost::shared_lock read_lock(m_accessLock); +#endif + boost::unordered_map >::const_iterator sentenceCache = m_cache.find(&input); + if (sentenceCache != m_cache.end()) { + boost::unordered_map::const_iterator cacheHit = sentenceCache->second.find(wordT[0]); + if (cacheHit != sentenceCache->second.end()) { + foundInCache = true; + score += cacheHit->second; + FEATUREVERBOSE(3, "Cached score( " << wordT << " ) = " << cacheHit->second << std::endl); + } + } + } + + if (!foundInCache) { + for (size_t posS=(m_is_syntax?1:0); 
posS<(m_is_syntax?sentence.GetSize()-1:sentence.GetSize()); ++posS) { // ignore and + const Word &wordS = sentence.GetWord(posS); + float modelProb = m_model1.GetProbability(wordS[0],wordT[0]); + FEATUREVERBOSE(4, "p( " << wordT << " | " << wordS << " ) = " << modelProb << std::endl); + thisWordProb += modelProb; + } + float thisWordScore = TransformScore(thisWordProb) - norm; + FEATUREVERBOSE(3, "score( " << wordT << " ) = " << thisWordScore << std::endl); + { +#ifdef WITH_THREADS + // need to update cache; write lock + boost::unique_lock lock(m_accessLock); +#endif + m_cache[&input][wordT[0]] = thisWordScore; + } + score += thisWordScore; + } + } + } + + scoreBreakdown.PlusEquals(this, score); +} + +void Model1Feature::CleanUpAfterSentenceProcessing(const InputType& source) +{ +#ifdef WITH_THREADS + // need to update cache; write lock + boost::unique_lock lock(m_accessLock); +#endif + // clear cache + boost::unordered_map >::iterator sentenceCache = m_cache.find(&source); + if (sentenceCache != m_cache.end()) { + sentenceCache->second.clear(); + m_cache.erase(sentenceCache); + } +} + +} + diff --git a/mosesdecoder/moses/FF/Model1Feature.h b/mosesdecoder/moses/FF/Model1Feature.h new file mode 100644 index 0000000000000000000000000000000000000000..cac894841d8e2e89eb86c63f50d9114db171e6e1 --- /dev/null +++ b/mosesdecoder/moses/FF/Model1Feature.h @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include +#include +#include "StatelessFeatureFunction.h" +#include "moses/Factor.h" + +#ifdef WITH_THREADS +#include +#endif + +namespace Moses +{ + +class Model1Vocabulary +{ +public: + +#define INVALID_ID std::numeric_limits::max() // UINT_MAX + static const std::string GIZANULL; + + Model1Vocabulary(); + bool Store(const Factor* word, const unsigned id); + unsigned StoreIfNew(const Factor* word); + unsigned GetWordID(const Factor* word) const; + const Factor* GetWord(unsigned id) const; + void Load(const std::string& fileName); + +protected: + boost::unordered_map 
m_lookup; + std::vector< const Factor* > m_vocab; + const Factor* m_NULL; +}; + + +class Model1LexicalTable +{ +public: + Model1LexicalTable(float floor=1e-7) : m_floor(floor) + {} + + void Load(const std::string& fileName, const Model1Vocabulary& vcbS, const Model1Vocabulary& vcbT); + + // p( wordT | wordS ) + float GetProbability(const Factor* wordS, const Factor* wordT) const; + +protected: + boost::unordered_map< const Factor*, boost::unordered_map< const Factor*, float > > m_ltable; + const float m_floor; +}; + + + +class Model1Feature : public StatelessFeatureFunction +{ +public: + Model1Feature(const std::string &line); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + void SetParameter(const std::string& key, const std::string& value); + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const + {}; + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const; + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const + {} + + void EvaluateWhenApplied( + const Hypothesis& cur_hypo, + ScoreComponentCollection* accumulator) const + {} + + void EvaluateWhenApplied( + const ChartHypothesis& cur_hypo, + ScoreComponentCollection* accumulator) const + {} + + void CleanUpAfterSentenceProcessing(const InputType& source); + +private: + std::string m_fileNameVcbS; + std::string m_fileNameVcbT; + std::string m_fileNameModel1; + Model1LexicalTable m_model1; + const Factor* m_emptyWord; + bool m_skipTargetPunctuation; + std::set m_punctuation; + bool m_is_syntax; + + void Load(AllOptions::ptr const& opts); + + // cache + mutable 
boost::unordered_map > m_cache; +#ifdef WITH_THREADS + // reader-writer lock + mutable boost::shared_mutex m_accessLock; +#endif +}; + + +} + diff --git a/mosesdecoder/moses/FF/PhraseBoundaryFeature.h b/mosesdecoder/moses/FF/PhraseBoundaryFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..9e84aaeef4f39bd73227ba57dc8d0d1869afa9f2 --- /dev/null +++ b/mosesdecoder/moses/FF/PhraseBoundaryFeature.h @@ -0,0 +1,70 @@ +#ifndef moses_PhraseBoundaryFeature_h +#define moses_PhraseBoundaryFeature_h + +#include +#include +#include + +#include "StatefulFeatureFunction.h" +#include "moses/FF/FFState.h" +#include "moses/Word.h" + +namespace Moses +{ + +class PhraseBoundaryState : public FFState +{ +public: + PhraseBoundaryState(const Word* sourceWord, const Word* targetWord) : + m_sourceWord(sourceWord), m_targetWord(targetWord) {} + const Word* GetSourceWord() const { + return m_sourceWord; + } + const Word* GetTargetWord() const { + return m_targetWord; + } + virtual size_t hash() const; + virtual bool operator==(const FFState& other) const; + + +private: + const Word* m_sourceWord; + const Word* m_targetWord; +}; + + +/** + * Concatenations of factors on boundaries of phrases. 
+ **/ +class PhraseBoundaryFeature : public StatefulFeatureFunction +{ +public: + PhraseBoundaryFeature(const std::string &line); + + bool IsUseable(const FactorMask &mask) const; + + virtual const FFState* EmptyHypothesisState(const InputType &) const; + + virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state, + ScoreComponentCollection* accumulator) const; + + virtual FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, + int /* featureID */, + ScoreComponentCollection* ) const { + throw std::logic_error("PhraseBoundaryState not supported in chart decoder, yet"); + } + + void SetParameter(const std::string& key, const std::string& value); + +private: + void AddFeatures( + const Word* leftWord, const Word* rightWord, const FactorList& factors, + const std::string& side, ScoreComponentCollection* scores) const ; + FactorList m_sourceFactors; + FactorList m_targetFactors; +}; + +} + + +#endif diff --git a/mosesdecoder/moses/FF/PhraseDistanceFeature.cpp b/mosesdecoder/moses/FF/PhraseDistanceFeature.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0544ec9fb7ea5bbafc0acaa7ce61416ec497bba0 --- /dev/null +++ b/mosesdecoder/moses/FF/PhraseDistanceFeature.cpp @@ -0,0 +1,123 @@ +#include "PhraseDistanceFeature.h" + +#include +#include +#include "moses/InputType.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/StaticData.h" +#include "util/exception.hh" + +using namespace std; + +namespace Moses +{ +PhraseDistanceFeature::PhraseDistanceFeature(const string &line) + : StatelessFeatureFunction(2, line) + , m_space("") + , m_spaceID(0) + , m_measure(EuclideanDistance) +{ + ReadParameters(); +} + +void PhraseDistanceFeature::EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores) const +{ + vector 
scores(m_numScoreComponents, 0); + bool broken = false; + // Input coord + map >::const_iterator ii; + if (input.m_coordMap) { + ii = input.m_coordMap->find(m_spaceID); + } else { + TRACE_ERR("No coordinates for space " << m_space << " on input (specify with coord XML tag)" << endl); + TRACE_ERR("Scores for " << m_description << " will be incorrect and probably all zeros" << endl); + broken = true; + } + if (ii == input.m_coordMap->end()) { + TRACE_ERR("No coordinates for space " << m_space << " on input (specify with coord XML tag)" << endl); + TRACE_ERR("Scores for " << m_description << " will be incorrect and probably all zeros" << endl); + broken = true; + } + // Target phrase coord + vector > > const* tpp = targetPhrase.GetCoordList(m_spaceID); + if (tpp == NULL) { + TRACE_ERR("No coordinates for space " << m_space << " on target phrase (PhraseDictionary implementation needs to set)" << endl); + TRACE_ERR("Scores for " << m_description << " will be incorrect and probably all zeros" << endl); + broken = true; + } + // Compute scores + if (!broken) { + vector const& inputCoord = ii->second; + vector > > const& tpCoord = *tpp; + // Centroid of target phrase instances (from phrase extraction) + vector centroid = vector(inputCoord.size(), 0); + BOOST_FOREACH(SPTR > const coord, tpCoord) { + for (size_t i = 0; i < inputCoord.size(); ++i) { + centroid[i] += (*coord)[i]; + } + } + for (size_t i = 0; i < inputCoord.size(); ++i) { + centroid[i] /= tpCoord.size(); + } + // Average distance from the target phrase instances to (1) the input and + // (2) the target phrase centroid + float inputDistance = 0; + float centroidDistance = 0; + if (m_measure == EuclideanDistance) { + BOOST_FOREACH(SPTR > const coord, tpCoord) { + float pointInputDistance = 0; + float pointCentroidDistance = 0; + for (size_t i = 0; i < inputCoord.size(); ++i) { + pointInputDistance += pow(inputCoord[i] - (*coord)[i], 2); + pointCentroidDistance += pow(centroid[i] - (*coord)[i], 2); + } + 
inputDistance += sqrt(pointInputDistance); + centroidDistance += sqrt(pointCentroidDistance); + } + } else if (m_measure == TotalVariationDistance) { + BOOST_FOREACH(SPTR > const coord, tpCoord) { + float pointInputDistance = 0; + float pointCentroidDistance = 0; + for (size_t i = 0; i < inputCoord.size(); ++i) { + pointInputDistance += abs(inputCoord[i] - (*coord)[i]); + pointCentroidDistance += abs(centroid[i] - (*coord)[i]); + } + inputDistance += pointInputDistance / 2; + centroidDistance += pointCentroidDistance / 2; + } + } + inputDistance /= tpCoord.size(); + centroidDistance /= tpCoord.size(); + // Log transform scores, max with float epsilon to avoid domain error + scores[0] = log(max(inputDistance, Moses::FLOAT_EPSILON)); + scores[1] = log(max(centroidDistance, Moses::FLOAT_EPSILON)); + } + // Set scores + scoreBreakdown.Assign(this, scores); + return; +} + +void PhraseDistanceFeature::SetParameter(const string& key, const string& value) +{ + if (key == "space") { + m_space = value; + m_spaceID = StaticData::InstanceNonConst().MapCoordSpace(m_space); + } else if (key == "measure") { + if (value == "euc") { + m_measure = EuclideanDistance; + } else if (value == "var") { + m_measure = TotalVariationDistance; + } else { + UTIL_THROW2("Unknown measure " << value << ", choices: euc var"); + } + } else { + StatelessFeatureFunction::SetParameter(key, value); + } +} + +} // namespace diff --git a/mosesdecoder/moses/FF/PhraseDistanceFeature.h b/mosesdecoder/moses/FF/PhraseDistanceFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..8c9e1a361e23406a001cbd34f7c2a0e9969dd368 --- /dev/null +++ b/mosesdecoder/moses/FF/PhraseDistanceFeature.h @@ -0,0 +1,56 @@ +#pragma once + +#include "StatelessFeatureFunction.h" + +namespace Moses +{ + +class PhraseDistanceFeature : public StatelessFeatureFunction +{ + enum Measure { + EuclideanDistance, + TotalVariationDistance, + }; + +public: + PhraseDistanceFeature(const std::string &line); + + bool 
IsUseable(const FactorMask &mask) const { + return true; + } + + virtual void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const { + } + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWhenApplied(const Syntax::SHyperedge &hyperedge, + ScoreComponentCollection* accumulator) const { + } + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const; + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + void SetParameter(const std::string& key, const std::string& value); + +protected: + Measure m_measure; + std::string m_space; + size_t m_spaceID; +}; + +} //namespace diff --git a/mosesdecoder/moses/FF/PhraseLengthFeature.cpp b/mosesdecoder/moses/FF/PhraseLengthFeature.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5598e2ecb861298427f2595f9841a4e1be31081 --- /dev/null +++ b/mosesdecoder/moses/FF/PhraseLengthFeature.cpp @@ -0,0 +1,46 @@ +#include +#include "PhraseLengthFeature.h" +#include "moses/Hypothesis.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/TranslationOption.h" +#include "util/string_stream.hh" + +namespace Moses +{ + +using namespace std; + +PhraseLengthFeature::PhraseLengthFeature(const std::string &line) + :StatelessFeatureFunction(0, line) +{ + ReadParameters(); +} + +void PhraseLengthFeature::EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , 
ScoreComponentCollection &estimatedScores) const +{ + // get length of source and target phrase + size_t targetLength = targetPhrase.GetSize(); + size_t sourceLength = source.GetSize(); + + // create feature names + util::StringStream nameSource; + nameSource << "s" << sourceLength; + + util::StringStream nameTarget; + nameTarget << "t" << targetLength; + + util::StringStream nameBoth; + nameBoth << sourceLength << "," << targetLength; + + // increase feature counts + scoreBreakdown.PlusEquals(this,nameSource.str(),1); + scoreBreakdown.PlusEquals(this,nameTarget.str(),1); + scoreBreakdown.PlusEquals(this,nameBoth.str(),1); + + //cerr << nameSource.str() << " " << nameTarget.str() << " " << nameBoth.str() << endl; +} + +} diff --git a/mosesdecoder/moses/FF/PhrasePenalty.cpp b/mosesdecoder/moses/FF/PhrasePenalty.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f709787f27acdb3a59ddf1efadf91b42aa9042c --- /dev/null +++ b/mosesdecoder/moses/FF/PhrasePenalty.cpp @@ -0,0 +1,51 @@ +#include +#include "PhrasePenalty.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/TranslationModel/PhraseDictionary.h" +#include "util/exception.hh" + +using namespace std; + +namespace Moses +{ +PhrasePenalty::PhrasePenalty(const std::string &line) + : StatelessFeatureFunction(1, line) + , m_perPhraseTable(false) +{ + ReadParameters(); +} + +void PhrasePenalty::EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const +{ + if (m_perPhraseTable) { + const PhraseDictionary *pt = targetPhrase.GetContainer(); + if (pt) { + size_t ptId = pt->GetId(); + UTIL_THROW_IF2(ptId >= m_numScoreComponents, "Wrong number of scores"); + + vector scores(m_numScoreComponents, 0); + scores[ptId] = 1.0f; + + scoreBreakdown.Assign(this, scores); + } + + } else { + scoreBreakdown.Assign(this, 1.0f); + } +} + +void PhrasePenalty::SetParameter(const std::string& 
key, const std::string& value) +{ + if (key == "per-phrase-table") { + m_perPhraseTable =Scan(value); + } else { + StatelessFeatureFunction::SetParameter(key, value); + } +} + + +} // namespace + diff --git a/mosesdecoder/moses/FF/PhrasePenalty.h b/mosesdecoder/moses/FF/PhrasePenalty.h new file mode 100644 index 0000000000000000000000000000000000000000..e6bda3435867cb6dd66485c3677e299303be12f8 --- /dev/null +++ b/mosesdecoder/moses/FF/PhrasePenalty.h @@ -0,0 +1,50 @@ +#pragma once + +#include "StatelessFeatureFunction.h" + +namespace Moses +{ + +class PhrasePenalty : public StatelessFeatureFunction +{ +public: + PhrasePenalty(const std::string &line); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + virtual void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const; + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWhenApplied(const Syntax::SHyperedge &hyperedge, + ScoreComponentCollection* accumulator) const { + } + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const { + } + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + void SetParameter(const std::string& key, const std::string& value); + +protected: + bool m_perPhraseTable; +}; + +} //namespace + diff --git a/mosesdecoder/moses/FF/ReferenceComparison.cpp b/mosesdecoder/moses/FF/ReferenceComparison.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..80dcbd234f3c7e89838d61e22f3fca788855f90f --- /dev/null +++ b/mosesdecoder/moses/FF/ReferenceComparison.cpp @@ -0,0 +1,11 @@ +#include "ReferenceComparison.h" + +namespace Moses +{ +ReferenceComparison::ReferenceComparison(const std::string &line) + :StatelessFeatureFunction(0, line) +{ +} + +} + diff --git a/mosesdecoder/moses/FF/RuleScope.h b/mosesdecoder/moses/FF/RuleScope.h new file mode 100644 index 0000000000000000000000000000000000000000..473cf22f5be0a08933e198fcde927e1535eee2ac --- /dev/null +++ b/mosesdecoder/moses/FF/RuleScope.h @@ -0,0 +1,56 @@ +#pragma once +#include +#include "StatelessFeatureFunction.h" + +namespace Moses +{ + +// Rule Scope - not quite completely implemented yet +class RuleScope : public StatelessFeatureFunction +{ +public: + RuleScope(const std::string &line); + + virtual bool IsUseable(const FactorMask &mask) const { + return true; + } + + virtual void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const; + + virtual void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const { + } + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + + + virtual void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + + virtual void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + + void SetParameter(const std::string& key, const std::string& value); + +protected: + bool m_sourceSyntax; + bool m_perScope; + bool m_futureCostOnly; + + bool IsGlueRule(const Phrase &source) const; + +}; + +} + diff --git 
a/mosesdecoder/moses/FF/SoftMatchingFeature.h b/mosesdecoder/moses/FF/SoftMatchingFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..ae2380dcb4c8b8b1ed21784a5cfd05b414f1aa83 --- /dev/null +++ b/mosesdecoder/moses/FF/SoftMatchingFeature.h @@ -0,0 +1,68 @@ +#pragma once + +#include "moses/Word.h" +#include "StatelessFeatureFunction.h" + +#ifdef WITH_THREADS +#include +#endif + +namespace Moses +{ + +class SoftMatchingFeature : public StatelessFeatureFunction +{ +public: + SoftMatchingFeature(const std::string &line); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + virtual void EvaluateWhenApplied(const ChartHypothesis& hypo, + ScoreComponentCollection* accumulator) const; + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const {}; + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const {}; + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const {}; + + bool Load(const std::string &filePath); + + std::vector >& GetSoftMatches() { + return m_softMatches; + } + + void ResizeCache() const; + + const std::string& GetOrSetFeatureName(const Word& RHS, const Word& LHS) const; + void SetParameter(const std::string& key, const std::string& value); + + +private: + mutable std::vector > m_softMatches; // map RHS of new rule to list of possible LHS of old rule (subtree) + mutable std::vector > m_nameCache; + bool m_scoreIdentical; + +#ifdef WITH_THREADS + //reader-writer lock + mutable boost::shared_mutex 
m_accessLock; +#endif + +}; + +} + diff --git a/mosesdecoder/moses/FF/SoftSourceSyntacticConstraintsFeature.cpp b/mosesdecoder/moses/FF/SoftSourceSyntacticConstraintsFeature.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afba59b47c31658c65f7814a1be60c579bbd2d95 --- /dev/null +++ b/mosesdecoder/moses/FF/SoftSourceSyntacticConstraintsFeature.cpp @@ -0,0 +1,651 @@ +#include +#include +#include +#include "SoftSourceSyntacticConstraintsFeature.h" +#include "moses/StaticData.h" +#include "moses/InputFileStream.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/Hypothesis.h" +#include "moses/ChartHypothesis.h" +#include "moses/ChartManager.h" +#include "moses/FactorCollection.h" +#include "moses/TreeInput.h" +#include "moses/PP/SourceLabelsPhraseProperty.h" + + +using namespace std; + +namespace Moses +{ + + +SoftSourceSyntacticConstraintsFeature::SoftSourceSyntacticConstraintsFeature(const std::string &line) + : StatelessFeatureFunction(6, line) + , m_useCoreSourceLabels(false) + , m_useLogprobs(true) + , m_useSparse(false) + , m_useSparseLabelPairs(false) + , m_noMismatches(false) + , m_floor(1e-7) +{ + VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ..."); + ReadParameters(); + VERBOSE(1, " Done."); + VERBOSE(1, " Config:"); + VERBOSE(1, " Log probabilities"); + if ( m_useLogprobs ) { + VERBOSE(1, " active."); + } else { + VERBOSE(1, " inactive."); + } + VERBOSE(1, " Sparse scores"); + if ( m_useSparse ) { + VERBOSE(1, " active."); + } else { + VERBOSE(1, " inactive."); + } + VERBOSE(1, " Sparse label pair scores"); + if ( m_useSparseLabelPairs ) { + VERBOSE(1, " active."); + } else { + VERBOSE(1, " inactive."); + } + VERBOSE(1, " Core labels"); + if ( m_useCoreSourceLabels ) { + VERBOSE(1, " active."); + } else { + VERBOSE(1, " inactive."); + } + VERBOSE(1, " No mismatches"); + if ( m_noMismatches ) { + VERBOSE(1, " active."); + } else { + VERBOSE(1, " inactive."); + } + VERBOSE(1, std::endl); +} + + 
+void SoftSourceSyntacticConstraintsFeature::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "sourceLabelSetFile") { + m_sourceLabelSetFile = value; + } else if (key == "coreSourceLabelSetFile") { + m_coreSourceLabelSetFile = value; + m_useCoreSourceLabels = true; + } else if (key == "targetSourceLeftHandSideJointCountFile") { + m_targetSourceLHSJointCountFile = value; + } else if (key == "noMismatches") { + m_noMismatches = Scan(value); // for a hard constraint, allow no mismatches (also set: weights 1 0 0 0 0 0, tuneable=false) + } else if (key == "logProbabilities") { + m_useLogprobs = Scan(value); + } else if (key == "sparse") { + m_useSparse = Scan(value); + } else if (key == "sparseLabelPairs") { + m_useSparseLabelPairs = Scan(value); + } else { + StatelessFeatureFunction::SetParameter(key, value); + } +} + +void SoftSourceSyntacticConstraintsFeature::Load(AllOptions::ptr const& opts) +{ + m_options = opts; + // don't change the loading order! + LoadSourceLabelSet(); + if (!m_coreSourceLabelSetFile.empty()) { + LoadCoreSourceLabelSet(); + } + if (!m_targetSourceLHSJointCountFile.empty()) { + LoadTargetSourceLeftHandSideJointCountFile(); + } +} + +void SoftSourceSyntacticConstraintsFeature::LoadSourceLabelSet() +{ + FEATUREVERBOSE(2, "Loading source label set from file " << m_sourceLabelSetFile << " ..."); + InputFileStream inFile(m_sourceLabelSetFile); + + FactorCollection &factorCollection = FactorCollection::Instance(); + + // read source label set + std::string line; + m_sourceLabels.clear(); + m_sourceLabelsByIndex.clear(); + m_sourceLabelsByIndex_RHS_1.clear(); + m_sourceLabelsByIndex_RHS_0.clear(); + m_sourceLabelsByIndex_LHS_1.clear(); + m_sourceLabelsByIndex_LHS_0.clear(); + m_sourceLabelIndexesByFactor.clear(); + while (getline(inFile, line)) { + std::istringstream tokenizer(line); + std::string label; + size_t index; + try { + tokenizer >> label >> index; + } catch (const std::exception &e) { + 
UTIL_THROW2(GetScoreProducerDescription() + << ": Error reading source label set file " << m_sourceLabelSetFile << " ."); + } + std::pair< boost::unordered_map::iterator, bool > inserted = m_sourceLabels.insert( std::pair(label,index) ); + UTIL_THROW_IF2(!inserted.second, GetScoreProducerDescription() + << ": Source label set file " << m_sourceLabelSetFile << " should contain each syntactic label only once."); + + if (index >= m_sourceLabelsByIndex.size()) { + m_sourceLabelsByIndex.resize(index+1); + m_sourceLabelsByIndex_RHS_1.resize(index+1); + m_sourceLabelsByIndex_RHS_0.resize(index+1); + m_sourceLabelsByIndex_LHS_1.resize(index+1); + m_sourceLabelsByIndex_LHS_0.resize(index+1); + } + m_sourceLabelsByIndex[index] = label; + m_sourceLabelsByIndex_RHS_1[index] = "RHS_1_" + label; + m_sourceLabelsByIndex_RHS_0[index] = "RHS_0_" + label; + m_sourceLabelsByIndex_LHS_1[index] = "LHS_1_" + label; + m_sourceLabelsByIndex_LHS_0[index] = "LHS_0_" + label; + const Factor* sourceLabelFactor = factorCollection.AddFactor(label,true); + m_sourceLabelIndexesByFactor[sourceLabelFactor] = index; + } + + inFile.Close(); + + std::list specialLabels; + specialLabels.push_back("GlueTop"); + specialLabels.push_back("GlueX"); +// specialLabels.push_back("XRHS"); +// specialLabels.push_back("XLHS"); + for (std::list::const_iterator iter=specialLabels.begin(); + iter!=specialLabels.end(); ++iter) { + boost::unordered_map::iterator found = m_sourceLabels.find(*iter); + UTIL_THROW_IF2(found == m_sourceLabels.end(), GetScoreProducerDescription() + << ": Source label set file " << m_sourceLabelSetFile << " should contain an entry for the special label \"" << *iter << "\"."); + if (!(found->first).compare("GlueTop")) { + m_GlueTopLabel = found->second; +// } else if (!(found->first).compare("XRHS")) { +// m_XRHSLabel = found->second; +// } else if (!(found->first).compare("XLHS")) { +// m_XLHSLabel = found->second; + } + } + FEATUREVERBOSE2(2, " Done." 
<< std::endl); +} + + +void SoftSourceSyntacticConstraintsFeature::LoadCoreSourceLabelSet() +{ + FEATUREVERBOSE(2, "Loading core source label set from file " << m_coreSourceLabelSetFile << " ..."); + // read core source label set + LoadLabelSet(m_coreSourceLabelSetFile, m_coreSourceLabels); + FEATUREVERBOSE2(2, " Done." << std::endl); +} + +void SoftSourceSyntacticConstraintsFeature::LoadLabelSet(std::string &filename, + boost::unordered_set &labelSet) +{ + InputFileStream inFile(filename); + std::string line; + labelSet.clear(); + while (getline(inFile, line)) { + istringstream tokenizer(line); + std::string label; + tokenizer >> label; + boost::unordered_map::iterator foundSourceLabelIndex = m_sourceLabels.find( label ); + if ( foundSourceLabelIndex != m_sourceLabels.end() ) { + labelSet.insert(foundSourceLabelIndex->second); + } else { + FEATUREVERBOSE(2, "Ignoring undefined source label \"" << label << "\" " + << "from core source label set file " << filename << "." + << std::endl); + } + } + inFile.Close(); +} + + +void SoftSourceSyntacticConstraintsFeature::LoadTargetSourceLeftHandSideJointCountFile() +{ + + FEATUREVERBOSE(2, "Loading target/source label joint counts from file " << m_targetSourceLHSJointCountFile << " ..."); + InputFileStream inFile(m_targetSourceLHSJointCountFile); + + for (boost::unordered_map >* >::iterator iter=m_labelPairProbabilities.begin(); + iter!=m_labelPairProbabilities.end(); ++iter) { + delete iter->second; + } + m_labelPairProbabilities.clear(); + + // read joint counts + std::string line; + FactorCollection &factorCollection = FactorCollection::Instance(); + boost::unordered_map targetLHSCounts; + std::vector sourceLHSCounts(m_sourceLabels.size(),0.0); + + while (getline(inFile, line)) { + istringstream tokenizer(line); + std::string targetLabel; + std::string sourceLabel; + float count; + tokenizer >> targetLabel; + tokenizer >> sourceLabel; + tokenizer >> count; + + boost::unordered_map::iterator foundSourceLabelIndex = 
m_sourceLabels.find( sourceLabel ); + UTIL_THROW_IF2(foundSourceLabelIndex == m_sourceLabels.end(), GetScoreProducerDescription() + << ": Target/source label joint count file " << m_targetSourceLHSJointCountFile + << " contains undefined source label \"" << sourceLabel << "\"."); + + const Factor* targetLabelFactor = factorCollection.AddFactor(targetLabel,true); + + sourceLHSCounts[foundSourceLabelIndex->second] += count; + std::pair< boost::unordered_map::iterator, bool > insertedTargetLHSCount = + targetLHSCounts.insert( std::pair(targetLabelFactor,count) ); + if (!insertedTargetLHSCount.second) { + (insertedTargetLHSCount.first)->second += count; + boost::unordered_map >* >::iterator jointCountIt = + m_labelPairProbabilities.find( targetLabelFactor ); + assert(jointCountIt != m_labelPairProbabilities.end()); + (jointCountIt->second)->at(foundSourceLabelIndex->second).first += count; + (jointCountIt->second)->at(foundSourceLabelIndex->second).second += count; + } else { + std::pair init(0.0,0.0); + std::vector< std::pair >* sourceVector = new std::vector< std::pair >(m_sourceLabels.size(),init); + sourceVector->at(foundSourceLabelIndex->second) = std::pair(count,count); + std::pair< boost::unordered_map >* >::iterator, bool > insertedJointCount = + m_labelPairProbabilities.insert( std::pair >* >(targetLabelFactor,sourceVector) ); + UTIL_THROW_IF2(!insertedJointCount.second, GetScoreProducerDescription() + << ": Loading target/source label joint counts from file " << m_targetSourceLHSJointCountFile << " failed."); + } + } + + // normalization + for (boost::unordered_map >* >::iterator iter=m_labelPairProbabilities.begin(); + iter!=m_labelPairProbabilities.end(); ++iter) { + float targetLHSCount = 0; + boost::unordered_map::const_iterator targetLHSCountIt = targetLHSCounts.find( iter->first ); + if ( targetLHSCountIt != targetLHSCounts.end() ) { + targetLHSCount = targetLHSCountIt->second; + } + std::vector< std::pair > &probabilities = *(iter->second); + for 
(size_t index=0; indexsize(); ++i) { + const ChartCellLabel &cell = *stackVec->at(i); + const Range &ntRange = cell.GetCoverage(); + FEATUREVERBOSE(3, "stackVec[ " << i << " ] : " << ntRange.GetStartPos() << " - " << ntRange.GetEndPos() << std::endl); + } + + for (AlignmentInfo::const_iterator it=targetPhrase.GetAlignNonTerm().begin(); + it!=targetPhrase.GetAlignNonTerm().end(); ++it) { + FEATUREVERBOSE(3, "alignNonTerm " << it->first << " " << it->second << std::endl); + } + } + + // dense scores + std::vector newScores(m_numScoreComponents,0); + + const TreeInput& treeInput = static_cast(input); + // const StaticData& staticData = StaticData::Instance(); + // const Word& outputDefaultNonTerminal = staticData.GetOutputDefaultNonTerminal(); + + size_t nNTs = 1; + bool treeInputMismatchLHSBinary = true; + size_t treeInputMismatchRHSCount = 0; + bool hasCompleteTreeInputMatch = false; + float ruleLabelledProbability = 0.0; + float treeInputMatchProbRHS = 0.0; + float treeInputMatchProbLHS = 0.0; + + // read SourceLabels property + const Factor* targetLHS = targetPhrase.GetTargetLHS()[0]; + bool isGlueGrammarRule = false; + bool isUnkRule = false; + + if (const PhraseProperty *property = targetPhrase.GetProperty("SourceLabels")) { + + const SourceLabelsPhraseProperty *sourceLabelsPhraseProperty = static_cast(property); + + nNTs = sourceLabelsPhraseProperty->GetNumberOfNonTerminals(); + float totalCount = sourceLabelsPhraseProperty->GetTotalCount(); + + // prepare for input tree label matching + std::vector< boost::unordered_set > treeInputLabelsRHS(nNTs-1); + boost::unordered_set treeInputLabelsLHS; + + // get index map for underlying hypotheses + const Range& range = inputPath.GetWordsRange(); + size_t startPos = range.GetStartPos(); + size_t endPos = range.GetEndPos(); + const Phrase *sourcePhrase = targetPhrase.GetRuleSource(); + + if (nNTs > 1) { // rule has right-hand side non-terminals, i.e. 
it's a hierarchical rule + size_t nonTerminalNumber = 0; + size_t sourceSentPos = startPos; + + for (size_t sourcePhrasePos=0; sourcePhrasePosGetSize(); ++sourcePhrasePos) { + // consult rule for either word or non-terminal + const Word &word = sourcePhrase->GetWord(sourcePhrasePos); + size_t symbolStartPos = sourceSentPos; + size_t symbolEndPos = sourceSentPos; + if ( word.IsNonTerminal() ) { + // retrieve information that is required for input tree label matching (RHS) + const ChartCellLabel &cell = *stackVec->at(nonTerminalNumber); + const Range& prevWordsRange = cell.GetCoverage(); + symbolStartPos = prevWordsRange.GetStartPos(); + symbolEndPos = prevWordsRange.GetEndPos(); + } + + const NonTerminalSet& treeInputLabels = treeInput.GetLabelSet(symbolStartPos,symbolEndPos); + + for (NonTerminalSet::const_iterator treeInputLabelsIt = treeInputLabels.begin(); + treeInputLabelsIt != treeInputLabels.end(); ++treeInputLabelsIt) { + if (*treeInputLabelsIt != m_options->syntax.output_default_non_terminal) { + boost::unordered_map::const_iterator foundTreeInputLabel + = m_sourceLabelIndexesByFactor.find((*treeInputLabelsIt)[0]); + if (foundTreeInputLabel != m_sourceLabelIndexesByFactor.end()) { + size_t treeInputLabelIndex = foundTreeInputLabel->second; + treeInputLabelsRHS[sourcePhrasePos].insert(treeInputLabelIndex); + } + } + } + + if ( word.IsNonTerminal() ) { + ++nonTerminalNumber; + } + sourceSentPos = symbolEndPos + 1; + } + } + + // retrieve information that is required for input tree label matching (LHS) + const NonTerminalSet& treeInputLabels = treeInput.GetLabelSet(startPos,endPos); + + for (NonTerminalSet::const_iterator treeInputLabelsIt = treeInputLabels.begin(); + treeInputLabelsIt != treeInputLabels.end(); ++treeInputLabelsIt) { + if (*treeInputLabelsIt != m_options->syntax.output_default_non_terminal) { + boost::unordered_map::const_iterator foundTreeInputLabel + = m_sourceLabelIndexesByFactor.find((*treeInputLabelsIt)[0]); + if (foundTreeInputLabel != 
m_sourceLabelIndexesByFactor.end()) { + size_t treeInputLabelIndex = foundTreeInputLabel->second; + treeInputLabelsLHS.insert(treeInputLabelIndex); + } + } + } + + + // inspect source-labelled rule items + + std::vector< boost::unordered_set > sparseScoredTreeInputLabelsRHS(nNTs-1); + boost::unordered_set sparseScoredTreeInputLabelsLHS; + + std::vector sourceLabelSeenAsLHS(m_sourceLabels.size(),false); + std::vector treeInputMatchRHSCountByNonTerminal(nNTs-1,false); + std::vector treeInputMatchProbRHSByNonTerminal(nNTs-1,0.0); + + const std::list &sourceLabelItems = sourceLabelsPhraseProperty->GetSourceLabelItems(); + + for (std::list::const_iterator sourceLabelItem = sourceLabelItems.begin(); + sourceLabelItem != sourceLabelItems.end() && !hasCompleteTreeInputMatch; ++sourceLabelItem) { + + const std::list &sourceLabelsRHS = sourceLabelItem->GetSourceLabelsRHS(); + const std::list< std::pair > &sourceLabelsLHSList = sourceLabelItem->GetSourceLabelsLHSList(); + float sourceLabelsRHSCount = sourceLabelItem->GetSourceLabelsRHSCount(); + + assert(sourceLabelsRHS.size() == nNTs-1); + + bool currentSourceLabelItemIsCompleteTreeInputMatch = true; + + size_t nonTerminalNumber=0; + for (std::list::const_iterator sourceLabelsRHSIt = sourceLabelsRHS.begin(); + sourceLabelsRHSIt != sourceLabelsRHS.end(); ++sourceLabelsRHSIt, ++nonTerminalNumber) { + + if (treeInputLabelsRHS[nonTerminalNumber].find(*sourceLabelsRHSIt) != treeInputLabelsRHS[nonTerminalNumber].end()) { + + treeInputMatchRHSCountByNonTerminal[nonTerminalNumber] = true; + treeInputMatchProbRHSByNonTerminal[nonTerminalNumber] += sourceLabelsRHSCount; // to be normalized later on + + if ( m_useSparse && + (!m_useCoreSourceLabels || m_coreSourceLabels.find(*sourceLabelsRHSIt) != m_coreSourceLabels.end()) ) { + // score sparse features: RHS match + if (sparseScoredTreeInputLabelsRHS[nonTerminalNumber].find(*sourceLabelsRHSIt) == sparseScoredTreeInputLabelsRHS[nonTerminalNumber].end()) { + // (only if no match has been 
scored for this tree input label and rule non-terminal with a previous sourceLabelItem) + float score_RHS_1 = (float)1/treeInputLabelsRHS[nonTerminalNumber].size(); + scoreBreakdown.PlusEquals(this, + m_sourceLabelsByIndex_RHS_1[*sourceLabelsRHSIt], + score_RHS_1); + sparseScoredTreeInputLabelsRHS[nonTerminalNumber].insert(*sourceLabelsRHSIt); + } + } + + } else { + + currentSourceLabelItemIsCompleteTreeInputMatch = false; + + } + } + + for (std::list< std::pair >::const_iterator sourceLabelsLHSIt = sourceLabelsLHSList.begin(); + sourceLabelsLHSIt != sourceLabelsLHSList.end(); ++sourceLabelsLHSIt) { + + if ( sourceLabelsLHSIt->first == m_GlueTopLabel ) { + isGlueGrammarRule = true; + } + + if (treeInputLabelsLHS.find(sourceLabelsLHSIt->first) != treeInputLabelsLHS.end()) { + + treeInputMismatchLHSBinary = false; + treeInputMatchProbLHS += sourceLabelsLHSIt->second; // to be normalized later on + + if ( m_useSparse && + (!m_useCoreSourceLabels || m_coreSourceLabels.find(sourceLabelsLHSIt->first) != m_coreSourceLabels.end()) ) { + // score sparse features: LHS match + if (sparseScoredTreeInputLabelsLHS.find(sourceLabelsLHSIt->first) == sparseScoredTreeInputLabelsLHS.end()) { + // (only if no match has been scored for this tree input label and rule non-terminal with a previous sourceLabelItem) + float score_LHS_1 = (float)1/treeInputLabelsLHS.size(); + scoreBreakdown.PlusEquals(this, + m_sourceLabelsByIndex_LHS_1[sourceLabelsLHSIt->first], + score_LHS_1); + sparseScoredTreeInputLabelsLHS.insert(sourceLabelsLHSIt->first); + } + } + + if ( currentSourceLabelItemIsCompleteTreeInputMatch ) { + ruleLabelledProbability += sourceLabelsLHSIt->second; // to be normalized later on + hasCompleteTreeInputMatch = true; + } + + } + } + } + + // normalization + for (std::vector::iterator treeInputMatchProbRHSByNonTerminalIt = treeInputMatchProbRHSByNonTerminal.begin(); + treeInputMatchProbRHSByNonTerminalIt != treeInputMatchProbRHSByNonTerminal.end(); 
++treeInputMatchProbRHSByNonTerminalIt) { + *treeInputMatchProbRHSByNonTerminalIt /= totalCount; + if ( *treeInputMatchProbRHSByNonTerminalIt != 0 ) { + treeInputMatchProbRHS += ( m_useLogprobs ? TransformScore(*treeInputMatchProbRHSByNonTerminalIt) : *treeInputMatchProbRHSByNonTerminalIt ); + } + } + treeInputMatchProbLHS /= totalCount; + ruleLabelledProbability /= totalCount; + + // input tree matching (RHS) + if ( !hasCompleteTreeInputMatch ) { + treeInputMismatchRHSCount = nNTs-1; + for (std::vector::const_iterator treeInputMatchRHSCountByNonTerminalIt = treeInputMatchRHSCountByNonTerminal.begin(); + treeInputMatchRHSCountByNonTerminalIt != treeInputMatchRHSCountByNonTerminal.end(); ++treeInputMatchRHSCountByNonTerminalIt) { + if (*treeInputMatchRHSCountByNonTerminalIt) { + --treeInputMismatchRHSCount; + } + } + } + + // score sparse features: mismatches + if ( m_useSparse ) { + + // RHS + + for (size_t nonTerminalNumber = 0; nonTerminalNumber < nNTs-1; ++nonTerminalNumber) { + // nNTs-1 because nNTs also counts the left-hand side non-terminal + + float score_RHS_0 = (float)1/treeInputLabelsRHS[nonTerminalNumber].size(); + for (boost::unordered_set::const_iterator treeInputLabelsRHSIt = treeInputLabelsRHS[nonTerminalNumber].begin(); + treeInputLabelsRHSIt != treeInputLabelsRHS[nonTerminalNumber].end(); ++treeInputLabelsRHSIt) { + + if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsRHSIt) != m_coreSourceLabels.end() ) { + + if (sparseScoredTreeInputLabelsRHS[nonTerminalNumber].find(*treeInputLabelsRHSIt) == sparseScoredTreeInputLabelsRHS[nonTerminalNumber].end()) { + // score sparse features: RHS mismatch + scoreBreakdown.PlusEquals(this, + m_sourceLabelsByIndex_RHS_0[*treeInputLabelsRHSIt], + score_RHS_0); + } + } + } + } + + // LHS + + float score_LHS_0 = (float)1/treeInputLabelsLHS.size(); + for (boost::unordered_set::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin(); + treeInputLabelsLHSIt != treeInputLabelsLHS.end(); 
++treeInputLabelsLHSIt) { + + if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsLHSIt) != m_coreSourceLabels.end() ) { + + if (sparseScoredTreeInputLabelsLHS.find(*treeInputLabelsLHSIt) == sparseScoredTreeInputLabelsLHS.end()) { + // score sparse features: RHS mismatch + scoreBreakdown.PlusEquals(this, + m_sourceLabelsByIndex_LHS_0[*treeInputLabelsLHSIt], + score_LHS_0); + } + } + } + + } + + if ( m_useSparseLabelPairs && !isGlueGrammarRule ) { + + // left-hand side label pairs (target NT, source NT) + float t2sLabelsScore = 0.0; + float s2tLabelsScore = 0.0; + for (boost::unordered_set::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin(); + treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) { + + scoreBreakdown.PlusEquals(this, + "LHSPAIR_" + targetLHS->GetString().as_string() + "_" + m_sourceLabelsByIndex[*treeInputLabelsLHSIt], + (float)1/treeInputLabelsLHS.size()); + + if (!m_targetSourceLHSJointCountFile.empty()) { + std::pair probPair = GetLabelPairProbabilities( targetLHS, *treeInputLabelsLHSIt); + t2sLabelsScore += probPair.first; + s2tLabelsScore += probPair.second; + } + } + if ( treeInputLabelsLHS.size() == 0 ) { + scoreBreakdown.PlusEquals(this, + "LHSPAIR_" + targetLHS->GetString().as_string() + "_" + + m_options->syntax.output_default_non_terminal[0] + ->GetString().as_string(), + 1); + if (!m_targetSourceLHSJointCountFile.empty()) { + t2sLabelsScore = TransformScore(m_floor); + s2tLabelsScore = TransformScore(m_floor); + } + } else { + if (!m_targetSourceLHSJointCountFile.empty()) { + float norm = TransformScore(treeInputLabelsLHS.size()); + t2sLabelsScore = TransformScore(t2sLabelsScore) - norm; + s2tLabelsScore = TransformScore(s2tLabelsScore) - norm; + } + } + if (!m_targetSourceLHSJointCountFile.empty()) { + scoreBreakdown.PlusEquals(this, "LHST2S", t2sLabelsScore); + scoreBreakdown.PlusEquals(this, "LHSS2T", s2tLabelsScore); + } + } + + } else { + + // abort with error message if the 
phrase does not translate an unknown word + UTIL_THROW_IF2(!targetPhrase.GetWord(0).IsOOV(), GetScoreProducerDescription() + << ": Missing SourceLabels property. " + << "Please check phrase table and glue rules."); + + // unknown word + isUnkRule = true; +// ruleLabelledProbability = 1; + + } + + // add scores + + // input tree matching + newScores[0] = !hasCompleteTreeInputMatch; + if ( m_noMismatches ) { + newScores[0] = ( (hasCompleteTreeInputMatch || isGlueGrammarRule || isUnkRule) ? 0 : -std::numeric_limits::infinity() ); + } + newScores[1] = treeInputMismatchLHSBinary; + newScores[2] = treeInputMismatchRHSCount; + + if ( m_useLogprobs ) { + if ( ruleLabelledProbability != 0 ) { + ruleLabelledProbability = TransformScore(ruleLabelledProbability); + } + if ( treeInputMatchProbLHS != 0 ) { + treeInputMatchProbLHS = TransformScore(treeInputMatchProbLHS); + } + } + + newScores[3] = ruleLabelledProbability; + newScores[4] = treeInputMatchProbLHS; + newScores[5] = treeInputMatchProbRHS; + + scoreBreakdown.PlusEquals(this, newScores); +} + + +std::pair SoftSourceSyntacticConstraintsFeature::GetLabelPairProbabilities( + const Factor* target, + const size_t source) const +{ + boost::unordered_map >* >::const_iterator found = + m_labelPairProbabilities.find(target); + if ( found == m_labelPairProbabilities.end() ) { + return std::pair(m_floor,m_floor); // floor values + } + std::pair ret = found->second->at(source); + if ( ret == std::pair(0,0) ) { + return std::pair(m_floor,m_floor); // floor values + } + return ret; +} + + +} + diff --git a/mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.cpp b/mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa0aa87b933951f6e4c0f1b472cbe9d4042c6d7b --- /dev/null +++ b/mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.cpp @@ -0,0 +1,75 @@ +#include +#include +#include +#include "SourceGHKMTreeInputMatchFeature.h" +#include "moses/StaticData.h" 
+#include "moses/InputFileStream.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/Hypothesis.h" +#include "moses/ChartHypothesis.h" +#include "moses/Factor.h" +#include "moses/FactorCollection.h" +#include "moses/InputPath.h" +#include "moses/TreeInput.h" + + +using namespace std; + +namespace Moses +{ + +SourceGHKMTreeInputMatchFeature::SourceGHKMTreeInputMatchFeature(const std::string &line) + : StatelessFeatureFunction(2, line) +{ + std::cerr << GetScoreProducerDescription() << "Initializing feature..."; + ReadParameters(); + std::cerr << " Done." << std::endl; +} + +void SourceGHKMTreeInputMatchFeature::SetParameter(const std::string& key, const std::string& value) +{ + UTIL_THROW(util::Exception, GetScoreProducerDescription() << ": Unknown parameter " << key << "=" << value); +} + +// assumes that source-side syntax labels are stored in the target non-terminal field of the rules +void SourceGHKMTreeInputMatchFeature::EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores) const +{ + const Range& range = inputPath.GetWordsRange(); + size_t startPos = range.GetStartPos(); + size_t endPos = range.GetEndPos(); + const TreeInput& treeInput = static_cast(input); + const NonTerminalSet& treeInputLabels = treeInput.GetLabelSet(startPos,endPos); + const Word& lhsLabel = targetPhrase.GetTargetLHS(); + + const StaticData& staticData = StaticData::Instance(); + + std::vector newScores(m_numScoreComponents,0.0); + // m_numScoreComponents == 2 // first fires for matches, second for mismatches + + if ( (treeInputLabels.find(lhsLabel) != treeInputLabels.end()) + && (lhsLabel != m_options->syntax.output_default_non_terminal) ) { + // match + newScores[0] = 1.0; + } else { + // mismatch + newScores[1] = 1.0; + } + + scoreBreakdown.PlusEquals(this, newScores); +} + +void 
+SourceGHKMTreeInputMatchFeature:: +Load(AllOptions::ptr const& opts) +{ + m_options = opts; + // m_output_default_nonterminal = opts->syntax.output_default_non_terminal; +} + +} + diff --git a/mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.h b/mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..403dca71668f8d9c5aa70da8eea8681dd275d655 --- /dev/null +++ b/mosesdecoder/moses/FF/SourceGHKMTreeInputMatchFeature.h @@ -0,0 +1,49 @@ +#pragma once + +#include "StatelessFeatureFunction.h" +#include "moses/parameters/AllOptions.h" + +namespace Moses +{ + +// assumes that source-side syntax labels are stored in the target non-terminal field of the rules +class SourceGHKMTreeInputMatchFeature : public StatelessFeatureFunction +{ +public: + SourceGHKMTreeInputMatchFeature(const std::string &line); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + void SetParameter(const std::string& key, const std::string& value); + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const {}; + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const; + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const {}; + + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const {}; + + void Load(AllOptions::ptr const& opts); +}; + + +} + diff --git a/mosesdecoder/moses/FF/SourceWordDeletionFeature.h 
b/mosesdecoder/moses/FF/SourceWordDeletionFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..0257621ead6a6b1af858725652c705f6c63b145a --- /dev/null +++ b/mosesdecoder/moses/FF/SourceWordDeletionFeature.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include + +#include "StatelessFeatureFunction.h" +#include "moses/FactorCollection.h" +#include "moses/AlignmentInfo.h" + +namespace Moses +{ + +/** Sets the features for source word deletion + */ +class SourceWordDeletionFeature : public StatelessFeatureFunction +{ +private: + boost::unordered_set m_vocab; + FactorType m_factorType; + bool m_unrestricted; + std::string m_filename; + +public: + SourceWordDeletionFeature(const std::string &line); + + void Load(AllOptions::ptr const& opts); + + bool IsUseable(const FactorMask &mask) const; + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const; + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const { + } + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + + void ComputeFeatures(const Phrase &source, + const TargetPhrase& targetPhrase, + ScoreComponentCollection* accumulator, + const AlignmentInfo &alignmentInfo) const; + void SetParameter(const std::string& key, const std::string& value); + +}; + +} + diff --git a/mosesdecoder/moses/FF/SpanLength.cpp b/mosesdecoder/moses/FF/SpanLength.cpp new file mode 100644 
index 0000000000000000000000000000000000000000..1a413013c0237a3813c6a59fb6af1eb77f335a29
--- /dev/null
+++ b/mosesdecoder/moses/FF/SpanLength.cpp
@@ -0,0 +1,88 @@
+#include <assert.h>
+#include "SpanLength.h"
+#include "moses/StaticData.h"
+#include "moses/Word.h"
+#include "moses/ChartCellLabel.h"
+#include "moses/Range.h"
+#include "moses/StackVec.h"
+#include "moses/TargetPhrase.h"
+#include "moses/PP/PhraseProperty.h"
+#include "moses/PP/SpanLengthPhraseProperty.h"
+
+using namespace std;
+
+namespace Moses
+{
+SpanLength::SpanLength(const std::string &line)
+  :StatelessFeatureFunction(1, line)
+  ,m_smoothingMethod(None)
+  ,m_const(0)
+{
+  ReadParameters();
+}
+
+void SpanLength::EvaluateInIsolation(const Phrase &source
+                                     , const TargetPhrase &targetPhrase
+                                     , ScoreComponentCollection &scoreBreakdown
+                                     , ScoreComponentCollection &estimatedScores) const
+{
+  targetPhrase.SetRuleSource(source);
+}
+
+void SpanLength::EvaluateWithSourceContext(const InputType &input
+    , const InputPath &inputPath
+    , const TargetPhrase &targetPhrase
+    , const StackVec *stackVec
+    , ScoreComponentCollection &scoreBreakdown
+    , ScoreComponentCollection *estimatedScores) const
+{
+  assert(stackVec);
+
+  const PhraseProperty *property = targetPhrase.GetProperty("SpanLength");
+  if (property == NULL) {
+    return;
+  }
+
+  const SpanLengthPhraseProperty *slProp = static_cast<const SpanLengthPhraseProperty*>(property);
+
+  assert(targetPhrase.GetRuleSource());
+
+  float score = 0;
+  for (size_t i = 0; i < stackVec->size(); ++i) {
+    const ChartCellLabel &cell = *stackVec->at(i);
+    const Range &ntRange = cell.GetCoverage();
+    size_t sourceWidth = ntRange.GetNumWordsCovered();
+    float prob = slProp->GetProb(i, sourceWidth, m_const);
+    score += TransformScore(prob);
+  }
+
+  if (score < -100.0f) {
+    float weight = StaticData::Instance().GetWeight(this);
+    if (weight < 0) {
+      score = -100;
+    }
+  }
+
+  scoreBreakdown.PlusEquals(this, score);
+
+}
+
+void SpanLength::SetParameter(const std::string& key, const std::string& value)
+{
+  if (key == "smoothing") {
+    if (value == "plus-constant") {
+      m_smoothingMethod = PlusConst;
+    } else if (value == "none") {
+      m_smoothingMethod = None;
+    } else {
+      UTIL_THROW(util::Exception, "Unknown smoothing type " << value);
+    }
+  } else if (key == "constant") {
+    m_const = Scan<float>(value);
+  } else {
+    StatelessFeatureFunction::SetParameter(key, value);
+  }
+}
+
+}
+
diff --git a/mosesdecoder/moses/FF/SpanLength.h b/mosesdecoder/moses/FF/SpanLength.h
new file mode 100644
index 0000000000000000000000000000000000000000..3a13d3a3d446c2c1ecad4f21c196eb852503351e
--- /dev/null
+++ b/mosesdecoder/moses/FF/SpanLength.h
@@ -0,0 +1,56 @@
+#pragma once
+#include <string>
+#include "StatelessFeatureFunction.h"
+
+namespace Moses
+{
+
+// Rule Scope - not quite completely implemented yet
+class SpanLength : public StatelessFeatureFunction
+{
+public:
+  SpanLength(const std::string &line);
+
+  virtual bool IsUseable(const FactorMask &mask) const {
+    return true;
+  }
+
+  virtual void EvaluateInIsolation(const Phrase &source
+                                   , const TargetPhrase &targetPhrase
+                                   , ScoreComponentCollection &scoreBreakdown
+                                   , ScoreComponentCollection &estimatedScores) const;
+
+  virtual void EvaluateWithSourceContext(const InputType &input
+                                         , const InputPath &inputPath
+                                         , const TargetPhrase &targetPhrase
+                                         , const StackVec *stackVec
+                                         , ScoreComponentCollection &scoreBreakdown
+                                         , ScoreComponentCollection *estimatedScores = NULL) const;
+
+  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
+      , const TranslationOptionList &translationOptionList) const {
+  }
+
+
+  virtual void EvaluateWhenApplied(const Hypothesis& hypo,
+                                   ScoreComponentCollection* accumulator) const {
+  }
+
+  virtual void EvaluateWhenApplied(const ChartHypothesis &hypo,
+                                   ScoreComponentCollection* accumulator) const {
+  }
+
+  void SetParameter(const std::string& key, const std::string& value);
+
+protected:
+  enum SmoothingMethod {
+    None,
+    PlusConst,
+  };
+  SmoothingMethod m_smoothingMethod;
+
+  float m_const;
+};
+
+}
+
diff --git a/mosesdecoder/moses/FF/SparseHieroReorderingFeature.cpp b/mosesdecoder/moses/FF/SparseHieroReorderingFeature.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..df689b807767a7645b709c2dc0d70d0b97f4831a
--- /dev/null
+++ b/mosesdecoder/moses/FF/SparseHieroReorderingFeature.cpp
@@ -0,0 +1,224 @@
+#include <fstream>
+
+#include "moses/ChartHypothesis.h"
+#include "moses/ChartManager.h"
+#include "moses/FactorCollection.h"
+#include "moses/Sentence.h"
+
+#include "util/exception.hh"
+#include "util/string_stream.hh"
+
+#include "SparseHieroReorderingFeature.h"
+
+using namespace std;
+
+namespace Moses
+{
+
+SparseHieroReorderingFeature::SparseHieroReorderingFeature(const std::string &line)
+  :StatelessFeatureFunction(0, line),
+   m_type(SourceCombined),
+   m_sourceFactor(0),
+   m_targetFactor(0),
+   m_sourceVocabFile(""),
+   m_targetVocabFile("")
+{
+
+  /*
+    Configuration of features.
+      factor - Which factor should it apply to
+      type - what type of sparse reordering feature. e.g. block (modelled on Matthias
+             Huck's EAMT 2012 features)
+      word - which words to include, e.g. src_bdry, src_all, tgt_bdry , ...
+      vocab - vocab file to limit it to
+      orientation - e.g. lr, etc.
+  */
+  cerr << "Constructing a Sparse Reordering feature" << endl;
+  ReadParameters();
+  m_otherFactor = FactorCollection::Instance().AddFactor("##OTHER##");
+  LoadVocabulary(m_sourceVocabFile, m_sourceVocab);
+  LoadVocabulary(m_targetVocabFile, m_targetVocab);
+}
+
+void SparseHieroReorderingFeature::SetParameter(const std::string& key, const std::string& value)
+{
+  if (key == "input-factor") {
+    m_sourceFactor = Scan<FactorType>(value);
+  } else if (key == "output-factor") {
+    m_targetFactor = Scan<FactorType>(value);
+  } else if (key == "input-vocab-file") {
+    m_sourceVocabFile = value;
+  } else if (key == "output-vocab-file") {
+    m_targetVocabFile = value;
+  } else if (key == "type") {
+    if (value == "SourceCombined") {
+      m_type = SourceCombined;
+    } else if (value == "SourceLeft") {
+      m_type = SourceLeft;
+    } else if (value == "SourceRight") {
+      m_type = SourceRight;
+    } else {
+      UTIL_THROW(util::Exception, "Unknown sparse reordering type " << value);
+    }
+  } else {
+    FeatureFunction::SetParameter(key, value);
+  }
+}
+
+void SparseHieroReorderingFeature::LoadVocabulary(const std::string& filename, Vocab& vocab)
+{
+  if (filename.empty()) return;
+  ifstream in(filename.c_str());
+  UTIL_THROW_IF(!in, util::Exception, "Unable to open vocab file: " << filename);
+  string line;
+  while(getline(in,line)) {
+    vocab.insert(FactorCollection::Instance().AddFactor(line));
+  }
+  in.close();
+}
+
+const Factor* SparseHieroReorderingFeature::GetFactor(const Word& word, const Vocab& vocab, FactorType factorType) const
+{
+  const Factor* factor = word.GetFactor(factorType);
+  if (vocab.size() && vocab.find(factor) == vocab.end()) return m_otherFactor;
+  return factor;
+}
+
+void SparseHieroReorderingFeature::EvaluateWhenApplied(
+  const ChartHypothesis& cur_hypo ,
+  ScoreComponentCollection* accumulator) const
+{
+  // get index map for underlying hypotheses
+  //const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
+  //  cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap();
+
+  //The Huck features. For a rule with source side:
+  //   abXcdXef
+  //We first have to split into blocks:
+  //  ab X cd X ef
+  //Then we extract features based in the boundary words of the neighbouring blocks
+  //For the block pair, we use the right word of the left block, and the left
+  //word of the right block.
+
+  //Need to get blocks, and their alignment. Each block has a word range (on the
+  // on the source), a non-terminal flag, and a set of alignment points in the target phrase
+
+  //We need to be able to map source word position to target word position, as
+  //much as possible (don't need interior of non-terminals). The alignment info
+  //objects just give us the mappings between *rule* positions. So if we can
+  //map source word position to source rule position, and target rule position
+  //to target word position, then we can map right through.
+
+  size_t sourceStart = cur_hypo.GetCurrSourceRange().GetStartPos();
+  size_t sourceSize = cur_hypo.GetCurrSourceRange().GetNumWordsCovered();
+
+  vector<Range> sourceNTSpans;
+  for (size_t prevHypoId = 0; prevHypoId < cur_hypo.GetPrevHypos().size(); ++prevHypoId) {
+    sourceNTSpans.push_back(cur_hypo.GetPrevHypo(prevHypoId)->GetCurrSourceRange());
+  }
+  //put in source order. Is this necessary?
+  sort(sourceNTSpans.begin(), sourceNTSpans.end());
+  //cerr << "Source NTs: ";
+  //for (size_t i = 0; i < sourceNTSpans.size(); ++i) cerr << sourceNTSpans[i] << " ";
+  //cerr << endl;
+
+  typedef pair<Range,bool> Block;//flag indicates NT
+  vector<Block> sourceBlocks;
+  sourceBlocks.push_back(Block(cur_hypo.GetCurrSourceRange(),false));
+  for (vector<Range>::const_iterator i = sourceNTSpans.begin();
+       i != sourceNTSpans.end(); ++i) {
+    const Range& prevHypoRange = *i;
+    Block lastBlock = sourceBlocks.back();
+    sourceBlocks.pop_back();
+    //split this range into before NT, NT and after NT
+    if (prevHypoRange.GetStartPos() > lastBlock.first.GetStartPos()) {
+      sourceBlocks.push_back(Block(Range(lastBlock.first.GetStartPos(),prevHypoRange.GetStartPos()-1),false));
+    }
+    sourceBlocks.push_back(Block(prevHypoRange,true));
+    if (prevHypoRange.GetEndPos() < lastBlock.first.GetEndPos()) {
+      sourceBlocks.push_back(Block(Range(prevHypoRange.GetEndPos()+1,lastBlock.first.GetEndPos()), false));
+    }
+  }
+  /*
+  cerr << "Source Blocks: ";
+  for (size_t i = 0; i < sourceBlocks.size(); ++i) cerr << sourceBlocks[i].first << " "
+      << (sourceBlocks[i].second ? "NT" : "T") << " ";
+  cerr << endl;
+  */
+
+  //Mapping from source word to target rule position
+  vector<size_t> sourceWordToTargetRulePos(sourceSize);
+  map<size_t,size_t> alignMap;
+  alignMap.insert(
+    cur_hypo.GetCurrTargetPhrase().GetAlignTerm().begin(),
+    cur_hypo.GetCurrTargetPhrase().GetAlignTerm().end());
+  alignMap.insert(
+    cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().begin(),
+    cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().end());
+  //vector<size_t> alignMapTerm = cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm()
+  size_t sourceRulePos = 0;
+  //cerr << "SW->RP ";
+  for (vector<Block>::const_iterator sourceBlockIt = sourceBlocks.begin();
+       sourceBlockIt != sourceBlocks.end(); ++sourceBlockIt) {
+    for (size_t sourceWordPos = sourceBlockIt->first.GetStartPos();
+         sourceWordPos <= sourceBlockIt->first.GetEndPos(); ++sourceWordPos) {
+      sourceWordToTargetRulePos[sourceWordPos - sourceStart] = alignMap[sourceRulePos];
+      //  cerr << sourceWordPos - sourceStart << "-" << alignMap[sourceRulePos] << " ";
+      if (! sourceBlockIt->second) {
+        //T
+        ++sourceRulePos;
+      }
+    }
+    if ( sourceBlockIt->second) {
+      //NT
+      ++sourceRulePos;
+    }
+  }
+  //cerr << endl;
+
+  //Iterate through block pairs
+  const Sentence& sentence =
+    static_cast<const Sentence&>(cur_hypo.GetManager().GetSource());
+  //const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
+  for (size_t i = 0; i < sourceBlocks.size()-1; ++i) {
+    Block& leftSourceBlock = sourceBlocks[i];
+    Block& rightSourceBlock = sourceBlocks[i+1];
+    size_t sourceLeftBoundaryPos = leftSourceBlock.first.GetEndPos();
+    size_t sourceRightBoundaryPos = rightSourceBlock.first.GetStartPos();
+    const Word& sourceLeftBoundaryWord = sentence.GetWord(sourceLeftBoundaryPos);
+    const Word& sourceRightBoundaryWord = sentence.GetWord(sourceRightBoundaryPos);
+    sourceLeftBoundaryPos -= sourceStart;
+    sourceRightBoundaryPos -= sourceStart;
+
+    // Need to figure out where these map to on the target.
+    size_t targetLeftRulePos =
+      sourceWordToTargetRulePos[sourceLeftBoundaryPos];
+    size_t targetRightRulePos =
+      sourceWordToTargetRulePos[sourceRightBoundaryPos];
+
+    bool isMonotone = true;
+    if ((sourceLeftBoundaryPos < sourceRightBoundaryPos &&
+         targetLeftRulePos > targetRightRulePos) ||
+        ((sourceLeftBoundaryPos > sourceRightBoundaryPos &&
+          targetLeftRulePos < targetRightRulePos))) {
+      isMonotone = false;
+    }
+    util::StringStream buf;
+    buf << "h_"; //sparse reordering, Huck
+    if (m_type == SourceLeft || m_type == SourceCombined) {
+      buf << GetFactor(sourceLeftBoundaryWord,m_sourceVocab,m_sourceFactor)->GetString();
+      buf << "_";
+    }
+    if (m_type == SourceRight || m_type == SourceCombined) {
+      buf << GetFactor(sourceRightBoundaryWord,m_sourceVocab,m_sourceFactor)->GetString();
+      buf << "_";
+    }
+    buf << (isMonotone ? "M" : "S");
+    accumulator->PlusEquals(this,buf.str(), 1);
+  }
+//  cerr << endl;
+}
+
+
+}
+
diff --git a/mosesdecoder/moses/FF/SparseHieroReorderingFeatureTest.cpp b/mosesdecoder/moses/FF/SparseHieroReorderingFeatureTest.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f05355df91da9019261ffb798e1ac78ff7542aad
--- /dev/null
+++ b/mosesdecoder/moses/FF/SparseHieroReorderingFeatureTest.cpp
@@ -0,0 +1,36 @@
+/***********************************************************************
+Moses - factored phrase-based language decoder
+Copyright (C) 2013- University of Edinburgh
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+***********************************************************************/
+#include <boost/test/unit_test.hpp>
+
+#include <iostream>
+
+#include "SparseHieroReorderingFeature.h"
+
+using namespace Moses;
+using namespace std;
+
+BOOST_AUTO_TEST_SUITE(shrf)
+
+BOOST_AUTO_TEST_CASE(lexical_rule)
+{
+  SparseHieroReorderingFeature feature("name=shrf");
+
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/mosesdecoder/moses/FF/StatefulFeatureFunction.cpp b/mosesdecoder/moses/FF/StatefulFeatureFunction.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..bfb56f88c7c19330c20406d3b1ba2f9b28c2aca7
--- /dev/null
+++ b/mosesdecoder/moses/FF/StatefulFeatureFunction.cpp
@@ -0,0 +1,23 @@
+#include "StatefulFeatureFunction.h"
+
+namespace Moses
+{
+
+std::vector<const StatefulFeatureFunction*> StatefulFeatureFunction::m_statefulFFs;
+
+StatefulFeatureFunction
+::StatefulFeatureFunction(const std::string &line, bool registerNow)
+  : FeatureFunction(line, registerNow)
+{
+  m_statefulFFs.push_back(this);
+}
+
+StatefulFeatureFunction
+::StatefulFeatureFunction(size_t numScoreComponents, const std::string &line)
+  : FeatureFunction(numScoreComponents, line)
+{
+  m_statefulFFs.push_back(this);
+}
+
+}
+
diff --git a/mosesdecoder/moses/FF/StatelessFeatureFunction.cpp b/mosesdecoder/moses/FF/StatelessFeatureFunction.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..15d97e4bc2ee03d06e2007ca7565a611ab96cd2a
--- /dev/null
+++ b/mosesdecoder/moses/FF/StatelessFeatureFunction.cpp
@@ -0,0 +1,23 @@
+#include "StatelessFeatureFunction.h"
+
+namespace Moses
+{
+
+std::vector<const StatelessFeatureFunction*> StatelessFeatureFunction::m_statelessFFs;
+
+StatelessFeatureFunction
+::StatelessFeatureFunction(const std::string &line, bool registerNow)
+  : FeatureFunction(line, registerNow)
+{
+  m_statelessFFs.push_back(this);
+}
+
+StatelessFeatureFunction
+::StatelessFeatureFunction(size_t numScoreComponents, const std::string &line)
+  : FeatureFunction(numScoreComponents, line)
+{
+  m_statelessFFs.push_back(this);
+}
+
+}
+
diff --git a/mosesdecoder/moses/FF/StatelessFeatureFunction.h b/mosesdecoder/moses/FF/StatelessFeatureFunction.h
new file mode 100644
index 0000000000000000000000000000000000000000..a364a811fb02a4cc6f68eb28e9a30dbbb373a69b
--- /dev/null
+++ b/mosesdecoder/moses/FF/StatelessFeatureFunction.h
@@ -0,0 +1,55 @@
+#pragma once
+
+#include "FeatureFunction.h"
+
+
+namespace Moses
+{
+
+namespace Syntax
+{
+struct SHyperedge;
+}
+
+/** base class for all stateless feature functions.
+ * eg. phrase table, word penalty, phrase penalty
+ */
+class StatelessFeatureFunction: public FeatureFunction
+{
+  //All stateless FFs, except those that cache scores in T-Option
+  static std::vector<const StatelessFeatureFunction*> m_statelessFFs;
+
+public:
+  static const std::vector<const StatelessFeatureFunction*>& GetStatelessFeatureFunctions() {
+    return m_statelessFFs;
+  }
+
+  StatelessFeatureFunction(const std::string &line, bool registerNow);
+  StatelessFeatureFunction(size_t numScoreComponents, const std::string &line);
+
+  /**
+   * This should be implemented for features that apply to phrase-based models.
+   **/
+  virtual void EvaluateWhenApplied(const Hypothesis& hypo,
+                                   ScoreComponentCollection* accumulator) const = 0;
+
+  /**
+   * Same for chart-based features.
+   **/
+  virtual void EvaluateWhenApplied(const ChartHypothesis &hypo,
+                                   ScoreComponentCollection* accumulator) const = 0;
+
+  virtual void EvaluateWhenApplied(const Syntax::SHyperedge &,
+                                   ScoreComponentCollection*) const {
+    assert(false);
+  }
+
+  virtual bool IsStateless() const {
+    return true;
+  }
+
+};
+
+
+} // namespace
+
diff --git a/mosesdecoder/moses/FF/SyntaxRHS.cpp b/mosesdecoder/moses/FF/SyntaxRHS.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2aaf1d5f9059527a60025e961667bcbbb45d078f
--- /dev/null
+++ b/mosesdecoder/moses/FF/SyntaxRHS.cpp
@@ -0,0 +1,42 @@
+#include <assert.h>
+#include "SyntaxRHS.h"
+#include "moses/ScoreComponentCollection.h"
+#include "moses/TargetPhrase.h"
+#include "moses/StackVec.h"
+
+using namespace std;
+
+namespace Moses
+{
+SyntaxRHS::SyntaxRHS(const std::string &line)
+  :StatelessFeatureFunction(1, line)
+{
+  ReadParameters();
+}
+
+void SyntaxRHS::EvaluateInIsolation(const Phrase &source
+                                    , const TargetPhrase &targetPhrase
+                                    , ScoreComponentCollection &scoreBreakdown
+                                    , ScoreComponentCollection &estimatedScores) const
+{
+}
+
+void SyntaxRHS::EvaluateWithSourceContext(const InputType &input
+    , const InputPath &inputPath
+    , const TargetPhrase &targetPhrase
+    , const StackVec *stackVec
+    , ScoreComponentCollection &scoreBreakdown
+    , ScoreComponentCollection *estimatedScores) const
+{
+  assert(stackVec);
+
+  if (targetPhrase.GetNumNonTerminals()) {
+    vector<float> newScores(m_numScoreComponents);
+    newScores[0] = - std::numeric_limits<float>::infinity();
+    scoreBreakdown.PlusEquals(this, newScores);
+  }
+
+}
+
+}
+
diff --git a/mosesdecoder/moses/FF/SyntaxRHS.h b/mosesdecoder/moses/FF/SyntaxRHS.h
new file mode 100644
index 0000000000000000000000000000000000000000..0096e286a9419808ba2c25f8406411ca10eb419a
--- /dev/null
+++ b/mosesdecoder/moses/FF/SyntaxRHS.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include <string>
+#include "StatelessFeatureFunction.h"
+
+namespace Moses
+{
+
+class SyntaxRHS : public StatelessFeatureFunction
+{
+public:
+  SyntaxRHS(const std::string &line);
+
+  bool IsUseable(const FactorMask &mask) const {
+    return true;
+  }
+
+  void EvaluateInIsolation(const Phrase &source
+                           , const TargetPhrase &targetPhrase
+                           , ScoreComponentCollection &scoreBreakdown
+                           , ScoreComponentCollection &estimatedScores) const;
+  void EvaluateWithSourceContext(const InputType &input
+                                 , const InputPath &inputPath
+                                 , const TargetPhrase &targetPhrase
+                                 , const StackVec *stackVec
+                                 , ScoreComponentCollection &scoreBreakdown
+                                 , ScoreComponentCollection *estimatedScores = NULL) const;
+
+  void EvaluateTranslationOptionListWithSourceContext(const InputType &input
+      , const TranslationOptionList &translationOptionList) const {
+  }
+
+  void EvaluateWhenApplied(const Hypothesis& hypo,
+                           ScoreComponentCollection* accumulator) const {
+  }
+
+  void EvaluateWhenApplied(const ChartHypothesis &hypo,
+                           ScoreComponentCollection* accumulator) const {
+  }
+
+};
+
+}
+
diff --git a/mosesdecoder/moses/FF/TargetBigramFeature.cpp b/mosesdecoder/moses/FF/TargetBigramFeature.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a6f3249e69d0967e2fc50b108b0ac9bad2c04746
--- /dev/null
+++ b/mosesdecoder/moses/FF/TargetBigramFeature.cpp
@@ -0,0 +1,133 @@
+#include "TargetBigramFeature.h"
+#include "moses/Phrase.h"
+#include "moses/TargetPhrase.h"
+#include "moses/Hypothesis.h"
+#include "moses/ScoreComponentCollection.h"
+#include "util/string_piece_hash.hh"
+#include "util/exception.hh"
+
+using namespace std;
+
+namespace Moses
+{
+
+size_t TargetBigramState::hash() const
+{
+  std::size_t ret = hash_value(m_word);
+  return ret;
+}
+
+bool TargetBigramState::operator==(const FFState& other) const
+{
+  const TargetBigramState& rhs = static_cast<const TargetBigramState&>(other);
+  return m_word == rhs.m_word;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TargetBigramFeature::TargetBigramFeature(const std::string &line)
+  :StatefulFeatureFunction(0, line)
+{
+  std::cerr << "Initializing target bigram feature.." << std::endl;
+  ReadParameters();
+
+  FactorCollection& factorCollection = FactorCollection::Instance();
+  const Factor* bosFactor =
+    factorCollection.AddFactor(Output,m_factorType,BOS_);
+  m_bos.SetFactor(m_factorType,bosFactor);
+
+}
+
+void TargetBigramFeature::SetParameter(const std::string& key, const std::string& value)
+{
+  if (key == "factor") {
+    m_factorType = Scan<FactorType>(value);
+  } else if (key == "path") {
+    m_filePath = value;
+  } else {
+    StatefulFeatureFunction::SetParameter(key, value);
+  }
+}
+
+void TargetBigramFeature::Load(AllOptions::ptr const& opts)
+{
+  m_options = opts;
+  if (m_filePath == "*")
+    return ; //allow all
+  ifstream inFile(m_filePath.c_str());
+  UTIL_THROW_IF2(!inFile, "Can't open file " << m_filePath);
+
+  std::string line;
+  m_vocab.insert(BOS_);
+  m_vocab.insert(EOS_);
+  while (getline(inFile, line)) {
+    m_vocab.insert(line);
+  }
+
+  inFile.close();
+}
+
+
+const FFState* TargetBigramFeature::EmptyHypothesisState(const InputType &/*input*/) const
+{
+  return new TargetBigramState(m_bos);
+}
+
+FFState* TargetBigramFeature::EvaluateWhenApplied(const Hypothesis& cur_hypo,
+    const FFState* prev_state,
+    ScoreComponentCollection* accumulator) const
+{
+  const TargetBigramState* tbState = static_cast<const TargetBigramState*>(prev_state);
+  assert(tbState);
+
+  // current hypothesis target phrase
+  const Phrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
+  if (targetPhrase.GetSize() == 0) {
+    return new TargetBigramState(*tbState);
+  }
+
+  // extract all bigrams w1 w2 from current hypothesis
+  for (size_t i = 0; i < targetPhrase.GetSize(); ++i) {
+    const Factor* f1 = NULL;
+    if (i == 0) {
+      f1 = tbState->GetWord().GetFactor(m_factorType);
+    } else {
+      f1 = targetPhrase.GetWord(i-1).GetFactor(m_factorType);
+    }
+    const Factor* f2 = targetPhrase.GetWord(i).GetFactor(m_factorType);
+    const StringPiece w1 = f1->GetString();
+    const StringPiece w2 = f2->GetString();
+
+    // skip bigrams if they don't belong to a given restricted vocabulary
+    if (m_vocab.size() &&
+        (FindStringPiece(m_vocab, w1) == m_vocab.end() || FindStringPiece(m_vocab, w2) == m_vocab.end())) {
+      continue;
+    }
+
+    string name(w1.data(), w1.size());
+    name += ":";
+    name.append(w2.data(), w2.size());
+    accumulator->PlusEquals(this,name,1);
+  }
+
+  if (cur_hypo.GetWordsBitmap().IsComplete()) {
+    const StringPiece w1 = targetPhrase.GetWord(targetPhrase.GetSize()-1).GetFactor(m_factorType)->GetString();
+    const string& w2 = EOS_;
+    if (m_vocab.empty() || (FindStringPiece(m_vocab, w1) != m_vocab.end())) {
+      string name(w1.data(), w1.size());
+      name += ":";
+      name += w2;
+      accumulator->PlusEquals(this,name,1);
+    }
+    return NULL;
+  }
+  return new TargetBigramState(targetPhrase.GetWord(targetPhrase.GetSize()-1));
+}
+
+bool TargetBigramFeature::IsUseable(const FactorMask &mask) const
+{
+  bool ret = mask[m_factorType];
+  return ret;
+}
+
+}
+
diff --git a/mosesdecoder/moses/FF/TargetBigramFeature.h b/mosesdecoder/moses/FF/TargetBigramFeature.h
new file mode 100644
index 0000000000000000000000000000000000000000..a12f3d25d261776e14b8b29489502ee397611745
--- /dev/null
+++ b/mosesdecoder/moses/FF/TargetBigramFeature.h
@@ -0,0 +1,63 @@
+#ifndef moses_TargetBigramFeature_h
+#define moses_TargetBigramFeature_h
+
+#include <string>
+#include <stdexcept>
+#include <boost/unordered_set.hpp>
+
+#include "moses/FF/FFState.h"
+#include "StatefulFeatureFunction.h"
+#include "moses/FactorCollection.h"
+#include "moses/Word.h"
+
+namespace Moses
+{
+
+class TargetBigramState : public FFState
+{
+public:
+  TargetBigramState(const Word& word): m_word(word) {}
+  const Word& GetWord() const {
+    return m_word;
+  }
+  size_t hash() const;
+  virtual bool operator==(const FFState& other) const;
+
+private:
+  Word m_word;
+};
+
+/** Sets the features of observed bigrams.
+ */
+class TargetBigramFeature : public StatefulFeatureFunction
+{
+public:
+  TargetBigramFeature(const std::string &line);
+
+  void Load(AllOptions::ptr const& opts);
+
+  bool IsUseable(const FactorMask &mask) const;
+
+  virtual const FFState* EmptyHypothesisState(const InputType &input) const;
+
+  virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state,
+                                       ScoreComponentCollection* accumulator) const;
+
+  virtual FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */,
+                                        int /* featureID */,
+                                        ScoreComponentCollection* ) const {
+    throw std::logic_error("TargetBigramFeature not valid in chart decoder");
+  }
+
+  void SetParameter(const std::string& key, const std::string& value);
+
+private:
+  FactorType m_factorType;
+  Word m_bos;
+  std::string m_filePath;
+  boost::unordered_set<std::string> m_vocab;
+};
+
+}
+
+#endif // moses_TargetBigramFeature_h
diff --git a/mosesdecoder/moses/FF/TargetPreferencesFeature.cpp b/mosesdecoder/moses/FF/TargetPreferencesFeature.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4c79177af6622fe81e7952538d9823e65a7b57c4
--- /dev/null
+++ b/mosesdecoder/moses/FF/TargetPreferencesFeature.cpp
@@ -0,0 +1,408 @@
+#include <iostream>
+#include <map>
+#include <string>
+#include <vector>
+#include "TargetPreferencesFeature.h"
+#include "moses/StaticData.h"
+#include "moses/InputFileStream.h"
+#include "moses/ScoreComponentCollection.h"
+#include "moses/Hypothesis.h"
+#include "moses/ChartHypothesis.h"
+#include "moses/ChartManager.h"
+#include "moses/FactorCollection.h"
+#include "moses/TreeInput.h"
+#include "moses/PP/TargetPreferencesPhraseProperty.h"
+
+
+using namespace std;
+
+namespace Moses
+{
+
+void TargetPreferencesFeatureState::AddProbabilityForLHSLabel(size_t label, double cost)
+{
+  std::pair< std::map<size_t,double>::iterator, bool > inserted =
+    m_probabilitiesForLHSLabels.insert(std::pair<size_t,double>(label,cost));
+  if ( !inserted.second ) {
+    (inserted.first)->second += cost;
+  }
+}
+
+void
TargetPreferencesFeatureState::NormalizeProbabilitiesForLHSLabels(double denominator) +{ + for ( std::map::iterator iter=m_probabilitiesForLHSLabels.begin(); + iter!=m_probabilitiesForLHSLabels.end(); ++iter ) { + (iter->second) /= denominator; + } +} + +double TargetPreferencesFeatureState::GetProbabilityForLHSLabel(size_t label, bool &isMatch) const +{ + std::map::const_iterator iter = m_probabilitiesForLHSLabels.find(label); + if ( iter != m_probabilitiesForLHSLabels.end() ) { + isMatch = true; + return iter->second; + } + isMatch = false; + return 0; +} + +size_t TargetPreferencesFeatureState::hash() const +{ + if (!m_distinguishStates) { + return 0; + } + size_t ret = 0; + boost::hash_combine(ret, m_probabilitiesForLHSLabels.size()); + for (std::map::const_iterator it=m_probabilitiesForLHSLabels.begin(); + it!=m_probabilitiesForLHSLabels.end(); ++it) { + boost::hash_combine(ret, it->first); + } + return ret; +}; + +bool TargetPreferencesFeatureState::operator==(const FFState& other) const +{ + if (!m_distinguishStates) { + return true; + } + + if (this == &other) { + return true; + } + + const TargetPreferencesFeatureState* otherState = + dynamic_cast(&other); + UTIL_THROW_IF2(otherState == NULL, "Wrong state type"); + + if (m_probabilitiesForLHSLabels.size() != (otherState->m_probabilitiesForLHSLabels).size()) { + return false; + } + std::map::const_iterator thisIt, otherIt; + for (thisIt=m_probabilitiesForLHSLabels.begin(), otherIt=(otherState->m_probabilitiesForLHSLabels).begin(); + thisIt!=m_probabilitiesForLHSLabels.end(); ++thisIt, ++otherIt) { + if (thisIt->first != otherIt->first) { + return false; + } + } + return true; +}; + + +TargetPreferencesFeature::TargetPreferencesFeature(const std::string &line) + : StatefulFeatureFunction(2, line) + , m_featureVariant(0) + , m_distinguishStates(false) + , m_noMismatches(false) +{ + VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ..."); + ReadParameters(); + VERBOSE(1, " Done." 
<< std::endl); + VERBOSE(1, " Feature variant: " << m_featureVariant << "." << std::endl); +} + +TargetPreferencesFeature::~TargetPreferencesFeature() +{} + +void TargetPreferencesFeature::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "label-set-file") { + m_labelSetFile = value; + } else if (key == "unknown-word-labels-file") { + m_unknownLeftHandSideFile = value; + } else if (key == "variant") { + m_featureVariant = Scan(value); + } else if (key == "distinguish-states") { + m_distinguishStates = Scan(value); + } else if (key == "no-mismatches") { + m_noMismatches = Scan(value); + } else { + StatefulFeatureFunction::SetParameter(key, value); + } +} + + +void TargetPreferencesFeature::Load(AllOptions::ptr const& opts) +{ + // don't change the loading order! + LoadLabelSet(); + LoadUnknownLeftHandSideFile(); +} + +void TargetPreferencesFeature::LoadLabelSet() +{ + FEATUREVERBOSE(2, "Loading label set from file " << m_labelSetFile << " ..."); + InputFileStream inFile(m_labelSetFile); + + // read label set + std::string line; + m_labels.clear(); + m_labelsByIndex.clear(); + while (getline(inFile, line)) { + std::istringstream tokenizer(line); + std::string label; + size_t index; + try { + tokenizer >> label >> index; + } catch (const std::exception &e) { + UTIL_THROW2(GetScoreProducerDescription() + << ": Error reading label set file " << m_labelSetFile << " ."); + } + std::pair< boost::unordered_map::iterator, bool > inserted = m_labels.insert( std::pair(label,index) ); + UTIL_THROW_IF2(!inserted.second, GetScoreProducerDescription() + << ": Label set file " << m_labelSetFile << " should contain each label only once."); + + if (index >= m_labelsByIndex.size()) { + m_labelsByIndex.resize(index+1); + } + m_labelsByIndex[index] = label; + } + + inFile.Close(); + + std::list specialLabels; + specialLabels.push_back("GlueTop"); + for (std::list::const_iterator iter=specialLabels.begin(); + iter!=specialLabels.end(); ++iter) { + 
boost::unordered_map::iterator found = m_labels.find(*iter); + UTIL_THROW_IF2(found == m_labels.end(), GetScoreProducerDescription() + << ": Label set file " << m_labelSetFile << " should contain an entry for the special label \"" << *iter << "\"."); + if (!(found->first).compare("GlueTop")) { + m_GlueTopLabel = found->second; + } + } + FEATUREVERBOSE2(2, " Done." << std::endl); +} + +// Make sure to call this method _after_ LoadLabelSet() +void TargetPreferencesFeature::LoadUnknownLeftHandSideFile() +{ + FEATUREVERBOSE(2, "Loading left-hand side labels for unknowns from file " << m_unknownLeftHandSideFile << std::endl); + InputFileStream inFile(m_unknownLeftHandSideFile); + + // read left-hand side labels for unknowns + std::string line; + m_unknownLHSProbabilities.clear(); + double countsSum = 0.0; + while (getline(inFile, line)) { + istringstream tokenizer(line); + std::string label; + double count; + tokenizer >> label; + tokenizer >> count; + boost::unordered_map::iterator found = m_labels.find( label ); + if ( found != m_labels.end() ) { + std::pair< std::map::iterator, bool > inserted = + m_unknownLHSProbabilities.insert( std::pair(found->second,count) ); + if ( !inserted.second ) { + (inserted.first)->second += count; + } + countsSum += count; + } else { + FEATUREVERBOSE(1, "WARNING: undefined label \"" << label << "\" in file " << m_unknownLeftHandSideFile << std::endl); + } + } + // compute probabilities from counts + countsSum += (double)m_labels.size(); + for (std::map::iterator iter=m_unknownLHSProbabilities.begin(); + iter!=m_unknownLHSProbabilities.end(); ++iter) { + iter->second /= countsSum; + } + + IFFEATUREVERBOSE(3) { + for (std::map::iterator iter=m_unknownLHSProbabilities.begin(); + iter!=m_unknownLHSProbabilities.end(); ++iter) { + FEATUREVERBOSE(3, GetScoreProducerDescription() << "::LoadUnknownLeftHandSideFile(): " << iter->first << " " << iter->second << std::endl); + } + } + + inFile.Close(); +} + +FFState* 
TargetPreferencesFeature::EvaluateWhenApplied( + const ChartHypothesis& hypo, + int featureID, // used to index the state in the previous hypotheses + ScoreComponentCollection* accumulator) const +{ + streamsize cerr_precision = std::cerr.precision(); + std::cerr.precision(20); // TODO: remove. just for debug purposes. + + // dense scores + std::vector newScores(m_numScoreComponents,0); // m_numScoreComponents == 2 + + // state: used to store tree probabilities of partial hypotheses + // and access the respective tree probabilities of subderivations + TargetPreferencesFeatureState *state = new TargetPreferencesFeatureState(m_distinguishStates); + + size_t nNTs = 1; + double overallTreeProbability = 0.0; + bool isGlueGrammarRule = false; + + // read TargetPreferences property + const TargetPhrase &currTarPhr = hypo.GetCurrTargetPhrase(); + + FEATUREVERBOSE(2, "Phrase: " << currTarPhr << std::endl); + + if (const PhraseProperty *property = currTarPhr.GetProperty("TargetPreferences")) { + + const TargetPreferencesPhraseProperty *targetPreferencesPhraseProperty = static_cast(property); + +// IFFEATUREVERBOSE(2) { +// const std::string *targetPreferencesPhrasePropertyValueString = targetPreferencesPhraseProperty->GetValueString(); +// if (targetPreferencesPhrasePropertyValueString) { +// FEATUREVERBOSE(2, "PreferencesPhraseProperty " << *targetPreferencesPhrasePropertyValueString << std::endl); +// } else { +// FEATUREVERBOSE(2, "PreferencesPhraseProperty NULL" << std::endl); +// } +// } + + nNTs = targetPreferencesPhraseProperty->GetNumberOfNonTerminals(); + double totalCount = targetPreferencesPhraseProperty->GetTotalCount(); + + // get index map for underlying hypotheses + const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = + currTarPhr.GetAlignNonTerm().GetNonTermIndexMap(); + + // retrieve states from previous hypotheses, if any + std::vector< const TargetPreferencesFeatureState* > prevStatesByNonTerminal(nNTs-1); + + if (nNTs > 1) { // rule has right-hand side 
non-terminals, i.e. it's a hierarchical rule + size_t nonTerminalNumber = 0; + + for (size_t phrasePos=0; phrasePos(prevHypo->GetFFState(featureID)); + prevStatesByNonTerminal[nonTerminalNumber] = prevState; + + IFFEATUREVERBOSE(2) { + // some log output that is not required in any way for the functionality + const std::map &prevHypoTreeProbabilities = + prevStatesByNonTerminal[nonTerminalNumber]->GetProbabilitiesForLHSLabels(); + FEATUREVERBOSE(2, "Previous tree probs:"); + for (std::map::const_iterator iter=prevHypoTreeProbabilities.begin(); + iter!=prevHypoTreeProbabilities.end(); ++iter) { + FEATUREVERBOSE2(2, " " << m_labelsByIndex[iter->first] << " " << iter->second); + } + FEATUREVERBOSE2(2, std::endl); + } + + ++nonTerminalNumber; + } + } + } + + // inspect labelled rule items + + overallTreeProbability = 0.0; + + const std::list &targetPreferencesItems = targetPreferencesPhraseProperty->GetTargetPreferencesItems(); + + for (std::list::const_iterator targetPreferencesItem = targetPreferencesItems.begin(); + targetPreferencesItem != targetPreferencesItems.end(); ++targetPreferencesItem) { + + const std::list &targetPreferencesRHS = targetPreferencesItem->GetTargetPreferencesRHS(); + const std::list< std::pair > &targetPreferencesLHSList = targetPreferencesItem->GetTargetPreferencesLHSList(); + + assert(targetPreferencesRHS.size() == nNTs-1); + + size_t currentTargetLabelsMismatches = nNTs - 1; + double matchingLabelsProbabilityProduct = 1.0; + + size_t nonTerminalNumber=0; + for (std::list::const_iterator targetPreferencesRHSIt = targetPreferencesRHS.begin(); + targetPreferencesRHSIt != targetPreferencesRHS.end(); ++targetPreferencesRHSIt, ++nonTerminalNumber) { + + bool isLabelMatch = false; + double matchingLabelsProbability = + prevStatesByNonTerminal[nonTerminalNumber]->GetProbabilityForLHSLabel(*targetPreferencesRHSIt, + isLabelMatch); + matchingLabelsProbabilityProduct *= matchingLabelsProbability; + + if ( isLabelMatch ) { + 
currentTargetLabelsMismatches -= 1; + } + } + + FEATUREVERBOSE(2, "matchingLabelsProbabilityProduct = " << matchingLabelsProbabilityProduct << std::endl); + + // LHS labels seen with this RHS + for (std::list< std::pair >::const_iterator targetPreferencesLHSIt = targetPreferencesLHSList.begin(); + targetPreferencesLHSIt != targetPreferencesLHSList.end(); ++targetPreferencesLHSIt) { + + size_t targetPreferenceLHS = targetPreferencesLHSIt->first; + + if ( targetPreferenceLHS == m_GlueTopLabel ) { + isGlueGrammarRule = true; + } + + // proceed with the actual probability computations + double ruleTargetPreferenceCount = targetPreferencesLHSIt->second; + double ruleTargetPreferenceProbability = ruleTargetPreferenceCount / totalCount; + + FEATUREVERBOSE(2, " ruleTargetPreferenceProbability = " << ruleTargetPreferenceProbability << std::endl); + + double weightedTargetPreferenceRuleProbability = ruleTargetPreferenceProbability * matchingLabelsProbabilityProduct; + if ( weightedTargetPreferenceRuleProbability != 0 ) { + state->AddProbabilityForLHSLabel(targetPreferenceLHS, weightedTargetPreferenceRuleProbability); + } + overallTreeProbability += weightedTargetPreferenceRuleProbability; + } + } + + IFFEATUREVERBOSE(2) { + FEATUREVERBOSE(2, "overallTreeProbability = " << overallTreeProbability); + if ( overallTreeProbability > 1.0001 ) { // account for some rounding error + FEATUREVERBOSE2(2, " -- WARNING: overallTreeProbability > 1"); + } + FEATUREVERBOSE2(2, std::endl); + } + + if ( overallTreeProbability != 0 ) { + UTIL_THROW_IF2(!boost::math::isnormal(overallTreeProbability), GetScoreProducerDescription() + << ": Oops. Numerical precision issues."); + state->NormalizeProbabilitiesForLHSLabels(overallTreeProbability); + } + + } else { + + // abort with error message if the phrase does not translate an unknown word + UTIL_THROW_IF2(!currTarPhr.GetWord(0).IsOOV(), GetScoreProducerDescription() + << ": Missing TargetPreferences property. 
Please check phrase table and glue rules."); + + // unknown word + overallTreeProbability = 1.0; + + for (std::map::const_iterator iter=m_unknownLHSProbabilities.begin(); + iter!=m_unknownLHSProbabilities.end(); ++iter) { + // update state + state->AddProbabilityForLHSLabel(iter->first, iter->second); + } + } + + FEATUREVERBOSE(2, "-> OVERALLTREEPROB = " << overallTreeProbability << std::endl); + + // add scores + + // tree probability (preference grammar style) + newScores[0] = (overallTreeProbability == 0 ? 0 : std::log(overallTreeProbability) ); + if ( m_noMismatches && (overallTreeProbability == 0) && !isGlueGrammarRule ) { + newScores[0] = -std::numeric_limits::infinity(); + } + // tree mismatch penalty + // TODO: deactivate the tree mismatch penalty score component automatically if feature configuration parameter no-mismatches=true + newScores[1] = (overallTreeProbability == 0 ? 1 : 0 ); + + accumulator->PlusEquals(this, newScores); + + std::cerr.precision(cerr_precision); + return state; +} + +} + diff --git a/mosesdecoder/moses/FF/TargetWordInsertionFeature.h b/mosesdecoder/moses/FF/TargetWordInsertionFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..d06f3248145240db5949132a83ca0ab676491746 --- /dev/null +++ b/mosesdecoder/moses/FF/TargetWordInsertionFeature.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +#include "StatelessFeatureFunction.h" +#include "moses/FactorCollection.h" +#include "moses/AlignmentInfo.h" + +namespace Moses +{ + +/** Sets the features for length of source phrase, target phrase, both. 
+ */ +class TargetWordInsertionFeature : public StatelessFeatureFunction +{ +private: + boost::unordered_set m_vocab; + FactorType m_factorType; + bool m_unrestricted; + std::string m_filename; + +public: + TargetWordInsertionFeature(const std::string &line); + + bool IsUseable(const FactorMask &mask) const; + + void Load(AllOptions::ptr const& opts); + + virtual void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const; + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const { + } + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + + void ComputeFeatures(const Phrase &source, + const TargetPhrase& targetPhrase, + ScoreComponentCollection* accumulator, + const AlignmentInfo &alignmentInfo) const; + void SetParameter(const std::string& key, const std::string& value); + +}; + +} + diff --git a/mosesdecoder/moses/FF/TreeStructureFeature.cpp b/mosesdecoder/moses/FF/TreeStructureFeature.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b83ef81ea1a0be1cdb52bf94dc344ecf8417e5f7 --- /dev/null +++ b/mosesdecoder/moses/FF/TreeStructureFeature.cpp @@ -0,0 +1,74 @@ +#include "TreeStructureFeature.h" +#include "moses/StaticData.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/ChartHypothesis.h" +#include +#include "moses/PP/TreeStructurePhraseProperty.h" + +namespace Moses +{ + +void 
TreeStructureFeature::Load(AllOptions::ptr const& opts) +{ + m_options = opts; + + // syntactic constraints can be hooked in here. + m_constraints = NULL; + + StaticData &staticData = StaticData::InstanceNonConst(); + staticData.SetTreeStructure(this); +} + + +FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hypo + , int featureID /* used to index the state in the previous hypotheses */ + , ScoreComponentCollection* accumulator) const +{ + if (const PhraseProperty *property = cur_hypo.GetCurrTargetPhrase().GetProperty("Tree")) { + const std::string *tree = property->GetValueString(); + TreePointer mytree (boost::make_shared(*tree)); + + //get subtrees (in target order) + std::vector previous_trees; + for (size_t pos = 0; pos < cur_hypo.GetCurrTargetPhrase().GetSize(); ++pos) { + const Word &word = cur_hypo.GetCurrTargetPhrase().GetWord(pos); + if (word.IsNonTerminal()) { + size_t nonTermInd = cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap()[pos]; + const ChartHypothesis *prevHypo = cur_hypo.GetPrevHypo(nonTermInd); + const TreeState* prev = static_cast(prevHypo->GetFFState(featureID)); + const TreePointer prev_tree = prev->GetTree(); + previous_trees.push_back(prev_tree); + } + } + + if (m_constraints) { + m_constraints->SyntacticRules(mytree, previous_trees, this, accumulator); + } + mytree->Combine(previous_trees); + + bool full_sentence = (mytree->GetChildren().back()->GetLabel() == m_send || (mytree->GetChildren().back()->GetLabel() == m_send_nt && mytree->GetChildren().back()->GetChildren().back()->GetLabel() == m_send)); + if (m_binarized && full_sentence) { + mytree->Unbinarize(); + } + + return new TreeState(mytree); + } else { + UTIL_THROW2("Error: TreeStructureFeature active, but no internal tree structure found"); + } + +} + +void TreeStructureFeature::SetParameter(const std::string& key, const std::string& value) +{ + std::cerr << "setting: " << this->GetScoreProducerDescription() << " - " << key << 
"\n"; + if (key == "tuneable") { + m_tuneable = Scan(value); + } else if (key == "filterable") { //ignore + } else if (key == "binarized") { // if trees have been binarized before learning translation model; output unbinarized trees + m_binarized = true; + } else { + UTIL_THROW(util::Exception, "Unknown argument " << key << "=" << value); + } +} + +} diff --git a/mosesdecoder/moses/FF/TreeStructureFeature.h b/mosesdecoder/moses/FF/TreeStructureFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..366b84fd21a7be27b6077b2dd4381f3562f3856f --- /dev/null +++ b/mosesdecoder/moses/FF/TreeStructureFeature.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include "StatefulFeatureFunction.h" +#include "FFState.h" +#include "moses/Word.h" +#include "InternalTree.h" + +namespace Moses +{ + +typedef int NTLabel; + + +// mapping from string nonterminal label to int representation. +// allows abstraction if multiple nonterminal strings should map to same label. +struct LabelSet { +public: + std::map string_to_label; +}; + + +// class to implement language-specific syntactic constraints. +// the method SyntacticRules is given pointer to ScoreComponentCollection, so it can add sparse features itself. 
+class SyntaxConstraints +{ +public: + virtual void SyntacticRules(TreePointer root, const std::vector &previous, const FeatureFunction* sp, ScoreComponentCollection* accumulator) = 0; + virtual ~SyntaxConstraints() {}; +}; + + +class TreeStructureFeature : public StatefulFeatureFunction +{ + SyntaxConstraints* m_constraints; + LabelSet* m_labelset; + bool m_binarized; + Word m_send; + Word m_send_nt; + +public: + TreeStructureFeature(const std::string &line) + :StatefulFeatureFunction(0, line) + , m_binarized(false) { + ReadParameters(); + std::vector factors; + factors.push_back(0); + m_send.CreateFromString(Output, factors, "", false); + m_send_nt.CreateFromString(Output, factors, "SEND", true); + } + ~TreeStructureFeature() { + delete m_constraints; + }; + + virtual const FFState* EmptyHypothesisState(const InputType &input) const { + return new TreeState(TreePointer()); + } + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + void SetParameter(const std::string& key, const std::string& value); + + FFState* EvaluateWhenApplied( + const Hypothesis& cur_hypo, + const FFState* prev_state, + ScoreComponentCollection* accumulator) const { + UTIL_THROW(util::Exception, "Not implemented"); + }; + FFState* EvaluateWhenApplied( + const ChartHypothesis& /* cur_hypo */, + int /* featureID - used to index the state in the previous hypotheses */, + ScoreComponentCollection* accumulator) const; + + void Load(AllOptions::ptr const& opts); +}; + + +} diff --git a/mosesdecoder/moses/FF/UnalignedWordCountFeature.h b/mosesdecoder/moses/FF/UnalignedWordCountFeature.h new file mode 100644 index 0000000000000000000000000000000000000000..807b2683d14d61ff30475146ddfdbafcb4393ff6 --- /dev/null +++ b/mosesdecoder/moses/FF/UnalignedWordCountFeature.h @@ -0,0 +1,47 @@ +#pragma once + +#include "StatelessFeatureFunction.h" +#include "moses/FactorCollection.h" +#include "moses/AlignmentInfo.h" + +namespace Moses +{ + +class UnalignedWordCountFeature : public 
StatelessFeatureFunction +{ +public: + UnalignedWordCountFeature(const std::string &line); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const; + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const { + } + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + +}; + +} + diff --git a/mosesdecoder/moses/FF/UnknownWordPenaltyProducer.cpp b/mosesdecoder/moses/FF/UnknownWordPenaltyProducer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..593c95a774c5b22aee09ba555bd01905b167a77e --- /dev/null +++ b/mosesdecoder/moses/FF/UnknownWordPenaltyProducer.cpp @@ -0,0 +1,29 @@ +#include +#include +#include "UnknownWordPenaltyProducer.h" +#include "util/exception.hh" + +using namespace std; + +namespace Moses +{ +UnknownWordPenaltyProducer *UnknownWordPenaltyProducer::s_instance = NULL; + +UnknownWordPenaltyProducer::UnknownWordPenaltyProducer(const std::string &line) + : StatelessFeatureFunction(1, line) +{ + m_tuneable = false; + ReadParameters(); + + UTIL_THROW_IF2(s_instance, "Can only have 1 unknown word penalty feature"); + s_instance = this; +} + +std::vector UnknownWordPenaltyProducer::DefaultWeights() const +{ + std::vector ret(1, 1.0f); + return ret; +} + +} + diff --git a/mosesdecoder/moses/FF/WordPenaltyProducer.h 
b/mosesdecoder/moses/FF/WordPenaltyProducer.h new file mode 100644 index 0000000000000000000000000000000000000000..34159eb9c1a8659709bd00c2ddf1678599a81018 --- /dev/null +++ b/mosesdecoder/moses/FF/WordPenaltyProducer.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include "StatelessFeatureFunction.h" + +namespace Moses +{ +class TargetPhrase; +class ScoreComponentCollection; + +class WordPenaltyProducer : public StatelessFeatureFunction +{ +protected: + static WordPenaltyProducer *s_instance; + +public: + static const WordPenaltyProducer& Instance() { + return *s_instance; + } + static WordPenaltyProducer& InstanceNonConst() { + return *s_instance; + } + + WordPenaltyProducer(const std::string &line); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + virtual void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const; + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWhenApplied(const Syntax::SHyperedge &hyperedge, + ScoreComponentCollection* accumulator) const { + } + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const { + } + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList &translationOptionList) const { + } + + + + /* + virtual void Evaluate(const InputType &source + , ScoreComponentCollection &scoreBreakdown) const; + */ +}; + +} + diff --git a/mosesdecoder/moses/FF/WordTranslationFeature.h b/mosesdecoder/moses/FF/WordTranslationFeature.h new file mode 100644 index 
0000000000000000000000000000000000000000..b3c3c18e221df081bc38bb6b5c54fb7c724dede2 --- /dev/null +++ b/mosesdecoder/moses/FF/WordTranslationFeature.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include + +#include "moses/FactorCollection.h" +#include "moses/Sentence.h" +#include "StatelessFeatureFunction.h" + +namespace Moses +{ + +/** Sets the features for word translation + */ +class WordTranslationFeature : public StatelessFeatureFunction +{ + + typedef std::map< char, short > CharHash; + typedef std::vector< boost::unordered_set > DocumentVector; + +private: + boost::unordered_set m_vocabSource; + boost::unordered_set m_vocabTarget; + DocumentVector m_vocabDomain; + FactorType m_factorTypeSource; + FactorType m_factorTypeTarget; + bool m_unrestricted; + bool m_simple; + bool m_sourceContext; + bool m_targetContext; + bool m_domainTrigger; + bool m_ignorePunctuation; + CharHash m_punctuationHash; + std::string m_filePathSource; + std::string m_filePathTarget; + +public: + WordTranslationFeature(const std::string &line); + + void SetParameter(const std::string& key, const std::string& value); + bool IsUseable(const FactorMask &mask) const; + + void Load(AllOptions::ptr const& opts); + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedScores = NULL) const; + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedScores) const { + } + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const { + } + + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const { + } + + void EvaluateTranslationOptionListWithSourceContext(const InputType &input + , const TranslationOptionList 
&translationOptionList) const { + } + + +}; + +} + diff --git a/mosesdecoder/moses/Syntax/S2T/DerivationWriter.cpp b/mosesdecoder/moses/Syntax/S2T/DerivationWriter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..159763a463ab117042844d072c15e9b3f57e3b51 --- /dev/null +++ b/mosesdecoder/moses/Syntax/S2T/DerivationWriter.cpp @@ -0,0 +1,100 @@ +#include "DerivationWriter.h" + +#include "moses/Factor.h" +#include "moses/Syntax/PVertex.h" +#include "moses/Syntax/SHyperedge.h" + +namespace Moses +{ +namespace Syntax +{ +namespace S2T +{ + +// 1-best version. +void DerivationWriter::Write(const SHyperedge ­peredge, + std::size_t sentNum, std::ostream &out) +{ + WriteLine(shyperedge, sentNum, out); + for (std::size_t i = 0; i < shyperedge.tail.size(); ++i) { + const SVertex &pred = *(shyperedge.tail[i]); + if (pred.best) { + Write(*pred.best, sentNum, out); + } + } +} + +// k-best derivation. +void DerivationWriter::Write(const KBestExtractor::Derivation &derivation, + std::size_t sentNum, std::ostream &out) +{ + WriteLine(derivation.edge->shyperedge, sentNum, out); + for (std::size_t i = 0; i < derivation.subderivations.size(); ++i) { + Write(*(derivation.subderivations[i]), sentNum, out); + } +} + +void DerivationWriter::WriteLine(const SHyperedge ­peredge, + std::size_t sentNum, std::ostream &out) +{ + // Sentence number. + out << sentNum << " |||"; + + // Source LHS. + out << " [X] ->"; + + // Source RHS symbols. + for (std::size_t i = 0; i < shyperedge.tail.size(); ++i) { + const Word &symbol = shyperedge.tail[i]->pvertex->symbol; + out << " "; + if (symbol.IsNonTerminal()) { + out << "[X]"; + } else { + WriteSymbol(symbol, out); + } + } + out << " |||"; + + // Target RHS. + out << " "; + WriteSymbol(shyperedge.head->pvertex->symbol, out); + out << " ->"; + + // Target RHS symbols. 
+ const TargetPhrase &phrase = *(shyperedge.label.translation); + for (std::size_t i = 0; i < phrase.GetSize(); ++i) { + out << " "; + WriteSymbol(phrase.GetWord(i), out); + } + out << " |||"; + + // Non-terminal alignments + const AlignmentInfo &a = phrase.GetAlignNonTerm(); + for (AlignmentInfo::const_iterator p = a.begin(); p != a.end(); ++p) { + out << " " << p->first << "-" << p->second; + } + out << " |||"; + + // Spans covered by source RHS symbols. + for (std::size_t i = 0; i < shyperedge.tail.size(); ++i) { + const SVertex *child = shyperedge.tail[i]; + const Range &span = child->pvertex->span; + out << " " << span.GetStartPos() << ".." << span.GetEndPos(); + } + + out << "\n"; +} + +void DerivationWriter::WriteSymbol(const Word &symbol, std::ostream &out) +{ + const Factor *f = symbol[0]; + if (symbol.IsNonTerminal()) { + out << "[" << f->GetString() << "]"; + } else { + out << f->GetString(); + } +} + +} // namespace S2T +} // namespace Syntax +} // namespace Moses diff --git a/mosesdecoder/moses/Syntax/S2T/Manager-inl.h b/mosesdecoder/moses/Syntax/S2T/Manager-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..7c9c3318329ffb5abafed7c9d201c58a9488bef2 --- /dev/null +++ b/mosesdecoder/moses/Syntax/S2T/Manager-inl.h @@ -0,0 +1,408 @@ +// -*- c++ -*- +#pragma once + +#include +#include + +#include "moses/DecodeGraph.h" +#include "moses/StaticData.h" +#include "moses/Syntax/BoundedPriorityContainer.h" +#include "moses/Syntax/CubeQueue.h" +#include "moses/Syntax/PHyperedge.h" +#include "moses/Syntax/RuleTable.h" +#include "moses/Syntax/RuleTableFF.h" +#include "moses/Syntax/SHyperedgeBundle.h" +#include "moses/Syntax/SVertex.h" +#include "moses/Syntax/SVertexRecombinationEqualityPred.h" +#include "moses/Syntax/SVertexRecombinationHasher.h" +#include "moses/Syntax/SymbolEqualityPred.h" +#include "moses/Syntax/SymbolHasher.h" + +#include "DerivationWriter.h" +#include "OovHandler.h" +#include "PChart.h" +#include "RuleTrie.h" +#include 
"SChart.h" + +namespace Moses +{ +namespace Syntax +{ +namespace S2T +{ + +template +Manager::Manager(ttasksptr const& ttask) + : Syntax::Manager(ttask) + , m_pchart(m_source.GetSize(), Parser::RequiresCompressedChart()) + , m_schart(m_source.GetSize()) +{ } + +template +void Manager::InitializeCharts() +{ + // Create a PVertex object and a SVertex object for each source word. + for (std::size_t i = 0; i < m_source.GetSize(); ++i) { + const Word &terminal = m_source.GetWord(i); + + // PVertex + PVertex tmp(Range(i,i), terminal); + PVertex &pvertex = m_pchart.AddVertex(tmp); + + // SVertex + boost::shared_ptr v(new SVertex()); + v->best = 0; + v->pvertex = &pvertex; + SChart::Cell &scell = m_schart.GetCell(i,i); + SVertexStack stack(1, v); + SChart::Cell::TMap::value_type x(terminal, stack); + scell.terminalStacks.insert(x); + } +} + +template +void Manager::InitializeParsers(PChart &pchart, + std::size_t ruleLimit) +{ + const std::vector &ffs = RuleTableFF::Instances(); + + const std::vector &graphs = + StaticData::Instance().GetDecodeGraphs(); + + UTIL_THROW_IF2(ffs.size() != graphs.size(), + "number of RuleTables does not match number of decode graphs"); + + for (std::size_t i = 0; i < ffs.size(); ++i) { + RuleTableFF *ff = ffs[i]; + std::size_t maxChartSpan = graphs[i]->GetMaxChartSpan(); + // This may change in the future, but currently we assume that every + // RuleTableFF is associated with a static, file-based rule table of + // some sort and that the table should have been loaded into a RuleTable + // by this point. + const RuleTable *table = ff->GetTable(); + assert(table); + RuleTable *nonConstTable = const_cast(table); + boost::shared_ptr parser; + typename Parser::RuleTrie *trie = + dynamic_cast(nonConstTable); + assert(trie); + parser.reset(new Parser(pchart, *trie, maxChartSpan)); + m_parsers.push_back(parser); + } + + // Check for OOVs and synthesize an additional rule trie + parser if + // necessary. 
+ m_oovs.clear(); + std::size_t maxOovWidth = 0; + FindOovs(pchart, m_oovs, maxOovWidth); + if (!m_oovs.empty()) { + // FIXME Add a hidden RuleTableFF for unknown words(?) + OovHandler oovHandler(*ffs[0]); + m_oovRuleTrie = oovHandler.SynthesizeRuleTrie(m_oovs.begin(), m_oovs.end()); + // Create a parser for the OOV rule trie. + boost::shared_ptr parser( + new Parser(pchart, *m_oovRuleTrie, maxOovWidth)); + m_parsers.push_back(parser); + } +} + +// Find the set of OOVs for this input. This function assumes that the +// PChart argument has already been initialized from the input. +template +void Manager::FindOovs(const PChart &pchart, boost::unordered_set &oovs, + std::size_t maxOovWidth) +{ + // Get the set of RuleTries. + std::vector tries; + const std::vector &ffs = RuleTableFF::Instances(); + for (std::size_t i = 0; i < ffs.size(); ++i) { + const RuleTableFF *ff = ffs[i]; + if (ff->GetTable()) { + const RuleTrie *trie = dynamic_cast(ff->GetTable()); + assert(trie); // FIXME + tries.push_back(trie); + } + } + + // For every sink vertex in pchart (except for and ), check whether + // the word has a preterminal rule in any of the rule tables. If not then + // add it to the OOV set. + oovs.clear(); + maxOovWidth = 0; + // Assume and have been added at sentence boundaries, so skip + // cells starting at position 0 and ending at the last position. 
+ for (std::size_t i = 1; i < pchart.GetWidth()-1; ++i) { + for (std::size_t j = i; j < pchart.GetWidth()-1; ++j) { + std::size_t width = j-i+1; + const PChart::Cell::TMap &map = pchart.GetCell(i,j).terminalVertices; + for (PChart::Cell::TMap::const_iterator p = map.begin(); + p != map.end(); ++p) { + const Word &word = p->first; + assert(!word.IsNonTerminal()); + bool found = false; + for (std::vector::const_iterator q = tries.begin(); + q != tries.end(); ++q) { + const RuleTrie *trie = *q; + if (trie->HasPreterminalRule(word)) { + found = true; + break; + } + } + if (!found) { + oovs.insert(word); + maxOovWidth = std::max(maxOovWidth, width); + } + } + } + } +} + +template +void Manager::Decode() +{ + // Get various pruning-related constants. + const std::size_t popLimit = options()->cube.pop_limit; + const std::size_t ruleLimit = options()->syntax.rule_limit; + const std::size_t stackLimit = options()->search.stack_size; + + // Initialise the PChart and SChart. + InitializeCharts(); + + // Initialize the parsers. + InitializeParsers(m_pchart, ruleLimit); + + // Create a callback to process the PHyperedges produced by the parsers. + typename Parser::CallbackType callback(m_schart, ruleLimit); + + // Visit each cell of PChart in right-to-left depth-first order. + std::size_t size = m_source.GetSize(); + for (int start = size-1; start >= 0; --start) { + for (std::size_t width = 1; width <= size-start; ++width) { + std::size_t end = start + width - 1; + + //PChart::Cell &pcell = m_pchart.GetCell(start, end); + SChart::Cell &scell = m_schart.GetCell(start, end); + + Range range(start, end); + + // Call the parsers to generate PHyperedges for this span and convert + // each one to a SHyperedgeBundle (via the callback). The callback + // prunes the SHyperedgeBundles and keeps the best ones (up to ruleLimit). 
+ callback.InitForRange(range); + for (typename std::vector >::iterator + p = m_parsers.begin(); p != m_parsers.end(); ++p) { + (*p)->EnumerateHyperedges(range, callback); + } + + // Retrieve the (pruned) set of SHyperedgeBundles from the callback. + const BoundedPriorityContainer &bundles = + callback.GetContainer(); + + // Use cube pruning to extract SHyperedges from SHyperedgeBundles. + // Collect the SHyperedges into buffers, one for each category. + CubeQueue cubeQueue(bundles.Begin(), bundles.End()); + std::size_t count = 0; + typedef boost::unordered_map, + SymbolHasher, SymbolEqualityPred > BufferMap; + BufferMap buffers; + while (count < popLimit && !cubeQueue.IsEmpty()) { + SHyperedge *hyperedge = cubeQueue.Pop(); + // BEGIN{HACK} + // The way things currently work, the LHS of each hyperedge is not + // determined until just before the point of its creation, when a + // target phrase is selected from the list of possible phrases (which + // happens during cube pruning). The cube pruning code doesn't (and + // shouldn't) know about the contents of PChart and so creation of + // the PVertex is deferred until this point. + const Word &lhs = hyperedge->label.translation->GetTargetLHS(); + hyperedge->head->pvertex = &m_pchart.AddVertex(PVertex(range, lhs)); + // END{HACK} + buffers[lhs].push_back(hyperedge); + ++count; + } + + // Recombine SVertices and sort into stacks. + for (BufferMap::const_iterator p = buffers.begin(); p != buffers.end(); + ++p) { + const Word &category = p->first; + const std::vector &buffer = p->second; + std::pair ret = + scell.nonTerminalStacks.Insert(category, SVertexStack()); + assert(ret.second); + SVertexStack &stack = ret.first->second; + RecombineAndSort(buffer, stack); + } + + // Prune stacks. 
+ if (stackLimit > 0) { + for (SChart::Cell::NMap::Iterator p = scell.nonTerminalStacks.Begin(); + p != scell.nonTerminalStacks.End(); ++p) { + SVertexStack &stack = p->second; + if (stack.size() > stackLimit) { + stack.resize(stackLimit); + } + } + } + + // Prune the PChart cell for this span by removing vertices for + // categories that don't occur in the SChart. +// Note: see HACK above. Pruning the chart isn't currently necessary. +// PrunePChart(scell, pcell); + } + } +} + +template +const SHyperedge *Manager::GetBestSHyperedge() const +{ + const SChart::Cell &cell = m_schart.GetCell(0, m_source.GetSize()-1); + const SChart::Cell::NMap &stacks = cell.nonTerminalStacks; + if (stacks.Size() == 0) { + return 0; + } + assert(stacks.Size() == 1); + const std::vector > &stack = stacks.Begin()->second; + // TODO Throw exception if stack is empty? Or return 0? + return stack[0]->best; +} + +template +void Manager::ExtractKBest( + std::size_t k, + std::vector > &kBestList, + bool onlyDistinct) const +{ + kBestList.clear(); + if (k == 0 || m_source.GetSize() == 0) { + return; + } + + // Get the top-level SVertex stack. + const SChart::Cell &cell = m_schart.GetCell(0, m_source.GetSize()-1); + const SChart::Cell::NMap &stacks = cell.nonTerminalStacks; + if (stacks.Size() == 0) { + return; + } + assert(stacks.Size() == 1); + const std::vector > &stack = stacks.Begin()->second; + // TODO Throw exception if stack is empty? Or return 0? + + KBestExtractor extractor; + + if (!onlyDistinct) { + // Return the k-best list as is, including duplicate translations. + extractor.Extract(stack, k, kBestList); + return; + } + + // Determine how many derivations to extract. If the k-best list is + // restricted to distinct translations then this limit should be bigger + // than k. The k-best factor determines how much bigger the limit should be, + // with 0 being 'unlimited.' This actually sets a large-ish limit in case + // too many translations are identical. 
+ const StaticData &staticData = StaticData::Instance(); + const std::size_t nBestFactor = staticData.options()->nbest.factor; + std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; + + // Extract the derivations. + KBestExtractor::KBestVec bigList; + bigList.reserve(numDerivations); + extractor.Extract(stack, numDerivations, bigList); + + // Copy derivations into kBestList, skipping ones with repeated translations. + std::set distinct; + for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); + kBestList.size() < k && p != bigList.end(); ++p) { + boost::shared_ptr derivation = *p; + Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); + if (distinct.insert(translation).second) { + kBestList.push_back(derivation); + } + } +} + +template +void Manager::PrunePChart(const SChart::Cell &scell, + PChart::Cell &pcell) +{ + /* FIXME + PChart::Cell::VertexMap::iterator p = pcell.vertices.begin(); + while (p != pcell.vertices.end()) { + const Word &category = p->first; + if (scell.stacks.find(category) == scell.stacks.end()) { + PChart::Cell::VertexMap::iterator q = p++; + pcell.vertices.erase(q); + } else { + ++p; + } + } + */ +} + +template +void Manager::RecombineAndSort(const std::vector &buffer, + SVertexStack &stack) +{ + // Step 1: Create a map containing a single instance of each distinct vertex + // (where distinctness is defined by the state value). The hyperedges' + // head pointers are updated to point to the vertex instances in the map and + // any 'duplicate' vertices are deleted. +// TODO Set? + typedef boost::unordered_map Map; + Map map; + for (std::vector::const_iterator p = buffer.begin(); + p != buffer.end(); ++p) { + SHyperedge *h = *p; + SVertex *v = h->head; + assert(v->best == h); + assert(v->recombined.empty()); + std::pair result = map.insert(Map::value_type(v, v)); + if (result.second) { + continue; // v's recombination value hasn't been seen before. 
+ } + // v is a duplicate (according to the recombination rules). + // Compare the score of h against the score of the best incoming hyperedge + // for the stored vertex. + SVertex *storedVertex = result.first->second; + if (h->label.futureScore > storedVertex->best->label.futureScore) { + // h's score is better. + storedVertex->recombined.push_back(storedVertex->best); + storedVertex->best = h; + } else { + storedVertex->recombined.push_back(h); + } + h->head->best = 0; + delete h->head; + h->head = storedVertex; + } + + // Step 2: Copy the vertices from the map to the stack. + stack.clear(); + stack.reserve(map.size()); + for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { + stack.push_back(boost::shared_ptr(p->first)); + } + + // Step 3: Sort the vertices in the stack. + std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); +} + +template +void Manager::OutputDetailedTranslationReport( + OutputCollector *collector) const +{ + const SHyperedge *best = GetBestSHyperedge(); + if (best == NULL || collector == NULL) { + return; + } + long translationId = m_source.GetTranslationId(); + std::ostringstream out; + DerivationWriter::Write(*best, translationId, out); + collector->Write(translationId, out.str()); +} + +} // S2T +} // Syntax +} // Moses diff --git a/mosesdecoder/moses/Syntax/S2T/Manager.h b/mosesdecoder/moses/Syntax/S2T/Manager.h new file mode 100644 index 0000000000000000000000000000000000000000..b0e6555cf6bfeceff7a79c174f4727d13b992aab --- /dev/null +++ b/mosesdecoder/moses/Syntax/S2T/Manager.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include + +#include + +#include "moses/InputType.h" +#include "moses/Syntax/KBestExtractor.h" +#include "moses/Syntax/Manager.h" +#include "moses/Syntax/SVertexStack.h" +#include "moses/Word.h" + +#include "OovHandler.h" +#include "ParserCallback.h" +#include "PChart.h" +#include "SChart.h" + +namespace Moses +{ +namespace Syntax +{ + +struct SHyperedge; + +namespace S2T +{ + +template +class 
Manager : public Syntax::Manager +{ +public: + Manager(ttasksptr const& ttask); + + void Decode(); + + // Get the SHyperedge for the 1-best derivation. + const SHyperedge *GetBestSHyperedge() const; + + void ExtractKBest( + std::size_t k, + std::vector > &kBestList, + bool onlyDistinct=false) const; + + void OutputDetailedTranslationReport(OutputCollector *collector) const; + +private: + void FindOovs(const PChart &, boost::unordered_set &, std::size_t); + + void InitializeCharts(); + + void InitializeParsers(PChart &, std::size_t); + + void RecombineAndSort(const std::vector &, SVertexStack &); + + void PrunePChart(const SChart::Cell &, PChart::Cell &); + + PChart m_pchart; + SChart m_schart; + boost::shared_ptr m_oovRuleTrie; + std::vector > m_parsers; +}; + +} // S2T +} // Syntax +} // Moses + +// Implementation +#include "Manager-inl.h" diff --git a/mosesdecoder/moses/Syntax/S2T/OovHandler.h b/mosesdecoder/moses/Syntax/S2T/OovHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..5d484d2fd9aa3ad7405b2ebd51169101fab56360 --- /dev/null +++ b/mosesdecoder/moses/Syntax/S2T/OovHandler.h @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include + +#include "moses/Phrase.h" +#include "moses/Syntax/RuleTableFF.h" +#include "moses/TargetPhrase.h" +#include "moses/Word.h" + +#include "RuleTrieCreator.h" + +namespace Moses +{ +namespace Syntax +{ +namespace S2T +{ + +template +class OovHandler : public RuleTrieCreator +{ +public: + OovHandler(const RuleTableFF &ff) : m_ruleTableFF(ff) {} + + // Synthesize a RuleTrie given a sequence of OOV words. The sequence is + // specified by a pair of iterators (indicating the beginning and end). It + // is assumed not to contain duplicates. 
+ template + boost::shared_ptr SynthesizeRuleTrie(InputIterator, InputIterator); + +private: + const RuleTableFF &m_ruleTableFF; + + bool ShouldDrop(const Word &); + + Phrase *SynthesizeSourcePhrase(const Word &); + + Word *SynthesizeTargetLhs(const std::string &); + + TargetPhrase *SynthesizeTargetPhrase(const Word &, const Phrase &, + const Word &, float); +}; + +} // S2T +} // Syntax +} // Moses + +#include "OovHandler-inl.h" diff --git a/mosesdecoder/moses/Syntax/S2T/ParserCallback.h b/mosesdecoder/moses/Syntax/S2T/ParserCallback.h new file mode 100644 index 0000000000000000000000000000000000000000..2314b27f37172c3aa1ec7ec095d83b872dfeac98 --- /dev/null +++ b/mosesdecoder/moses/Syntax/S2T/ParserCallback.h @@ -0,0 +1,91 @@ +#pragma once + +#include "moses/Syntax/BoundedPriorityContainer.h" +#include "moses/Syntax/PHyperedge.h" +#include "moses/Syntax/PVertex.h" +#include "moses/Syntax/SHyperedgeBundle.h" +#include "moses/Syntax/SHyperedgeBundleScorer.h" + +#include "PHyperedgeToSHyperedgeBundle.h" + +namespace Moses +{ +namespace Syntax +{ +namespace S2T +{ + +class StandardParserCallback +{ +private: + typedef BoundedPriorityContainer Container; + +public: + StandardParserCallback(const SChart &schart, std::size_t ruleLimit) + : m_schart(schart) + , m_container(ruleLimit) {} + + void operator()(const PHyperedge &hyperedge) { + PHyperedgeToSHyperedgeBundle(hyperedge, m_schart, m_tmpBundle); + float score = SHyperedgeBundleScorer::Score(m_tmpBundle); + m_container.SwapIn(m_tmpBundle, score); + } + + void InitForRange(const Range &range) { + m_container.LazyClear(); + } + + const Container &GetContainer() { + return m_container; + } + +private: + const SChart &m_schart; + SHyperedgeBundle m_tmpBundle; + BoundedPriorityContainer m_container; +}; + +class EagerParserCallback +{ +private: + typedef BoundedPriorityContainer Container; + +public: + EagerParserCallback(const SChart &schart, std::size_t ruleLimit) + : m_schart(schart) + , m_containers(schart.GetWidth(), 
Container(ruleLimit)) + , m_prevStart(std::numeric_limits::max()) {} + + void operator()(const PHyperedge &hyperedge, std::size_t end) { + PHyperedgeToSHyperedgeBundle(hyperedge, m_schart, m_tmpBundle); + float score = SHyperedgeBundleScorer::Score(m_tmpBundle); + m_containers[end].SwapIn(m_tmpBundle, score); + } + + void InitForRange(const Range &range) { + const std::size_t start = range.GetStartPos(); + m_end = range.GetEndPos(); + if (start != m_prevStart) { + for (std::vector::iterator p = m_containers.begin(); + p != m_containers.end(); ++p) { + p->LazyClear(); + } + m_prevStart = start; + } + } + + const Container &GetContainer() { + return m_containers[m_end]; + } + +private: + const SChart &m_schart; + SHyperedgeBundle m_tmpBundle; + std::vector m_containers; + std::size_t m_end; + std::size_t m_prevStart; +}; + +} // S2T +} // Syntax +} // Moses diff --git a/mosesdecoder/moses/Syntax/S2T/RuleTrieCYKPlus.cpp b/mosesdecoder/moses/Syntax/S2T/RuleTrieCYKPlus.cpp new file mode 100644 index 0000000000000000000000000000000000000000..68da5f5b753c82480ac88b0b56fd8fca8ff1933c --- /dev/null +++ b/mosesdecoder/moses/Syntax/S2T/RuleTrieCYKPlus.cpp @@ -0,0 +1,146 @@ +#include "RuleTrieCYKPlus.h" + +#include +#include + +#include +#include +#include + +#include "moses/NonTerminal.h" +#include "moses/TargetPhrase.h" +#include "moses/TargetPhraseCollection.h" +#include "moses/Util.h" +#include "moses/Word.h" + +namespace Moses +{ +namespace Syntax +{ +namespace S2T +{ + +void RuleTrieCYKPlus::Node::Prune(std::size_t tableLimit) +{ + // recusively prune + for (SymbolMap::iterator p = m_sourceTermMap.begin(); + p != m_sourceTermMap.end(); ++p) { + p->second.Prune(tableLimit); + } + for (SymbolMap::iterator p = m_nonTermMap.begin(); + p != m_nonTermMap.end(); ++p) { + p->second.Prune(tableLimit); + } + + // prune TargetPhraseCollection in this node + m_targetPhraseCollection->Prune(true, tableLimit); +} + +void RuleTrieCYKPlus::Node::Sort(std::size_t tableLimit) +{ + // 
recusively sort + for (SymbolMap::iterator p = m_sourceTermMap.begin(); + p != m_sourceTermMap.end(); ++p) { + p->second.Sort(tableLimit); + } + for (SymbolMap::iterator p = m_nonTermMap.begin(); + p != m_nonTermMap.end(); ++p) { + p->second.Sort(tableLimit); + } + + // prune TargetPhraseCollection in this node + m_targetPhraseCollection->Sort(true, tableLimit); +} + +RuleTrieCYKPlus::Node *RuleTrieCYKPlus::Node::GetOrCreateChild( + const Word &sourceTerm) +{ + return &m_sourceTermMap[sourceTerm]; +} + +RuleTrieCYKPlus::Node *RuleTrieCYKPlus::Node::GetOrCreateNonTerminalChild(const Word &targetNonTerm) +{ + UTIL_THROW_IF2(!targetNonTerm.IsNonTerminal(), + "Not a non-terminal: " << targetNonTerm); + + return &m_nonTermMap[targetNonTerm]; +} + +const RuleTrieCYKPlus::Node *RuleTrieCYKPlus::Node::GetChild( + const Word &sourceTerm) const +{ + UTIL_THROW_IF2(sourceTerm.IsNonTerminal(), + "Not a terminal: " << sourceTerm); + + SymbolMap::const_iterator p = m_sourceTermMap.find(sourceTerm); + return (p == m_sourceTermMap.end()) ? NULL : &p->second; +} + +const RuleTrieCYKPlus::Node *RuleTrieCYKPlus::Node::GetNonTerminalChild( + const Word &targetNonTerm) const +{ + UTIL_THROW_IF2(!targetNonTerm.IsNonTerminal(), + "Not a non-terminal: " << targetNonTerm); + + SymbolMap::const_iterator p = m_nonTermMap.find(targetNonTerm); + return (p == m_nonTermMap.end()) ? 
NULL : &p->second; +} + +TargetPhraseCollection::shared_ptr +RuleTrieCYKPlus:: +GetOrCreateTargetPhraseCollection(const Phrase &source, + const TargetPhrase &target, + const Word *sourceLHS) +{ + Node &currNode = GetOrCreateNode(source, target, sourceLHS); + return currNode.GetTargetPhraseCollection(); +} + +RuleTrieCYKPlus::Node &RuleTrieCYKPlus::GetOrCreateNode( + const Phrase &source, const TargetPhrase &target, const Word *sourceLHS) +{ + const std::size_t size = source.GetSize(); + + const AlignmentInfo &alignmentInfo = target.GetAlignNonTerm(); + AlignmentInfo::const_iterator iterAlign = alignmentInfo.begin(); + + Node *currNode = &m_root; + for (std::size_t pos = 0 ; pos < size ; ++pos) { + const Word& word = source.GetWord(pos); + + if (word.IsNonTerminal()) { + UTIL_THROW_IF2(iterAlign == alignmentInfo.end(), + "No alignment for non-term at position " << pos); + UTIL_THROW_IF2(iterAlign->first != pos, + "Alignment info incorrect at position " << pos); + std::size_t targetNonTermInd = iterAlign->second; + ++iterAlign; + const Word &targetNonTerm = target.GetWord(targetNonTermInd); + currNode = currNode->GetOrCreateNonTerminalChild(targetNonTerm); + } else { + currNode = currNode->GetOrCreateChild(word); + } + + UTIL_THROW_IF2(currNode == NULL, "Node not found at position " << pos); + } + + return *currNode; +} + +void RuleTrieCYKPlus::SortAndPrune(std::size_t tableLimit) +{ + if (tableLimit) { + m_root.Sort(tableLimit); + } +} + +bool RuleTrieCYKPlus::HasPreterminalRule(const Word &w) const +{ + const Node::SymbolMap &map = m_root.GetTerminalMap(); + Node::SymbolMap::const_iterator p = map.find(w); + return p != map.end() && p->second.HasRules(); +} + +} // namespace S2T +} // namespace Syntax +} // namespace Moses diff --git a/mosesdecoder/moses/Syntax/S2T/RuleTrieCYKPlus.h b/mosesdecoder/moses/Syntax/S2T/RuleTrieCYKPlus.h new file mode 100644 index 0000000000000000000000000000000000000000..31880e0ed113ec3ffaa37e65e2f38d111547b426 --- /dev/null +++ 
b/mosesdecoder/moses/Syntax/S2T/RuleTrieCYKPlus.h @@ -0,0 +1,102 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "moses/Syntax/SymbolEqualityPred.h" +#include "moses/Syntax/SymbolHasher.h" +#include "moses/TargetPhrase.h" +#include "moses/TargetPhraseCollection.h" +#include "moses/Terminal.h" +#include "moses/Util.h" +#include "moses/Word.h" + +#include "RuleTrie.h" + +namespace Moses +{ +namespace Syntax +{ +namespace S2T +{ + +class RuleTrieCYKPlus : public RuleTrie +{ +public: + class Node + { + public: + typedef boost::unordered_map SymbolMap; + + bool IsLeaf() const { + return m_sourceTermMap.empty() && m_nonTermMap.empty(); + } + + bool HasRules() const { + return !m_targetPhraseCollection->IsEmpty(); + } + + void Prune(std::size_t tableLimit); + void Sort(std::size_t tableLimit); + + Node *GetOrCreateChild(const Word &sourceTerm); + Node *GetOrCreateNonTerminalChild(const Word &targetNonTerm); + + const Node *GetChild(const Word &sourceTerm) const; + const Node *GetNonTerminalChild(const Word &targetNonTerm) const; + + TargetPhraseCollection::shared_ptr + GetTargetPhraseCollection() const { + return m_targetPhraseCollection; + } + + TargetPhraseCollection::shared_ptr + GetTargetPhraseCollection() { + return m_targetPhraseCollection; + } + + const SymbolMap &GetTerminalMap() const { + return m_sourceTermMap; + } + + const SymbolMap &GetNonTerminalMap() const { + return m_nonTermMap; + } + + Node() : m_targetPhraseCollection(new TargetPhraseCollection) {} + + private: + SymbolMap m_sourceTermMap; + SymbolMap m_nonTermMap; + TargetPhraseCollection::shared_ptr m_targetPhraseCollection; + }; + + RuleTrieCYKPlus(const RuleTableFF *ff) : RuleTrie(ff) {} + + const Node &GetRootNode() const { + return m_root; + } + + bool HasPreterminalRule(const Word &) const; + +private: + TargetPhraseCollection::shared_ptr + GetOrCreateTargetPhraseCollection + (const Phrase &source, const TargetPhrase &target, const Word *sourceLHS); + + Node 
&GetOrCreateNode(const Phrase &source, const TargetPhrase &target, + const Word *sourceLHS); + + void SortAndPrune(std::size_t); + + Node m_root; +}; + +} // namespace S2T +} // namespace Syntax +} // namespace Moses diff --git a/mosesdecoder/moses/Syntax/S2T/RuleTrieCreator.h b/mosesdecoder/moses/Syntax/S2T/RuleTrieCreator.h new file mode 100644 index 0000000000000000000000000000000000000000..84a4a1636eebbc305ec43c3b31108aa91baa2d21 --- /dev/null +++ b/mosesdecoder/moses/Syntax/S2T/RuleTrieCreator.h @@ -0,0 +1,34 @@ +#pragma once + +#include "RuleTrie.h" + +namespace Moses +{ +namespace Syntax +{ +namespace S2T +{ + +// Base for classes that create a RuleTrie (currently RuleTrieLoader and +// OovHandler). RuleTrieCreator is a friend of RuleTrie. +class RuleTrieCreator +{ +protected: + // Provide access to RuleTrie's private SortAndPrune function. + void SortAndPrune(RuleTrie &trie, std::size_t limit) { + trie.SortAndPrune(limit); + } + + // Provide access to RuleTrie's private GetOrCreateTargetPhraseCollection + // function. 
+ TargetPhraseCollection::shared_ptr + GetOrCreateTargetPhraseCollection + ( RuleTrie &trie, const Phrase &source, const TargetPhrase &target, + const Word *sourceLHS) { + return trie.GetOrCreateTargetPhraseCollection(source, target, sourceLHS); + } +}; + +} // namespace S2T +} // namespace Syntax +} // namespace Moses diff --git a/mosesdecoder/moses/Syntax/S2T/RuleTrieScope3.cpp b/mosesdecoder/moses/Syntax/S2T/RuleTrieScope3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ceaee9501078c43bd0b6a7990e930b9edd00c538 --- /dev/null +++ b/mosesdecoder/moses/Syntax/S2T/RuleTrieScope3.cpp @@ -0,0 +1,158 @@ +#include "RuleTrieScope3.h" + +#include +#include + +#include +#include +#include + +#include "moses/NonTerminal.h" +#include "moses/TargetPhrase.h" +#include "moses/TargetPhraseCollection.h" +#include "moses/Util.h" +#include "moses/Word.h" + +namespace Moses +{ +namespace Syntax +{ +namespace S2T +{ + +void RuleTrieScope3::Node::Prune(std::size_t tableLimit) +{ + // Recusively prune child node values. + for (TerminalMap::iterator p = m_terminalMap.begin(); + p != m_terminalMap.end(); ++p) { + p->second.Prune(tableLimit); + } + if (m_gapNode) { + m_gapNode->Prune(tableLimit); + } + + // Prune TargetPhraseCollections at this node. + for (LabelMap::iterator p = m_labelMap.begin(); p != m_labelMap.end(); ++p) { + p->second->Prune(true, tableLimit); + } +} + +void RuleTrieScope3::Node::Sort(std::size_t tableLimit) +{ + // Recusively sort child node values. + for (TerminalMap::iterator p = m_terminalMap.begin(); + p != m_terminalMap.end(); ++p) { + p->second.Sort(tableLimit); + } + if (m_gapNode) { + m_gapNode->Sort(tableLimit); + } + + // Sort TargetPhraseCollections at this node. 
+ for (LabelMap::iterator p = m_labelMap.begin(); p != m_labelMap.end(); ++p) { + p->second->Sort(true, tableLimit); + } +} + +RuleTrieScope3::Node *RuleTrieScope3::Node::GetOrCreateTerminalChild( + const Word &sourceTerm) +{ + assert(!sourceTerm.IsNonTerminal()); + std::pair result; + result = m_terminalMap.insert(std::make_pair(sourceTerm, Node())); + const TerminalMap::iterator &iter = result.first; + Node &child = iter->second; + return &child; +} + +RuleTrieScope3::Node *RuleTrieScope3::Node::GetOrCreateNonTerminalChild( + const Word &targetNonTerm) +{ + assert(targetNonTerm.IsNonTerminal()); + if (m_gapNode == NULL) { + m_gapNode = new Node(); + } + return m_gapNode; +} + +TargetPhraseCollection::shared_ptr +RuleTrieScope3:: +Node:: +GetOrCreateTargetPhraseCollection(const TargetPhrase &target) +{ + const AlignmentInfo &alignmentInfo = target.GetAlignNonTerm(); + const std::size_t rank = alignmentInfo.GetSize(); + + std::vector vec; + vec.reserve(rank); + + m_labelTable.resize(rank); + + int i = 0; + for (AlignmentInfo::const_iterator p = alignmentInfo.begin(); + p != alignmentInfo.end(); ++p) { + std::size_t targetNonTermIndex = p->second; + const Word &targetNonTerm = target.GetWord(targetNonTermIndex); + vec.push_back(InsertLabel(i++, targetNonTerm)); + } + TargetPhraseCollection::shared_ptr& ret = m_labelMap[vec]; + if (!ret) ret.reset(new TargetPhraseCollection); + return ret; +} + +TargetPhraseCollection::shared_ptr +RuleTrieScope3:: +GetOrCreateTargetPhraseCollection(const Phrase &source, + const TargetPhrase &target, + const Word *sourceLHS) +{ + Node &currNode = GetOrCreateNode(source, target, sourceLHS); + return currNode.GetOrCreateTargetPhraseCollection(target); +} + +RuleTrieScope3::Node &RuleTrieScope3::GetOrCreateNode( + const Phrase &source, const TargetPhrase &target, const Word */*sourceLHS*/) +{ + const std::size_t size = source.GetSize(); + + const AlignmentInfo &alignmentInfo = target.GetAlignNonTerm(); + AlignmentInfo::const_iterator 
iterAlign = alignmentInfo.begin(); + + Node *currNode = &m_root; + for (std::size_t pos = 0 ; pos < size ; ++pos) { + const Word &word = source.GetWord(pos); + + if (word.IsNonTerminal()) { + assert(iterAlign != alignmentInfo.end()); + assert(iterAlign->first == pos); + std::size_t targetNonTermInd = iterAlign->second; + ++iterAlign; + const Word &targetNonTerm = target.GetWord(targetNonTermInd); + currNode = currNode->GetOrCreateNonTerminalChild(targetNonTerm); + } else { + currNode = currNode->GetOrCreateTerminalChild(word); + } + + assert(currNode != NULL); + } + + return *currNode; +} + +void RuleTrieScope3::SortAndPrune(std::size_t tableLimit) +{ + if (tableLimit) { + m_root.Sort(tableLimit); + } +} + +bool RuleTrieScope3::HasPreterminalRule(const Word &w) const +{ + const Node::TerminalMap &map = m_root.GetTerminalMap(); + Node::TerminalMap::const_iterator p = map.find(w); + return p != map.end() && p->second.HasRules(); +} + +} // namespace S2T +} // namespace Syntax +} // namespace Moses diff --git a/mosesdecoder/moses/Syntax/S2T/SChart.cpp b/mosesdecoder/moses/Syntax/S2T/SChart.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f47d6efdbad5fec0b4390a9ea2fb45beccc23965 --- /dev/null +++ b/mosesdecoder/moses/Syntax/S2T/SChart.cpp @@ -0,0 +1,20 @@ +#include "SChart.h" + +namespace Moses +{ +namespace Syntax +{ +namespace S2T +{ + +SChart::SChart(std::size_t width) +{ + m_cells.resize(width); + for (std::size_t i = 0; i < width; ++i) { + m_cells[i].resize(width); + } +} + +} // namespace S2T +} // namespace Syntax +} // namespace Moses