sleepyhead111 commited on
Commit
edace67
·
verified ·
1 Parent(s): 5610c2f

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. fairseq-0.10.2/docs/conf.py +133 -0
  3. fairseq-0.10.2/docs/lr_scheduler.rst +34 -0
  4. fairseq-0.10.2/docs/requirements.txt +2 -0
  5. fairseq-0.10.2/docs/tutorial_classifying_names.rst +415 -0
  6. fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so +3 -0
  7. fairseq-0.10.2/fairseq_cli/__init__.py +0 -0
  8. fairseq-0.10.2/fairseq_cli/__pycache__/generate.cpython-310.pyc +0 -0
  9. fairseq-0.10.2/fairseq_cli/eval_lm.py +279 -0
  10. fairseq-0.10.2/fairseq_cli/train.py +356 -0
  11. mosesdecoder/phrase-extract/Alignment.cpp +70 -0
  12. mosesdecoder/phrase-extract/AlignmentPhrase.h +74 -0
  13. mosesdecoder/phrase-extract/DomainFeature.cpp +170 -0
  14. mosesdecoder/phrase-extract/DomainFeature.h +143 -0
  15. mosesdecoder/phrase-extract/HoleCollection.cpp +77 -0
  16. mosesdecoder/phrase-extract/HoleCollection.h +95 -0
  17. mosesdecoder/phrase-extract/InputFileStream.cpp +61 -0
  18. mosesdecoder/phrase-extract/InputFileStream.h +48 -0
  19. mosesdecoder/phrase-extract/InternalStructFeature.h +64 -0
  20. mosesdecoder/phrase-extract/OutputFileStream.h +81 -0
  21. mosesdecoder/phrase-extract/PhraseExtractionOptions.h +193 -0
  22. mosesdecoder/phrase-extract/RuleExtractionOptions.h +95 -0
  23. mosesdecoder/phrase-extract/ScoreFeature.cpp +114 -0
  24. mosesdecoder/phrase-extract/SyntaxTree.h +12 -0
  25. mosesdecoder/phrase-extract/consolidate-direct-main.cpp +131 -0
  26. mosesdecoder/phrase-extract/extract-lex.h +70 -0
  27. mosesdecoder/phrase-extract/filter-rule-table/CfgFilter.h +30 -0
  28. mosesdecoder/phrase-extract/filter-rule-table/FilterRuleTable.h +54 -0
  29. mosesdecoder/phrase-extract/filter-rule-table/Forest.h +59 -0
  30. mosesdecoder/phrase-extract/filter-rule-table/ForestTsgFilter.cpp +196 -0
  31. mosesdecoder/phrase-extract/filter-rule-table/ForestTsgFilter.h +70 -0
  32. mosesdecoder/phrase-extract/filter-rule-table/Jamfile +1 -0
  33. mosesdecoder/phrase-extract/filter-rule-table/StringCfgFilter.cpp +323 -0
  34. mosesdecoder/phrase-extract/filter-rule-table/StringCfgFilter.h +143 -0
  35. mosesdecoder/phrase-extract/filter-rule-table/StringForest.h +24 -0
  36. mosesdecoder/phrase-extract/filter-rule-table/TreeTsgFilter.h +55 -0
  37. mosesdecoder/phrase-extract/filter-rule-table/TsgFilter.h +55 -0
  38. mosesdecoder/phrase-extract/lexical-reordering/InputFileStream.cpp +68 -0
  39. mosesdecoder/phrase-extract/lexical-reordering/InputFileStream.h +49 -0
  40. mosesdecoder/phrase-extract/lexical-reordering/Jamfile +2 -0
  41. mosesdecoder/phrase-extract/lexical-reordering/gzfilebuf.h +88 -0
  42. mosesdecoder/phrase-extract/lexical-reordering/reordering_classes.cpp +416 -0
  43. mosesdecoder/phrase-extract/lexical-reordering/reordering_classes.h +148 -0
  44. mosesdecoder/phrase-extract/lexical-reordering/score.cpp +269 -0
  45. mosesdecoder/phrase-extract/pcfg-extract/Jamfile +1 -0
  46. mosesdecoder/phrase-extract/pcfg-extract/options.h +41 -0
  47. mosesdecoder/phrase-extract/pcfg-extract/pcfg_extract.cc +138 -0
  48. mosesdecoder/phrase-extract/pcfg-extract/pcfg_extract.h +48 -0
  49. mosesdecoder/phrase-extract/pcfg-extract/rule_collection.h +73 -0
  50. mosesdecoder/phrase-extract/pcfg-extract/rule_extractor.h +51 -0
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
37
  fairseq-0.10.2/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
37
  fairseq-0.10.2/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
38
+ fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
fairseq-0.10.2/docs/conf.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ #
4
+ # fairseq documentation build configuration file, created by
5
+ # sphinx-quickstart on Fri Aug 17 21:45:30 2018.
6
+ #
7
+ # This file is execfile()d with the current directory set to its
8
+ # containing dir.
9
+ #
10
+ # Note that not all possible configuration values are present in this
11
+ # autogenerated file.
12
+ #
13
+ # All configuration values have a default; values that are commented out
14
+ # serve to show the default.
15
+
16
+ # If extensions (or modules to document with autodoc) are in another directory,
17
+ # add these directories to sys.path here. If the directory is relative to the
18
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
19
+
20
+ import os
21
+ import sys
22
+
23
+
24
+ # source code directory, relative to this file, for sphinx-autobuild
25
+ sys.path.insert(0, os.path.abspath(".."))
26
+
27
+ source_suffix = [".rst"]
28
+
29
+ # -- General configuration ------------------------------------------------
30
+
31
+ # If your documentation needs a minimal Sphinx version, state it here.
32
+ #
33
+ # needs_sphinx = '1.0'
34
+
35
+ # Add any Sphinx extension module names here, as strings. They can be
36
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
37
+ # ones.
38
+ extensions = [
39
+ "sphinx.ext.autodoc",
40
+ "sphinx.ext.intersphinx",
41
+ "sphinx.ext.viewcode",
42
+ "sphinx.ext.napoleon",
43
+ "sphinxarg.ext",
44
+ ]
45
+
46
+ # Add any paths that contain templates here, relative to this directory.
47
+ templates_path = ["_templates"]
48
+
49
+ # The master toctree document.
50
+ master_doc = "index"
51
+
52
+ # General information about the project.
53
+ project = "fairseq"
54
+ copyright = "2019, Facebook AI Research (FAIR)"
55
+ author = "Facebook AI Research (FAIR)"
56
+
57
+ github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"
58
+
59
+ # The version info for the project you're documenting, acts as replacement for
60
+ # |version| and |release|, also used in various other places throughout the
61
+ # built documents.
62
+ #
63
+ # The short X.Y version.
64
+ version = "0.10.2"
65
+ # The full version, including alpha/beta/rc tags.
66
+ release = "0.10.2"
67
+
68
+ # The language for content autogenerated by Sphinx. Refer to documentation
69
+ # for a list of supported languages.
70
+ #
71
+ # This is also used if you do content translation via gettext catalogs.
72
+ # Usually you set "language" from the command line for these cases.
73
+ language = None
74
+
75
+ # List of patterns, relative to source directory, that match files and
76
+ # directories to ignore when looking for source files.
77
+ # This patterns also effect to html_static_path and html_extra_path
78
+ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
79
+
80
+ # The name of the Pygments (syntax highlighting) style to use.
81
+ pygments_style = "sphinx"
82
+ highlight_language = "python"
83
+
84
+ # If true, `todo` and `todoList` produce output, else they produce nothing.
85
+ todo_include_todos = False
86
+
87
+
88
+ # -- Options for HTML output ----------------------------------------------
89
+
90
+ # The theme to use for HTML and HTML Help pages. See the documentation for
91
+ # a list of builtin themes.
92
+ #
93
+ html_theme = "sphinx_rtd_theme"
94
+
95
+ # Theme options are theme-specific and customize the look and feel of a theme
96
+ # further. For a list of options available for each theme, see the
97
+ # documentation.
98
+ #
99
+ # html_theme_options = {}
100
+
101
+ # Add any paths that contain custom static files (such as style sheets) here,
102
+ # relative to this directory. They are copied after the builtin static files,
103
+ # so a file named "default.css" will overwrite the builtin "default.css".
104
+ html_static_path = ["_static"]
105
+
106
+ html_context = {
107
+ "css_files": [
108
+ "_static/theme_overrides.css", # override wide tables in RTD theme
109
+ ],
110
+ }
111
+
112
+ # Custom sidebar templates, must be a dictionary that maps document names
113
+ # to template names.
114
+ #
115
+ # This is required for the alabaster theme
116
+ # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
117
+ # html_sidebars = {
118
+ # '**': [
119
+ # 'about.html',
120
+ # 'navigation.html',
121
+ # 'relations.html', # needs 'show_related': True theme option to display
122
+ # 'searchbox.html',
123
+ # 'donate.html',
124
+ # ]
125
+ # }
126
+
127
+
128
+ # Example configuration for intersphinx: refer to the Python standard library.
129
+ intersphinx_mapping = {
130
+ "numpy": ("http://docs.scipy.org/doc/numpy/", None),
131
+ "python": ("https://docs.python.org/", None),
132
+ "torch": ("https://pytorch.org/docs/master/", None),
133
+ }
fairseq-0.10.2/docs/lr_scheduler.rst ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. _Learning Rate Schedulers:
5
+
6
+ Learning Rate Schedulers
7
+ ========================
8
+
9
+ Learning Rate Schedulers update the learning rate over the course of training.
10
+ Learning rates can be updated after each update via :func:`step_update` or at
11
+ epoch boundaries via :func:`step`.
12
+
13
+ .. automodule:: fairseq.optim.lr_scheduler
14
+ :members:
15
+
16
+ .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
17
+ :members:
18
+ :undoc-members:
19
+
20
+ .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
21
+ :members:
22
+ :undoc-members:
23
+ .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
24
+ :members:
25
+ :undoc-members:
26
+ .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
27
+ :members:
28
+ :undoc-members:
29
+ .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
30
+ :members:
31
+ :undoc-members:
32
+ .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
33
+ :members:
34
+ :undoc-members:
fairseq-0.10.2/docs/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sphinx<2.0
2
+ sphinx-argparse
fairseq-0.10.2/docs/tutorial_classifying_names.rst ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tutorial: Classifying Names with a Character-Level RNN
2
+ ======================================================
3
+
4
+ In this tutorial we will extend fairseq to support *classification* tasks. In
5
+ particular we will re-implement the PyTorch tutorial for `Classifying Names with
6
+ a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
7
+ in fairseq. It is recommended to quickly skim that tutorial before beginning
8
+ this one.
9
+
10
+ This tutorial covers:
11
+
12
+ 1. **Preprocessing the data** to create dictionaries.
13
+ 2. **Registering a new Model** that encodes an input sentence with a simple RNN
14
+ and predicts the output label.
15
+ 3. **Registering a new Task** that loads our dictionaries and dataset.
16
+ 4. **Training the Model** using the existing command-line tools.
17
+ 5. **Writing an evaluation script** that imports fairseq and allows us to
18
+ interactively evaluate our model on new inputs.
19
+
20
+
21
+ 1. Preprocessing the data
22
+ -------------------------
23
+
24
+ The original tutorial provides raw data, but we'll work with a modified version
25
+ of the data that is already tokenized into characters and split into separate
26
+ train, valid and test sets.
27
+
28
+ Download and extract the data from here:
29
+ `tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_
30
+
31
+ Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
32
+ command-line tool to create the dictionaries. While this tool is primarily
33
+ intended for sequence-to-sequence problems, we're able to reuse it here by
34
+ treating the label as a "target" sequence of length 1. We'll also output the
35
+ preprocessed files in "raw" format using the ``--dataset-impl`` option to
36
+ enhance readability:
37
+
38
+ .. code-block:: console
39
+
40
+ > fairseq-preprocess \
41
+ --trainpref names/train --validpref names/valid --testpref names/test \
42
+ --source-lang input --target-lang label \
43
+ --destdir names-bin --dataset-impl raw
44
+
45
+ After running the above command you should see a new directory,
46
+ :file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
47
+
48
+
49
+ 2. Registering a new Model
50
+ --------------------------
51
+
52
+ Next we'll register a new model in fairseq that will encode an input sentence
53
+ with a simple RNN and predict the output label. Compared to the original PyTorch
54
+ tutorial, our version will also work with batches of data and GPU Tensors.
55
+
56
+ First let's copy the simple RNN module implemented in the `PyTorch tutorial
57
+ <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
58
+ Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
59
+ following contents::
60
+
61
+ import torch
62
+ import torch.nn as nn
63
+
64
+ class RNN(nn.Module):
65
+
66
+ def __init__(self, input_size, hidden_size, output_size):
67
+ super(RNN, self).__init__()
68
+
69
+ self.hidden_size = hidden_size
70
+
71
+ self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
72
+ self.i2o = nn.Linear(input_size + hidden_size, output_size)
73
+ self.softmax = nn.LogSoftmax(dim=1)
74
+
75
+ def forward(self, input, hidden):
76
+ combined = torch.cat((input, hidden), 1)
77
+ hidden = self.i2h(combined)
78
+ output = self.i2o(combined)
79
+ output = self.softmax(output)
80
+ return output, hidden
81
+
82
+ def initHidden(self):
83
+ return torch.zeros(1, self.hidden_size)
84
+
85
+ We must also *register* this model with fairseq using the
86
+ :func:`~fairseq.models.register_model` function decorator. Once the model is
87
+ registered we'll be able to use it with the existing :ref:`Command-line Tools`.
88
+
89
+ All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
90
+ interface, so we'll create a small wrapper class in the same file and register
91
+ it in fairseq with the name ``'rnn_classifier'``::
92
+
93
+ from fairseq.models import BaseFairseqModel, register_model
94
+
95
+ # Note: the register_model "decorator" should immediately precede the
96
+ # definition of the Model class.
97
+
98
+ @register_model('rnn_classifier')
99
+ class FairseqRNNClassifier(BaseFairseqModel):
100
+
101
+ @staticmethod
102
+ def add_args(parser):
103
+ # Models can override this method to add new command-line arguments.
104
+ # Here we'll add a new command-line argument to configure the
105
+ # dimensionality of the hidden state.
106
+ parser.add_argument(
107
+ '--hidden-dim', type=int, metavar='N',
108
+ help='dimensionality of the hidden state',
109
+ )
110
+
111
+ @classmethod
112
+ def build_model(cls, args, task):
113
+ # Fairseq initializes models by calling the ``build_model()``
114
+ # function. This provides more flexibility, since the returned model
115
+ # instance can be of a different type than the one that was called.
116
+ # In this case we'll just return a FairseqRNNClassifier instance.
117
+
118
+ # Initialize our RNN module
119
+ rnn = RNN(
120
+ # We'll define the Task in the next section, but for now just
121
+ # notice that the task holds the dictionaries for the "source"
122
+ # (i.e., the input sentence) and "target" (i.e., the label).
123
+ input_size=len(task.source_dictionary),
124
+ hidden_size=args.hidden_dim,
125
+ output_size=len(task.target_dictionary),
126
+ )
127
+
128
+ # Return the wrapped version of the module
129
+ return FairseqRNNClassifier(
130
+ rnn=rnn,
131
+ input_vocab=task.source_dictionary,
132
+ )
133
+
134
+ def __init__(self, rnn, input_vocab):
135
+ super(FairseqRNNClassifier, self).__init__()
136
+
137
+ self.rnn = rnn
138
+ self.input_vocab = input_vocab
139
+
140
+ # The RNN module in the tutorial expects one-hot inputs, so we can
141
+ # precompute the identity matrix to help convert from indices to
142
+ # one-hot vectors. We register it as a buffer so that it is moved to
143
+ # the GPU when ``cuda()`` is called.
144
+ self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))
145
+
146
+ def forward(self, src_tokens, src_lengths):
147
+ # The inputs to the ``forward()`` function are determined by the
148
+ # Task, and in particular the ``'net_input'`` key in each
149
+ # mini-batch. We'll define the Task in the next section, but for
150
+ # now just know that *src_tokens* has shape `(batch, src_len)` and
151
+ # *src_lengths* has shape `(batch)`.
152
+ bsz, max_src_len = src_tokens.size()
153
+
154
+ # Initialize the RNN hidden state. Compared to the original PyTorch
155
+ # tutorial we'll also handle batched inputs and work on the GPU.
156
+ hidden = self.rnn.initHidden()
157
+ hidden = hidden.repeat(bsz, 1) # expand for batched inputs
158
+ hidden = hidden.to(src_tokens.device) # move to GPU
159
+
160
+ for i in range(max_src_len):
161
+ # WARNING: The inputs have padding, so we should mask those
162
+ # elements here so that padding doesn't affect the results.
163
+ # This is left as an exercise for the reader. The padding symbol
164
+ # is given by ``self.input_vocab.pad()`` and the unpadded length
165
+ # of each input is given by *src_lengths*.
166
+
167
+ # One-hot encode a batch of input characters.
168
+ input = self.one_hot_inputs[src_tokens[:, i].long()]
169
+
170
+ # Feed the input to our RNN.
171
+ output, hidden = self.rnn(input, hidden)
172
+
173
+ # Return the final output state for making a prediction
174
+ return output
175
+
176
+ Finally let's define a *named architecture* with the configuration for our
177
+ model. This is done with the :func:`~fairseq.models.register_model_architecture`
178
+ function decorator. Thereafter this named architecture can be used with the
179
+ ``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::
180
+
181
+ from fairseq.models import register_model_architecture
182
+
183
+ # The first argument to ``register_model_architecture()`` should be the name
184
+ # of the model we registered above (i.e., 'rnn_classifier'). The function we
185
+ # register here should take a single argument *args* and modify it in-place
186
+ # to match the desired architecture.
187
+
188
+ @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
189
+ def pytorch_tutorial_rnn(args):
190
+ # We use ``getattr()`` to prioritize arguments that are explicitly given
191
+ # on the command-line, so that the defaults defined below are only used
192
+ # when no other value has been specified.
193
+ args.hidden_dim = getattr(args, 'hidden_dim', 128)
194
+
195
+
196
+ 3. Registering a new Task
197
+ -------------------------
198
+
199
+ Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
200
+ dictionaries and dataset. Tasks can also control how the data is batched into
201
+ mini-batches, but in this tutorial we'll reuse the batching provided by
202
+ :class:`fairseq.data.LanguagePairDataset`.
203
+
204
+ Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
205
+ following contents::
206
+
207
+ import os
208
+ import torch
209
+
210
+ from fairseq.data import Dictionary, LanguagePairDataset
211
+ from fairseq.tasks import FairseqTask, register_task
212
+
213
+
214
+ @register_task('simple_classification')
215
+ class SimpleClassificationTask(FairseqTask):
216
+
217
+ @staticmethod
218
+ def add_args(parser):
219
+ # Add some command-line arguments for specifying where the data is
220
+ # located and the maximum supported input length.
221
+ parser.add_argument('data', metavar='FILE',
222
+ help='file prefix for data')
223
+ parser.add_argument('--max-positions', default=1024, type=int,
224
+ help='max input length')
225
+
226
+ @classmethod
227
+ def setup_task(cls, args, **kwargs):
228
+ # Here we can perform any setup required for the task. This may include
229
+ # loading Dictionaries, initializing shared Embedding layers, etc.
230
+ # In this case we'll just load the Dictionaries.
231
+ input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
232
+ label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
233
+ print('| [input] dictionary: {} types'.format(len(input_vocab)))
234
+ print('| [label] dictionary: {} types'.format(len(label_vocab)))
235
+
236
+ return SimpleClassificationTask(args, input_vocab, label_vocab)
237
+
238
+ def __init__(self, args, input_vocab, label_vocab):
239
+ super().__init__(args)
240
+ self.input_vocab = input_vocab
241
+ self.label_vocab = label_vocab
242
+
243
+ def load_dataset(self, split, **kwargs):
244
+ """Load a given dataset split (e.g., train, valid, test)."""
245
+
246
+ prefix = os.path.join(self.args.data, '{}.input-label'.format(split))
247
+
248
+ # Read input sentences.
249
+ sentences, lengths = [], []
250
+ with open(prefix + '.input', encoding='utf-8') as file:
251
+ for line in file:
252
+ sentence = line.strip()
253
+
254
+ # Tokenize the sentence, splitting on spaces
255
+ tokens = self.input_vocab.encode_line(
256
+ sentence, add_if_not_exist=False,
257
+ )
258
+
259
+ sentences.append(tokens)
260
+ lengths.append(tokens.numel())
261
+
262
+ # Read labels.
263
+ labels = []
264
+ with open(prefix + '.label', encoding='utf-8') as file:
265
+ for line in file:
266
+ label = line.strip()
267
+ labels.append(
268
+ # Convert label to a numeric ID.
269
+ torch.LongTensor([self.label_vocab.add_symbol(label)])
270
+ )
271
+
272
+ assert len(sentences) == len(labels)
273
+ print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))
274
+
275
+ # We reuse LanguagePairDataset since classification can be modeled as a
276
+ # sequence-to-sequence task where the target sequence has length 1.
277
+ self.datasets[split] = LanguagePairDataset(
278
+ src=sentences,
279
+ src_sizes=lengths,
280
+ src_dict=self.input_vocab,
281
+ tgt=labels,
282
+ tgt_sizes=torch.ones(len(labels)), # targets have length 1
283
+ tgt_dict=self.label_vocab,
284
+ left_pad_source=False,
285
+ # Since our target is a single class label, there's no need for
286
+ # teacher forcing. If we set this to ``True`` then our Model's
287
+ # ``forward()`` method would receive an additional argument called
288
+ # *prev_output_tokens* that would contain a shifted version of the
289
+ # target sequence.
290
+ input_feeding=False,
291
+ )
292
+
293
+ def max_positions(self):
294
+ """Return the max input length allowed by the task."""
295
+ # The source should be less than *args.max_positions* and the "target"
296
+ # has max length 1.
297
+ return (self.args.max_positions, 1)
298
+
299
+ @property
300
+ def source_dictionary(self):
301
+ """Return the source :class:`~fairseq.data.Dictionary`."""
302
+ return self.input_vocab
303
+
304
+ @property
305
+ def target_dictionary(self):
306
+ """Return the target :class:`~fairseq.data.Dictionary`."""
307
+ return self.label_vocab
308
+
309
+ # We could override this method if we wanted more control over how batches
310
+ # are constructed, but it's not necessary for this tutorial since we can
311
+ # reuse the batching provided by LanguagePairDataset.
312
+ #
313
+ # def get_batch_iterator(
314
+ # self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
315
+ # ignore_invalid_inputs=False, required_batch_size_multiple=1,
316
+ # seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
317
+ # data_buffer_size=0, disable_iterator_cache=False,
318
+ # ):
319
+ # (...)
320
+
321
+
322
+ 4. Training the Model
323
+ ---------------------
324
+
325
+ Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
326
+ command-line tool for this, making sure to specify our new Task (``--task
327
+ simple_classification``) and Model architecture (``--arch
328
+ pytorch_tutorial_rnn``):
329
+
330
+ .. note::
331
+
332
+ You can also configure the dimensionality of the hidden state by passing the
333
+ ``--hidden-dim`` argument to :ref:`fairseq-train`.
334
+
335
+ .. code-block:: console
336
+
337
+ > fairseq-train names-bin \
338
+ --task simple_classification \
339
+ --arch pytorch_tutorial_rnn \
340
+ --optimizer adam --lr 0.001 --lr-shrink 0.5 \
341
+ --max-tokens 1000
342
+ (...)
343
+ | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
344
+ | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
345
+ | done training in 31.6 seconds
346
+
347
+ The model files should appear in the :file:`checkpoints/` directory.
348
+
349
+
350
+ 5. Writing an evaluation script
351
+ -------------------------------
352
+
353
+ Finally we can write a short script to evaluate our model on new inputs. Create
354
+ a new file named :file:`eval_classifier.py` with the following contents::
355
+
356
+ from fairseq import checkpoint_utils, data, options, tasks
357
+
358
+ # Parse command-line arguments for generation
359
+ parser = options.get_generation_parser(default_task='simple_classification')
360
+ args = options.parse_args_and_arch(parser)
361
+
362
+ # Setup task
363
+ task = tasks.setup_task(args)
364
+
365
+ # Load model
366
+ print('| loading model from {}'.format(args.path))
367
+ models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
368
+ model = models[0]
369
+
370
+ while True:
371
+ sentence = input('\nInput: ')
372
+
373
+ # Tokenize into characters
374
+ chars = ' '.join(list(sentence.strip()))
375
+ tokens = task.source_dictionary.encode_line(
376
+ chars, add_if_not_exist=False,
377
+ )
378
+
379
+ # Build mini-batch to feed to the model
380
+ batch = data.language_pair_dataset.collate(
381
+ samples=[{'id': -1, 'source': tokens}], # bsz = 1
382
+ pad_idx=task.source_dictionary.pad(),
383
+ eos_idx=task.source_dictionary.eos(),
384
+ left_pad_source=False,
385
+ input_feeding=False,
386
+ )
387
+
388
+ # Feed batch to the model and get predictions
389
+ preds = model(**batch['net_input'])
390
+
391
+ # Print top 3 predictions and their log-probabilities
392
+ top_scores, top_labels = preds[0].topk(k=3)
393
+ for score, label_idx in zip(top_scores, top_labels):
394
+ label_name = task.target_dictionary.string([label_idx])
395
+ print('({:.2f})\t{}'.format(score, label_name))
396
+
397
+ Now we can evaluate our model interactively. Note that we have included the
398
+ original data path (:file:`names-bin/`) so that the dictionaries can be loaded:
399
+
400
+ .. code-block:: console
401
+
402
+ > python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
403
+ | [input] dictionary: 64 types
404
+ | [label] dictionary: 24 types
405
+ | loading model from checkpoints/checkpoint_best.pt
406
+
407
+ Input: Satoshi
408
+ (-0.61) Japanese
409
+ (-1.20) Arabic
410
+ (-2.86) Italian
411
+
412
+ Input: Sinbad
413
+ (-0.30) Arabic
414
+ (-1.76) English
415
+ (-4.08) Russian
fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5893e460c344970e372a5cba54a7892a793e4519e085acde37ad9ad57ea5c48f
3
+ size 1855456
fairseq-0.10.2/fairseq_cli/__init__.py ADDED
File without changes
fairseq-0.10.2/fairseq_cli/__pycache__/generate.cpython-310.pyc ADDED
Binary file (8.12 kB). View file
 
fairseq-0.10.2/fairseq_cli/eval_lm.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Evaluate the perplexity of a trained language model.
9
+ """
10
+
11
+ import logging
12
+ import math
13
+ import os
14
+
15
+ import torch
16
+ from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
17
+ from fairseq.data import LMContextWindowDataset
18
+ from fairseq.logging import progress_bar
19
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
20
+ from fairseq.sequence_scorer import SequenceScorer
21
+
22
+
23
+ logging.basicConfig(
24
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
25
+ datefmt="%Y-%m-%d %H:%M:%S",
26
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
27
+ )
28
+ logger = logging.getLogger("fairseq_cli.eval_lm")
29
+
30
+
31
+ class WordStat(object):
32
+ def __init__(self, word, is_bpe):
33
+ self.word = word
34
+ self.is_bpe = is_bpe
35
+ self.log_prob = 0
36
+ self.next_word_prob = 0
37
+ self.count = 0
38
+ self.missing_next_words = 0
39
+
40
+ def add(self, log_prob, next_word_prob):
41
+ """increments counters for the sum of log probs of current word and next
42
+ word (given context ending at current word). Since the next word might be at the end of the example,
43
+ or it might be not counted because it is not an ending subword unit,
44
+ also keeps track of how many of those we have seen"""
45
+ if next_word_prob is not None:
46
+ self.next_word_prob += next_word_prob
47
+ else:
48
+ self.missing_next_words += 1
49
+ self.log_prob += log_prob
50
+ self.count += 1
51
+
52
+ def __str__(self):
53
+ return "{}\t{}\t{}\t{}\t{}\t{}".format(
54
+ self.word,
55
+ self.count,
56
+ self.log_prob,
57
+ self.is_bpe,
58
+ self.next_word_prob,
59
+ self.count - self.missing_next_words,
60
+ )
61
+
62
+
63
def main(parsed_args, **unused_kwargs):
    """Evaluate a trained language model on the --gen-subset split.

    Loads the model ensemble from --path, scores every token with
    SequenceScorer, and logs the average token loss (base 2) and perplexity.
    With --remove-bpe, BPE continuation units are merged so statistics are
    per word; --output-word-probs / --output-word-stats add per-word output.
    """
    assert parsed_args.path is not None, "--path required for evaluation!"

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info("loading model(s) from {}".format(parsed_args.path))
    # NOTE(review): eval() on --model-overrides executes arbitrary Python from
    # the command line; fine for a local CLI tool, never expose to untrusted input.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
        strict=(parsed_args.checkpoint_shard_count == 1),
        num_shards=parsed_args.checkpoint_shard_count,
    )

    # Override the checkpoint args with the current CLI args, except settings
    # that must stay as they were at training time.
    for arg in vars(parsed_args).keys():
        if arg not in {
            "self_target",
            "future_target",
            "past_target",
            "tokens_per_sample",
            "output_size_dictionary",
            "add_bos_token",
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        # give each scored token at least `context_window` tokens of context
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        if args.fp16:
            model.half()
        if use_cuda and not args.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(args)

    assert len(models) > 0

    logger.info(
        "num. model params: {}".format(sum(p.numel() for p in models[0].parameters()))
    )

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=("tqdm" if not args.no_progress_bar else "none"),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.0
    count = 0

    # Precompute the set of dictionary ids that are BPE continuation units so
    # their scores can be folded into the following token.
    if args.remove_bpe is not None:
        if args.remove_bpe == "sentencepiece":
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
            bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if getattr(args, "add_bos_token", False):
                assert hypo["tokens"][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            # Fold BPE-continuation scores into the next token so each merged
            # word is scored once.  (Loop index renamed from `i` to `j` to stop
            # shadowing the enumerate index above.)
            skipped_toks = 0
            if bpe_toks is not None:
                for j in range(tgt_len - 1):
                    if tokens[j].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[j + 1] += pos_scores[j]
                        pos_scores[j] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                # Fixed: pass the token string through lazy %-formatting.  The
                # original passed it as a second positional argument with no
                # placeholder in the message, so logging discarded it (and
                # reported a string-formatting error).
                logger.info(
                    "skipping tokens with inf scores: %s",
                    task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if args.output_word_probs or args.output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        # continuation unit: keep accumulating the word
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        # score of the next word = first non-zero positional
                        # score after position i (zeros were folded away above)
                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + (
                            "\t".join(
                                "{} [{:2f}]".format(x[0], x[1]) for x in word_prob
                            )
                        )
                    )

        wps_meter.update(sample["ntokens"])
        progress.log({"wps": round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info(
        "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg
        )
    )
    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            avg_nll_loss, 2 ** avg_nll_loss
        )
    )

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
270
+
271
+
272
def cli_main():
    """Entry point: parse eval-lm CLI arguments and dispatch to main()."""
    lm_parser = options.get_eval_lm_parser()
    parsed = options.parse_args_and_arch(lm_parser)
    distributed_utils.call_main(parsed, main)


if __name__ == "__main__":
    cli_main()
fairseq-0.10.2/fairseq_cli/train.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Train a new model on one or across multiple GPUs.
8
+ """
9
+
10
+ import argparse
11
+ import logging
12
+ import math
13
+ import os
14
+ import random
15
+ import sys
16
+
17
+ import numpy as np
18
+ import torch
19
+ from fairseq import (
20
+ checkpoint_utils,
21
+ distributed_utils,
22
+ options,
23
+ quantization_utils,
24
+ tasks,
25
+ utils,
26
+ )
27
+ from fairseq.data import iterators
28
+ from fairseq.logging import meters, metrics, progress_bar
29
+ from fairseq.model_parallel.megatron_trainer import MegatronTrainer
30
+ from fairseq.trainer import Trainer
31
+
32
+
33
+ logging.basicConfig(
34
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
35
+ datefmt="%Y-%m-%d %H:%M:%S",
36
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
37
+ stream=sys.stdout,
38
+ )
39
+ logger = logging.getLogger("fairseq_cli.train")
40
+
41
+
42
def main(args):
    """Top-level training driver: build the task/model/criterion and trainer,
    restore the latest checkpoint, then train epoch by epoch until a stopping
    condition fires or the learning rate decays to args.min_lr."""
    utils.import_user_module(args)

    assert (
        args.max_tokens is not None or args.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    metrics.reset()

    # Seed RNGs so runs are reproducible.
    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    # Only the master process checks that the save directory is writable.
    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info(
        "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
    )
    logger.info(
        "num. model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
    )

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        # model-parallel path; note the quantizer is not passed here
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info(
        "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
    )
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.batch_size
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
141
+
142
+
143
def should_stop_early(args, valid_loss):
    """Return True when validation performance has not improved for
    args.patience consecutive validation runs.

    State (best value so far, number of non-improving runs) is kept on the
    function object itself, so it persists across calls within a run.
    """
    # No validation happened this epoch: nothing to decide.
    if valid_loss is None:
        return False
    # Patience <= 0 disables early stopping entirely.
    if args.patience <= 0:
        return False

    best_so_far = getattr(should_stop_early, "best", None)
    if args.maximize_best_checkpoint_metric:
        improved = best_so_far is None or valid_loss > best_so_far
    else:
        improved = best_so_far is None or valid_loss < best_so_far

    if improved:
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False

    should_stop_early.num_runs += 1
    if should_stop_early.num_runs < args.patience:
        return False
    logger.info(
        "early stop since valid performance hasn't improved for last {} runs".format(
            args.patience
        )
    )
    return True
169
+
170
+
171
@metrics.aggregate("train")
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator; shuffling is off while still within the
    # curriculum period (next_epoch_idx <= args.curriculum).
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # Group `update_freq` batches per train_step (gradient accumulation);
    # update_freq may be specified per epoch, last value applies afterwards.
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_losses = [None]
    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        # record_function labels each step for the autograd profiler
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % args.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        # validate/save on the configured schedule; may request stopping
        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
237
+
238
+
239
def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
    """Run validation and/or save a checkpoint according to the configured
    schedule, and decide whether training should stop.

    Returns (valid_losses, should_stop); valid_losses is [None] when no
    validation was performed.
    """
    num_updates = trainer.get_num_updates()
    max_update = args.max_update or math.inf
    hit_max_update = num_updates >= max_update

    # Save at end of epoch (every save_interval epochs), at the update cap,
    # or every save_interval_updates updates once past validate_after_updates.
    end_of_epoch_save = end_of_epoch and epoch_itr.epoch % args.save_interval == 0
    mid_epoch_save = (
        args.save_interval_updates > 0
        and num_updates > 0
        and num_updates % args.save_interval_updates == 0
        and num_updates >= args.validate_after_updates
    )
    do_save = end_of_epoch_save or hit_max_update or mid_epoch_save

    # Validate alongside mid-epoch saves, at the epoch/update intervals, or
    # at the update cap — unless validation is disabled entirely.
    if args.disable_validation:
        do_validate = False
    else:
        do_validate = (
            (do_save and not end_of_epoch)
            or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
            or hit_max_update
            or (
                args.validate_interval_updates > 0
                and num_updates > 0
                and num_updates % args.validate_interval_updates == 0
            )
        )

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

    # Stopping conditions (should_stop_early is always called first: it
    # updates its internal best/num_runs state)
    should_stop = (
        should_stop_early(args, valid_losses[0])
        or hit_max_update
        or (
            args.stop_time_hours > 0
            and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours
        )
    )

    # Save checkpoint
    if do_save or should_stop:
        logger.info("begin save checkpoint")
        checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    return valid_losses, should_stop
284
+
285
+
286
def get_training_stats(stats):
    """Attach total wall-clock time (seconds) to a stats dict and return it."""
    wall_meter = metrics.get_meter("default", "wall")
    stats["wall"] = round(wall_meter.elapsed_time, 0)
    return stats
289
+
290
+
291
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses.

    Returns one value per subset: the stat named by
    args.best_checkpoint_metric for that subset.
    """

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        # track the metric used for checkpoint selection / early stopping
        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
331
+
332
+
333
def get_valid_stats(args, trainer, stats):
    """Augment validation stats with the update count and, when known, the
    running best value of the checkpoint-selection metric."""
    stats["num_updates"] = trainer.get_num_updates()
    # save_checkpoint tracks the best metric value seen so far as an attribute
    if hasattr(checkpoint_utils.save_checkpoint, "best"):
        metric = args.best_checkpoint_metric
        key = "best_{0}".format(metric)
        if args.maximize_best_checkpoint_metric:
            stats[key] = max(checkpoint_utils.save_checkpoint.best, stats[metric])
        else:
            stats[key] = min(checkpoint_utils.save_checkpoint.best, stats[metric])
    return stats
342
+
343
+
344
def cli_main(modify_parser=None):
    """Parse training CLI arguments and launch training, optionally wrapped
    in the CUDA/NVTX profilers when --profile is given."""
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
    if not args.profile:
        distributed_utils.call_main(args, main)
        return
    with torch.cuda.profiler.profile():
        with torch.autograd.profiler.emit_nvtx():
            distributed_utils.call_main(args, main)


if __name__ == "__main__":
    cli_main()
mosesdecoder/phrase-extract/Alignment.cpp ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***********************************************************************
2
+ Moses - statistical machine translation system
3
+ Copyright (C) 2006-2011 University of Edinburgh
4
+
5
+ This library is free software; you can redistribute it and/or
6
+ modify it under the terms of the GNU Lesser General Public
7
+ License as published by the Free Software Foundation; either
8
+ version 2.1 of the License, or (at your option) any later version.
9
+
10
+ This library is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
+ Lesser General Public License for more details.
14
+
15
+ You should have received a copy of the GNU Lesser General Public
16
+ License along with this library; if not, write to the Free Software
17
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18
+ ***********************************************************************/
19
+
20
+ #include "Alignment.h"
21
+
22
+ #include "phrase-extract/syntax-common/exception.h"
23
+
24
+ #include <algorithm>
25
+ #include <cassert>
26
+ #include <cstdlib>
27
+
28
+ namespace MosesTraining
29
+ {
30
+
31
// Parse a Moses-style alignment string ("src-tgt src-tgt ...") into a list of
// (source index, target index) pairs, replacing any previous contents of `a`.
// Parsing stops silently when no further '-' is found; a '-' at the very end
// of the string (missing target index) raises a Syntax::Exception.
void ReadAlignment(const std::string &s, Alignment &a)
{
  const std::string digits = "0123456789";

  a.clear();

  std::string::size_type begin = 0;
  while (true) {
    // Locate the '-' separating the source index from the target index.
    std::string::size_type end = s.find("-", begin);
    if (end == std::string::npos) {
      return;  // no more alignment points
    }
    int src = std::atoi(s.substr(begin, end-begin).c_str());
    if (end+1 == s.size()) {
      throw Syntax::Exception("Target index missing");
    }

    // The target index runs up to the first non-digit character.
    begin = end+1;
    end = s.find_first_not_of(digits, begin+1);
    int tgt;
    if (end == std::string::npos) {
      // Last alignment point: target index extends to end of string.
      tgt = std::atoi(s.substr(begin).c_str());
      a.push_back(std::make_pair(src, tgt));
      return;
    } else {
      tgt = std::atoi(s.substr(begin, end-begin).c_str());
      a.push_back(std::make_pair(src, tgt));
    }
    begin = end+1;
  }
}
62
+
63
+ void FlipAlignment(Alignment &a)
64
+ {
65
+ for (Alignment::iterator p = a.begin(); p != a.end(); ++p) {
66
+ std::swap(p->first, p->second);
67
+ }
68
+ }
69
+
70
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/AlignmentPhrase.h ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // $Id$
2
+ /***********************************************************************
3
+ Moses - factored phrase-based language decoder
4
+ Copyright (C) 2006 University of Edinburgh
5
+
6
+ This library is free software; you can redistribute it and/or
7
+ modify it under the terms of the GNU Lesser General Public
8
+ License as published by the Free Software Foundation; either
9
+ version 2.1 of the License, or (at your option) any later version.
10
+
11
+ This library is distributed in the hope that it will be useful,
12
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14
+ Lesser General Public License for more details.
15
+
16
+ You should have received a copy of the GNU Lesser General Public
17
+ License along with this library; if not, write to the Free Software
18
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19
+ ***********************************************************************/
20
+
21
+ #pragma once
22
+
23
+ #include <vector>
24
+ #include <set>
25
+
26
+ namespace MosesTraining
27
+ {
28
+
29
+ class WordsRange;
30
+
31
// Set of positions in the other language that one word aligns to.
class AlignmentElement
{
protected:
  std::set<size_t> m_elements;  // aligned positions (unique, sorted by std::set)
public:
  typedef std::set<size_t>::iterator iterator;
  typedef std::set<size_t>::const_iterator const_iterator;
  const_iterator begin() const {
    return m_elements.begin();
  }
  const_iterator end() const {
    return m_elements.end();
  }

  AlignmentElement() {
  }

  // Number of alignment points recorded for this word.
  size_t GetSize() const {
    return m_elements.size();
  }

  // Add one aligned position (implementation in the .cpp file).
  void Merge(size_t align);
};
54
+
55
// Alignment information for one phrase: one AlignmentElement per word position.
class AlignmentPhrase
{
protected:
  std::vector<AlignmentElement> m_elements;  // indexed by word position
public:
  // Create alignment info for a phrase of `size` words, all initially empty.
  AlignmentPhrase(size_t size)
    :m_elements(size) {
  }
  // Merge in alignments from another phrase / raw data (defined in .cpp).
  void Merge(const AlignmentPhrase &newAlignment, const WordsRange &newAlignmentRange);
  void Merge(const std::vector< std::vector<size_t> > &source);
  // Number of word positions in the phrase.
  size_t GetSize() const {
    return m_elements.size();
  }
  const AlignmentElement &GetElement(size_t pos) const {
    return m_elements[pos];
  }
};
72
+
73
+ } // namespace
74
+
mosesdecoder/phrase-extract/DomainFeature.cpp ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "DomainFeature.h"
2
+ #include "ExtractionPhrasePair.h"
3
+ #include "tables-core.h"
4
+ #include "InputFileStream.h"
5
+ #include "util/tokenize.hh"
6
+
7
+ using namespace std;
8
+
9
+ namespace MosesTraining
10
+ {
11
+
12
// handling of domain names: load database with sentence-id / domain name info
// Each line has the form "<lastSentenceId> <domainName>"; per
// getDomainOfSentence(), a domain covers all sentences up to and including
// that id.  Fills `spec` (in file order), `list` (unique names) and
// `name2id`.  Exits the process on a malformed line.
void Domain::load( const std::string &domainFileName )
{
  Moses::InputFileStream fileS( domainFileName );
  istream *fileP = &fileS;

  string line;
  while(getline(*fileP, line)) {
    // read
    const vector< string > domainSpecLine = util::tokenize( line );
    int lineNumber;
    if (domainSpecLine.size() != 2 ||
        ! sscanf(domainSpecLine[0].c_str(), "%d", &lineNumber)) {
      std::cerr << "ERROR: in domain specification line: '" << line << "'" << endl;
      exit(1);
    }
    // store
    const string &name = domainSpecLine[1];
    spec.push_back( make_pair( lineNumber, name ));
    if (name2id.find( name ) == name2id.end()) {
      // first occurrence of this domain name: assign it the next id
      name2id[ name ] = list.size();
      list.push_back( name );
    }
  }
}
37
+
38
+ // get domain name based on sentence number
39
+ string Domain::getDomainOfSentence( int sentenceId ) const
40
+ {
41
+ for(size_t i=0; i<spec.size(); i++) {
42
+ if (sentenceId <= spec[i].first) {
43
+ return spec[i].second;
44
+ }
45
+ }
46
+ return "undefined";
47
+ }
48
+
49
// Construct the feature: per-domain counts will be stored on phrase pairs
// under the property key "domain", using the mapping read from domainFile.
DomainFeature::DomainFeature(const string& domainFile) : m_propertyKey("domain")
{
  //process domain file
  m_domain.load(domainFile);
}
54
+
55
+ void DomainFeature::addPropertiesToPhrasePair(ExtractionPhrasePair &phrasePair,
56
+ float count,
57
+ int sentenceId) const
58
+ {
59
+ std::string value = m_domain.getDomainOfSentence(sentenceId);
60
+ phrasePair.AddProperty(m_propertyKey, value, count);
61
+ }
62
+
63
// Compute feature values for a phrase pair from its per-domain counts.
// The "domain" property must have been filled in by
// addPropertiesToPhrasePair() during extraction.
void DomainFeature::add(const ScoreFeatureContext& context,
                        std::vector<float>& denseValues,
                        std::map<std::string,float>& sparseValues) const
{
  const map<string,float> *domainCount = context.phrasePair.GetProperty(m_propertyKey);
  assert( domainCount != NULL );  // extraction must have recorded domain counts
  // delegate to the subclass-specific scoring scheme
  add(*domainCount,
      context.phrasePair.GetCount(),
      context.maybeLog,
      denseValues, sparseValues);
}
74
+
75
+ void SubsetDomainFeature::add(const map<string,float>& domainCount,
76
+ float count,
77
+ const MaybeLog& maybeLog,
78
+ std::vector<float>& denseValues,
79
+ std::map<std::string,float>& sparseValues) const
80
+ {
81
+ if (m_domain.list.size() > 6) {
82
+ UTIL_THROW_IF(m_domain.list.size() > 6, ScoreFeatureArgumentException,
83
+ "too many domains for core domain subset features");
84
+ }
85
+ size_t bitmap = 0;
86
+ for(size_t bit = 0; bit < m_domain.list.size(); bit++) {
87
+ if (domainCount.find( m_domain.list[ bit ] ) != domainCount.end()) {
88
+ bitmap += 1 << bit;
89
+ }
90
+ }
91
+ for(size_t i = 1; i < (1 << m_domain.list.size()); i++) {
92
+ denseValues.push_back(maybeLog( (bitmap == i) ? 2.718 : 1 ));
93
+ }
94
+ }
95
+
96
+ void SparseSubsetDomainFeature::add(const map<string,float>& domainCount,float count,
97
+ const MaybeLog& maybeLog,
98
+ std::vector<float>& denseValues,
99
+ std::map<std::string,float>& sparseValues) const
100
+ {
101
+ typedef vector<string>::const_iterator I;
102
+ ostringstream key;
103
+ key << "doms";
104
+ for (I i = m_domain.list.begin(); i != m_domain.list.end(); ++i) {
105
+ if (domainCount.find(*i) != domainCount.end()) {
106
+ key << "_" << *i;
107
+ }
108
+ }
109
+ sparseValues[key.str()] = 1;
110
+ }
111
+
112
+
113
+ void RatioDomainFeature::add(const map<string,float>& domainCount,float count,
114
+ const MaybeLog& maybeLog,
115
+ std::vector<float>& denseValues,
116
+ std::map<std::string,float>& sparseValues) const
117
+ {
118
+ typedef vector< string >::const_iterator I;
119
+ for (I i = m_domain.list.begin(); i != m_domain.list.end(); i++ ) {
120
+ map<string,float>::const_iterator dci = domainCount.find(*i);
121
+ if (dci == domainCount.end() ) {
122
+ denseValues.push_back(maybeLog( 1 ));
123
+ } else {
124
+ denseValues.push_back(maybeLog(exp( dci->second / count ) ));
125
+ }
126
+ }
127
+ }
128
+
129
+
130
+ void SparseRatioDomainFeature::add(const map<string,float>& domainCount,float count,
131
+ const MaybeLog& maybeLog,
132
+ std::vector<float>& denseValues,
133
+ std::map<std::string,float>& sparseValues) const
134
+ {
135
+ typedef map< string, float >::const_iterator I;
136
+ for (I i=domainCount.begin(); i != domainCount.end(); i++) {
137
+ sparseValues["domr_" + i->first] = (i->second / count);
138
+ }
139
+ }
140
+
141
+
142
+ void IndicatorDomainFeature::add(const map<string,float>& domainCount,float count,
143
+ const MaybeLog& maybeLog,
144
+ std::vector<float>& denseValues,
145
+ std::map<std::string,float>& sparseValues) const
146
+ {
147
+ typedef vector< string >::const_iterator I;
148
+ for (I i = m_domain.list.begin(); i != m_domain.list.end(); i++ ) {
149
+ map<string,float>::const_iterator dci = domainCount.find(*i);
150
+ if (dci == domainCount.end() ) {
151
+ denseValues.push_back(maybeLog( 1 ));
152
+ } else {
153
+ denseValues.push_back(maybeLog(2.718));
154
+ }
155
+ }
156
+ }
157
+
158
+ void SparseIndicatorDomainFeature::add(const map<string,float>& domainCount,float count,
159
+ const MaybeLog& maybeLog,
160
+ std::vector<float>& denseValues,
161
+ std::map<std::string,float>& sparseValues) const
162
+ {
163
+ typedef map< string, float >::const_iterator I;
164
+ for (I i=domainCount.begin(); i != domainCount.end(); i++) {
165
+ sparseValues["dom_" + i->first] = 1;
166
+ }
167
+ }
168
+
169
+ }
170
+
mosesdecoder/phrase-extract/DomainFeature.h ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // $Id$
2
+
3
+ #ifndef _DOMAIN_H
4
+ #define _DOMAIN_H
5
+
6
+ #include <iostream>
7
+ #include <fstream>
8
+ #include <cassert>
9
+ #include <cstdlib>
10
+ #include <string>
11
+ #include <queue>
12
+ #include <map>
13
+ #include <cmath>
14
+
15
+ #include "ScoreFeature.h"
16
+
17
+ namespace MosesTraining
18
+ {
19
+
20
// Mapping from sentence ids to corpus domain names, loaded from a file of
// "<lastSentenceId> <domainName>" lines (see DomainFeature.cpp).
class Domain
{
public:
  std::vector< std::pair< int, std::string > > spec;  // (last sentence id, name) in file order
  std::vector< std::string > list;                    // unique domain names
  std::map< std::string, int > name2id;               // name -> index into `list`
  void load( const std::string &fileName );
  std::string getDomainOfSentence( int sentenceId ) const;
};
29
+
30
// Base class for phrase-table features derived from the corpus domains the
// phrase pair was extracted from.  Subclasses implement the protected add()
// overload to turn per-domain counts into dense/sparse feature values.
class DomainFeature : public ScoreFeature
{
public:

  DomainFeature(const std::string& domainFile);

  // Record the extracting sentence's domain (with its count) as a property
  // on the phrase pair.
  void addPropertiesToPhrasePair(ExtractionPhrasePair &phrasePair,
                                 float count,
                                 int sentenceId) const;

  // Compute feature values from the phrase pair's stored domain counts.
  void add(const ScoreFeatureContext& context,
           std::vector<float>& denseValues,
           std::map<std::string,float>& sparseValues) const;

protected:
  /** Overridden in subclass */
  virtual void add(const std::map<std::string,float>& domainCounts, float count,
                   const MaybeLog& maybeLog,
                   std::vector<float>& denseValues,
                   std::map<std::string,float>& sparseValues) const = 0;


  Domain m_domain;  // sentence-id -> domain-name mapping

  const std::string m_propertyKey;  // phrase-pair property key for domain counts

};
57
+
58
// Dense indicator features, one per non-empty subset of domains (the .cpp
// implementation limits this scheme to at most 6 domains).
class SubsetDomainFeature : public DomainFeature
{
public:
  SubsetDomainFeature(const std::string& domainFile) :
    DomainFeature(domainFile) {}

protected:
  virtual void add(const std::map<std::string,float>& domainCounts, float count,
                   const MaybeLog& maybeLog,
                   std::vector<float>& denseValues,
                   std::map<std::string,float>& sparseValues) const;
};
70
+
71
// Sparse feature naming the exact set of domains the phrase pair occurs in
// ("doms_<name>_<name>...").
class SparseSubsetDomainFeature : public DomainFeature
{
public:
  SparseSubsetDomainFeature(const std::string& domainFile) :
    DomainFeature(domainFile) {}

protected:
  virtual void add(const std::map<std::string,float>& domainCounts, float count,
                   const MaybeLog& maybeLog,
                   std::vector<float>& denseValues,
                   std::map<std::string,float>& sparseValues) const;

};
84
+
85
// Dense per-domain indicator features (one value per known domain).
class IndicatorDomainFeature : public DomainFeature
{
public:
  IndicatorDomainFeature(const std::string& domainFile) :
    DomainFeature(domainFile) {}

protected:
  virtual void add(const std::map<std::string,float>& domainCounts, float count,
                   const MaybeLog& maybeLog,
                   std::vector<float>& denseValues,
                   std::map<std::string,float>& sparseValues) const;
};
97
+
98
+
99
// Sparse per-domain indicator features ("dom_<name>" = 1 for each observed
// domain).
class SparseIndicatorDomainFeature : public DomainFeature
{
public:
  SparseIndicatorDomainFeature(const std::string& domainFile) :
    DomainFeature(domainFile) {}

protected:
  virtual void add(const std::map<std::string,float>& domainCounts, float count,
                   const MaybeLog& maybeLog,
                   std::vector<float>& denseValues,
                   std::map<std::string,float>& sparseValues) const;
};
111
+
112
+
113
// Dense per-domain features based on the ratio of the phrase pair's count
// coming from each domain.
class RatioDomainFeature : public DomainFeature
{
public:
  RatioDomainFeature(const std::string& domainFile) :
    DomainFeature(domainFile) {}

protected:
  virtual void add(const std::map<std::string,float>& domainCounts, float count,
                   const MaybeLog& maybeLog,
                   std::vector<float>& denseValues,
                   std::map<std::string,float>& sparseValues) const;
};
125
+
126
+
127
// Sparse per-domain count-ratio features ("domr_<name>" = domain share of
// the phrase pair's count).
class SparseRatioDomainFeature : public DomainFeature
{
public:
  SparseRatioDomainFeature(const std::string& domainFile) :
    DomainFeature(domainFile) {}

protected:
  virtual void add(const std::map<std::string,float>& domainCounts, float count,
                   const MaybeLog& maybeLog,
                   std::vector<float>& denseValues,
                   std::map<std::string,float>& sparseValues) const;
};
139
+
140
+
141
+ }
142
+
143
+ #endif
mosesdecoder/phrase-extract/HoleCollection.cpp ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***********************************************************************
2
+ Moses - factored phrase-based language decoder
3
+ Copyright (C) 2010 University of Edinburgh
4
+
5
+ This library is free software; you can redistribute it and/or
6
+ modify it under the terms of the GNU Lesser General Public
7
+ License as published by the Free Software Foundation; either
8
+ version 2.1 of the License, or (at your option) any later version.
9
+
10
+ This library is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
+ Lesser General Public License for more details.
14
+
15
+ You should have received a copy of the GNU Lesser General Public
16
+ License along with this library; if not, write to the Free Software
17
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18
+ ***********************************************************************/
19
+
20
+ #include "HoleCollection.h"
21
+
22
+ #include <algorithm>
23
+
24
+ namespace MosesTraining
25
+ {
26
+
27
+ void HoleCollection::SortSourceHoles()
28
+ {
29
+ assert(m_sortedSourceHoles.size() == 0);
30
+
31
+ // add
32
+ HoleList::iterator iter;
33
+ for (iter = m_holes.begin(); iter != m_holes.end(); ++iter) {
34
+ Hole &currHole = *iter;
35
+ m_sortedSourceHoles.push_back(&currHole);
36
+ }
37
+
38
+ // sort
39
+ std::sort(m_sortedSourceHoles.begin(), m_sortedSourceHoles.end(), HoleSourceOrderer());
40
+ }
41
+
42
+ void HoleCollection::Add(int startT, int endT, int startS, int endS)
43
+ {
44
+ Hole hole(startS, endS, startT, endT);
45
+ m_scope.push_back(Scope(hole));
46
+ m_sourceHoleStartPoints.push_back(startS);
47
+ m_sourceHoleEndPoints.push_back(endS);
48
+ m_holes.push_back(hole);
49
+ m_sortedSourceHoles.clear();
50
+ }
51
+
52
+ void HoleCollection::RemoveLast()
53
+ {
54
+ m_scope.pop_back();
55
+ m_sourceHoleStartPoints.pop_back();
56
+ m_sourceHoleEndPoints.pop_back();
57
+ m_holes.pop_back();
58
+ m_sortedSourceHoles.clear();
59
+ }
60
+
61
+ int HoleCollection::Scope(const Hole &proposedHole) const
62
+ {
63
+ const int holeStart = proposedHole.GetStart(0);
64
+ const int holeEnd = proposedHole.GetEnd(0);
65
+ int scope = m_scope.back();
66
+ if (holeStart == m_sourcePhraseStart.back() ||
67
+ find(m_sourceHoleEndPoints.begin(), m_sourceHoleEndPoints.end(), holeStart-1) != m_sourceHoleEndPoints.end()) {
68
+ ++scope; // Adding hole would introduce choice point at start of hole.
69
+ }
70
+ if (holeEnd == m_sourcePhraseEnd.back() ||
71
+ find(m_sourceHoleStartPoints.begin(), m_sourceHoleStartPoints.end(), holeEnd-1) != m_sourceHoleStartPoints.end()) {
72
+ ++scope; // Adding hole would introduce choice point at end of hole.
73
+ }
74
+ return scope;
75
+ }
76
+
77
+ }
mosesdecoder/phrase-extract/HoleCollection.h ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***********************************************************************
2
+ Moses - factored phrase-based language decoder
3
+ Copyright (C) 2010 University of Edinburgh
4
+
5
+ This library is free software; you can redistribute it and/or
6
+ modify it under the terms of the GNU Lesser General Public
7
+ License as published by the Free Software Foundation; either
8
+ version 2.1 of the License, or (at your option) any later version.
9
+
10
+ This library is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
+ Lesser General Public License for more details.
14
+
15
+ You should have received a copy of the GNU Lesser General Public
16
+ License along with this library; if not, write to the Free Software
17
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18
+ ***********************************************************************/
19
+
20
+ #pragma once
21
+ #ifndef HOLECOLLECTION_H_INCLUDED_
22
+ #define HOLECOLLECTION_H_INCLUDED_
23
+
24
+ #include <set>
25
+ #include <vector>
26
+
27
+ #include "Hole.h"
28
+
29
+ namespace MosesTraining
30
+ {
31
+
32
/** Set of non-terminal "holes" cut out of a source/target phrase pair
 *  during hierarchical rule extraction, together with bookkeeping needed
 *  to compute rule scope incrementally (see Scope() in the .cpp).
 */
class HoleCollection
{
protected:
  HoleList m_holes;                          // the holes currently in the rule
  std::vector<Hole*> m_sortedSourceHoles;    // lazy cache built by SortSourceHoles()
  std::vector<int> m_sourceHoleStartPoints;  // source start of each hole, parallel to m_holes
  std::vector<int> m_sourceHoleEndPoints;    // source end of each hole, parallel to m_holes
  std::vector<int> m_scope;                  // scope after each Add(); front entry is the base scope 0
  std::vector<int> m_sourcePhraseStart;      // source span of the enclosing phrase
  std::vector<int> m_sourcePhraseEnd;

public:
  HoleCollection(int sourcePhraseStart, int sourcePhraseEnd)
    : m_scope(1, 0)
    , m_sourcePhraseStart(1, sourcePhraseStart)
    , m_sourcePhraseEnd(1, sourcePhraseEnd) {
  }

  const HoleList &GetHoles() const {
    return m_holes;
  }

  HoleList &GetHoles() {
    return m_holes;
  }

  // Valid only after SortSourceHoles(); cleared by Add()/RemoveLast().
  std::vector<Hole*> &GetSortedSourceHoles() {
    return m_sortedSourceHoles;
  }

  // Append a hole with target span [startT,endT] and source span [startS,endS].
  void Add(int startT, int endT, int startS, int endS);

  // Undo the most recent Add().
  void RemoveLast();

  // True if sourceHole overlaps any existing hole on the source side.
  bool OverlapSource(const Hole &sourceHole) const {
    HoleList::const_iterator iter;
    for (iter = m_holes.begin(); iter != m_holes.end(); ++iter) {
      const Hole &currHole = *iter;
      if (currHole.Overlap(sourceHole, 0))
        return true;
    }
    return false;
  }

  // True if sourceHole is directly adjacent to any existing hole on the
  // source side (used to forbid consecutive source non-terminals).
  bool ConsecSource(const Hole &sourceHole) const {
    HoleList::const_iterator iter;
    for (iter = m_holes.begin(); iter != m_holes.end(); ++iter) {
      const Hole &currHole = *iter;
      if (currHole.Neighbor(sourceHole, 0))
        return true;
    }
    return false;
  }

  // Determine the scope that would result from adding the given hole.
  int Scope(const Hole &proposedHole) const;

  // Build m_sortedSourceHoles (pointers to holes ordered by source position).
  void SortSourceHoles();

};
92
+
93
+ }
94
+
95
+ #endif
mosesdecoder/phrase-extract/InputFileStream.cpp ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // $Id: InputFileStream.cpp 2780 2010-01-29 17:11:17Z bojar $
2
+
3
+ /***********************************************************************
4
+ Moses - factored phrase-based language decoder
5
+ Copyright (C) 2006 University of Edinburgh
6
+
7
+ This library is free software; you can redistribute it and/or
8
+ modify it under the terms of the GNU Lesser General Public
9
+ License as published by the Free Software Foundation; either
10
+ version 2.1 of the License, or (at your option) any later version.
11
+
12
+ This library is distributed in the hope that it will be useful,
13
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15
+ Lesser General Public License for more details.
16
+
17
+ You should have received a copy of the GNU Lesser General Public
18
+ License along with this library; if not, write to the Free Software
19
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20
+ ***********************************************************************/
21
+
22
+ #include "InputFileStream.h"
23
+ #include "gzfilebuf.h"
24
+ #include <iostream>
25
+
26
+ using namespace std;
27
+
28
+ namespace Moses
29
+ {
30
+ InputFileStream::InputFileStream(const std::string &filePath)
31
+ : std::istream(NULL)
32
+ , m_streambuf(NULL)
33
+ {
34
+ if (filePath.size() > 3 &&
35
+ filePath.substr(filePath.size() - 3, 3) == ".gz") {
36
+ m_streambuf = new gzfilebuf(filePath.c_str());
37
+ } else {
38
+ std::filebuf* fb = new std::filebuf();
39
+ fb = fb->open(filePath.c_str(), std::ios::in);
40
+ if (! fb) {
41
+ cerr << "Can't read " << filePath.c_str() << endl;
42
+ exit(1);
43
+ }
44
+ m_streambuf = fb;
45
+ }
46
+ this->init(m_streambuf);
47
+ }
48
+
49
+ InputFileStream::~InputFileStream()
50
+ {
51
+ delete m_streambuf;
52
+ m_streambuf = NULL;
53
+ }
54
+
55
// Intentionally a no-op: the buffer is released in the destructor.
// Kept so callers written against an explicit-Close() interface still link.
void InputFileStream::Close()
{
}
58
+
59
+
60
+ }
61
+
mosesdecoder/phrase-extract/InputFileStream.h ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // $Id: InputFileStream.h 2939 2010-02-24 11:15:44Z jfouet $
2
+
3
+ /***********************************************************************
4
+ Moses - factored phrase-based language decoder
5
+ Copyright (C) 2006 University of Edinburgh
6
+
7
+ This library is free software; you can redistribute it and/or
8
+ modify it under the terms of the GNU Lesser General Public
9
+ License as published by the Free Software Foundation; either
10
+ version 2.1 of the License, or (at your option) any later version.
11
+
12
+ This library is distributed in the hope that it will be useful,
13
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15
+ Lesser General Public License for more details.
16
+
17
+ You should have received a copy of the GNU Lesser General Public
18
+ License along with this library; if not, write to the Free Software
19
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20
+ ***********************************************************************/
21
+
22
+ #ifndef moses_InputFileStream_h
23
+ #define moses_InputFileStream_h
24
+
25
+ #include <cstdlib>
26
+ #include <fstream>
27
+ #include <string>
28
+
29
+ namespace Moses
30
+ {
31
+
32
+ /** Used in place of std::istream, can read zipped files if it ends in .gz
33
+ */
34
+ class InputFileStream : public std::istream
35
+ {
36
+ protected:
37
+ std::streambuf *m_streambuf;
38
+ public:
39
+
40
+ explicit InputFileStream(const std::string &filePath);
41
+ ~InputFileStream();
42
+
43
+ void Close();
44
+ };
45
+
46
+ }
47
+
48
+ #endif
mosesdecoder/phrase-extract/InternalStructFeature.h ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <iostream>
2
+ #include <fstream>
3
+ #include <cassert>
4
+ #include <cstdlib>
5
+ #include <string>
6
+ #include <queue>
7
+ #include <map>
8
+ #include <cmath>
9
+
10
+ #include "ScoreFeature.h"
11
+ #include "extract-ghkm/Node.h"
12
+
13
+ namespace MosesTraining
14
+ {
15
+
16
+
17
+ class InternalStructFeature : public ScoreFeature
18
+ {
19
+ public:
20
+ InternalStructFeature() : m_type(0) {};
21
+ /** Add the values for this feature function. */
22
+ void add(const ScoreFeatureContext& context,
23
+ std::vector<float>& denseValues,
24
+ std::map<std::string,float>& sparseValues) const;
25
+
26
+
27
+ protected:
28
+ /** Overridden in subclass */
29
+ virtual void add(const std::string *treeFragment,
30
+ float count,
31
+ std::vector<float>& denseValues,
32
+ std::map<std::string,float>& sparseValues) const = 0;
33
+ int m_type;
34
+ };
35
+
36
+ class InternalStructFeatureDense : public InternalStructFeature
37
+ {
38
+ public:
39
+ InternalStructFeatureDense()
40
+ :InternalStructFeature() {
41
+ m_type=1;
42
+ } //std::cout<<"InternalStructFeatureDense: Construct "<<m_type<<"\n";}
43
+ protected:
44
+ virtual void add(const std::string *treeFragment,
45
+ float count,
46
+ std::vector<float>& denseValues,
47
+ std::map<std::string,float>& sparseValues) const;
48
+ };
49
+
50
+ class InternalStructFeatureSparse : public InternalStructFeature
51
+ {
52
+ public:
53
+ InternalStructFeatureSparse()
54
+ :InternalStructFeature() {
55
+ m_type=2;
56
+ }// std::cout<<"InternalStructFeatureSparse: Construct "<<m_type<<"\n";}
57
+ protected:
58
+ virtual void add(const std::string *treeFragment,
59
+ float count,
60
+ std::vector<float>& denseValues,
61
+ std::map<std::string,float>& sparseValues) const;
62
+ };
63
+
64
+ }
mosesdecoder/phrase-extract/OutputFileStream.h ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // $Id: InputFileStream.h 2939 2010-02-24 11:15:44Z jfouet $
2
+
3
+ /***********************************************************************
4
+ Moses - factored phrase-based language decoder
5
+ Copyright (C) 2006 University of Edinburgh
6
+
7
+ This library is free software; you can redistribute it and/or
8
+ modify it under the terms of the GNU Lesser General Public
9
+ License as published by the Free Software Foundation; either
10
+ version 2.1 of the License, or (at your option) any later version.
11
+
12
+ This library is distributed in the hope that it will be useful,
13
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15
+ Lesser General Public License for more details.
16
+
17
+ You should have received a copy of the GNU Lesser General Public
18
+ License along with this library; if not, write to the Free Software
19
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20
+ ***********************************************************************/
21
+
22
+ #pragma once
23
+
24
+ #include <cstdlib>
25
+ #include <fstream>
26
+ #include <string>
27
+ #include <iostream>
28
+ #include <boost/iostreams/filtering_stream.hpp>
29
+
30
+ namespace Moses
31
+ {
32
+
33
+ /** Version of std::ostream with transparent compression.
34
+ *
35
+ * Transparently compresses output when writing to a file whose name ends in
36
+ * ".gz". Or, writes to stdout instead of a file when given a filename
37
+ * consisting of just a dash ("-").
38
+ */
39
class OutputFileStream : public boost::iostreams::filtering_ostream
{
private:
  /** File that needs flushing & closing when we close this stream.
   *
   * Is NULL when no file is opened, e.g. when writing to standard output.
   */
  std::ofstream *m_outFile;

  /// Is this stream open?
  bool m_open;

public:
  /** Create an unopened OutputFileStream.
   *
   * Until it's been opened, nothing can be done with this stream.
   */
  OutputFileStream();

  /// Create an OutputFileStream, and open it by calling Open().
  OutputFileStream(const std::string &filePath);
  virtual ~OutputFileStream();

  // TODO: Can we please just always throw an exception when this fails?
  /** Open stream.
   *
   * If filePath is "-" (just a dash), this opens the stream for writing to
   * standard output. Otherwise, it opens the given file. If the filename
   * has the ".gz" suffix, output will be transparently compressed.
   *
   * Call Close() to close the file.
   *
   * Returns whether opening the file was successful. It may also throw an
   * exception on failure.
   */
  bool Open(const std::string &filePath);

  /// Flush and close stream. After this, the stream can be opened again.
  void Close();
};
79
+
80
+ }
81
+
mosesdecoder/phrase-extract/PhraseExtractionOptions.h ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ /***********************************************************************
3
+ Moses - factored phrase-based language decoder
4
+ Copyright (C) 2010 University of Edinburgh
5
+
6
+ This library is free software; you can redistribute it and/or
7
+ modify it under the terms of the GNU Lesser General Public
8
+ License as published by the Free Software Foundation; either
9
+ version 2.1 of the License, or (at your option) any later version.
10
+
11
+ This library is distributed in the hope that it will be useful,
12
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14
+ Lesser General Public License for more details.
15
+
16
+ You should have received a copy of the GNU Lesser General Public
17
+ License along with this library; if not, write to the Free Software
18
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19
+ ***********************************************************************/
20
+
21
+
22
+ #include <string>
23
+ #include <vector>
24
+
25
+ namespace MosesTraining
26
+ {
27
+ enum REO_MODEL_TYPE {REO_MSD, REO_MSLR, REO_MONO};
28
+ enum REO_POS {LEFT, RIGHT, DLEFT, DRIGHT, UNKNOWN};
29
+
30
+
31
/** Option bag for the phrase-pair extractor.
 *
 *  Flags are private and exposed through init*/is* pairs; the few public
 *  members (phrase-length bounds, separator, placeholders, debug) are set
 *  directly by the command-line parser.
 */
class PhraseExtractionOptions
{

public:
  int maxPhraseLength;
  int minPhraseLength;
  std::string separator;   // field separator used in the extract files

private:
  bool allModelsOutputFlag;
  bool wordModel;
  REO_MODEL_TYPE wordType;
  bool phraseModel;
  REO_MODEL_TYPE phraseType;
  bool hierModel;
  REO_MODEL_TYPE hierType;
  bool orientationFlag;
  bool translationFlag;
  bool includeSentenceIdFlag; //include sentence id in extract file
  bool onlyOutputSpanInfo;
  bool gzOutput;
  std::string instanceWeightsFile; //weights for each sentence
  bool targetConstituentConstrainedFlag;
  bool targetConstituentBoundariesFlag;
  bool flexScoreFlag;
  bool singleWordHeuristicFlag;

public:
  std::vector<std::string> placeholders;
  bool debug;

  // Defaults: translation table on, everything else off; reordering models
  // default to the MSD orientation type.
  PhraseExtractionOptions(const int initmaxPhraseLength):
    maxPhraseLength(initmaxPhraseLength),
    minPhraseLength(3),
    separator("|||"),
    allModelsOutputFlag(false),
    wordModel(false),
    wordType(REO_MSD),
    phraseModel(false),
    phraseType(REO_MSD),
    hierModel(false),
    hierType(REO_MSD),
    orientationFlag(false),
    translationFlag(true),
    includeSentenceIdFlag(false),
    onlyOutputSpanInfo(false),
    gzOutput(false),
    targetConstituentConstrainedFlag(false),
    targetConstituentBoundariesFlag(false),
    flexScoreFlag(false),
    singleWordHeuristicFlag(false),
    debug(false) {
  }

  //functions for initialization of options
  void initAllModelsOutputFlag(const bool initallModelsOutputFlag) {
    allModelsOutputFlag=initallModelsOutputFlag;
  }
  void initWordModel(const bool initwordModel) {
    wordModel=initwordModel;
  }
  void initWordType(REO_MODEL_TYPE initwordType ) {
    wordType=initwordType;
  }
  void initPhraseModel(const bool initphraseModel ) {
    phraseModel=initphraseModel;
  }
  void initPhraseType(REO_MODEL_TYPE initphraseType) {
    phraseType=initphraseType;
  }
  void initHierModel(const bool inithierModel) {
    hierModel=inithierModel;
  }
  void initHierType(REO_MODEL_TYPE inithierType) {
    hierType=inithierType;
  }
  void initOrientationFlag(const bool initorientationFlag) {
    orientationFlag=initorientationFlag;
  }
  void initTranslationFlag(const bool inittranslationFlag) {
    translationFlag=inittranslationFlag;
  }
  void initIncludeSentenceIdFlag(const bool initincludeSentenceIdFlag) {
    includeSentenceIdFlag=initincludeSentenceIdFlag;
  }
  void initOnlyOutputSpanInfo(const bool initonlyOutputSpanInfo) {
    onlyOutputSpanInfo= initonlyOutputSpanInfo;
  }
  void initGzOutput (const bool initgzOutput) {
    gzOutput= initgzOutput;
  }
  void initInstanceWeightsFile(const char* initInstanceWeightsFile) {
    instanceWeightsFile = std::string(initInstanceWeightsFile);
  }
  void initTargetConstituentConstrainedFlag(const bool initTargetConstituentConstrainedFlag) {
    targetConstituentConstrainedFlag = initTargetConstituentConstrainedFlag;
  }
  void initTargetConstituentBoundariesFlag(const bool initTargetConstituentBoundariesFlag) {
    targetConstituentBoundariesFlag = initTargetConstituentBoundariesFlag;
  }
  void initFlexScoreFlag(const bool initflexScoreFlag) {
    flexScoreFlag=initflexScoreFlag;
  }
  void initSingleWordHeuristicFlag(const bool initSingleWordHeuristicFlag) {
    singleWordHeuristicFlag = initSingleWordHeuristicFlag;
  }

  // functions for getting values
  bool isAllModelsOutputFlag() const {
    return allModelsOutputFlag;
  }
  bool isWordModel() const {
    return wordModel;
  }
  REO_MODEL_TYPE isWordType() const {
    return wordType;
  }
  bool isPhraseModel() const {
    return phraseModel;
  }
  REO_MODEL_TYPE isPhraseType() const {
    return phraseType;
  }
  bool isHierModel() const {
    return hierModel;
  }
  REO_MODEL_TYPE isHierType() const {
    return hierType;
  }
  bool isOrientationFlag() const {
    return orientationFlag;
  }
  bool isTranslationFlag() const {
    return translationFlag;
  }
  bool isIncludeSentenceIdFlag() const {
    return includeSentenceIdFlag;
  }
  bool isOnlyOutputSpanInfo() const {
    return onlyOutputSpanInfo;
  }
  bool isGzOutput () const {
    return gzOutput;
  }
  std::string getInstanceWeightsFile() const {
    return instanceWeightsFile;
  }
  bool isTargetConstituentConstrainedFlag() const {
    return targetConstituentConstrainedFlag;
  }
  bool isTargetConstituentBoundariesFlag() const {
    return targetConstituentBoundariesFlag;
  }
  bool isFlexScoreFlag() const {
    return flexScoreFlag;
  }
  bool isSingleWordHeuristicFlag() const {
    return singleWordHeuristicFlag;
  }
};
191
+
192
+ }
193
+
mosesdecoder/phrase-extract/RuleExtractionOptions.h ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***********************************************************************
2
+ Moses - factored phrase-based language decoder
3
+ Copyright (C) 2010 University of Edinburgh
4
+
5
+ This library is free software; you can redistribute it and/or
6
+ modify it under the terms of the GNU Lesser General Public
7
+ License as published by the Free Software Foundation; either
8
+ version 2.1 of the License, or (at your option) any later version.
9
+
10
+ This library is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
+ Lesser General Public License for more details.
14
+
15
+ You should have received a copy of the GNU Lesser General Public
16
+ License along with this library; if not, write to the Free Software
17
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18
+ ***********************************************************************/
19
+
20
+ #pragma once
21
+
22
+ namespace MosesTraining
23
+ {
24
+
25
/** Option bag for hierarchical (SCFG) rule extraction.
 *
 *  Plain public fields; defaults are established in the constructor below
 *  and overridden by the command-line parser.
 */
struct RuleExtractionOptions {
public:
  int maxSpan;              // maximum source span of a rule
  int minHoleSource;        // minimum source width of a non-terminal hole
  int minHoleTarget;        // minimum target width of a non-terminal hole
  int minWords;
  int maxSymbolsTarget;
  int maxSymbolsSource;
  int maxNonTerm;           // maximum number of non-terminals per rule
  int maxScope;
  bool onlyDirectFlag;
  bool glueGrammarFlag;
  bool unknownWordLabelFlag;
  bool onlyOutputSpanInfo;
  bool noFileLimit;
  bool properConditioning;
  bool nonTermFirstWord;
  bool nonTermConsecTarget;
  bool nonTermConsecSource;
  bool requireAlignedWord;
  bool sourceSyntax;
  bool targetSyntax;
  bool targetSyntacticPreferences;
  bool duplicateRules;
  bool fractionalCounting;
  bool pcfgScore;
  bool gzOutput;
  bool unpairedExtractFormat;
  bool conditionOnTargetLhs;
  bool boundaryRules;
  bool flexScoreFlag;
  bool phraseOrientation;

  RuleExtractionOptions()
    : maxSpan(10)
    , minHoleSource(2)
    , minHoleTarget(1)
    , minWords(1)
    , maxSymbolsTarget(999)
    , maxSymbolsSource(5)
    , maxNonTerm(2)
    , maxScope(999)
      // int minHoleSize(1)
      // int minSubPhraseSize(1) // minimum size of a remaining lexical phrase
    , onlyDirectFlag(false)
    , glueGrammarFlag(false)
    , unknownWordLabelFlag(false)
    , onlyOutputSpanInfo(false)
    , noFileLimit(false)
      //bool zipFiles(false)
    , properConditioning(false)
    , nonTermFirstWord(true)
    , nonTermConsecTarget(true)
    , nonTermConsecSource(false)
    , requireAlignedWord(true)
    , sourceSyntax(false)
    , targetSyntax(false)
    , targetSyntacticPreferences(false)
    , duplicateRules(true)
    , fractionalCounting(true)
    , pcfgScore(false)
    , gzOutput(false)
    , unpairedExtractFormat(false)
    , conditionOnTargetLhs(false)
    , boundaryRules(false)
    , flexScoreFlag(false)
    , phraseOrientation(false) {}
};
93
+
94
+ }
95
+
mosesdecoder/phrase-extract/ScoreFeature.cpp ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***********************************************************************
2
+ Moses - factored phrase-based language decoder
3
+ Copyright (C) 2012- University of Edinburgh
4
+
5
+ This library is free software; you can redistribute it and/or
6
+ modify it under the terms of the GNU Lesser General Public
7
+ License as published by the Free Software Foundation; either
8
+ version 2.1 of the License, or (at your option) any later version.
9
+
10
+ This library is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
+ Lesser General Public License for more details.
14
+
15
+ You should have received a copy of the GNU Lesser General Public
16
+ License along with this library; if not, write to the Free Software
17
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18
+ ***********************************************************************/
19
+
20
+ #include <boost/algorithm/string/predicate.hpp>
21
+ #include "ScoreFeature.h"
22
+ #include "DomainFeature.h"
23
+ #include "InternalStructFeature.h"
24
+
25
+ using namespace std;
26
+ using namespace boost::algorithm;
27
+
28
+ namespace MosesTraining
29
+ {
30
+
31
+
32
+ const string& ScoreFeatureManager::usage() const
33
+ {
34
+ const static string& usage = "[--[Sparse]Domain[Indicator|Ratio|Subset|Bin] domain-file [bins]]" ;
35
+ return usage;
36
+ }
37
+
38
+ void ScoreFeatureManager::configure(const std::vector<std::string> args)
39
+ {
40
+ bool domainAdded = false;
41
+ bool sparseDomainAdded = false;
42
+
43
+ for (size_t i = 0; i < args.size(); ++i) {
44
+ if (args[i] == "--IgnoreSentenceId") {
45
+ m_includeSentenceId = true;
46
+ } else if (starts_with(args[i], "--Domain")) {
47
+ string type = args[i].substr(8);
48
+ ++i;
49
+ UTIL_THROW_IF(i == args.size(), ScoreFeatureArgumentException, "Missing domain file");
50
+ string domainFile = args[i];
51
+ UTIL_THROW_IF(domainAdded, ScoreFeatureArgumentException,
52
+ "Only allowed one domain feature");
53
+ if (type == "Subset") {
54
+ m_features.push_back(ScoreFeaturePtr(new SubsetDomainFeature(domainFile)));
55
+ } else if (type == "Ratio") {
56
+ m_features.push_back(ScoreFeaturePtr(new RatioDomainFeature(domainFile)));
57
+ } else if (type == "Indicator") {
58
+ m_features.push_back(ScoreFeaturePtr(new IndicatorDomainFeature(domainFile)));
59
+ } else {
60
+ UTIL_THROW(ScoreFeatureArgumentException, "Unknown domain feature type " << type);
61
+ }
62
+ domainAdded = true;
63
+ m_includeSentenceId = true;
64
+ } else if (starts_with(args[i], "--SparseDomain")) {
65
+ string type = args[i].substr(14);
66
+ ++i;
67
+ UTIL_THROW_IF(i == args.size(), ScoreFeatureArgumentException, "Missing domain file");
68
+ string domainFile = args[i];
69
+ UTIL_THROW_IF(sparseDomainAdded, ScoreFeatureArgumentException,
70
+ "Only allowed one sparse domain feature");
71
+ if (type == "Subset") {
72
+ m_features.push_back(ScoreFeaturePtr(new SparseSubsetDomainFeature(domainFile)));
73
+ } else if (type == "Ratio") {
74
+ m_features.push_back(ScoreFeaturePtr(new SparseRatioDomainFeature(domainFile)));
75
+ } else if (type == "Indicator") {
76
+ m_features.push_back(ScoreFeaturePtr(new SparseIndicatorDomainFeature(domainFile)));
77
+ } else {
78
+ UTIL_THROW(ScoreFeatureArgumentException, "Unknown domain feature type " << type);
79
+ }
80
+ sparseDomainAdded = true;
81
+ m_includeSentenceId = true;
82
+ } else if(args[i] == "--TreeFeatureSparse") {
83
+ //MARIA
84
+ m_features.push_back(ScoreFeaturePtr(new InternalStructFeatureSparse()));
85
+ } else if(args[i] == "--TreeFeatureDense") {
86
+ //MARIA
87
+ m_features.push_back(ScoreFeaturePtr(new InternalStructFeatureDense()));
88
+ } else {
89
+ UTIL_THROW(ScoreFeatureArgumentException,"Unknown score argument " << args[i]);
90
+ }
91
+
92
+ }
93
+
94
+ }
95
+
96
+ void ScoreFeatureManager::addPropertiesToPhrasePair(ExtractionPhrasePair &phrasePair,
97
+ float count,
98
+ int sentenceId) const
99
+ {
100
+ for (size_t i = 0; i < m_features.size(); ++i) {
101
+ m_features[i]->addPropertiesToPhrasePair(phrasePair, count, sentenceId);
102
+ }
103
+ }
104
+
105
+ void ScoreFeatureManager::addFeatures(const ScoreFeatureContext& context,
106
+ std::vector<float>& denseValues,
107
+ std::map<std::string,float>& sparseValues) const
108
+ {
109
+ for (size_t i = 0; i < m_features.size(); ++i) {
110
+ m_features[i]->add(context, denseValues, sparseValues);
111
+ }
112
+ }
113
+ }
114
+
mosesdecoder/phrase-extract/SyntaxTree.h ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include "syntax-common/tree.h"
4
+
5
+ #include "SyntaxNode.h"
6
+
7
+ namespace MosesTraining
8
+ {
9
+
10
+ typedef Syntax::Tree<SyntaxNode> SyntaxTree;
11
+
12
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/consolidate-direct-main.cpp ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***********************************************************************
2
+ Moses - factored phrase-based language decoder
3
+ Copyright (C) 2009 University of Edinburgh
4
+
5
+ This library is free software; you can redistribute it and/or
6
+ modify it under the terms of the GNU Lesser General Public
7
+ License as published by the Free Software Foundation; either
8
+ version 2.1 of the License, or (at your option) any later version.
9
+
10
+ This library is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
+ Lesser General Public License for more details.
14
+
15
+ You should have received a copy of the GNU Lesser General Public
16
+ License along with this library; if not, write to the Free Software
17
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18
+ ***********************************************************************/
19
+
20
+ #include <string.h>
21
+ #include <fstream>
22
+ #include <vector>
23
+ #include <string>
24
+ #include <iostream>
25
+ #include <cstdlib>
26
+ #include "InputFileStream.h"
27
+ #include "OutputFileStream.h"
28
+ #include "util/tokenize.hh"
29
+
30
+ using namespace std;
31
+
32
// Split a " ||| "-delimited extract-file line into its fields.
// Two adjacent delimiters produce an empty field; a line without any
// delimiter is returned as a single field.
std::vector< std::string > splitLine(const char *line)
{
  std::vector< std::string > fields;
  int fieldStart = 0;
  int pos = 0;
  while (line[pos] != '\0') {
    const bool atDelimiter =
      line[pos]   == ' ' &&
      line[pos+1] == '|' &&
      line[pos+2] == '|' &&
      line[pos+3] == '|' &&
      line[pos+4] == ' ';
    if (atDelimiter) {
      if (fieldStart > pos) fieldStart = pos; // empty field
      fields.push_back( std::string( line + fieldStart, pos - fieldStart ) );
      fieldStart = pos + 5;
      pos += 3; // plus the increment below skips to the char after " ||| "
    }
    pos++;
  }
  fields.push_back( std::string( line + fieldStart, pos - fieldStart ) );

  return fields;
}
53
+
54
+ bool getLine( istream &fileP, vector< string > &item )
55
+ {
56
+ if (fileP.eof())
57
+ return false;
58
+
59
+ string line;
60
+ if (getline(fileP, line)) {
61
+ item = splitLine(line.c_str());
62
+ return true;
63
+ } else {
64
+ return false;
65
+ }
66
+ }
67
+
68
+
69
// Usage: consolidate-direct <extract-file> <output-file | ->
//
// Reads a sorted extract file (fields separated by " ||| ") and writes a
// direct phrase table line per input line, inserting a probability computed
// from the counts field.
// NOTE(review): argc is never checked — running with fewer than two
// arguments dereferences argv[1]/argv[2].
int main(int argc, char* argv[])
{
  cerr << "Starting..." << endl;

  char* &fileNameDirect = argv[1];
  Moses::InputFileStream fileDirect(fileNameDirect);


  //fileDirect.open(fileNameDirect);
  if (fileDirect.fail()) {
    cerr << "ERROR: could not open extract file " << fileNameDirect << endl;
    exit(1);
  }
  istream &fileDirectP = fileDirect;

  // "-" writes to stdout; anything else is opened via Moses::OutputFileStream.
  char* &fileNameConsolidated = argv[2];
  ostream *fileConsolidated;

  if (strcmp(fileNameConsolidated, "-") == 0) {
    fileConsolidated = &cout;
  } else {
    Moses::OutputFileStream *outputFile = new Moses::OutputFileStream();
    bool success = outputFile->Open(fileNameConsolidated);
    if (!success) {
      cerr << "ERROR: could not open file phrase table file "
           << fileNameConsolidated << endl;
      exit(1);
    }
    fileConsolidated = outputFile;
  }

  int i=0;
  while(true) {
    i++;
    // Progress indicators: '.' per 1000 lines, ':' per 10000, '!' per 100000.
    if (i%1000 == 0) cerr << "." << flush;
    if (i%10000 == 0) cerr << ":" << flush;
    if (i%100000 == 0) cerr << "!" << flush;

    vector< string > itemDirect;
    if (! getLine(fileDirectP, itemDirect ))
      break;

    // Field 4 holds the two counts as "countEF countF".
    const vector< string > count = util::tokenize( itemDirect[4] );
    float countEF = atof(count[0].c_str());
    float countF = atof(count[1].c_str());
    // NOTE(review): this computes countF/countEF; a direct translation
    // probability would normally be countEF/countF (joint over marginal)
    // — confirm which is intended.
    float prob = countF/countEF;

    (*fileConsolidated) << itemDirect[0] << " ||| " // source
                        << itemDirect[1] << " ||| " // target
                        << prob << " ||| " // prob
                        // NOTE(review): no space before "|||" here, unlike
                        // the other field separators — confirm intended.
                        << itemDirect[2] << "||| " // alignment
                        << itemDirect[4] << " " << countEF // counts
                        << " ||| " << endl;
  }

  fileConsolidated->flush();
  if (fileConsolidated != &cout) {
    delete fileConsolidated;
  }

  cerr << "Finished" << endl;
}
131
+
mosesdecoder/phrase-extract/extract-lex.h ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <map>
4
+ #include <set>
5
+ #include <sstream>
6
+ #include <fstream>
7
+ #include <iostream>
8
+
9
+ namespace MosesTraining
10
+ {
11
+
12
// Accumulates a count for one word together with the counts of the words it
// co-occurs with (a nested word -> WordCount map keyed by interned strings).
class WordCount
{
  friend std::ostream& operator<<(std::ostream&, const WordCount&);
public:
  // Occurrence count of this word (or word pair).
  float m_count;

  // Co-occurrence counts, keyed by interned word pointer (see Vocab).
  std::map<const std::string*, WordCount> m_coll;

  WordCount()
    :m_count(0) {
  }

  //WordCount(const WordCount &copy);

  WordCount(float count)
    :m_count(count) {
  }

  // Increase the count by incr (defined in extract-lex.cpp).
  void AddCount(float incr);

  std::map<const std::string*, WordCount> &GetColl() {
    return m_coll;
  }
  const std::map<const std::string*, WordCount> &GetColl() const {
    return m_coll;
  }

  // Return the accumulated count.  (A top-level const qualifier on a
  // by-value return type is meaningless and triggers -Wignored-qualifiers,
  // so it is not used here.)
  float GetCount() const {
    return m_count;
  }

};
44
+
45
// Interns word strings: each distinct word is stored exactly once, so words
// can be identified, compared, and used as map keys by pointer.
class Vocab
{
  std::set<std::string> m_coll;
public:
  // Return a pointer to the stored copy of word, inserting it first if
  // absent (defined in extract-lex.cpp).
  const std::string *GetOrAdd(const std::string &word);
};
51
+
52
// Accumulates source-to-target and target-to-source lexical co-occurrence
// counts from aligned sentence pairs and writes the two lexical tables.
// All member functions are defined in extract-lex.cpp.
class ExtractLex
{
  Vocab m_vocab;  // interned vocabulary shared by both count tables
  std::map<const std::string*, WordCount> m_collS2T, m_collT2S;  // per-direction counts

  // Count one aligned (target, source) word pair in both directions.
  void Process(const std::string *target, const std::string *source);
  // Record one co-occurrence of 'out' under the entry wcIn.
  void Process(WordCount &wcIn, const std::string *out);
  // Handle words left unaligned, using the alignment flag vectors.
  void ProcessUnaligned(std::vector<std::string> &toksTarget, std::vector<std::string> &toksSource
                        , const std::vector<bool> &m_sourceAligned, const std::vector<bool> &m_targetAligned);

  // Write one direction's counts to outStream.
  void Output(const std::map<const std::string*, WordCount> &coll, std::ofstream &outStream);

public:
  // Process one sentence pair given its alignment tokens; lineCount is the
  // input line number (presumably used for diagnostics — confirm in .cpp).
  void Process(std::vector<std::string> &toksTarget, std::vector<std::string> &toksSource, std::vector<std::string> &toksAlign, size_t lineCount);
  // Write both lexical tables.
  void Output(std::ofstream &streamLexS2T, std::ofstream &streamLexT2S);

};
69
+
70
+ } // namespace
mosesdecoder/phrase-extract/filter-rule-table/CfgFilter.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <istream>
4
+ #include <ostream>
5
+ #include <string>
6
+ #include <vector>
7
+
8
+ namespace MosesTraining
9
+ {
10
+ namespace Syntax
11
+ {
12
+ namespace FilterRuleTable
13
+ {
14
+
15
// Base class for StringCfgFilter and TreeCfgFilter, both of which filter rule
// tables where the source-side is CFG.
class CfgFilter
{
public:
  virtual ~CfgFilter() {}

  // Read a rule table from 'in' and filter it according to the test sentences:
  // rules judged unable to apply to any test sentence are dropped, surviving
  // rules are written to 'out'.
  virtual void Filter(std::istream &in, std::ostream &out) = 0;

protected:
};
27
+
28
+ } // namespace FilterRuleTable
29
+ } // namespace Syntax
30
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/filter-rule-table/FilterRuleTable.h ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <vector>
4
+ #include <string>
5
+
6
+ #include <boost/shared_ptr.hpp>
7
+
8
+ #include "SyntaxTree.h"
9
+
10
+ #include "syntax-common/tool.h"
11
+
12
+ #include "StringForest.h"
13
+
14
+ namespace MosesTraining
15
+ {
16
+ namespace Syntax
17
+ {
18
+ namespace FilterRuleTable
19
+ {
20
+
21
+ struct Options;
22
+
23
// Top-level tool class for the filter-rule-table command: reads a test set
// and filters a rule table (read from std::cin) against it.
class FilterRuleTable : public Tool
{
public:
  FilterRuleTable() : Tool("filter-rule-table") {}

  // Entry point: processes options, reads the test set, then dispatches to
  // the appropriate Filter() overload.
  virtual int Main(int argc, char *argv[]);

private:
  // Filter rule table (on std::cin) for test set (string version).
  // NOTE(review): takes pre-tokenized sentences while ReadTestSet produces
  // raw strings — the conversion presumably happens in Main; confirm.
  void Filter(const std::vector<std::vector<std::string> > &);

  // Filter rule table (on std::cin) for test set (parse tree version).
  void Filter(const std::vector<boost::shared_ptr<SyntaxTree> > &);

  void ProcessOptions(int, char *[], Options &) const;

  // Read test set (string version)
  void ReadTestSet(std::istream &,
                   std::vector<boost::shared_ptr<std::string> > &);

  // Read test set (tree version)
  void ReadTestSet(std::istream &,
                   std::vector<boost::shared_ptr<SyntaxTree> > &);

  // Read test set (forest version)
  void ReadTestSet(std::istream &,
                   std::vector<boost::shared_ptr<StringForest> > &);
};
51
+
52
+ } // namespace FilterRuleTable
53
+ } // namespace Syntax
54
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/filter-rule-table/Forest.h ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <vector>
4
+
5
+ namespace MosesTraining
6
+ {
7
+ namespace Syntax
8
+ {
9
+ namespace FilterRuleTable
10
+ {
11
+
12
// A hypergraph ("forest") over vertices carrying values of type T.
// The Forest owns its vertices; each vertex owns its incoming hyperedges.
template<typename T>
struct Forest {
  struct Vertex;

  // A hyperedge connects a head vertex to an ordered sequence of tail
  // vertices.
  struct Hyperedge {
    Vertex *head;
    std::vector<Vertex *> tail;
  };

  struct Vertex {
    ~Vertex();
    T value;
    std::vector<Hyperedge *> incoming;
  };

  Forest() {}

  // Deletes all vertices (which in turn delete their incoming hyperedges).
  ~Forest();

  // All vertices of the forest, owned by this object.
  std::vector<Vertex *> vertices;

private:
  // Copying is not allowed.
  Forest(const Forest &);
  Forest &operator=(const Forest &);
};

template<typename T>
Forest<T>::~Forest()
{
  for (typename std::vector<Vertex *>::iterator p = vertices.begin();
       p != vertices.end(); ++p) {
    delete *p;
  }
}

template<typename T>
Forest<T>::Vertex::~Vertex()
{
  // Assumes each hyperedge appears in exactly one vertex's 'incoming' list
  // (its head's), so it is freed exactly once — this is how the builders in
  // this directory construct forests; confirm for any new producer.
  for (typename std::vector<Hyperedge *>::iterator p = incoming.begin();
       p != incoming.end(); ++p) {
    delete *p;
  }
}
56
+
57
+ } // namespace FilterRuleTable
58
+ } // namespace Syntax
59
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/filter-rule-table/ForestTsgFilter.cpp ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "ForestTsgFilter.h"
2
+
3
+ #include <boost/make_shared.hpp>
4
+
5
+ namespace MosesTraining
6
+ {
7
+ namespace Syntax
8
+ {
9
+ namespace FilterRuleTable
10
+ {
11
+
12
+ // kMatchLimit is used to limit the effort spent trying to match an individual
13
+ // rule. It defines the maximum number of times that MatchFragment() can be
14
+ // called before the search is aborted and the rule is (possibly wrongly)
15
+ // accepted.
16
+ // FIXME Use a better matching algorithm.
17
+ const std::size_t ForestTsgFilter::kMatchLimit = 10000;
18
+
19
// Initialize the filter for a set of test forests: convert each forest to
// integer vocabulary ids, then index every vertex by its symbol id.
ForestTsgFilter::ForestTsgFilter(
  const std::vector<boost::shared_ptr<StringForest> > &sentences)
{
  // Convert each StringForest to an IdForest.
  m_sentences.reserve(sentences.size());
  for (std::vector<boost::shared_ptr<StringForest> >::const_iterator p =
         sentences.begin(); p != sentences.end(); ++p) {
    m_sentences.push_back(StringForestToIdForest(**p));
  }

  // Construct a map from vocabulary Ids to IdForest nodes.  m_testVocab is
  // fully populated by the conversion loop above, so its size now covers
  // every symbol id that can occur below.
  m_idToSentence.resize(m_testVocab.Size());
  for (std::size_t i = 0; i < m_sentences.size(); ++i) {
    const IdForest &forest = *(m_sentences[i]);
    for (std::vector<IdForest::Vertex *>::const_iterator
         p = forest.vertices.begin(); p != forest.vertices.end(); ++p) {
      // Record that this vertex's symbol occurs in sentence i at vertex *p.
      m_idToSentence[(*p)->value.id][i].push_back(*p);
    }
  }
}
39
+
40
// Convert a StringForest into an IdForest: vertex symbols are interned in
// m_testVocab (inserting new symbols as needed) and the hyperedge structure
// is reproduced over the new vertices.
boost::shared_ptr<ForestTsgFilter::IdForest>
ForestTsgFilter::StringForestToIdForest(const StringForest &f)
{
  typedef StringForest::Vertex StringVertex;
  typedef StringForest::Hyperedge StringHyperedge;
  typedef IdForest::Vertex IdVertex;
  typedef IdForest::Hyperedge IdHyperedge;

  boost::shared_ptr<IdForest> g = boost::make_shared<IdForest>();

  // Map from f's vertices to g's vertices.
  boost::unordered_map<const StringVertex *, const IdVertex *> vertexMap;

  // Create idForest's vertices and populate vertexMap.
  for (std::vector<StringVertex *>::const_iterator p = f.vertices.begin();
       p != f.vertices.end(); ++p) {
    const StringVertex *v = *p;
    IdVertex *w = new IdVertex();
    w->value.id = m_testVocab.Insert(v->value.symbol);
    w->value.start = v->value.start;
    w->value.end = v->value.end;
    g->vertices.push_back(w);
    vertexMap[v] = w;
  }

  // Create g's hyperedges.  Each new hyperedge is attached to (and thus
  // owned by) its head vertex's 'incoming' list.
  for (std::vector<StringVertex *>::const_iterator p = f.vertices.begin();
       p != f.vertices.end(); ++p) {
    for (std::vector<StringHyperedge *>::const_iterator
         q = (*p)->incoming.begin(); q != (*p)->incoming.end(); ++q) {
      IdHyperedge *e = new IdHyperedge();
      e->head = const_cast<IdVertex *>(vertexMap[(*q)->head]);
      e->tail.reserve((*q)->tail.size());
      for (std::vector<StringVertex*>::const_iterator
           r = (*q)->tail.begin(); r != (*q)->tail.end(); ++r) {
        e->tail.push_back(const_cast<IdVertex *>(vertexMap[*r]));
      }
      e->head->incoming.push_back(e);
    }
  }

  return g;
}
83
+
84
// Decide whether the rule fragment can match anywhere in the test set;
// 'leaves' are the fragment's leaf nodes.  May over-approximate: once the
// per-rule effort budget (kMatchLimit) is exhausted the rule is accepted
// without a complete check.
bool ForestTsgFilter::MatchFragment(const IdTree &fragment,
                                    const std::vector<IdTree *> &leaves)
{
  typedef std::vector<const IdTree *> TreeVec;  // NOTE(review): unused here

  // Reset the match counter.
  m_matchCount = 0;

  // Determine which of the fragment's leaves occurs in the smallest number of
  // sentences in the test set. If the fragment contains a rare word
  // (which is pretty likely assuming a Zipfian distribution) then we only
  // have to try matching the fragment against a small number of potential
  // match sites.
  const IdTree *rarestLeaf = leaves[0];
  std::size_t lowestCount = m_idToSentence[rarestLeaf->value()].size();
  for (std::size_t i = 1; i < leaves.size(); ++i) {
    const IdTree *leaf = leaves[i];
    std::size_t count = m_idToSentence[leaf->value()].size();
    if (count < lowestCount) {
      lowestCount = count;
      rarestLeaf = leaf;
    }
  }

  // Try to match the rule fragment against the sentences where the rarest
  // leaf was found.
  const InnerMap &leafSentenceMap = m_idToSentence[rarestLeaf->value()];
  const InnerMap &rootSentenceMap = m_idToSentence[fragment.value()];

  std::vector<std::pair<std::size_t, std::size_t> > spans;
  // For each forest i that contains the rarest leaf symbol...
  for (InnerMap::const_iterator p = leafSentenceMap.begin();
       p != leafSentenceMap.end(); ++p) {
    std::size_t i = p->first;
    // Get the set of candidate match sites in forest i (these are vertices
    // with the same label as the root of the rule fragment).
    InnerMap::const_iterator q = rootSentenceMap.find(i);
    if (q == rootSentenceMap.end()) {
      continue;
    }
    const std::vector<const IdForest::Vertex*> &candidates = q->second;
    // Record the span(s) of the rare leaf symbol in forest i.
    spans.clear();
    for (std::vector<const IdForest::Vertex*>::const_iterator
         r = p->second.begin(); r != p->second.end(); ++r) {
      spans.push_back(std::make_pair((*r)->value.start, (*r)->value.end));
    }
    // For each candidate match site in forest i...
    for (std::vector<const IdForest::Vertex*>::const_iterator
         r = candidates.begin(); r != candidates.end(); ++r) {
      const IdForest::Vertex &v = **r;
      // Check that the subtrees rooted at v are at least as wide as the
      // fragment (counting each non-terminal as being one token wide).
      if (v.value.end - v.value.start + 1 < leaves.size()) {
        continue;
      }
      // Check that the candidate's span covers one of the rare leaf symbols.
      bool covered = false;
      for (std::vector<std::pair<std::size_t, std::size_t> >::const_iterator
           s = spans.begin(); s != spans.end(); ++s) {
        if (v.value.start <= s->first && v.value.end >= s->second) {
          covered = true;
          break;
        }
      }
      if (!covered) {
        continue;
      }
      // Attempt to match the fragment at the candidate site.
      if (MatchFragment(fragment, v)) {
        return true;
      }
    }
  }
  return false;
}
160
+
161
// Recursively match a rule fragment against the forest rooted at vertex v.
// Each call increments m_matchCount; once kMatchLimit is reached the search
// gives up and reports a match (erring on the side of keeping the rule).
bool ForestTsgFilter::MatchFragment(const IdTree &fragment,
                                    const IdForest::Vertex &v)
{
  if (++m_matchCount >= kMatchLimit) {
    return true;
  }
  // The fragment's root must carry the same symbol as the vertex.
  if (fragment.value() != v.value.id) {
    return false;
  }
  const std::vector<IdTree*> &children = fragment.children();
  if (children.empty()) {
    // A fragment leaf matches any vertex with the right symbol.
    return true;
  }
  // Try every incoming hyperedge: the fragment's children must match the
  // edge's tail vertices one-for-one, in order.
  for (std::vector<IdForest::Hyperedge *>::const_iterator
       p = v.incoming.begin(); p != v.incoming.end(); ++p) {
    const std::vector<IdForest::Vertex*> &tail = (*p)->tail;
    if (children.size() != tail.size()) {
      continue;
    }
    bool match = true;
    for (std::size_t i = 0; i < children.size(); ++i) {
      if (!MatchFragment(*children[i], *tail[i])) {
        match = false;
        break;
      }
    }
    if (match) {
      return true;
    }
  }
  return false;
}
193
+
194
+ } // namespace FilterRuleTable
195
+ } // namespace Syntax
196
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/filter-rule-table/ForestTsgFilter.h ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <istream>
4
+ #include <ostream>
5
+ #include <string>
6
+ #include <vector>
7
+
8
+ #include <boost/shared_ptr.hpp>
9
+ #include <boost/unordered_map.hpp>
10
+ #include <boost/unordered_set.hpp>
11
+
12
+ #include "syntax-common/numbered_set.h"
13
+ #include "syntax-common/tree.h"
14
+ #include "syntax-common/tree_fragment_tokenizer.h"
15
+
16
+ #include "Forest.h"
17
+ #include "StringForest.h"
18
+ #include "TsgFilter.h"
19
+
20
+ namespace MosesTraining
21
+ {
22
+ namespace Syntax
23
+ {
24
+ namespace FilterRuleTable
25
+ {
26
+
27
// Filters a rule table, discarding rules that cannot be applied to a given
// test set. The rule table must have a TSG source-side and the test sentences
// must be parse forests.
class ForestTsgFilter : public TsgFilter
{
public:
  // Initialize the filter for a given set of test forests.
  ForestTsgFilter(const std::vector<boost::shared_ptr<StringForest> > &);

private:
  // Payload of an IdForest vertex: interned symbol id plus the start/end
  // positions of the span the vertex covers.
  struct IdForestValue {
    Vocabulary::IdType id;
    std::size_t start;
    std::size_t end;
  };

  // Cap on MatchFragment calls per rule; when reached, the rule is accepted
  // without a complete check (see ForestTsgFilter.cpp).
  static const std::size_t kMatchLimit;

  // Represents a forest using integer vocabulary values.
  typedef Forest<IdForestValue> IdForest;

  // Maps a sentence index to the vertices (within that sentence's forest)
  // carrying a particular symbol id.
  typedef boost::unordered_map<std::size_t,
          std::vector<const IdForest::Vertex*> > InnerMap;

  // Indexed by symbol id: where each symbol occurs in the test set.
  typedef std::vector<InnerMap> IdToSentenceMap;

  // Forest-specific implementation of virtual function.
  bool MatchFragment(const IdTree &, const std::vector<IdTree *> &);

  // Try to match a fragment against a specific vertex of a test forest.
  bool MatchFragment(const IdTree &, const IdForest::Vertex &);

  // Convert a StringForest to an IdForest (wrt m_testVocab). Inserts symbols
  // into m_testVocab.
  boost::shared_ptr<IdForest> StringForestToIdForest(const StringForest &);

  std::vector<boost::shared_ptr<IdForest> > m_sentences;  // one IdForest per test sentence
  IdToSentenceMap m_idToSentence;  // symbol id -> occurrences in the test set
  std::size_t m_matchCount;        // MatchFragment calls made for the current rule
};
67
+
68
+ } // namespace FilterRuleTable
69
+ } // namespace Syntax
70
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/filter-rule-table/Jamfile ADDED
@@ -0,0 +1 @@
 
 
1
+ exe filter-rule-table : [ glob *.cpp ] ..//syntax-common ..//deps ../..//boost_iostreams ../..//boost_program_options ../..//z : <include>.. ;
mosesdecoder/phrase-extract/filter-rule-table/StringCfgFilter.cpp ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "StringCfgFilter.h"
2
+
3
+ #include <algorithm>
4
+
5
+ #include "util/string_piece_hash.hh"
6
+
7
+ namespace MosesTraining
8
+ {
9
+ namespace Syntax
10
+ {
11
+ namespace FilterRuleTable
12
+ {
13
+
14
+ const std::size_t StringCfgFilter::kMaxNGramLength = 5;
15
+
16
// Initialize the filter for a set of test sentences: tokenize each sentence,
// intern its words in m_testVocab, and index every n-gram of up to
// kMaxNGramLength words by the (sentence, position) pairs where it occurs.
StringCfgFilter::StringCfgFilter(
  const std::vector<boost::shared_ptr<std::string> > &sentences)
  : m_maxSentenceLength(-1)  // stays -1 if the test set is empty
{
  // Populate m_ngramCoordinateMap (except for the CoordinateTable's
  // sentence vectors) and record the sentence lengths.
  m_sentenceLengths.reserve(sentences.size());
  const util::AnyCharacter delimiter(" \t");
  std::vector<Vocabulary::IdType> vocabIds;
  for (std::size_t i = 0; i < sentences.size(); ++i) {
    vocabIds.clear();
    for (util::TokenIter<util::AnyCharacter, true> p(*sentences[i], delimiter);
         p; ++p) {
      std::string tmp;
      p->CopyToString(&tmp);
      vocabIds.push_back(m_testVocab.Insert(tmp));
    }
    AddSentenceNGrams(vocabIds, i);
    const int sentenceLength = static_cast<int>(vocabIds.size());
    m_sentenceLengths.push_back(sentenceLength);
    m_maxSentenceLength = std::max(sentenceLength, m_maxSentenceLength);
  }

  // Populate the CoordinateTable's sentence vectors: the sorted list of
  // sentence indices each n-gram occurs in (sorted so MatchPattern can use
  // std::set_intersection over them).
  for (NGramCoordinateMap::iterator p = m_ngramCoordinateMap.begin();
       p != m_ngramCoordinateMap.end(); ++p) {
    CoordinateTable &ct = p->second;
    ct.sentences.reserve(ct.intraSentencePositions.size());
    for (boost::unordered_map<int, PositionSeq>::const_iterator
         q = ct.intraSentencePositions.begin();
         q != ct.intraSentencePositions.end(); ++q) {
      ct.sentences.push_back(q->first);
    }
    std::sort(ct.sentences.begin(), ct.sentences.end());
  }
}
52
+
53
// Stream a rule table from 'in' to 'out', keeping only rules whose
// source-side pattern can match at least one test sentence.
void StringCfgFilter::Filter(std::istream &in, std::ostream &out)
{
  const util::MultiCharacter fieldDelimiter("|||");
  const util::AnyCharacter symbolDelimiter(" \t");

  std::string line;
  std::string prevLine;
  // 'source' is a view into prevLine's buffer between iterations (see the
  // swap at the bottom of the loop), so it stays valid across getline calls.
  StringPiece source;
  std::vector<StringPiece> symbols;
  Pattern pattern;
  bool keep = true;
  int lineNum = 0;

  while (std::getline(in, line)) {
    ++lineNum;

    // Read the source-side of the rule.
    util::TokenIter<util::MultiCharacter> it(line, fieldDelimiter);

    // Check if this rule has the same source-side as the previous rule.  If
    // it does then we already know whether or not to keep the rule.  This
    // optimisation is based on the assumption that the rule table is sorted
    // (which is the case in the standard Moses training pipeline).
    if (*it == source) {
      if (keep) {
        out << line << std::endl;
      }
      continue;
    }

    // The source-side is different from the previous rule's.
    source = *it;

    // Tokenize the source-side.
    symbols.clear();
    for (util::TokenIter<util::AnyCharacter, true> p(source, symbolDelimiter);
         p; ++p) {
      symbols.push_back(*p);
    }

    // Generate a pattern (fails if any source-side terminal is not in the
    // test set vocabulary) and attempt to match it against the test sentences.
    keep = GeneratePattern(symbols, pattern) && MatchPattern(pattern);
    if (keep) {
      out << line << std::endl;
    }

    // Retain line for the next iteration (in order that the source StringPiece
    // remains valid).
    prevLine.swap(line);
  }
}
105
+
106
+ void StringCfgFilter::AddSentenceNGrams(
107
+ const std::vector<Vocabulary::IdType> &s, std::size_t sentNum)
108
+ {
109
+ const std::size_t len = s.size();
110
+
111
+ NGram ngram;
112
+ // For each starting position in the sentence:
113
+ for (std::size_t i = 0; i < len; ++i) {
114
+ // For each n-gram length: 1, 2, 3, ... kMaxNGramLength (or less when
115
+ // approaching the end of the sentence):
116
+ for (std::size_t n = 1; n <= std::min(kMaxNGramLength, len-i); ++n) {
117
+ ngram.clear();
118
+ for (std::size_t j = 0; j < n; ++j) {
119
+ ngram.push_back(s[i+j]);
120
+ }
121
+ m_ngramCoordinateMap[ngram].intraSentencePositions[sentNum].push_back(i);
122
+ }
123
+ }
124
+ }
125
+
126
// Convert the tokenized source-side of a rule into a Pattern: a sequence of
// terminal n-grams (subpatterns, each at most kMaxNGramLength long)
// separated by variable-width gaps, one gap per run of non-terminals.
// Returns false if any terminal is absent from the test-set vocabulary, in
// which case the rule cannot possibly match.
bool StringCfgFilter::GeneratePattern(const std::vector<StringPiece> &symbols,
                                      Pattern &pattern) const
{
  pattern.subpatterns.clear();
  pattern.minGapWidths.clear();

  int gapWidth = 0;

  // The first symbol is handled as a special case because there is always a
  // leading gap / non-gap.
  if (IsNonTerminal(symbols[0])) {
    ++gapWidth;
  } else {
    pattern.minGapWidths.push_back(0);
    // Add the symbol to the first n-gram.
    Vocabulary::IdType vocabId =
      m_testVocab.Lookup(symbols[0], StringPieceCompatibleHash(),
                         StringPieceCompatibleEquals());
    if (vocabId == Vocabulary::NullId()) {
      return false;
    }
    pattern.subpatterns.push_back(NGram(1, vocabId));
  }

  // Process the remaining symbols (except the last which is the RHS).
  for (std::size_t i = 1; i < symbols.size()-1; ++i) {
    // Is current symbol a non-terminal?
    if (IsNonTerminal(symbols[i])) {
      ++gapWidth;
      continue;
    }
    // Does the current terminal follow a non-terminal?
    if (gapWidth > 0) {
      pattern.minGapWidths.push_back(gapWidth);
      gapWidth = 0;
      pattern.subpatterns.resize(pattern.subpatterns.size()+1);
      // Is the current n-gram full?
    } else if (pattern.subpatterns.back().size() == kMaxNGramLength) {
      // Start a new subpattern separated by a zero-width gap so no single
      // n-gram exceeds kMaxNGramLength (the index only stores up to that).
      pattern.minGapWidths.push_back(0);
      pattern.subpatterns.resize(pattern.subpatterns.size()+1);
    }
    // Add the symbol to the current n-gram.
    Vocabulary::IdType vocabId =
      m_testVocab.Lookup(symbols[i], StringPieceCompatibleHash(),
                         StringPieceCompatibleEquals());
    if (vocabId == Vocabulary::NullId()) {
      return false;
    }
    pattern.subpatterns.back().push_back(vocabId);
  }

  // Add the final gap width value (0 if the last symbol was a terminal).
  pattern.minGapWidths.push_back(gapWidth);
  return true;
}
181
+
182
+ bool StringCfgFilter::IsNonTerminal(const StringPiece &symbol) const
183
+ {
184
+ return symbol.size() >= 3 && symbol[0] == '[' &&
185
+ symbol[symbol.size()-1] == ']';
186
+ }
187
+
188
// Return true if the pattern could match at least one test sentence.
bool StringCfgFilter::MatchPattern(const Pattern &pattern) const
{
  // Step 0: If the pattern is just a single gap (i.e. the original rule
  //         was fully non-lexical) then the pattern matches unless the
  //         minimum gap width is wider than any sentence.
  if (pattern.subpatterns.empty()) {
    assert(pattern.minGapWidths.size() == 1);
    return pattern.minGapWidths[0] <= m_maxSentenceLength;
  }

  // Step 1: Look up all of the subpatterns in m_ngramCoordinateMap and record
  //         pointers to their CoordinateTables.
  std::vector<const CoordinateTable *> tables;
  for (std::vector<NGram>::const_iterator p = pattern.subpatterns.begin();
       p != pattern.subpatterns.end(); ++p) {
    NGramCoordinateMap::const_iterator q = m_ngramCoordinateMap.find(*p);
    // If a subpattern doesn't appear in m_ngramCoordinateMap then the match
    // has already failed.
    if (q == m_ngramCoordinateMap.end()) {
      return false;
    }
    tables.push_back(&(q->second));
  }

  // Step 2: Intersect the CoordinateTables' sentence sets to find the set of
  //         test set sentences in which all subpatterns occur.
  //         (The 'sentences' vectors are sorted in the constructor, as
  //         std::set_intersection requires.)
  std::vector<int> intersection = tables[0]->sentences;
  std::vector<int> tmp(intersection.size());
  for (std::size_t i = 1; i < tables.size(); ++i) {
    std::vector<int>::iterator p = std::set_intersection(
      intersection.begin(), intersection.end(), tables[i]->sentences.begin(),
      tables[i]->sentences.end(), tmp.begin());
    tmp.resize(p-tmp.begin());
    if (tmp.empty()) {
      return false;
    }
    intersection.swap(tmp);
  }

  // Step 3: For each sentence in the intersection, try to find a consistent
  //         sequence of intra-sentence positions (one for each subpattern).
  //         'Consistent' here means that the subpatterns occur in the right
  //         order and are separated by at least the minimum widths required
  //         by the pattern's gaps).
  for (std::vector<int>::const_iterator p = intersection.begin();
       p != intersection.end(); ++p) {
    if (MatchPattern(pattern, tables, *p)) {
      return true;
    }
  }
  return false;
}
240
+
241
// Return true if the pattern matches the given test sentence: each
// subpattern must occur at some position consistent with the ordering and
// minimum gap widths of the pattern.  'tables' holds one CoordinateTable
// per subpattern (in pattern order), as collected by the caller.
bool StringCfgFilter::MatchPattern(
  const Pattern &pattern,
  std::vector<const CoordinateTable *> &tables,
  int sentenceId) const
{
  const int sentenceLength = m_sentenceLengths[sentenceId];

  // In the for loop below, we need to know the set of start position ranges
  // where subpattern i is allowed to occur (rangeSet) and we are generating
  // the ranges for subpattern i+1 (nextRangeSet).
  // TODO Merge ranges if subpattern i follows a non-zero gap.
  std::vector<Range> rangeSet;
  std::vector<Range> nextRangeSet;

  // Calculate the range for the first subpattern.
  int minStart = pattern.minGapWidths[0];
  int maxStart = sentenceLength - MinWidth(pattern, 0);
  rangeSet.push_back(Range(minStart, maxStart));

  // Attempt to match subpatterns.
  // NOTE(review): 'int i' is compared against size_t sizes below
  // (signed/unsigned warning); subpattern counts are small so this is
  // harmless in practice.
  for (int i = 0; i < pattern.subpatterns.size(); ++i) {
    // Look-up the intra-sentence position sequence.
    boost::unordered_map<int, PositionSeq>::const_iterator r =
      tables[i]->intraSentencePositions.find(sentenceId);
    assert(r != tables[i]->intraSentencePositions.end());
    const PositionSeq &col = r->second;
    for (PositionSeq::const_iterator p = col.begin(); p != col.end(); ++p) {
      bool inRange = false;
      for (std::vector<Range>::const_iterator q = rangeSet.begin();
           q != rangeSet.end(); ++q) {
        // TODO Use the fact that the ranges are ordered to break early.
        if (*p >= q->first && *p <= q->second) {
          inRange = true;
          break;
        }
      }
      if (!inRange) {
        continue;
      }
      // If this is the last subpattern then we're done.
      if (i+1 == pattern.subpatterns.size()) {
        return true;
      }
      nextRangeSet.push_back(CalcNextRange(pattern, i, *p, sentenceLength));
    }
    if (nextRangeSet.empty()) {
      return false;
    }
    rangeSet.swap(nextRangeSet);
    nextRangeSet.clear();
  }
  // The last loop iteration either returns true (position found) or false
  // (nextRangeSet empty), so this is reached only in degenerate cases.
  return true;
}
294
+
295
+ StringCfgFilter::Range StringCfgFilter::CalcNextRange(
296
+ const Pattern &pattern, int i, int x, int sentenceLength) const
297
+ {
298
+ assert(i+1 < pattern.subpatterns.size());
299
+ Range range;
300
+ if (pattern.minGapWidths[i+1] == 0) {
301
+ // The next subpattern follows this one without a gap.
302
+ range.first = range.second = x + pattern.subpatterns[i].size();
303
+ } else {
304
+ range.first = x + pattern.subpatterns[i].size() + pattern.minGapWidths[i+1];
305
+ // TODO MinWidth should only be computed once per subpattern.
306
+ range.second = sentenceLength - MinWidth(pattern, i+1);
307
+ }
308
+ return range;
309
+ }
310
+
311
+ int StringCfgFilter::MinWidth(const Pattern &pattern, int i) const
312
+ {
313
+ int minWidth = 0;
314
+ for (; i < pattern.subpatterns.size(); ++i) {
315
+ minWidth += pattern.subpatterns[i].size();
316
+ minWidth += pattern.minGapWidths[i+1];
317
+ }
318
+ return minWidth;
319
+ }
320
+
321
+ } // namespace FilterRuleTable
322
+ } // namespace Syntax
323
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/filter-rule-table/StringCfgFilter.h ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <string>
4
+ #include <vector>
5
+
6
+ #include "syntax-common/numbered_set.h"
7
+
8
+ #include <boost/shared_ptr.hpp>
9
+ #include <boost/unordered_map.hpp>
10
+
11
+ #include "util/string_piece.hh"
12
+ #include "util/tokenize_piece.hh"
13
+
14
+ #include "CfgFilter.h"
15
+
16
+ namespace MosesTraining
17
+ {
18
+ namespace Syntax
19
+ {
20
+ namespace FilterRuleTable
21
+ {
22
+
23
+ // Filters a rule table, discarding rules that cannot be applied to a given
24
+ // test set. The rule table must have a CFG source-side and the test sentences
25
+ // must be strings.
26
class StringCfgFilter : public CfgFilter
{
public:
  // Initialize the filter for a given set of test sentences.
  StringCfgFilter(const std::vector<boost::shared_ptr<std::string> > &);

  // Read a rule table from 'in' and write the subset of rules whose source
  // side can potentially match some test sentence to 'out'.
  void Filter(std::istream &in, std::ostream &out);

private:
  // Filtering works by converting the source LHSs of translation rules to
  // patterns containing variable length gaps and then pattern matching
  // against the test set.
  //
  // The algorithm is vaguely similar to Algorithm 1 from Rahman et al. (2006),
  // but with a slightly different definition of a pattern and designed for a
  // text containing sentence boundaries.  Here the text is assumed to be
  // short (a few thousand sentences) and the number of patterns is assumed to
  // be large (tens of millions of rules).
  //
  // M. Sohel Rahman, Costas S. Iliopoulos, Inbok Lee, Manal Mohamed, and
  // William F. Smyth
  // "Finding Patterns with Variable Length Gaps or Don't Cares"
  // In proceedings of COCOON, 2006

  // Max NGram length.
  static const std::size_t kMaxNGramLength;

  // Maps words from strings to integers.
  typedef NumberedSet<std::string, std::size_t> Vocabulary;

  // A NGram is a sequence of words.
  typedef std::vector<Vocabulary::IdType> NGram;

  // A pattern is an alternating sequence of gaps and NGram subpatterns,
  // starting and ending with a gap.  Every gap has a minimum width, which
  // can be any integer >= 0 (a gap of width 0 is really a non-gap).
  //
  // The source LHSs of translation rules are converted to patterns where each
  // sequence of m consecutive non-terminals is converted to a gap with minimum
  // width m.  For example, if a rule has the source LHS:
  //
  //    [NP] and all the king 's men could n't [VB] [NP] together again
  //
  // and kMaxNGramLength is set to 5 then the following pattern is used:
  //
  //    * <and all the king 's> * <men could n't> * <together again> *
  //
  // where the gaps have minimum widths of 1, 0, 2, and 0.
  //
  struct Pattern {
    std::vector<NGram> subpatterns;
    // One entry per gap, so minGapWidths.size() == subpatterns.size() + 1
    // (patterns start and end with a gap, possibly of width 0).
    std::vector<int> minGapWidths;
  };

  // A sorted (ascending) sequence of start positions.
  typedef std::vector<int> PositionSeq;

  // A range of start positions.
  typedef std::pair<int, int> Range;

  // A CoordinateTable records the set of sentences in which a single
  // n-gram occurs and for each of those sentences, the start positions
  struct CoordinateTable {
    // Sentences IDs (ascending).  This contains the same values as the key set
    // from intraSentencePositions but sorted into ascending order.
    std::vector<int> sentences;
    // Map from sentence ID to set of intra-sentence start positions.
    boost::unordered_map<int, PositionSeq> intraSentencePositions;
  };

  // NGramCoordinateMap is the main search structure.  It maps a NGram to
  // a CoordinateTable containing the positions that the n-gram occurs at
  // in the test set.
  typedef boost::unordered_map<NGram, CoordinateTable> NGramCoordinateMap;

  // Add all n-grams and coordinates for a single sentence s with index i.
  void AddSentenceNGrams(const std::vector<Vocabulary::IdType> &s,
                         std::size_t i);

  // Calculate the range of possible start positions for subpattern i+1
  // assuming that subpattern i has position x.
  Range CalcNextRange(const Pattern &p, int i, int x, int sentenceLength) const;

  // Generate the pattern corresponding to the given source-side of a rule.
  // This will fail if the rule's source-side contains any terminals that
  // do not occur in the test sentence vocabulary.
  bool GeneratePattern(const std::vector<StringPiece> &, Pattern &) const;

  // Calculate the minimum width of the pattern suffix starting
  // at subpattern i.
  int MinWidth(const Pattern &p, int i) const;

  // Return true if the given rule symbol is a non-terminal.
  bool IsNonTerminal(const StringPiece &symbol) const;

  // Try to match the pattern p against any sentence in the test set.
  bool MatchPattern(const Pattern &p) const;

  // Try to match the pattern p against the sentence with the given ID.
  bool MatchPattern(const Pattern &p,
                    std::vector<const CoordinateTable *> &tables,
                    int id) const;

  // The main search structure constructed from the test set sentences.
  NGramCoordinateMap m_ngramCoordinateMap;

  // The lengths of the test sentences.
  std::vector<int> m_sentenceLengths;

  // The maximum length of any test sentence.
  int m_maxSentenceLength;

  // The symbol vocabulary of the test sentences.
  Vocabulary m_testVocab;
};
140
+
141
+ } // namespace FilterRuleTable
142
+ } // namespace Syntax
143
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/filter-rule-table/StringForest.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <string>
4
+
5
+ #include "Forest.h"
6
+
7
+ namespace MosesTraining
8
+ {
9
+ namespace Syntax
10
+ {
11
+ namespace FilterRuleTable
12
+ {
13
+
14
// Value stored at each vertex of a StringForest: the vertex's symbol and the
// sentence span it covers.
struct StringForestValue {
  std::string symbol; // terminal or non-terminal (without square brackets)
  std::size_t start;  // start of the covered span (word index; assumed inclusive — TODO confirm)
  std::size_t end;    // end of the covered span (word index; inclusive or one-past-end unclear from here — TODO confirm)
};
19
+
20
+ typedef Forest<StringForestValue> StringForest;
21
+
22
+ } // namespace FilterRuleTable
23
+ } // namespace Syntax
24
+ }  // namespace MosesTraining
mosesdecoder/phrase-extract/filter-rule-table/TreeTsgFilter.h ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <istream>
4
+ #include <ostream>
5
+ #include <string>
6
+ #include <vector>
7
+
8
+ #include <boost/shared_ptr.hpp>
9
+ #include <boost/unordered_map.hpp>
10
+
11
+ #include "SyntaxTree.h"
12
+
13
+ #include "syntax-common/numbered_set.h"
14
+ #include "syntax-common/tree.h"
15
+ #include "syntax-common/tree_fragment_tokenizer.h"
16
+
17
+ #include "TsgFilter.h"
18
+
19
+ namespace MosesTraining
20
+ {
21
+ namespace Syntax
22
+ {
23
+ namespace FilterRuleTable
24
+ {
25
+
26
+ // Filters a rule table, discarding rules that cannot be applied to a given
27
+ // test set. The rule table must have a TSG source-side and the test sentences
28
+ // must be parse trees.
29
// Filters a rule table, discarding rules that cannot be applied to a given
// test set.  The rule table must have a TSG source-side and the test sentences
// must be parse trees.
class TreeTsgFilter : public TsgFilter
{
public:
  // Initialize the filter for a given set of test sentences.
  TreeTsgFilter(const std::vector<boost::shared_ptr<SyntaxTree> > &);

private:
  // Add an entry to m_labelToTree for every subtree of the given tree.
  void AddNodesToMap(const IdTree &);

  // Tree-specific implementation of virtual function.
  bool MatchFragment(const IdTree &, const std::vector<IdTree *> &);

  // Try to match a fragment against a specific subtree of a test tree.
  bool MatchFragment(const IdTree &, const IdTree &);

  // Convert a SyntaxTree to an IdTree (wrt m_testVocab).  Inserts symbols into
  // m_testVocab.
  IdTree *SyntaxTreeToIdTree(const SyntaxTree &);

  // Test parse trees converted to vocabulary-ID form.
  std::vector<boost::shared_ptr<IdTree> > m_sentences;

  // NOTE(review): presumably an index from node-label ID to the test subtrees
  // rooted at that label (populated by AddNodesToMap) — confirm against the
  // .cpp implementation.
  std::vector<std::vector<const IdTree *> > m_labelToTree;
};
52
+
53
+ } // namespace FilterRuleTable
54
+ } // namespace Syntax
55
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/filter-rule-table/TsgFilter.h ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <istream>
4
+ #include <ostream>
5
+ #include <string>
6
+ #include <vector>
7
+
8
+ #include "syntax-common/numbered_set.h"
9
+ #include "syntax-common/tree.h"
10
+ #include "syntax-common/tree_fragment_tokenizer.h"
11
+
12
+ namespace MosesTraining
13
+ {
14
+ namespace Syntax
15
+ {
16
+ namespace FilterRuleTable
17
+ {
18
+
19
+ // Base class for TreeTsgFilter and ForestTsgFilter, both of which filter rule
20
+ // tables where the source-side is TSG.
21
// Base class for TreeTsgFilter and ForestTsgFilter, both of which filter rule
// tables where the source-side is TSG.
class TsgFilter
{
public:
  virtual ~TsgFilter() {}

  // Read a rule table from 'in' and filter it according to the test sentences.
  void Filter(std::istream &in, std::ostream &out);

protected:
  // Maps symbols (terminals and non-terminals) from strings to integers.
  typedef NumberedSet<std::string, std::size_t> Vocabulary;

  // Represents a tree using integer vocabulary values.
  typedef Tree<Vocabulary::IdType> IdTree;

  // Build an IdTree (wrt m_testVocab) for the tree beginning at position i of
  // the token sequence or return 0 if any symbol in the fragment is not in
  // m_testVocab.  If successful then on return, i will be set to the position
  // immediately after the last token of the tree and leaves will contain the
  // pointers to the fragment's leaves.  If the build fails then i and leaves
  // are undefined.
  IdTree *BuildTree(const std::vector<TreeFragmentToken> &tokens, int &i,
                    std::vector<IdTree *> &leaves);

  // Try to match a fragment.  The implementation depends on whether the test
  // sentences are trees or forests.
  virtual bool MatchFragment(const IdTree &, const std::vector<IdTree *> &) = 0;

  // The symbol vocabulary of the test sentences.
  Vocabulary m_testVocab;
};
52
+
53
+ } // namespace FilterRuleTable
54
+ } // namespace Syntax
55
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/lexical-reordering/InputFileStream.cpp ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // $Id: InputFileStream.cpp 2780 2010-01-29 17:11:17Z bojar $
2
+
3
+ /***********************************************************************
4
+ Moses - factored phrase-based language decoder
5
+ Copyright (C) 2006 University of Edinburgh
6
+
7
+ This library is free software; you can redistribute it and/or
8
+ modify it under the terms of the GNU Lesser General Public
9
+ License as published by the Free Software Foundation; either
10
+ version 2.1 of the License, or (at your option) any later version.
11
+
12
+ This library is distributed in the hope that it will be useful,
13
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15
+ Lesser General Public License for more details.
16
+
17
+ You should have received a copy of the GNU Lesser General Public
18
+ License along with this library; if not, write to the Free Software
19
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20
+ ***********************************************************************/
21
+
22
+ #include "InputFileStream.h"
23
+ #include "gzfilebuf.h"
24
+ #include <iostream>
25
+ #include <boost/algorithm/string/predicate.hpp>
26
+
27
+ using namespace std;
28
+ using namespace boost::algorithm;
29
+
30
+ namespace Moses
31
+ {
32
+ InputFileStream::InputFileStream(const std::string &filePath)
33
+ : std::istream(NULL)
34
+ , m_streambuf(NULL)
35
+ {
36
+ Open(filePath);
37
+ }
38
+
39
+ InputFileStream::~InputFileStream()
40
+ {
41
+ Close();
42
+ }
43
+
44
+ void InputFileStream::Open(const std::string &filePath)
45
+ {
46
+ if (ends_with(filePath, ".gz")) {
47
+ m_streambuf = new gzfilebuf(filePath.c_str());
48
+ } else {
49
+ std::filebuf* fb = new std::filebuf();
50
+ fb = fb->open(filePath.c_str(), std::ios::in);
51
+ if (! fb) {
52
+ cerr << "Can't read " << filePath.c_str() << endl;
53
+ exit(1);
54
+ }
55
+ m_streambuf = fb;
56
+ }
57
+ this->init(m_streambuf);
58
+ }
59
+
60
+ void InputFileStream::Close()
61
+ {
62
+ delete m_streambuf;
63
+ m_streambuf = NULL;
64
+ }
65
+
66
+
67
+ }
68
+
mosesdecoder/phrase-extract/lexical-reordering/InputFileStream.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // $Id: InputFileStream.h 2939 2010-02-24 11:15:44Z jfouet $
2
+
3
+ /***********************************************************************
4
+ Moses - factored phrase-based language decoder
5
+ Copyright (C) 2006 University of Edinburgh
6
+
7
+ This library is free software; you can redistribute it and/or
8
+ modify it under the terms of the GNU Lesser General Public
9
+ License as published by the Free Software Foundation; either
10
+ version 2.1 of the License, or (at your option) any later version.
11
+
12
+ This library is distributed in the hope that it will be useful,
13
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15
+ Lesser General Public License for more details.
16
+
17
+ You should have received a copy of the GNU Lesser General Public
18
+ License along with this library; if not, write to the Free Software
19
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20
+ ***********************************************************************/
21
+
22
+ #ifndef moses_InputFileStream_h
23
+ #define moses_InputFileStream_h
24
+
25
+ #include <cstdlib>
26
+ #include <fstream>
27
+ #include <string>
28
+
29
+ namespace Moses
30
+ {
31
+
32
+ /** Used in place of std::istream, can read zipped files if it ends in .gz
33
+ */
34
/** Used in place of std::istream, can read zipped files if it ends in .gz
 */
class InputFileStream : public std::istream
{
protected:
  // Owned stream buffer: a gzip-decompressing buffer for ".gz" paths, a
  // std::filebuf otherwise.  Deleted by Close().
  std::streambuf *m_streambuf;
public:

  // Opens the file immediately; exits the process if it cannot be read.
  explicit InputFileStream(const std::string &filePath);
  ~InputFileStream();

  void Open(const std::string &filePath);
  void Close();
};
46
+
47
+ }
48
+
49
+ #endif
mosesdecoder/phrase-extract/lexical-reordering/Jamfile ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ exe lexical-reordering-score : InputFileStream.cpp reordering_classes.cpp score.cpp ../OutputFileStream.cpp ../..//boost_iostreams ../..//boost_filesystem ../../util//kenutil ../..//z ;
2
+
mosesdecoder/phrase-extract/lexical-reordering/gzfilebuf.h ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef moses_gzfile_buf_h
2
+ #define moses_gzfile_buf_h
3
+
4
+ #include <stdexcept>
5
+ #include <streambuf>
6
+ #include <zlib.h>
7
+ #include <cstring>
8
+
9
// Read-only std::streambuf over a gzip-compressed file.  Decompressed data is
// served from an internal buffer with a sizeof(int)-byte putback area at the
// front.  Write and seek operations are unsupported.
class gzfilebuf : public std::streambuf
{
public:
  // Open the gzip file for reading; throws std::runtime_error on failure.
  gzfilebuf(const char *filename) {
    _gzf = gzopen(filename, "rb");
    if (!_gzf)
      throw std::runtime_error("Could not open " + std::string(filename) + ".");
    setg (_buff+sizeof(int),     // beginning of putback area
          _buff+sizeof(int),     // read position
          _buff+sizeof(int));    // end position
  }
  ~gzfilebuf() {
    gzclose(_gzf);
  }
protected:
  // Writing is unsupported.
  // NOTE(review): a bare `throw;` with no active exception calls
  // std::terminate rather than raising a catchable exception — confirm this
  // hard failure is intended.
  virtual int_type overflow (int_type c) {
    throw;
  }

  // write multiple characters -- unsupported (see NOTE on overflow).
  virtual
  std::streamsize xsputn (const char* s,
                          std::streamsize num) {
    throw;
  }

  // Seeking is unsupported (see NOTE on overflow).
  virtual std::streampos seekpos ( std::streampos sp, std::ios_base::openmode which = std::ios_base::in | std::ios_base::out ) {
    throw;
  }

  //read one character
  virtual int_type underflow () {
    // is read position before end of _buff?
    if (gptr() < egptr()) {
      return traits_type::to_int_type(*gptr());
    }

    /* process size of putback area
     * - use number of characters read
     * - but at most four
     */
    unsigned int numPutback = gptr() - eback();
    if (numPutback > sizeof(int)) {
      numPutback = sizeof(int);
    }

    /* copy up to four characters previously read into
     * the putback _buff (area of first four characters)
     */
    std::memmove (_buff+(sizeof(int)-numPutback), gptr()-numPutback,
                  numPutback);

    // read new characters
    int num = gzread(_gzf, _buff+sizeof(int), _buffsize-sizeof(int));
    if (num <= 0) {
      // ERROR or EOF
      // NOTE(review): returns the macro EOF rather than traits_type::eof();
      // identical for char streams but not strictly portable.
      return EOF;
    }

    // reset _buff pointers
    setg (_buff+(sizeof(int)-numPutback),   // beginning of putback area
          _buff+sizeof(int),                // read position
          _buff+sizeof(int)+num);           // end of buffer

    // return next character
    return traits_type::to_int_type(*gptr());
  }

  // Bulk read that bypasses the internal buffer; returns gzread's result
  // (bytes read; NOTE(review): gzread returns -1 on error, which callers
  // presumably treat like EOF — confirm).
  std::streamsize xsgetn (char* s,
                          std::streamsize num) {
    return gzread(_gzf,s,num);
  }

private:
  gzFile _gzf;                              // zlib file handle
  static const unsigned int _buffsize = 1024; // total buffer size incl. putback area
  char _buff[_buffsize];
};
87
+
88
+ #endif
mosesdecoder/phrase-extract/lexical-reordering/reordering_classes.cpp ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #include <vector>
3
+ #include <iostream>
4
+ #include <cstdlib>
5
+ #include <numeric>
6
+ #include <cstdio>
7
+ #include <sstream>
8
+ #include <string>
9
+ #include "zlib.h"
10
+
11
+ #include "reordering_classes.h"
12
+
13
+ using namespace std;
14
+
15
+ ModelScore::ModelScore()
16
+ {
17
+ for(int i=MONO; i<=NOMONO; ++i) {
18
+ count_fe_prev.push_back(0);
19
+ count_fe_next.push_back(0);
20
+ count_f_prev.push_back(0);
21
+ count_f_next.push_back(0);
22
+ }
23
+ }
24
+
25
+ ModelScore::~ModelScore() {}
26
+
27
+ ModelScore* ModelScore::createModelScore(const string& modeltype)
28
+ {
29
+ if (modeltype.compare("mslr") == 0) {
30
+ return new ModelScoreMSLR();
31
+ } else if (modeltype.compare("msd") == 0) {
32
+ return new ModelScoreMSD();
33
+ } else if (modeltype.compare("monotonicity") == 0 ) {
34
+ return new ModelScoreMonotonicity();
35
+ } else if (modeltype.compare("leftright") == 0) {
36
+ return new ModelScoreLR();
37
+ } else {
38
+ cerr << "Illegal model type given for lexical reordering model scoring: "
39
+ << modeltype
40
+ << ". The allowed types are: mslr, msd, monotonicity, leftright"
41
+ << endl;
42
+ exit(1);
43
+ }
44
+ }
45
+
46
+ void ModelScore::reset_fe()
47
+ {
48
+ for(int i=MONO; i<=NOMONO; ++i) {
49
+ count_fe_prev[i] = 0;
50
+ count_fe_next[i] = 0;
51
+ }
52
+ }
53
+
54
+ void ModelScore::reset_f()
55
+ {
56
+ for(int i=MONO; i<=NOMONO; ++i) {
57
+ count_f_prev[i] = 0;
58
+ count_f_next[i] = 0;
59
+ }
60
+ }
61
+
62
+ void ModelScore::add_example
63
+ (const StringPiece& previous, const StringPiece& next, float weight)
64
+ {
65
+ count_fe_prev[getType(previous)]+=weight;
66
+ count_f_prev[getType(previous)]+=weight;
67
+ count_fe_next[getType(next)]+=weight;
68
+ count_f_next[getType(next)]+=weight;
69
+ }
70
+
71
+ const vector<double>& ModelScore::get_scores_fe_prev() const
72
+ {
73
+ return count_fe_prev;
74
+ }
75
+
76
+ const vector<double>& ModelScore::get_scores_fe_next() const
77
+ {
78
+ return count_fe_next;
79
+ }
80
+
81
+ const vector<double>& ModelScore::get_scores_f_prev() const
82
+ {
83
+ return count_f_prev;
84
+ }
85
+
86
+ const vector<double>& ModelScore::get_scores_f_next() const
87
+ {
88
+ return count_f_next;
89
+ }
90
+
91
+
92
+ ORIENTATION ModelScore::getType(const StringPiece& s)
93
+ {
94
+ if (s.compare("mono") == 0) {
95
+ return MONO;
96
+ } else if (s.compare("swap") == 0) {
97
+ return SWAP;
98
+ } else if (s.compare("dright") == 0) {
99
+ return DRIGHT;
100
+ } else if (s.compare("dleft") == 0) {
101
+ return DLEFT;
102
+ } else if (s.compare("other") == 0) {
103
+ return OTHER;
104
+ } else if (s.compare("nomono") == 0) {
105
+ return NOMONO;
106
+ } else {
107
+ cerr << "Illegal reordering type used: " << s << endl;
108
+ exit(1);
109
+ }
110
+ }
111
+
112
+
113
+ ORIENTATION ModelScoreMSLR::getType(const StringPiece& s)
114
+ {
115
+ if (s.compare("mono") == 0) {
116
+ return MONO;
117
+ } else if (s.compare("swap") == 0) {
118
+ return SWAP;
119
+ } else if (s.compare("dright") == 0) {
120
+ return DRIGHT;
121
+ } else if (s.compare("dleft") == 0) {
122
+ return DLEFT;
123
+ } else if (s.compare("other") == 0 || s.compare("nomono") == 0) {
124
+ cerr << "Illegal reordering type used: " << s << " for model type mslr. You have to re-run step 5 in order to train such a model." << endl;
125
+ exit(1);
126
+ } else {
127
+ cerr << "Illegal reordering type used: " << s << endl;
128
+ exit(1);
129
+ }
130
+ }
131
+
132
+
133
+ ORIENTATION ModelScoreLR::getType(const StringPiece& s)
134
+ {
135
+ if (s.compare("mono") == 0 || s.compare("dright") == 0) {
136
+ return DRIGHT;
137
+ } else if (s.compare("swap") == 0 || s.compare("dleft") == 0) {
138
+ return DLEFT;
139
+ } else if (s.compare("other") == 0 || s.compare("nomono") == 0) {
140
+ cerr << "Illegal reordering type used: " << s << " for model type LeftRight. You have to re-run step 5 in order to train such a model." << endl;
141
+ exit(1);
142
+ } else {
143
+ cerr << "Illegal reordering type used: " << s << endl;
144
+ exit(1);
145
+ }
146
+ }
147
+
148
+
149
+ ORIENTATION ModelScoreMSD::getType(const StringPiece& s)
150
+ {
151
+ if (s.compare("mono") == 0) {
152
+ return MONO;
153
+ } else if (s.compare("swap") == 0) {
154
+ return SWAP;
155
+ } else if (s.compare("dleft") == 0 ||
156
+ s.compare("dright") == 0 ||
157
+ s.compare("other") == 0) {
158
+ return OTHER;
159
+ } else if (s.compare("nomono") == 0) {
160
+ cerr << "Illegal reordering type used: " << s << " for model type msd. You have to re-run step 5 in order to train such a model." << endl;
161
+ exit(1);
162
+ } else {
163
+ cerr << "Illegal reordering type used: " << s << endl;
164
+ exit(1);
165
+ }
166
+ }
167
+
168
+ ORIENTATION ModelScoreMonotonicity::getType(const StringPiece& s)
169
+ {
170
+ if (s.compare("mono") == 0) {
171
+ return MONO;
172
+ } else if (s.compare("swap") == 0 ||
173
+ s.compare("dleft") == 0 ||
174
+ s.compare("dright") == 0 ||
175
+ s.compare("other") == 0 ||
176
+ s.compare("nomono") == 0 ) {
177
+ return NOMONO;
178
+ } else {
179
+ cerr << "Illegal reordering type used: " << s << endl;
180
+ exit(1);
181
+ }
182
+ }
183
+
184
+
185
+
186
+ void ScorerMSLR::score(const vector<double>& all_scores, vector<double>& scores) const
187
+ {
188
+ scores.push_back(all_scores[MONO]);
189
+ scores.push_back(all_scores[SWAP]);
190
+ scores.push_back(all_scores[DLEFT]);
191
+ scores.push_back(all_scores[DRIGHT]);
192
+ }
193
+
194
+ void ScorerMSD::score(const vector<double>& all_scores, vector<double>& scores) const
195
+ {
196
+ scores.push_back(all_scores[MONO]);
197
+ scores.push_back(all_scores[SWAP]);
198
+ scores.push_back(all_scores[DRIGHT]+all_scores[DLEFT]+all_scores[OTHER]);
199
+ }
200
+
201
+ void ScorerMonotonicity::score(const vector<double>& all_scores, vector<double>& scores) const
202
+ {
203
+ scores.push_back(all_scores[MONO]);
204
+ scores.push_back(all_scores[SWAP]+all_scores[DRIGHT]+all_scores[DLEFT]+all_scores[OTHER]+all_scores[NOMONO]);
205
+ }
206
+
207
+
208
+ void ScorerLR::score(const vector<double>& all_scores, vector<double>& scores) const
209
+ {
210
+ scores.push_back(all_scores[MONO]+all_scores[DRIGHT]);
211
+ scores.push_back(all_scores[SWAP]+all_scores[DLEFT]);
212
+ }
213
+
214
+
215
+ void ScorerMSLR::createSmoothing(const vector<double>& scores, double weight, vector<double>& smoothing) const
216
+ {
217
+ double total = accumulate(scores.begin(), scores.end(), 0);
218
+ smoothing.push_back(weight*(scores[MONO]+0.1)/total);
219
+ smoothing.push_back(weight*(scores[SWAP]+0.1)/total);
220
+ smoothing.push_back(weight*(scores[DLEFT]+0.1)/total);
221
+ smoothing.push_back(weight*(scores[DRIGHT]+0.1)/total);
222
+ }
223
+
224
+ void ScorerMSLR::createConstSmoothing(double weight, vector<double>& smoothing) const
225
+ {
226
+ for (int i=1; i<=4; ++i) {
227
+ smoothing.push_back(weight);
228
+ }
229
+ }
230
+
231
+
232
+ void ScorerMSD::createSmoothing(const vector<double>& scores, double weight, vector<double>& smoothing) const
233
+ {
234
+ double total = accumulate(scores.begin(), scores.end(), 0);
235
+ smoothing.push_back(weight*(scores[MONO]+0.1)/total);
236
+ smoothing.push_back(weight*(scores[SWAP]+0.1)/total);
237
+ smoothing.push_back(weight*(scores[DLEFT]+scores[DRIGHT]+scores[OTHER]+0.1)/total);
238
+ }
239
+
240
+ void ScorerMSD::createConstSmoothing(double weight, vector<double>& smoothing) const
241
+ {
242
+ for (int i=1; i<=3; ++i) {
243
+ smoothing.push_back(weight);
244
+ }
245
+ }
246
+
247
+ void ScorerMonotonicity::createSmoothing(const vector<double>& scores, double weight, vector<double>& smoothing) const
248
+ {
249
+ double total = accumulate(scores.begin(), scores.end(), 0);
250
+ smoothing.push_back(weight*(scores[MONO]+0.1)/total);
251
+ smoothing.push_back(weight*(scores[SWAP]+scores[DLEFT]+scores[DRIGHT]+scores[OTHER]+scores[NOMONO]+0.1)/total);
252
+ }
253
+
254
+ void ScorerMonotonicity::createConstSmoothing(double weight, vector<double>& smoothing) const
255
+ {
256
+ for (double i=1; i<=2; ++i) {
257
+ smoothing.push_back(weight);
258
+ }
259
+ }
260
+
261
+
262
+ void ScorerLR::createSmoothing(const vector<double>& scores, double weight, vector<double>& smoothing) const
263
+ {
264
+ double total = accumulate(scores.begin(), scores.end(), 0);
265
+ smoothing.push_back(weight*(scores[MONO]+scores[DRIGHT]+0.1)/total);
266
+ smoothing.push_back(weight*(scores[SWAP]+scores[DLEFT])/total);
267
+ }
268
+
269
+ void ScorerLR::createConstSmoothing(double weight, vector<double>& smoothing) const
270
+ {
271
+ for (int i=1; i<=2; ++i) {
272
+ smoothing.push_back(weight);
273
+ }
274
+ }
275
+
276
+ void Model::score_fe(const string& f, const string& e)
277
+ {
278
+ if (!fe) //Make sure we do not do anything if it is not a fe model
279
+ return;
280
+ outputFile << f << " ||| " << e << " |||";
281
+ //condition on the previous phrase
282
+ if (previous) {
283
+ vector<double> scores;
284
+ scorer->score(modelscore->get_scores_fe_prev(), scores);
285
+ double sum = 0;
286
+ for(size_t i=0; i<scores.size(); ++i) {
287
+ scores[i] += smoothing_prev[i];
288
+ sum += scores[i];
289
+ }
290
+ for(size_t i=0; i<scores.size(); ++i) {
291
+ outputFile << " " << (scores[i]/sum);
292
+ }
293
+ }
294
+ //condition on the next phrase
295
+ if (next) {
296
+ vector<double> scores;
297
+ scorer->score(modelscore->get_scores_fe_next(), scores);
298
+ double sum = 0;
299
+ for(size_t i=0; i<scores.size(); ++i) {
300
+ scores[i] += smoothing_next[i];
301
+ sum += scores[i];
302
+ }
303
+ for(size_t i=0; i<scores.size(); ++i) {
304
+ outputFile << " " << (scores[i]/sum);
305
+ }
306
+ }
307
+ outputFile << endl;
308
+ }
309
+
310
+ void Model::score_f(const string& f)
311
+ {
312
+ if (fe) //Make sure we do not do anything if it is not a f model
313
+ return;
314
+ cout << f << " |||";
315
+ //condition on the previous phrase
316
+ if (previous) {
317
+ vector<double> scores;
318
+ scorer->score(modelscore->get_scores_f_prev(), scores);
319
+ double sum = 0;
320
+ for(size_t i=0; i<scores.size(); ++i) {
321
+ scores[i] += smoothing_prev[i];
322
+ sum += scores[i];
323
+ }
324
+ for(size_t i=0; i<scores.size(); ++i) {
325
+ outputFile << " " << (scores[i]/sum);
326
+ }
327
+ }
328
+ //condition on the next phrase
329
+ if (next) {
330
+ vector<double> scores;
331
+ scorer->score(modelscore->get_scores_f_next(), scores);
332
+ double sum = 0;
333
+ for(size_t i=0; i<scores.size(); ++i) {
334
+ scores[i] += smoothing_next[i];
335
+ sum += scores[i];
336
+ }
337
+ for(size_t i=0; i<scores.size(); ++i) {
338
+ outputFile << " " << (scores[i]/sum);
339
+ }
340
+ }
341
+ outputFile << endl;
342
+ }
343
+
344
+ Model::Model(ModelScore* ms, Scorer* sc, const string& dir, const string& lang, const string& fn)
345
+ : modelscore(ms), scorer(sc), filename(fn)
346
+ {
347
+ outputFile.Open( (filename+".gz").c_str() );
348
+ fe = false;
349
+ if (lang.compare("fe") == 0) {
350
+ fe = true;
351
+ } else if (lang.compare("f") != 0) {
352
+ cerr << "You have given an illegal language to condition on: " << lang
353
+ << "\nLegal types: fe (on both languages), f (only on source language)\n";
354
+ exit(1);
355
+ }
356
+
357
+ previous = true;
358
+ next = true;
359
+ if (dir.compare("backward") == 0) {
360
+ next = false;
361
+ } else if (dir.compare("forward") == 0) {
362
+ previous = false;
363
+ }
364
+ }
365
+
366
+ Model::~Model()
367
+ {
368
+ outputFile.Close();
369
+ delete modelscore;
370
+ delete scorer;
371
+ }
372
+
373
+ void Model::split_config(const string& config, string& dir, string& lang, string& orient)
374
+ {
375
+ istringstream is(config);
376
+ string type;
377
+ getline(is, type, '-');
378
+ getline(is, orient, '-');
379
+ getline(is, dir, '-');
380
+ getline(is, lang, '-');
381
+ }
382
+
383
+ Model* Model::createModel(ModelScore* modelscore, const string& config, const string& filepath)
384
+ {
385
+ string dir, lang, orient, filename;
386
+ split_config(config,dir,lang,orient);
387
+
388
+ filename = filepath + config;
389
+ if (orient.compare("mslr") == 0) {
390
+ return new Model(modelscore, new ScorerMSLR(), dir, lang, filename);
391
+ } else if (orient.compare("msd") == 0) {
392
+ return new Model(modelscore, new ScorerMSD(), dir, lang, filename);
393
+ } else if (orient.compare("monotonicity") == 0) {
394
+ return new Model(modelscore, new ScorerMonotonicity(), dir, lang, filename);
395
+ } else if (orient.compare("leftright") == 0) {
396
+ return new Model(modelscore, new ScorerLR(), dir, lang, filename);
397
+ } else {
398
+ cerr << "Illegal orientation type of reordering model: " << orient
399
+ << "\n allowed types: mslr, msd, monotonicity, leftright\n";
400
+ exit(1);
401
+ }
402
+ }
403
+
404
+
405
+
406
+ void Model::createSmoothing(double w)
407
+ {
408
+ scorer->createSmoothing(modelscore->get_scores_fe_prev(), w, smoothing_prev);
409
+ scorer->createSmoothing(modelscore->get_scores_fe_next(), w, smoothing_next);
410
+ }
411
+
412
+ void Model::createConstSmoothing(double w)
413
+ {
414
+ scorer->createConstSmoothing(w, smoothing_prev);
415
+ scorer->createConstSmoothing(w, smoothing_next);
416
+ }
mosesdecoder/phrase-extract/lexical-reordering/reordering_classes.h ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/*
 * reordering_classes.h
 * Utility classes for lexical reordering table scoring
 *
 * Created by: Sara Stymne - Linköping University
 * Machine Translation Marathon 2010, Dublin
 */

#pragma once

#include <vector>
#include <string>
#include <fstream>

#include "util/string_piece.hh"
#include "../OutputFileStream.h"

// Orientation categories used by the reordering models.  Which subset a
// model distinguishes depends on the ModelScore/Scorer subclass pair
// chosen for it (mslr, msd, monotonicity, leftright).
enum ORIENTATION {MONO, SWAP, DRIGHT, DLEFT, OTHER, NOMONO};


//Keeps the counts for the different reordering types
//(Instantiated in 1-3 instances, one for each type of model (hier, phrase, wbe))
class ModelScore
{
private:
  // Weighted orientation counts, conditioned on the phrase pair (fe) or
  // on the source phrase alone (f), kept separately for the previous-
  // and next-phrase directions.
  std::vector<double> count_fe_prev;
  std::vector<double> count_fe_next;
  std::vector<double> count_f_prev;
  std::vector<double> count_f_next;

protected:
  // Map an orientation label from the extract file onto an ORIENTATION
  // category; overridden by the subclasses below to collapse labels as
  // appropriate for each model type.
  virtual ORIENTATION getType(const StringPiece& s);

public:
  ModelScore();
  virtual ~ModelScore();
  // Record one weighted training example: the orientation labels seen
  // with respect to the previous and the next phrase.
  void add_example(const StringPiece& previous, const StringPiece& next, float weight);
  // Clear the per-phrase-pair (fe) / per-source-phrase (f) counters when
  // the current phrase (pair) changes during scoring.
  void reset_fe();
  void reset_f();
  const std::vector<double>& get_scores_fe_prev() const;
  const std::vector<double>& get_scores_fe_next() const;
  const std::vector<double>& get_scores_f_prev() const;
  const std::vector<double>& get_scores_f_next() const;

  // Factory: create the ModelScore subclass matching the given model
  // type string.
  static ModelScore* createModelScore(const std::string& modeltype);
};

class ModelScoreMSLR : public ModelScore
{
protected:
  virtual ORIENTATION getType(const StringPiece& s);
};

class ModelScoreLR : public ModelScore
{
protected:
  virtual ORIENTATION getType(const StringPiece& s);
};

class ModelScoreMSD : public ModelScore
{
protected:
  virtual ORIENTATION getType(const StringPiece& s);
};

class ModelScoreMonotonicity : public ModelScore
{
protected:
  virtual ORIENTATION getType(const StringPiece& s);
};

//Class for calculating total counts, and to calculate smoothing
class Scorer
{
public:
  virtual ~Scorer() {}
  // Turn raw orientation counts into the probabilities written to the
  // reordering table.
  virtual void score(const std::vector<double>&, std::vector<double>&) const = 0;
  // Build smoothing values from global counts (--SmoothWithCounts) or
  // from a constant, respectively.
  virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const = 0;
  virtual void createConstSmoothing(double, std::vector<double>&) const = 0;
};

class ScorerMSLR : public Scorer
{
public:
  virtual void score(const std::vector<double>&, std::vector<double>&) const;
  virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const;
  virtual void createConstSmoothing(double, std::vector<double>&) const;
};

class ScorerMSD : public Scorer
{
public:
  virtual void score(const std::vector<double>&, std::vector<double>&) const;
  virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const;
  virtual void createConstSmoothing(double, std::vector<double>&) const;
};

class ScorerMonotonicity : public Scorer
{
public:
  virtual void score(const std::vector<double>&, std::vector<double>&) const;
  virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const;
  virtual void createConstSmoothing(double, std::vector<double>&) const;
};

class ScorerLR : public Scorer
{
public:
  virtual void score(const std::vector<double>&, std::vector<double>&) const;
  virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const;
  virtual void createConstSmoothing(double, std::vector<double>&) const;
};


//Class for representing each model
//Contains a modelscore and scorer (which can be of different model types (mslr, msd...)),
//and file handling.
//This class also keeps track of bidirectionality, and which language to condition on
class Model
{
private:
  ModelScore* modelscore;  // owned; deleted in ~Model
  Scorer* scorer;          // owned; deleted in ~Model

  std::string filename;
  Moses::OutputFileStream outputFile;

  // fe: condition on the phrase pair (vs. source phrase only);
  // previous/next: which direction(s) this model scores.
  bool fe;
  bool previous;
  bool next;

  std::vector<double> smoothing_prev;
  std::vector<double> smoothing_next;

  // Split "type-orientation-direction-language" into its fields
  // (the leading type field is discarded).
  static void split_config(const std::string& config, std::string& dir,
                           std::string& lang, std::string& orient);
public:
  Model(ModelScore* ms, Scorer* sc, const std::string& dir,
        const std::string& lang, const std::string& fn);
  ~Model();
  static Model* createModel(ModelScore*, const std::string&, const std::string&);
  void createSmoothing(double w);
  void createConstSmoothing(double w);
  // Emit scores for the current phrase pair / source phrase.
  void score_fe(const std::string& f, const std::string& e);
  void score_f(const std::string& f);
  void zipFile();
};
mosesdecoder/phrase-extract/lexical-reordering/score.cpp ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * score_reordering.cpp
3
+ *
4
+ * Created by: Sara Stymne - Linköping University
5
+ * Machine Translation Marathon 2010, Dublin
6
+ */
7
+
8
+ #include <string>
9
+ #include <vector>
10
+ #include <map>
11
+ #include <iostream>
12
+ #include <fstream>
13
+ #include <sstream>
14
+ #include <cstdlib>
15
+ #include <cstring>
16
+
17
+ #include "util/exception.hh"
18
+ #include "util/file_piece.hh"
19
+ #include "util/string_piece.hh"
20
+ #include "util/tokenize_piece.hh"
21
+
22
+ #include "InputFileStream.h"
23
+ #include "reordering_classes.h"
24
+
25
+ using namespace std;
26
+
27
+ void split_line(const StringPiece& line, StringPiece& foreign, StringPiece& english, StringPiece& wbe, StringPiece& phrase, StringPiece& hier, float& weight);
28
+ void get_orientations(const StringPiece& pair, StringPiece& previous, StringPiece& next);
29
+
30
// Exception thrown when a line of the extract file cannot be parsed.
// The throw site appends the offending line to the message prefix set
// up in the constructor.
class FileFormatException : public util::Exception
{
public:
  FileFormatException() throw() {
    *this << "Invalid extract file format: ";
  }
  ~FileFormatException() throw() {}
};
38
+
39
+ int main(int argc, char* argv[])
40
+ {
41
+
42
+ cerr << "Lexical Reordering Scorer\n"
43
+ << "scores lexical reordering models of several types (hierarchical, phrase-based and word-based-extraction\n";
44
+
45
+ if (argc < 3) {
46
+ cerr << "syntax: score_reordering extractFile smoothingValue filepath (--model \"type max-orientation (specification-strings)\" )+\n";
47
+ exit(1);
48
+ }
49
+
50
+ char* extractFileName = argv[1];
51
+ double smoothingValue = atof(argv[2]);
52
+ string filepath = argv[3];
53
+
54
+ util::FilePiece eFile(extractFileName);
55
+
56
+ bool smoothWithCounts = false;
57
+ map<string,ModelScore*> modelScores;
58
+ vector<Model*> models;
59
+ bool hier = false;
60
+ bool phrase = false;
61
+ bool wbe = false;
62
+
63
+ StringPiece e,f,w,p,h;
64
+ StringPiece prev, next;
65
+
66
+ int i = 4;
67
+ while (i<argc) {
68
+ if (strcmp(argv[i],"--SmoothWithCounts") == 0) {
69
+ smoothWithCounts = true;
70
+ } else if (strcmp(argv[i],"--model") == 0) {
71
+ if (i+1 >= argc) {
72
+ cerr << "score: syntax error, no model information provided to the option" << argv[i] << endl;
73
+ exit(1);
74
+ }
75
+ istringstream is(argv[++i]);
76
+ string m,t;
77
+ is >> m >> t;
78
+ modelScores[m] = ModelScore::createModelScore(t);
79
+ if (m.compare("hier") == 0) {
80
+ hier = true;
81
+ } else if (m.compare("phrase") == 0) {
82
+ phrase = true;
83
+ }
84
+ if (m.compare("wbe") == 0) {
85
+ wbe = true;
86
+ }
87
+
88
+ if (!hier && !phrase && !wbe) {
89
+ cerr << "WARNING: No models specified for lexical reordering. No lexical reordering table will be trained.\n";
90
+ return 0;
91
+ }
92
+
93
+ string config;
94
+ //Store all models
95
+ while (is >> config) {
96
+ models.push_back(Model::createModel(modelScores[m],config,filepath));
97
+ }
98
+ } else {
99
+ cerr << "illegal option given to lexical reordering model score\n";
100
+ exit(1);
101
+ }
102
+ i++;
103
+ }
104
+
105
+ ////////////////////////////////////
106
+ //calculate smoothing
107
+ if (smoothWithCounts) {
108
+ util::FilePiece eFileForCounts(extractFileName);
109
+ while (true) {
110
+ StringPiece line;
111
+ try {
112
+ line = eFileForCounts.ReadLine();
113
+ } catch (util::EndOfFileException &e) {
114
+ break;
115
+ }
116
+ float weight = 1;
117
+ split_line(line,e,f,w,p,h,weight);
118
+ if (hier) {
119
+ get_orientations(h, prev, next);
120
+ modelScores["hier"]->add_example(prev,next,weight);
121
+ }
122
+ if (phrase) {
123
+ get_orientations(p, prev, next);
124
+ modelScores["phrase"]->add_example(prev,next,weight);
125
+ }
126
+ if (wbe) {
127
+ get_orientations(w, prev, next);
128
+ modelScores["wbe"]->add_example(prev,next,weight);
129
+ }
130
+ }
131
+
132
+ // calculate smoothing for each model
133
+ for (size_t i=0; i<models.size(); ++i) {
134
+ models[i]->createSmoothing(smoothingValue);
135
+ }
136
+
137
+ } else {
138
+ //constant smoothing
139
+ for (size_t i=0; i<models.size(); ++i) {
140
+ models[i]->createConstSmoothing(smoothingValue);
141
+ }
142
+ }
143
+
144
+ ////////////////////////////////////
145
+ //calculate scores for reordering table
146
+ string f_current,e_current;
147
+ bool first = true;
148
+ while (true) {
149
+ StringPiece line;
150
+ try {
151
+ line = eFile.ReadLine();
152
+ } catch (util::EndOfFileException &e) {
153
+ break;
154
+ }
155
+ float weight = 1;
156
+ split_line(line,f,e,w,p,h,weight);
157
+
158
+ if (first) {
159
+ f_current = f.as_string(); //FIXME: Avoid the copy.
160
+ e_current = e.as_string();
161
+ first = false;
162
+ } else if (f.compare(f_current) != 0 || e.compare(e_current) != 0) {
163
+ //fe - score
164
+ for (size_t i=0; i<models.size(); ++i) {
165
+ models[i]->score_fe(f_current,e_current);
166
+ }
167
+ //reset
168
+ for(map<string,ModelScore*>::const_iterator it = modelScores.begin(); it != modelScores.end(); ++it) {
169
+ it->second->reset_fe();
170
+ }
171
+
172
+ if (f.compare(f_current) != 0) {
173
+ //f - score
174
+ for (size_t i=0; i<models.size(); ++i) {
175
+ models[i]->score_f(f_current);
176
+ }
177
+ //reset
178
+ for(map<string,ModelScore*>::const_iterator it = modelScores.begin(); it != modelScores.end(); ++it) {
179
+ it->second->reset_f();
180
+ }
181
+ }
182
+ f_current = f.as_string();
183
+ e_current = e.as_string();
184
+ }
185
+
186
+ // uppdate counts
187
+ if (hier) {
188
+ get_orientations(h, prev, next);
189
+ modelScores["hier"]->add_example(prev,next,weight);
190
+ }
191
+ if (phrase) {
192
+ get_orientations(p, prev, next);
193
+ modelScores["phrase"]->add_example(prev,next,weight);
194
+ }
195
+ if (wbe) {
196
+ get_orientations(w, prev, next);
197
+ modelScores["wbe"]->add_example(prev,next,weight);
198
+ }
199
+ }
200
+ //Score the last phrases
201
+ for (size_t i=0; i<models.size(); ++i) {
202
+ models[i]->score_fe(f_current,e_current);
203
+ }
204
+ for (size_t i=0; i<models.size(); ++i) {
205
+ models[i]->score_f(f_current);
206
+ }
207
+
208
+ // delete model objects (and close files)
209
+ for (size_t i=0; i<models.size(); ++i) {
210
+ delete models[i];
211
+ }
212
+ return 0;
213
+ }
214
+
215
// Return the next token from the iterator and advance it; throw
// FileFormatException (carrying the full input line) if the iterator
// is exhausted.
template <class It> StringPiece
GrabOrDie(It &it, const StringPiece& line)
{
  UTIL_THROW_IF(!it, FileFormatException, line.as_string());
  return *it++;
}
221
+
222
+
223
// Split one extract-file line into its fields.  Throws
// FileFormatException on malformed lines.
//
// weight is only overwritten when a weight field is present; the caller
// supplies the default.  phrase/hier are cleared when the line carries
// only the word-based orientation.
void split_line(
  const StringPiece& line,
  StringPiece& foreign,
  StringPiece& english,
  StringPiece& wbe,
  StringPiece& phrase,
  StringPiece& hier,
  float& weight)
{
  /*Format is source ||| target ||| orientations
    followed by one of the following 4 possibilities
    eps
    ||| weight
    | phrase | hier
    | phrase | hier ||| weight
  */

  util::TokenIter<util::MultiCharacter> pipes(line, util::MultiCharacter(" ||| "));
  foreign = GrabOrDie(pipes,line);
  english = GrabOrDie(pipes,line);
  StringPiece next = GrabOrDie(pipes,line);

  // The third ||| field may itself be split by " | " into
  // wbe | phrase | hier.
  util::TokenIter<util::MultiCharacter> singlePipe(next, util::MultiCharacter(" | "));
  wbe = GrabOrDie(singlePipe,line);
  if (singlePipe) {
    phrase = GrabOrDie(singlePipe, line);
    hier = GrabOrDie(singlePipe, line);
  } else {
    phrase.clear();
    hier.clear();
  }

  if (pipes) {
    // read the weight; reject lines where no number could be parsed
    char* errIndex;
    next = *pipes++;
    weight = static_cast<float>(strtod(next.data(), &errIndex));
    UTIL_THROW_IF(errIndex == next.data(), FileFormatException, line.as_string());
  }
}
263
+
264
// Split a space-separated orientation pair ("prev next") into the
// previous- and next-phrase orientation labels.  Throws
// FileFormatException if fewer than two tokens are present.
void get_orientations(const StringPiece& pair, StringPiece& previous, StringPiece& next)
{
  util::TokenIter<util::SingleCharacter> tok(pair, util::SingleCharacter(' '));
  previous = GrabOrDie(tok,pair);
  next = GrabOrDie(tok,pair);
}
mosesdecoder/phrase-extract/pcfg-extract/Jamfile ADDED
@@ -0,0 +1 @@
 
 
1
+ exe pcfg-extract : [ glob *.cc ] ..//syntax-common ../..//boost_program_options : <include>.. ;
mosesdecoder/phrase-extract/pcfg-extract/options.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/***********************************************************************
 Moses - statistical machine translation system
 Copyright (C) 2006-2012 University of Edinburgh

 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Lesser General Public
 License as published by the Free Software Foundation; either
 version 2.1 of the License, or (at your option) any later version.

 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 Lesser General Public License for more details.

 You should have received a copy of the GNU Lesser General Public
 License along with this library; if not, write to the Free Software
 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 ***********************************************************************/

#pragma once
#ifndef PCFG_EXTRACT_OPTIONS_H_
#define PCFG_EXTRACT_OPTIONS_H_

#include <string>

namespace MosesTraining
{
namespace Syntax
{
namespace PCFG
{

// Command-line options for the pcfg-extract tool.
struct Options {
  // Path of the training corpus.
  // NOTE(review): not referenced by the visible extraction code, which
  // reads the corpus from stdin -- confirm whether this field is used.
  std::string corpus_file;
};

} // namespace PCFG
} // namespace Syntax
} // namespace MosesTraining

#endif
mosesdecoder/phrase-extract/pcfg-extract/pcfg_extract.cc ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***********************************************************************
2
+ Moses - statistical machine translation system
3
+ Copyright (C) 2006-2012 University of Edinburgh
4
+
5
+ This library is free software; you can redistribute it and/or
6
+ modify it under the terms of the GNU Lesser General Public
7
+ License as published by the Free Software Foundation; either
8
+ version 2.1 of the License, or (at your option) any later version.
9
+
10
+ This library is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
+ Lesser General Public License for more details.
14
+
15
+ You should have received a copy of the GNU Lesser General Public
16
+ License along with this library; if not, write to the Free Software
17
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18
+ ***********************************************************************/
19
+
20
+ #include "pcfg_extract.h"
21
+
22
+ #include <cassert>
23
+ #include <cstdlib>
24
+ #include <fstream>
25
+ #include <iostream>
26
+ #include <map>
27
+ #include <memory>
28
+ #include <set>
29
+ #include <string>
30
+ #include <vector>
31
+
32
+ #include <boost/program_options.hpp>
33
+
34
+ #include "syntax-common/exception.h"
35
+ #include "syntax-common/pcfg.h"
36
+ #include "syntax-common/vocabulary.h"
37
+ #include "syntax-common/xml_tree_parser.h"
38
+
39
+ #include "SyntaxTree.h"
40
+
41
+ #include "options.h"
42
+ #include "rule_collection.h"
43
+ #include "rule_extractor.h"
44
+
45
+ namespace MosesTraining
46
+ {
47
+ namespace Syntax
48
+ {
49
+ namespace PCFG
50
+ {
51
+
52
// Read a parsed corpus (one XML tree per line) from stdin, extract PCFG
// rules from each tree, and write the scored PCFG to stdout.  A line
// that fails to parse is reported via Error(); a line yielding no tree
// produces a warning and is skipped.
int PcfgExtract::Main(int argc, char *argv[])
{
  // Process command-line options.
  Options options;
  ProcessOptions(argc, argv, options);

  // Extract PCFG rules from corpus.
  Vocabulary non_term_vocab;
  RuleExtractor rule_extractor(non_term_vocab);
  RuleCollection rule_collection;
  XmlTreeParser parser;
  std::string line;
  std::size_t line_num = 0;
  // NOTE(review): std::auto_ptr is deprecated (removed in C++17); kept
  // to match the return type of XmlTreeParser::Parse -- verify before
  // migrating to std::unique_ptr.
  std::auto_ptr<MosesTraining::SyntaxTree> tree;
  while (std::getline(std::cin, line)) {
    ++line_num;
    try {
      tree = parser.Parse(line);
    } catch (Exception &e) {
      std::ostringstream msg;
      msg << "line " << line_num << ": " << e.msg();
      // NOTE(review): assumes Error() terminates the program -- confirm
      // in the Tool base class.
      Error(msg.str());
    }
    if (!tree.get()) {
      std::ostringstream msg;
      msg << "no tree at line " << line_num;
      Warn(msg.str());
      continue;
    }
    rule_extractor.Extract(*tree, rule_collection);
  }

  // Score rules and write PCFG to output.
  Pcfg pcfg;
  rule_collection.CreatePcfg(pcfg);
  pcfg.Write(non_term_vocab, std::cout);

  return 0;
}
91
+
92
// Parse the command line with boost::program_options.  Currently only
// --help is accepted; no positional or hidden options are registered,
// and the Options struct is left untouched.  Parse failures are
// reported via Error(); --help prints usage and exits successfully.
void PcfgExtract::ProcessOptions(int argc, char *argv[],
                                 Options &options) const
{
  namespace po = boost::program_options;

  std::ostringstream usage_top;
  usage_top << "Usage: " << name() << "\n\n" << "Options";

  // Declare the command line options that are visible to the user.
  po::options_description visible(usage_top.str());
  visible.add_options()
  ("help", "print help message and exit")
  ;

  // Declare the command line options that are hidden from the user
  // (these are used as positional options).
  po::options_description hidden("Hidden options");
  hidden.add_options();

  // Compose the full set of command-line options.
  po::options_description cmd_line_options;
  cmd_line_options.add(visible).add(hidden);

  // Register the positional options.
  po::positional_options_description p;

  // Process the command-line.
  po::variables_map vm;
  try {
    po::store(po::command_line_parser(argc, argv).style(MosesOptionStyle()).
              options(cmd_line_options).positional(p).run(), vm);
    po::notify(vm);
  } catch (const std::exception &e) {
    std::ostringstream msg;
    msg << e.what() << "\n\n" << visible;
    Error(msg.str());
  }

  if (vm.count("help")) {
    std::cout << visible << std::endl;
    std::exit(0);
  }
}
135
+
136
+ } // namespace PCFG
137
+ } // namespace Syntax
138
+ } // namespace MosesTraining
mosesdecoder/phrase-extract/pcfg-extract/pcfg_extract.h ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/***********************************************************************
 Moses - statistical machine translation system
 Copyright (C) 2006-2012 University of Edinburgh

 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Lesser General Public
 License as published by the Free Software Foundation; either
 version 2.1 of the License, or (at your option) any later version.

 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 Lesser General Public License for more details.

 You should have received a copy of the GNU Lesser General Public
 License along with this library; if not, write to the Free Software
 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 ***********************************************************************/

#pragma once
#ifndef PCFG_EXTRACT_PCFG_EXTRACT_H_
#define PCFG_EXTRACT_PCFG_EXTRACT_H_

#include "syntax-common/tool.h"

namespace MosesTraining
{
namespace Syntax
{
namespace PCFG
{

struct Options;

// The pcfg-extract command-line tool: extracts PCFG rules from a parsed
// corpus read on stdin and writes the scored grammar to stdout.
class PcfgExtract : public Tool
{
public:
  PcfgExtract() : Tool("pcfg-extract") {}
  // Tool entry point; returns the process exit code.
  virtual int Main(int, char *[]);
private:
  // Parse the command line into an Options struct.
  void ProcessOptions(int, char *[], Options &) const;
};

} // namespace PCFG
} // namespace Syntax
} // namespace MosesTraining

#endif
mosesdecoder/phrase-extract/pcfg-extract/rule_collection.h ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/***********************************************************************
 Moses - statistical machine translation system
 Copyright (C) 2006-2012 University of Edinburgh

 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Lesser General Public
 License as published by the Free Software Foundation; either
 version 2.1 of the License, or (at your option) any later version.

 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 Lesser General Public License for more details.

 You should have received a copy of the GNU Lesser General Public
 License along with this library; if not, write to the Free Software
 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 ***********************************************************************/

#pragma once
#ifndef PCFG_EXTRACT_RULE_COLLECTION_H_
#define PCFG_EXTRACT_RULE_COLLECTION_H_

#include <vector>

#include <boost/unordered_map.hpp>

#include "syntax-common/pcfg.h"

namespace MosesTraining
{
namespace Syntax
{
namespace PCFG
{

// Contains PCFG rules and their counts.
class RuleCollection
{
public:
  // Maps an RHS symbol sequence to its occurrence count.
  typedef boost::unordered_map<std::vector<std::size_t>, std::size_t> RhsCountMap;
  // Outer map keyed by symbol id (presumably the rule's LHS, matching
  // Add()'s first parameter -- confirm against the implementation).
  typedef boost::unordered_map<std::size_t, RhsCountMap> Map;
  typedef Map::iterator iterator;
  typedef Map::const_iterator const_iterator;

  RuleCollection() {}

  iterator begin() {
    return collection_.begin();
  }
  const_iterator begin() const {
    return collection_.begin();
  }

  iterator end() {
    return collection_.end();
  }
  const_iterator end() const {
    return collection_.end();
  }

  // Record one occurrence of the rule lhs -> rhs.
  void Add(std::size_t, const std::vector<std::size_t> &);
  // Convert the accumulated counts into a scored Pcfg.
  void CreatePcfg(Pcfg &);

private:
  Map collection_;
};

} // namespace PCFG
} // namespace Syntax
} // namespace MosesTraining

#endif
mosesdecoder/phrase-extract/pcfg-extract/rule_extractor.h ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/***********************************************************************
 Moses - statistical machine translation system
 Copyright (C) 2006-2012 University of Edinburgh

 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Lesser General Public
 License as published by the Free Software Foundation; either
 version 2.1 of the License, or (at your option) any later version.

 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 Lesser General Public License for more details.

 You should have received a copy of the GNU Lesser General Public
 License along with this library; if not, write to the Free Software
 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 ***********************************************************************/

#pragma once
#ifndef PCFG_EXTRACT_RULE_EXTRACTOR_H_
#define PCFG_EXTRACT_RULE_EXTRACTOR_H_

#include "SyntaxTree.h"

#include "syntax-common/vocabulary.h"

#include "rule_collection.h"

namespace MosesTraining
{
namespace Syntax
{
namespace PCFG
{

// Extracts PCFG rules from syntax trees and adds them to a RuleCollection.
// Holds a reference to the shared non-terminal vocabulary, so the
// Vocabulary passed to the constructor must outlive the extractor.
class RuleExtractor
{
public:
  RuleExtractor(Vocabulary &);
  // Walk the tree and add the extracted rules to the collection.
  void Extract(const SyntaxTree &, RuleCollection &) const;
private:
  Vocabulary &non_term_vocab_;
};

} // namespace PCFG
} // namespace Syntax
} // namespace MosesTraining

#endif