Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- fairseq-0.10.2/docs/conf.py +133 -0
- fairseq-0.10.2/docs/lr_scheduler.rst +34 -0
- fairseq-0.10.2/docs/requirements.txt +2 -0
- fairseq-0.10.2/docs/tutorial_classifying_names.rst +415 -0
- fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so +3 -0
- fairseq-0.10.2/fairseq_cli/__init__.py +0 -0
- fairseq-0.10.2/fairseq_cli/__pycache__/generate.cpython-310.pyc +0 -0
- fairseq-0.10.2/fairseq_cli/eval_lm.py +279 -0
- fairseq-0.10.2/fairseq_cli/train.py +356 -0
- mosesdecoder/phrase-extract/Alignment.cpp +70 -0
- mosesdecoder/phrase-extract/AlignmentPhrase.h +74 -0
- mosesdecoder/phrase-extract/DomainFeature.cpp +170 -0
- mosesdecoder/phrase-extract/DomainFeature.h +143 -0
- mosesdecoder/phrase-extract/HoleCollection.cpp +77 -0
- mosesdecoder/phrase-extract/HoleCollection.h +95 -0
- mosesdecoder/phrase-extract/InputFileStream.cpp +61 -0
- mosesdecoder/phrase-extract/InputFileStream.h +48 -0
- mosesdecoder/phrase-extract/InternalStructFeature.h +64 -0
- mosesdecoder/phrase-extract/OutputFileStream.h +81 -0
- mosesdecoder/phrase-extract/PhraseExtractionOptions.h +193 -0
- mosesdecoder/phrase-extract/RuleExtractionOptions.h +95 -0
- mosesdecoder/phrase-extract/ScoreFeature.cpp +114 -0
- mosesdecoder/phrase-extract/SyntaxTree.h +12 -0
- mosesdecoder/phrase-extract/consolidate-direct-main.cpp +131 -0
- mosesdecoder/phrase-extract/extract-lex.h +70 -0
- mosesdecoder/phrase-extract/filter-rule-table/CfgFilter.h +30 -0
- mosesdecoder/phrase-extract/filter-rule-table/FilterRuleTable.h +54 -0
- mosesdecoder/phrase-extract/filter-rule-table/Forest.h +59 -0
- mosesdecoder/phrase-extract/filter-rule-table/ForestTsgFilter.cpp +196 -0
- mosesdecoder/phrase-extract/filter-rule-table/ForestTsgFilter.h +70 -0
- mosesdecoder/phrase-extract/filter-rule-table/Jamfile +1 -0
- mosesdecoder/phrase-extract/filter-rule-table/StringCfgFilter.cpp +323 -0
- mosesdecoder/phrase-extract/filter-rule-table/StringCfgFilter.h +143 -0
- mosesdecoder/phrase-extract/filter-rule-table/StringForest.h +24 -0
- mosesdecoder/phrase-extract/filter-rule-table/TreeTsgFilter.h +55 -0
- mosesdecoder/phrase-extract/filter-rule-table/TsgFilter.h +55 -0
- mosesdecoder/phrase-extract/lexical-reordering/InputFileStream.cpp +68 -0
- mosesdecoder/phrase-extract/lexical-reordering/InputFileStream.h +49 -0
- mosesdecoder/phrase-extract/lexical-reordering/Jamfile +2 -0
- mosesdecoder/phrase-extract/lexical-reordering/gzfilebuf.h +88 -0
- mosesdecoder/phrase-extract/lexical-reordering/reordering_classes.cpp +416 -0
- mosesdecoder/phrase-extract/lexical-reordering/reordering_classes.h +148 -0
- mosesdecoder/phrase-extract/lexical-reordering/score.cpp +269 -0
- mosesdecoder/phrase-extract/pcfg-extract/Jamfile +1 -0
- mosesdecoder/phrase-extract/pcfg-extract/options.h +41 -0
- mosesdecoder/phrase-extract/pcfg-extract/pcfg_extract.cc +138 -0
- mosesdecoder/phrase-extract/pcfg-extract/pcfg_extract.h +48 -0
- mosesdecoder/phrase-extract/pcfg-extract/rule_collection.h +73 -0
- mosesdecoder/phrase-extract/pcfg-extract/rule_extractor.h +51 -0
.gitattributes
CHANGED
|
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 37 |
fairseq-0.10.2/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
fairseq-0.10.2/fairseq/libbleu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 37 |
fairseq-0.10.2/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
fairseq-0.10.2/docs/conf.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
#
|
| 4 |
+
# fairseq documentation build configuration file, created by
|
| 5 |
+
# sphinx-quickstart on Fri Aug 17 21:45:30 2018.
|
| 6 |
+
#
|
| 7 |
+
# This file is execfile()d with the current directory set to its
|
| 8 |
+
# containing dir.
|
| 9 |
+
#
|
| 10 |
+
# Note that not all possible configuration values are present in this
|
| 11 |
+
# autogenerated file.
|
| 12 |
+
#
|
| 13 |
+
# All configuration values have a default; values that are commented out
|
| 14 |
+
# serve to show the default.
|
| 15 |
+
|
| 16 |
+
# If extensions (or modules to document with autodoc) are in another directory,
|
| 17 |
+
# add these directories to sys.path here. If the directory is relative to the
|
| 18 |
+
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# source code directory, relative to this file, for sphinx-autobuild
|
| 25 |
+
sys.path.insert(0, os.path.abspath(".."))
|
| 26 |
+
|
| 27 |
+
source_suffix = [".rst"]
|
| 28 |
+
|
| 29 |
+
# -- General configuration ------------------------------------------------
|
| 30 |
+
|
| 31 |
+
# If your documentation needs a minimal Sphinx version, state it here.
|
| 32 |
+
#
|
| 33 |
+
# needs_sphinx = '1.0'
|
| 34 |
+
|
| 35 |
+
# Add any Sphinx extension module names here, as strings. They can be
|
| 36 |
+
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
| 37 |
+
# ones.
|
| 38 |
+
extensions = [
|
| 39 |
+
"sphinx.ext.autodoc",
|
| 40 |
+
"sphinx.ext.intersphinx",
|
| 41 |
+
"sphinx.ext.viewcode",
|
| 42 |
+
"sphinx.ext.napoleon",
|
| 43 |
+
"sphinxarg.ext",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
# Add any paths that contain templates here, relative to this directory.
|
| 47 |
+
templates_path = ["_templates"]
|
| 48 |
+
|
| 49 |
+
# The master toctree document.
|
| 50 |
+
master_doc = "index"
|
| 51 |
+
|
| 52 |
+
# General information about the project.
|
| 53 |
+
project = "fairseq"
|
| 54 |
+
copyright = "2019, Facebook AI Research (FAIR)"
|
| 55 |
+
author = "Facebook AI Research (FAIR)"
|
| 56 |
+
|
| 57 |
+
github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"
|
| 58 |
+
|
| 59 |
+
# The version info for the project you're documenting, acts as replacement for
|
| 60 |
+
# |version| and |release|, also used in various other places throughout the
|
| 61 |
+
# built documents.
|
| 62 |
+
#
|
| 63 |
+
# The short X.Y version.
|
| 64 |
+
version = "0.10.2"
|
| 65 |
+
# The full version, including alpha/beta/rc tags.
|
| 66 |
+
release = "0.10.2"
|
| 67 |
+
|
| 68 |
+
# The language for content autogenerated by Sphinx. Refer to documentation
|
| 69 |
+
# for a list of supported languages.
|
| 70 |
+
#
|
| 71 |
+
# This is also used if you do content translation via gettext catalogs.
|
| 72 |
+
# Usually you set "language" from the command line for these cases.
|
| 73 |
+
language = None
|
| 74 |
+
|
| 75 |
+
# List of patterns, relative to source directory, that match files and
|
| 76 |
+
# directories to ignore when looking for source files.
|
| 77 |
+
# This patterns also effect to html_static_path and html_extra_path
|
| 78 |
+
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
| 79 |
+
|
| 80 |
+
# The name of the Pygments (syntax highlighting) style to use.
|
| 81 |
+
pygments_style = "sphinx"
|
| 82 |
+
highlight_language = "python"
|
| 83 |
+
|
| 84 |
+
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
| 85 |
+
todo_include_todos = False
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# -- Options for HTML output ----------------------------------------------
|
| 89 |
+
|
| 90 |
+
# The theme to use for HTML and HTML Help pages. See the documentation for
|
| 91 |
+
# a list of builtin themes.
|
| 92 |
+
#
|
| 93 |
+
html_theme = "sphinx_rtd_theme"
|
| 94 |
+
|
| 95 |
+
# Theme options are theme-specific and customize the look and feel of a theme
|
| 96 |
+
# further. For a list of options available for each theme, see the
|
| 97 |
+
# documentation.
|
| 98 |
+
#
|
| 99 |
+
# html_theme_options = {}
|
| 100 |
+
|
| 101 |
+
# Add any paths that contain custom static files (such as style sheets) here,
|
| 102 |
+
# relative to this directory. They are copied after the builtin static files,
|
| 103 |
+
# so a file named "default.css" will overwrite the builtin "default.css".
|
| 104 |
+
html_static_path = ["_static"]
|
| 105 |
+
|
| 106 |
+
html_context = {
|
| 107 |
+
"css_files": [
|
| 108 |
+
"_static/theme_overrides.css", # override wide tables in RTD theme
|
| 109 |
+
],
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
# Custom sidebar templates, must be a dictionary that maps document names
|
| 113 |
+
# to template names.
|
| 114 |
+
#
|
| 115 |
+
# This is required for the alabaster theme
|
| 116 |
+
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
|
| 117 |
+
# html_sidebars = {
|
| 118 |
+
# '**': [
|
| 119 |
+
# 'about.html',
|
| 120 |
+
# 'navigation.html',
|
| 121 |
+
# 'relations.html', # needs 'show_related': True theme option to display
|
| 122 |
+
# 'searchbox.html',
|
| 123 |
+
# 'donate.html',
|
| 124 |
+
# ]
|
| 125 |
+
# }
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# Example configuration for intersphinx: refer to the Python standard library.
|
| 129 |
+
intersphinx_mapping = {
|
| 130 |
+
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
|
| 131 |
+
"python": ("https://docs.python.org/", None),
|
| 132 |
+
"torch": ("https://pytorch.org/docs/master/", None),
|
| 133 |
+
}
|
fairseq-0.10.2/docs/lr_scheduler.rst
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.. role:: hidden
|
| 2 |
+
:class: hidden-section
|
| 3 |
+
|
| 4 |
+
.. _Learning Rate Schedulers:
|
| 5 |
+
|
| 6 |
+
Learning Rate Schedulers
|
| 7 |
+
========================
|
| 8 |
+
|
| 9 |
+
Learning Rate Schedulers update the learning rate over the course of training.
|
| 10 |
+
Learning rates can be updated after each update via :func:`step_update` or at
|
| 11 |
+
epoch boundaries via :func:`step`.
|
| 12 |
+
|
| 13 |
+
.. automodule:: fairseq.optim.lr_scheduler
|
| 14 |
+
:members:
|
| 15 |
+
|
| 16 |
+
.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
|
| 17 |
+
:members:
|
| 18 |
+
:undoc-members:
|
| 19 |
+
|
| 20 |
+
.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
|
| 21 |
+
:members:
|
| 22 |
+
:undoc-members:
|
| 23 |
+
.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
|
| 24 |
+
:members:
|
| 25 |
+
:undoc-members:
|
| 26 |
+
.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
|
| 27 |
+
:members:
|
| 28 |
+
:undoc-members:
|
| 29 |
+
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
|
| 30 |
+
:members:
|
| 31 |
+
:undoc-members:
|
| 32 |
+
.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
|
| 33 |
+
:members:
|
| 34 |
+
:undoc-members:
|
fairseq-0.10.2/docs/requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sphinx<2.0
|
| 2 |
+
sphinx-argparse
|
fairseq-0.10.2/docs/tutorial_classifying_names.rst
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Tutorial: Classifying Names with a Character-Level RNN
|
| 2 |
+
======================================================
|
| 3 |
+
|
| 4 |
+
In this tutorial we will extend fairseq to support *classification* tasks. In
|
| 5 |
+
particular we will re-implement the PyTorch tutorial for `Classifying Names with
|
| 6 |
+
a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
|
| 7 |
+
in fairseq. It is recommended to quickly skim that tutorial before beginning
|
| 8 |
+
this one.
|
| 9 |
+
|
| 10 |
+
This tutorial covers:
|
| 11 |
+
|
| 12 |
+
1. **Preprocessing the data** to create dictionaries.
|
| 13 |
+
2. **Registering a new Model** that encodes an input sentence with a simple RNN
|
| 14 |
+
and predicts the output label.
|
| 15 |
+
3. **Registering a new Task** that loads our dictionaries and dataset.
|
| 16 |
+
4. **Training the Model** using the existing command-line tools.
|
| 17 |
+
5. **Writing an evaluation script** that imports fairseq and allows us to
|
| 18 |
+
interactively evaluate our model on new inputs.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
1. Preprocessing the data
|
| 22 |
+
-------------------------
|
| 23 |
+
|
| 24 |
+
The original tutorial provides raw data, but we'll work with a modified version
|
| 25 |
+
of the data that is already tokenized into characters and split into separate
|
| 26 |
+
train, valid and test sets.
|
| 27 |
+
|
| 28 |
+
Download and extract the data from here:
|
| 29 |
+
`tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_
|
| 30 |
+
|
| 31 |
+
Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
|
| 32 |
+
command-line tool to create the dictionaries. While this tool is primarily
|
| 33 |
+
intended for sequence-to-sequence problems, we're able to reuse it here by
|
| 34 |
+
treating the label as a "target" sequence of length 1. We'll also output the
|
| 35 |
+
preprocessed files in "raw" format using the ``--dataset-impl`` option to
|
| 36 |
+
enhance readability:
|
| 37 |
+
|
| 38 |
+
.. code-block:: console
|
| 39 |
+
|
| 40 |
+
> fairseq-preprocess \
|
| 41 |
+
--trainpref names/train --validpref names/valid --testpref names/test \
|
| 42 |
+
--source-lang input --target-lang label \
|
| 43 |
+
--destdir names-bin --dataset-impl raw
|
| 44 |
+
|
| 45 |
+
After running the above command you should see a new directory,
|
| 46 |
+
:file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
2. Registering a new Model
|
| 50 |
+
--------------------------
|
| 51 |
+
|
| 52 |
+
Next we'll register a new model in fairseq that will encode an input sentence
|
| 53 |
+
with a simple RNN and predict the output label. Compared to the original PyTorch
|
| 54 |
+
tutorial, our version will also work with batches of data and GPU Tensors.
|
| 55 |
+
|
| 56 |
+
First let's copy the simple RNN module implemented in the `PyTorch tutorial
|
| 57 |
+
<https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
|
| 58 |
+
Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
|
| 59 |
+
following contents::
|
| 60 |
+
|
| 61 |
+
import torch
|
| 62 |
+
import torch.nn as nn
|
| 63 |
+
|
| 64 |
+
class RNN(nn.Module):
|
| 65 |
+
|
| 66 |
+
def __init__(self, input_size, hidden_size, output_size):
|
| 67 |
+
super(RNN, self).__init__()
|
| 68 |
+
|
| 69 |
+
self.hidden_size = hidden_size
|
| 70 |
+
|
| 71 |
+
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
|
| 72 |
+
self.i2o = nn.Linear(input_size + hidden_size, output_size)
|
| 73 |
+
self.softmax = nn.LogSoftmax(dim=1)
|
| 74 |
+
|
| 75 |
+
def forward(self, input, hidden):
|
| 76 |
+
combined = torch.cat((input, hidden), 1)
|
| 77 |
+
hidden = self.i2h(combined)
|
| 78 |
+
output = self.i2o(combined)
|
| 79 |
+
output = self.softmax(output)
|
| 80 |
+
return output, hidden
|
| 81 |
+
|
| 82 |
+
def initHidden(self):
|
| 83 |
+
return torch.zeros(1, self.hidden_size)
|
| 84 |
+
|
| 85 |
+
We must also *register* this model with fairseq using the
|
| 86 |
+
:func:`~fairseq.models.register_model` function decorator. Once the model is
|
| 87 |
+
registered we'll be able to use it with the existing :ref:`Command-line Tools`.
|
| 88 |
+
|
| 89 |
+
All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
|
| 90 |
+
interface, so we'll create a small wrapper class in the same file and register
|
| 91 |
+
it in fairseq with the name ``'rnn_classifier'``::
|
| 92 |
+
|
| 93 |
+
from fairseq.models import BaseFairseqModel, register_model
|
| 94 |
+
|
| 95 |
+
# Note: the register_model "decorator" should immediately precede the
|
| 96 |
+
# definition of the Model class.
|
| 97 |
+
|
| 98 |
+
@register_model('rnn_classifier')
|
| 99 |
+
class FairseqRNNClassifier(BaseFairseqModel):
|
| 100 |
+
|
| 101 |
+
@staticmethod
|
| 102 |
+
def add_args(parser):
|
| 103 |
+
# Models can override this method to add new command-line arguments.
|
| 104 |
+
# Here we'll add a new command-line argument to configure the
|
| 105 |
+
# dimensionality of the hidden state.
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
'--hidden-dim', type=int, metavar='N',
|
| 108 |
+
help='dimensionality of the hidden state',
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def build_model(cls, args, task):
|
| 113 |
+
# Fairseq initializes models by calling the ``build_model()``
|
| 114 |
+
# function. This provides more flexibility, since the returned model
|
| 115 |
+
# instance can be of a different type than the one that was called.
|
| 116 |
+
# In this case we'll just return a FairseqRNNClassifier instance.
|
| 117 |
+
|
| 118 |
+
# Initialize our RNN module
|
| 119 |
+
rnn = RNN(
|
| 120 |
+
# We'll define the Task in the next section, but for now just
|
| 121 |
+
# notice that the task holds the dictionaries for the "source"
|
| 122 |
+
# (i.e., the input sentence) and "target" (i.e., the label).
|
| 123 |
+
input_size=len(task.source_dictionary),
|
| 124 |
+
hidden_size=args.hidden_dim,
|
| 125 |
+
output_size=len(task.target_dictionary),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Return the wrapped version of the module
|
| 129 |
+
return FairseqRNNClassifier(
|
| 130 |
+
rnn=rnn,
|
| 131 |
+
input_vocab=task.source_dictionary,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
def __init__(self, rnn, input_vocab):
|
| 135 |
+
super(FairseqRNNClassifier, self).__init__()
|
| 136 |
+
|
| 137 |
+
self.rnn = rnn
|
| 138 |
+
self.input_vocab = input_vocab
|
| 139 |
+
|
| 140 |
+
# The RNN module in the tutorial expects one-hot inputs, so we can
|
| 141 |
+
# precompute the identity matrix to help convert from indices to
|
| 142 |
+
# one-hot vectors. We register it as a buffer so that it is moved to
|
| 143 |
+
# the GPU when ``cuda()`` is called.
|
| 144 |
+
self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))
|
| 145 |
+
|
| 146 |
+
def forward(self, src_tokens, src_lengths):
|
| 147 |
+
# The inputs to the ``forward()`` function are determined by the
|
| 148 |
+
# Task, and in particular the ``'net_input'`` key in each
|
| 149 |
+
# mini-batch. We'll define the Task in the next section, but for
|
| 150 |
+
# now just know that *src_tokens* has shape `(batch, src_len)` and
|
| 151 |
+
# *src_lengths* has shape `(batch)`.
|
| 152 |
+
bsz, max_src_len = src_tokens.size()
|
| 153 |
+
|
| 154 |
+
# Initialize the RNN hidden state. Compared to the original PyTorch
|
| 155 |
+
# tutorial we'll also handle batched inputs and work on the GPU.
|
| 156 |
+
hidden = self.rnn.initHidden()
|
| 157 |
+
hidden = hidden.repeat(bsz, 1) # expand for batched inputs
|
| 158 |
+
hidden = hidden.to(src_tokens.device) # move to GPU
|
| 159 |
+
|
| 160 |
+
for i in range(max_src_len):
|
| 161 |
+
# WARNING: The inputs have padding, so we should mask those
|
| 162 |
+
# elements here so that padding doesn't affect the results.
|
| 163 |
+
# This is left as an exercise for the reader. The padding symbol
|
| 164 |
+
# is given by ``self.input_vocab.pad()`` and the unpadded length
|
| 165 |
+
# of each input is given by *src_lengths*.
|
| 166 |
+
|
| 167 |
+
# One-hot encode a batch of input characters.
|
| 168 |
+
input = self.one_hot_inputs[src_tokens[:, i].long()]
|
| 169 |
+
|
| 170 |
+
# Feed the input to our RNN.
|
| 171 |
+
output, hidden = self.rnn(input, hidden)
|
| 172 |
+
|
| 173 |
+
# Return the final output state for making a prediction
|
| 174 |
+
return output
|
| 175 |
+
|
| 176 |
+
Finally let's define a *named architecture* with the configuration for our
|
| 177 |
+
model. This is done with the :func:`~fairseq.models.register_model_architecture`
|
| 178 |
+
function decorator. Thereafter this named architecture can be used with the
|
| 179 |
+
``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::
|
| 180 |
+
|
| 181 |
+
from fairseq.models import register_model_architecture
|
| 182 |
+
|
| 183 |
+
# The first argument to ``register_model_architecture()`` should be the name
|
| 184 |
+
# of the model we registered above (i.e., 'rnn_classifier'). The function we
|
| 185 |
+
# register here should take a single argument *args* and modify it in-place
|
| 186 |
+
# to match the desired architecture.
|
| 187 |
+
|
| 188 |
+
@register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
|
| 189 |
+
def pytorch_tutorial_rnn(args):
|
| 190 |
+
# We use ``getattr()`` to prioritize arguments that are explicitly given
|
| 191 |
+
# on the command-line, so that the defaults defined below are only used
|
| 192 |
+
# when no other value has been specified.
|
| 193 |
+
args.hidden_dim = getattr(args, 'hidden_dim', 128)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
3. Registering a new Task
|
| 197 |
+
-------------------------
|
| 198 |
+
|
| 199 |
+
Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
|
| 200 |
+
dictionaries and dataset. Tasks can also control how the data is batched into
|
| 201 |
+
mini-batches, but in this tutorial we'll reuse the batching provided by
|
| 202 |
+
:class:`fairseq.data.LanguagePairDataset`.
|
| 203 |
+
|
| 204 |
+
Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
|
| 205 |
+
following contents::
|
| 206 |
+
|
| 207 |
+
import os
|
| 208 |
+
import torch
|
| 209 |
+
|
| 210 |
+
from fairseq.data import Dictionary, LanguagePairDataset
|
| 211 |
+
from fairseq.tasks import FairseqTask, register_task
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@register_task('simple_classification')
|
| 215 |
+
class SimpleClassificationTask(FairseqTask):
|
| 216 |
+
|
| 217 |
+
@staticmethod
|
| 218 |
+
def add_args(parser):
|
| 219 |
+
# Add some command-line arguments for specifying where the data is
|
| 220 |
+
# located and the maximum supported input length.
|
| 221 |
+
parser.add_argument('data', metavar='FILE',
|
| 222 |
+
help='file prefix for data')
|
| 223 |
+
parser.add_argument('--max-positions', default=1024, type=int,
|
| 224 |
+
help='max input length')
|
| 225 |
+
|
| 226 |
+
@classmethod
|
| 227 |
+
def setup_task(cls, args, **kwargs):
|
| 228 |
+
# Here we can perform any setup required for the task. This may include
|
| 229 |
+
# loading Dictionaries, initializing shared Embedding layers, etc.
|
| 230 |
+
# In this case we'll just load the Dictionaries.
|
| 231 |
+
input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
|
| 232 |
+
label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
|
| 233 |
+
print('| [input] dictionary: {} types'.format(len(input_vocab)))
|
| 234 |
+
print('| [label] dictionary: {} types'.format(len(label_vocab)))
|
| 235 |
+
|
| 236 |
+
return SimpleClassificationTask(args, input_vocab, label_vocab)
|
| 237 |
+
|
| 238 |
+
def __init__(self, args, input_vocab, label_vocab):
|
| 239 |
+
super().__init__(args)
|
| 240 |
+
self.input_vocab = input_vocab
|
| 241 |
+
self.label_vocab = label_vocab
|
| 242 |
+
|
| 243 |
+
def load_dataset(self, split, **kwargs):
|
| 244 |
+
"""Load a given dataset split (e.g., train, valid, test)."""
|
| 245 |
+
|
| 246 |
+
prefix = os.path.join(self.args.data, '{}.input-label'.format(split))
|
| 247 |
+
|
| 248 |
+
# Read input sentences.
|
| 249 |
+
sentences, lengths = [], []
|
| 250 |
+
with open(prefix + '.input', encoding='utf-8') as file:
|
| 251 |
+
for line in file:
|
| 252 |
+
sentence = line.strip()
|
| 253 |
+
|
| 254 |
+
# Tokenize the sentence, splitting on spaces
|
| 255 |
+
tokens = self.input_vocab.encode_line(
|
| 256 |
+
sentence, add_if_not_exist=False,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
sentences.append(tokens)
|
| 260 |
+
lengths.append(tokens.numel())
|
| 261 |
+
|
| 262 |
+
# Read labels.
|
| 263 |
+
labels = []
|
| 264 |
+
with open(prefix + '.label', encoding='utf-8') as file:
|
| 265 |
+
for line in file:
|
| 266 |
+
label = line.strip()
|
| 267 |
+
labels.append(
|
| 268 |
+
# Convert label to a numeric ID.
|
| 269 |
+
torch.LongTensor([self.label_vocab.add_symbol(label)])
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
assert len(sentences) == len(labels)
|
| 273 |
+
print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))
|
| 274 |
+
|
| 275 |
+
# We reuse LanguagePairDataset since classification can be modeled as a
|
| 276 |
+
# sequence-to-sequence task where the target sequence has length 1.
|
| 277 |
+
self.datasets[split] = LanguagePairDataset(
|
| 278 |
+
src=sentences,
|
| 279 |
+
src_sizes=lengths,
|
| 280 |
+
src_dict=self.input_vocab,
|
| 281 |
+
tgt=labels,
|
| 282 |
+
tgt_sizes=torch.ones(len(labels)), # targets have length 1
|
| 283 |
+
tgt_dict=self.label_vocab,
|
| 284 |
+
left_pad_source=False,
|
| 285 |
+
# Since our target is a single class label, there's no need for
|
| 286 |
+
# teacher forcing. If we set this to ``True`` then our Model's
|
| 287 |
+
# ``forward()`` method would receive an additional argument called
|
| 288 |
+
# *prev_output_tokens* that would contain a shifted version of the
|
| 289 |
+
# target sequence.
|
| 290 |
+
input_feeding=False,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
def max_positions(self):
|
| 294 |
+
"""Return the max input length allowed by the task."""
|
| 295 |
+
# The source should be less than *args.max_positions* and the "target"
|
| 296 |
+
# has max length 1.
|
| 297 |
+
return (self.args.max_positions, 1)
|
| 298 |
+
|
| 299 |
+
@property
|
| 300 |
+
def source_dictionary(self):
|
| 301 |
+
"""Return the source :class:`~fairseq.data.Dictionary`."""
|
| 302 |
+
return self.input_vocab
|
| 303 |
+
|
| 304 |
+
@property
|
| 305 |
+
def target_dictionary(self):
|
| 306 |
+
"""Return the target :class:`~fairseq.data.Dictionary`."""
|
| 307 |
+
return self.label_vocab
|
| 308 |
+
|
| 309 |
+
# We could override this method if we wanted more control over how batches
|
| 310 |
+
# are constructed, but it's not necessary for this tutorial since we can
|
| 311 |
+
# reuse the batching provided by LanguagePairDataset.
|
| 312 |
+
#
|
| 313 |
+
# def get_batch_iterator(
|
| 314 |
+
# self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
|
| 315 |
+
# ignore_invalid_inputs=False, required_batch_size_multiple=1,
|
| 316 |
+
# seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
|
| 317 |
+
# data_buffer_size=0, disable_iterator_cache=False,
|
| 318 |
+
# ):
|
| 319 |
+
# (...)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
4. Training the Model
|
| 323 |
+
---------------------
|
| 324 |
+
|
| 325 |
+
Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
|
| 326 |
+
command-line tool for this, making sure to specify our new Task (``--task
|
| 327 |
+
simple_classification``) and Model architecture (``--arch
|
| 328 |
+
pytorch_tutorial_rnn``):
|
| 329 |
+
|
| 330 |
+
.. note::
|
| 331 |
+
|
| 332 |
+
You can also configure the dimensionality of the hidden state by passing the
|
| 333 |
+
``--hidden-dim`` argument to :ref:`fairseq-train`.
|
| 334 |
+
|
| 335 |
+
.. code-block:: console
|
| 336 |
+
|
| 337 |
+
> fairseq-train names-bin \
|
| 338 |
+
--task simple_classification \
|
| 339 |
+
--arch pytorch_tutorial_rnn \
|
| 340 |
+
--optimizer adam --lr 0.001 --lr-shrink 0.5 \
|
| 341 |
+
--max-tokens 1000
|
| 342 |
+
(...)
|
| 343 |
+
| epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
|
| 344 |
+
| epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
|
| 345 |
+
| done training in 31.6 seconds
|
| 346 |
+
|
| 347 |
+
The model files should appear in the :file:`checkpoints/` directory.
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
5. Writing an evaluation script
|
| 351 |
+
-------------------------------
|
| 352 |
+
|
| 353 |
+
Finally we can write a short script to evaluate our model on new inputs. Create
|
| 354 |
+
a new file named :file:`eval_classifier.py` with the following contents::
|
| 355 |
+
|
| 356 |
+
from fairseq import checkpoint_utils, data, options, tasks
|
| 357 |
+
|
| 358 |
+
# Parse command-line arguments for generation
|
| 359 |
+
parser = options.get_generation_parser(default_task='simple_classification')
|
| 360 |
+
args = options.parse_args_and_arch(parser)
|
| 361 |
+
|
| 362 |
+
# Setup task
|
| 363 |
+
task = tasks.setup_task(args)
|
| 364 |
+
|
| 365 |
+
# Load model
|
| 366 |
+
print('| loading model from {}'.format(args.path))
|
| 367 |
+
models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
|
| 368 |
+
model = models[0]
|
| 369 |
+
|
| 370 |
+
while True:
|
| 371 |
+
sentence = input('\nInput: ')
|
| 372 |
+
|
| 373 |
+
# Tokenize into characters
|
| 374 |
+
chars = ' '.join(list(sentence.strip()))
|
| 375 |
+
tokens = task.source_dictionary.encode_line(
|
| 376 |
+
chars, add_if_not_exist=False,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
# Build mini-batch to feed to the model
|
| 380 |
+
batch = data.language_pair_dataset.collate(
|
| 381 |
+
samples=[{'id': -1, 'source': tokens}], # bsz = 1
|
| 382 |
+
pad_idx=task.source_dictionary.pad(),
|
| 383 |
+
eos_idx=task.source_dictionary.eos(),
|
| 384 |
+
left_pad_source=False,
|
| 385 |
+
input_feeding=False,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
# Feed batch to the model and get predictions
|
| 389 |
+
preds = model(**batch['net_input'])
|
| 390 |
+
|
| 391 |
+
# Print top 3 predictions and their log-probabilities
|
| 392 |
+
top_scores, top_labels = preds[0].topk(k=3)
|
| 393 |
+
for score, label_idx in zip(top_scores, top_labels):
|
| 394 |
+
label_name = task.target_dictionary.string([label_idx])
|
| 395 |
+
print('({:.2f})\t{}'.format(score, label_name))
|
| 396 |
+
|
| 397 |
+
Now we can evaluate our model interactively. Note that we have included the
|
| 398 |
+
original data path (:file:`names-bin/`) so that the dictionaries can be loaded:
|
| 399 |
+
|
| 400 |
+
.. code-block:: console
|
| 401 |
+
|
| 402 |
+
> python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
|
| 403 |
+
| [input] dictionary: 64 types
|
| 404 |
+
| [label] dictionary: 24 types
|
| 405 |
+
| loading model from checkpoints/checkpoint_best.pt
|
| 406 |
+
|
| 407 |
+
Input: Satoshi
|
| 408 |
+
(-0.61) Japanese
|
| 409 |
+
(-1.20) Arabic
|
| 410 |
+
(-2.86) Italian
|
| 411 |
+
|
| 412 |
+
Input: Sinbad
|
| 413 |
+
(-0.30) Arabic
|
| 414 |
+
(-1.76) English
|
| 415 |
+
(-4.08) Russian
|
fairseq-0.10.2/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5893e460c344970e372a5cba54a7892a793e4519e085acde37ad9ad57ea5c48f
|
| 3 |
+
size 1855456
|
fairseq-0.10.2/fairseq_cli/__init__.py
ADDED
|
File without changes
|
fairseq-0.10.2/fairseq_cli/__pycache__/generate.cpython-310.pyc
ADDED
|
Binary file (8.12 kB). View file
|
|
|
fairseq-0.10.2/fairseq_cli/eval_lm.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3 -u
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Evaluate the perplexity of a trained language model.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
import math
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
|
| 17 |
+
from fairseq.data import LMContextWindowDataset
|
| 18 |
+
from fairseq.logging import progress_bar
|
| 19 |
+
from fairseq.logging.meters import StopwatchMeter, TimeMeter
|
| 20 |
+
from fairseq.sequence_scorer import SequenceScorer
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Configure root logging once at import time; the level comes from the
# $LOGLEVEL environment variable (default INFO).
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
# Module-level logger shared by all functions in this file.
logger = logging.getLogger("fairseq_cli.eval_lm")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class WordStat(object):
    """Accumulates per-word scoring statistics for word-level LM evaluation."""

    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0            # running sum of log-probs of this word
        self.next_word_prob = 0      # running sum of log-probs of the following word
        self.count = 0               # number of occurrences seen
        self.missing_next_words = 0  # occurrences with no scorable next word

    def add(self, log_prob, next_word_prob):
        """Record one occurrence: its log-prob and (optionally) the next word's.

        ``next_word_prob`` may be ``None`` when the next word is at the end of
        the example or is not a word-final subword unit; those cases are
        tallied in ``missing_next_words`` instead of the probability sum.
        """
        if next_word_prob is None:
            self.missing_next_words += 1
        else:
            self.next_word_prob += next_word_prob
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        # Tab-separated summary row; last column is the number of occurrences
        # that actually had a scorable next word.
        fields = (
            self.word,
            self.count,
            self.log_prob,
            self.is_bpe,
            self.next_word_prob,
            self.count - self.missing_next_words,
        )
        return "\t".join(str(f) for f in fields)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main(parsed_args, **unused_kwargs):
    """Evaluate a trained language model on ``--gen-subset`` and report perplexity.

    Loads the model ensemble from ``--path``, scores every example with
    ``SequenceScorer``, and logs the loss (base 2) and perplexity.  Optionally
    emits per-word probabilities and/or per-word statistics when
    ``--output-word-probs`` / ``--output-word-stats`` are set.

    Args:
        parsed_args: fully parsed fairseq eval-lm command-line namespace.
        **unused_kwargs: ignored (kept for call-site compatibility).
    """
    assert parsed_args.path is not None, "--path required for evaluation!"

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info("loading model(s) from {}".format(parsed_args.path))
    # NOTE(review): eval() on --model-overrides executes arbitrary Python;
    # acceptable only because this is trusted CLI input.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
        strict=(parsed_args.checkpoint_shard_count == 1),
        num_shards=parsed_args.checkpoint_shard_count,
    )

    # Override checkpoint args with the current command line, except for the
    # model/data-shape settings that must match the trained checkpoint.
    for arg in vars(parsed_args).keys():
        if arg not in {
            "self_target",
            "future_target",
            "past_target",
            "tokens_per_sample",
            "output_size_dictionary",
            "add_bos_token",
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        if args.fp16:
            model.half()
        if use_cuda and not args.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(args)

    assert len(models) > 0

    logger.info(
        "num. model params: {}".format(sum(p.numel() for p in models[0].parameters()))
    )

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=("tqdm" if not args.no_progress_bar else "none"),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.0
    count = 0

    # Pre-compute the set of dictionary indices that are BPE continuation
    # tokens so per-word scores can be aggregated across subword units.
    if args.remove_bpe is not None:
        if args.remove_bpe == "sentencepiece":
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
            bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if getattr(args, "add_bos_token", False):
                assert hypo["tokens"][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            # Fold the score of each BPE continuation token into the token
            # that follows it, so only word-final positions carry scores.
            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                # BUG FIX: the message previously had no %s placeholder and the
                # token string was passed as a stray lazy-format argument, so
                # logging raised an internal formatting error and dropped it.
                logger.info(
                    "skipping tokens with inf scores: %s",
                    task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if args.output_word_probs or args.output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        # continuation subword: strip the marker and keep accumulating
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        # next word's prob is the next non-zero positional score
                        # (zeros were folded into following tokens above)
                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + (
                            "\t".join(
                                "{} [{:2f}]".format(x[0], x[1]) for x in word_prob
                            )
                        )
                    )

        wps_meter.update(sample["ntokens"])
        progress.log({"wps": round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info(
        "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg
        )
    )
    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            avg_nll_loss, 2 ** avg_nll_loss
        )
    )

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def cli_main():
    """Command-line entry point: parse eval-lm arguments and dispatch to main()."""
    lm_parser = options.get_eval_lm_parser()
    parsed = options.parse_args_and_arch(lm_parser)
    distributed_utils.call_main(parsed, main)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Allow running this module directly (``python eval_lm.py ...``) in addition
# to the installed ``fairseq-eval-lm`` console entry point.
if __name__ == "__main__":
    cli_main()
|
fairseq-0.10.2/fairseq_cli/train.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3 -u
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""
|
| 7 |
+
Train a new model on one or across multiple GPUs.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import logging
|
| 12 |
+
import math
|
| 13 |
+
import os
|
| 14 |
+
import random
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
from fairseq import (
|
| 20 |
+
checkpoint_utils,
|
| 21 |
+
distributed_utils,
|
| 22 |
+
options,
|
| 23 |
+
quantization_utils,
|
| 24 |
+
tasks,
|
| 25 |
+
utils,
|
| 26 |
+
)
|
| 27 |
+
from fairseq.data import iterators
|
| 28 |
+
from fairseq.logging import meters, metrics, progress_bar
|
| 29 |
+
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
|
| 30 |
+
from fairseq.trainer import Trainer
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Configure root logging to stdout once at import time; the level comes from
# the $LOGLEVEL environment variable (default INFO).
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
# Module-level logger shared by all functions in this file.
logger = logging.getLogger("fairseq_cli.train")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def main(args):
    """Top-level training driver: build the task, model and trainer, then run
    the epoch loop until the learning rate drops to ``--min-lr``, the epoch
    limit is reached, or a stopping condition fires inside :func:`train`.

    Args:
        args: fully parsed fairseq training command-line namespace.
    """
    utils.import_user_module(args)

    assert (
        args.max_tokens is not None or args.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    metrics.reset()

    # Seed numpy and torch for reproducibility across runs with the same seed.
    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info(
        "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
    )
    logger.info(
        "num. model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
    )

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer (Megatron trainer handles model-parallel training)
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info(
        "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
    )
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.batch_size
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def should_stop_early(args, valid_loss):
    """Return True when validation hasn't improved for ``args.patience`` runs.

    The best metric seen so far and the count of non-improving runs are kept
    on function attributes (``best`` / ``num_runs``), so the state persists
    across calls within a single training process.
    """
    # No validation this epoch, or early stopping disabled via patience <= 0.
    if valid_loss is None or args.patience <= 0:
        return False

    def improved(candidate, best):
        if args.maximize_best_checkpoint_metric:
            return candidate > best
        return candidate < best

    best_so_far = getattr(should_stop_early, "best", None)
    if best_so_far is None or improved(valid_loss, best_so_far):
        # First result or a new best: record it and reset the bad-run counter.
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False

    should_stop_early.num_runs += 1
    if should_stop_early.num_runs < args.patience:
        return False
    logger.info(
        "early stop since valid performance hasn't improved for last {} runs".format(
            args.patience
        )
    )
    return True
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@metrics.aggregate("train")
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses.

    Returns:
        (valid_losses, should_stop): losses from the most recent validation
        (``[None]`` if none ran) and whether training should stop entirely.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # Gradient accumulation: group `update_freq` batches per optimizer update;
    # the schedule may differ per epoch, falling back to the last entry.
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_losses = [None]
    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % args.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        # Possibly validate and/or save a checkpoint after every update.
        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
    """Decide whether to validate and/or save a checkpoint now, and do so.

    Returns:
        (valid_losses, should_stop): ``valid_losses`` is ``[None]`` when no
        validation ran; ``should_stop`` is True when a stopping condition
        (early stop, max updates, or time limit) has been reached.
    """
    num_updates = trainer.get_num_updates()
    max_update = args.max_update or math.inf
    # Save at end-of-epoch save intervals, at the final update, or at
    # mid-epoch --save-interval-updates boundaries (after warmup).
    do_save = (
        (end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
        or num_updates >= max_update
        or (
            args.save_interval_updates > 0
            and num_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates >= args.validate_after_updates
        )
    )
    do_validate = (
        (not end_of_epoch and do_save)  # validate during mid-epoch saves
        or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
        or num_updates >= max_update
        or (
            args.validate_interval_updates > 0
            and num_updates > 0
            and num_updates % args.validate_interval_updates == 0
        )
    ) and not args.disable_validation

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

    # Stopping conditions
    should_stop = (
        should_stop_early(args, valid_losses[0])
        or num_updates >= max_update
        or (
            args.stop_time_hours > 0
            and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours
        )
    )

    # Save checkpoint
    if do_save or should_stop:
        logger.info("begin save checkpoint")
        checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    return valid_losses, should_stop
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def get_training_stats(stats):
    """Attach elapsed wall-clock seconds (rounded) to a training stats dict."""
    stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
    return stats
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses.

    Returns one loss per subset, each taken from the stats entry keyed by
    ``--best-checkpoint-metric``.
    """

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def get_valid_stats(args, trainer, stats):
    """Augment validation stats with the update count and the best metric so far."""
    stats["num_updates"] = trainer.get_num_updates()
    if hasattr(checkpoint_utils.save_checkpoint, "best"):
        # save_checkpoint caches the best metric value seen so far on itself.
        pick_best = max if args.maximize_best_checkpoint_metric else min
        stats["best_{0}".format(args.best_checkpoint_metric)] = pick_best(
            checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric]
        )
    return stats
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def cli_main(modify_parser=None):
    """Command-line entry point for ``fairseq-train``.

    Args:
        modify_parser: optional callable that tweaks the argument parser
            before parsing (used by downstream wrappers).

    When ``--profile`` is set, training runs inside CUDA profiler and NVTX
    annotation contexts so it can be inspected with nvprof/Nsight.
    """
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(args, main)
    else:
        distributed_utils.call_main(args, main)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# Allow running this module directly (``python train.py ...``) in addition
# to the installed ``fairseq-train`` console entry point.
if __name__ == "__main__":
    cli_main()
|
mosesdecoder/phrase-extract/Alignment.cpp
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - statistical machine translation system
|
| 3 |
+
Copyright (C) 2006-2011 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#include "Alignment.h"
|
| 21 |
+
|
| 22 |
+
#include "phrase-extract/syntax-common/exception.h"
|
| 23 |
+
|
| 24 |
+
#include <algorithm>
|
| 25 |
+
#include <cassert>
|
| 26 |
+
#include <cstdlib>
|
| 27 |
+
|
| 28 |
+
namespace MosesTraining
|
| 29 |
+
{
|
| 30 |
+
|
| 31 |
+
void ReadAlignment(const std::string &s, Alignment &a)
|
| 32 |
+
{
|
| 33 |
+
const std::string digits = "0123456789";
|
| 34 |
+
|
| 35 |
+
a.clear();
|
| 36 |
+
|
| 37 |
+
std::string::size_type begin = 0;
|
| 38 |
+
while (true) {
|
| 39 |
+
std::string::size_type end = s.find("-", begin);
|
| 40 |
+
if (end == std::string::npos) {
|
| 41 |
+
return;
|
| 42 |
+
}
|
| 43 |
+
int src = std::atoi(s.substr(begin, end-begin).c_str());
|
| 44 |
+
if (end+1 == s.size()) {
|
| 45 |
+
throw Syntax::Exception("Target index missing");
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
begin = end+1;
|
| 49 |
+
end = s.find_first_not_of(digits, begin+1);
|
| 50 |
+
int tgt;
|
| 51 |
+
if (end == std::string::npos) {
|
| 52 |
+
tgt = std::atoi(s.substr(begin).c_str());
|
| 53 |
+
a.push_back(std::make_pair(src, tgt));
|
| 54 |
+
return;
|
| 55 |
+
} else {
|
| 56 |
+
tgt = std::atoi(s.substr(begin, end-begin).c_str());
|
| 57 |
+
a.push_back(std::make_pair(src, tgt));
|
| 58 |
+
}
|
| 59 |
+
begin = end+1;
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
void FlipAlignment(Alignment &a)
|
| 64 |
+
{
|
| 65 |
+
for (Alignment::iterator p = a.begin(); p != a.end(); ++p) {
|
| 66 |
+
std::swap(p->first, p->second);
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
} // namespace MosesTraining
|
mosesdecoder/phrase-extract/AlignmentPhrase.h
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// $Id$
|
| 2 |
+
/***********************************************************************
|
| 3 |
+
Moses - factored phrase-based language decoder
|
| 4 |
+
Copyright (C) 2006 University of Edinburgh
|
| 5 |
+
|
| 6 |
+
This library is free software; you can redistribute it and/or
|
| 7 |
+
modify it under the terms of the GNU Lesser General Public
|
| 8 |
+
License as published by the Free Software Foundation; either
|
| 9 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 10 |
+
|
| 11 |
+
This library is distributed in the hope that it will be useful,
|
| 12 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 13 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 14 |
+
Lesser General Public License for more details.
|
| 15 |
+
|
| 16 |
+
You should have received a copy of the GNU Lesser General Public
|
| 17 |
+
License along with this library; if not, write to the Free Software
|
| 18 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 19 |
+
***********************************************************************/
|
| 20 |
+
|
| 21 |
+
#pragma once
|
| 22 |
+
|
| 23 |
+
#include <vector>
|
| 24 |
+
#include <set>
|
| 25 |
+
|
| 26 |
+
namespace MosesTraining
|
| 27 |
+
{
|
| 28 |
+
|
| 29 |
+
class WordsRange;
|
| 30 |
+
|
| 31 |
+
class AlignmentElement
|
| 32 |
+
{
|
| 33 |
+
protected:
|
| 34 |
+
std::set<size_t> m_elements;
|
| 35 |
+
public:
|
| 36 |
+
typedef std::set<size_t>::iterator iterator;
|
| 37 |
+
typedef std::set<size_t>::const_iterator const_iterator;
|
| 38 |
+
const_iterator begin() const {
|
| 39 |
+
return m_elements.begin();
|
| 40 |
+
}
|
| 41 |
+
const_iterator end() const {
|
| 42 |
+
return m_elements.end();
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
AlignmentElement() {
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
size_t GetSize() const {
|
| 49 |
+
return m_elements.size();
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
void Merge(size_t align);
|
| 53 |
+
};
|
| 54 |
+
|
| 55 |
+
class AlignmentPhrase
|
| 56 |
+
{
|
| 57 |
+
protected:
|
| 58 |
+
std::vector<AlignmentElement> m_elements;
|
| 59 |
+
public:
|
| 60 |
+
AlignmentPhrase(size_t size)
|
| 61 |
+
:m_elements(size) {
|
| 62 |
+
}
|
| 63 |
+
void Merge(const AlignmentPhrase &newAlignment, const WordsRange &newAlignmentRange);
|
| 64 |
+
void Merge(const std::vector< std::vector<size_t> > &source);
|
| 65 |
+
size_t GetSize() const {
|
| 66 |
+
return m_elements.size();
|
| 67 |
+
}
|
| 68 |
+
const AlignmentElement &GetElement(size_t pos) const {
|
| 69 |
+
return m_elements[pos];
|
| 70 |
+
}
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
} // namespace
|
| 74 |
+
|
mosesdecoder/phrase-extract/DomainFeature.cpp
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "DomainFeature.h"
|
| 2 |
+
#include "ExtractionPhrasePair.h"
|
| 3 |
+
#include "tables-core.h"
|
| 4 |
+
#include "InputFileStream.h"
|
| 5 |
+
#include "util/tokenize.hh"
|
| 6 |
+
|
| 7 |
+
using namespace std;
|
| 8 |
+
|
| 9 |
+
namespace MosesTraining
|
| 10 |
+
{
|
| 11 |
+
|
| 12 |
+
// handling of domain names: load database with sentence-id / domain name info
|
| 13 |
+
void Domain::load( const std::string &domainFileName )
|
| 14 |
+
{
|
| 15 |
+
Moses::InputFileStream fileS( domainFileName );
|
| 16 |
+
istream *fileP = &fileS;
|
| 17 |
+
|
| 18 |
+
string line;
|
| 19 |
+
while(getline(*fileP, line)) {
|
| 20 |
+
// read
|
| 21 |
+
const vector< string > domainSpecLine = util::tokenize( line );
|
| 22 |
+
int lineNumber;
|
| 23 |
+
if (domainSpecLine.size() != 2 ||
|
| 24 |
+
! sscanf(domainSpecLine[0].c_str(), "%d", &lineNumber)) {
|
| 25 |
+
std::cerr << "ERROR: in domain specification line: '" << line << "'" << endl;
|
| 26 |
+
exit(1);
|
| 27 |
+
}
|
| 28 |
+
// store
|
| 29 |
+
const string &name = domainSpecLine[1];
|
| 30 |
+
spec.push_back( make_pair( lineNumber, name ));
|
| 31 |
+
if (name2id.find( name ) == name2id.end()) {
|
| 32 |
+
name2id[ name ] = list.size();
|
| 33 |
+
list.push_back( name );
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// get domain name based on sentence number
|
| 39 |
+
string Domain::getDomainOfSentence( int sentenceId ) const
|
| 40 |
+
{
|
| 41 |
+
for(size_t i=0; i<spec.size(); i++) {
|
| 42 |
+
if (sentenceId <= spec[i].first) {
|
| 43 |
+
return spec[i].second;
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
return "undefined";
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
DomainFeature::DomainFeature(const string& domainFile) : m_propertyKey("domain")
|
| 50 |
+
{
|
| 51 |
+
//process domain file
|
| 52 |
+
m_domain.load(domainFile);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
void DomainFeature::addPropertiesToPhrasePair(ExtractionPhrasePair &phrasePair,
|
| 56 |
+
float count,
|
| 57 |
+
int sentenceId) const
|
| 58 |
+
{
|
| 59 |
+
std::string value = m_domain.getDomainOfSentence(sentenceId);
|
| 60 |
+
phrasePair.AddProperty(m_propertyKey, value, count);
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
void DomainFeature::add(const ScoreFeatureContext& context,
|
| 64 |
+
std::vector<float>& denseValues,
|
| 65 |
+
std::map<std::string,float>& sparseValues) const
|
| 66 |
+
{
|
| 67 |
+
const map<string,float> *domainCount = context.phrasePair.GetProperty(m_propertyKey);
|
| 68 |
+
assert( domainCount != NULL );
|
| 69 |
+
add(*domainCount,
|
| 70 |
+
context.phrasePair.GetCount(),
|
| 71 |
+
context.maybeLog,
|
| 72 |
+
denseValues, sparseValues);
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
void SubsetDomainFeature::add(const map<string,float>& domainCount,
|
| 76 |
+
float count,
|
| 77 |
+
const MaybeLog& maybeLog,
|
| 78 |
+
std::vector<float>& denseValues,
|
| 79 |
+
std::map<std::string,float>& sparseValues) const
|
| 80 |
+
{
|
| 81 |
+
if (m_domain.list.size() > 6) {
|
| 82 |
+
UTIL_THROW_IF(m_domain.list.size() > 6, ScoreFeatureArgumentException,
|
| 83 |
+
"too many domains for core domain subset features");
|
| 84 |
+
}
|
| 85 |
+
size_t bitmap = 0;
|
| 86 |
+
for(size_t bit = 0; bit < m_domain.list.size(); bit++) {
|
| 87 |
+
if (domainCount.find( m_domain.list[ bit ] ) != domainCount.end()) {
|
| 88 |
+
bitmap += 1 << bit;
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
for(size_t i = 1; i < (1 << m_domain.list.size()); i++) {
|
| 92 |
+
denseValues.push_back(maybeLog( (bitmap == i) ? 2.718 : 1 ));
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
void SparseSubsetDomainFeature::add(const map<string,float>& domainCount,float count,
|
| 97 |
+
const MaybeLog& maybeLog,
|
| 98 |
+
std::vector<float>& denseValues,
|
| 99 |
+
std::map<std::string,float>& sparseValues) const
|
| 100 |
+
{
|
| 101 |
+
typedef vector<string>::const_iterator I;
|
| 102 |
+
ostringstream key;
|
| 103 |
+
key << "doms";
|
| 104 |
+
for (I i = m_domain.list.begin(); i != m_domain.list.end(); ++i) {
|
| 105 |
+
if (domainCount.find(*i) != domainCount.end()) {
|
| 106 |
+
key << "_" << *i;
|
| 107 |
+
}
|
| 108 |
+
}
|
| 109 |
+
sparseValues[key.str()] = 1;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
void RatioDomainFeature::add(const map<string,float>& domainCount,float count,
|
| 114 |
+
const MaybeLog& maybeLog,
|
| 115 |
+
std::vector<float>& denseValues,
|
| 116 |
+
std::map<std::string,float>& sparseValues) const
|
| 117 |
+
{
|
| 118 |
+
typedef vector< string >::const_iterator I;
|
| 119 |
+
for (I i = m_domain.list.begin(); i != m_domain.list.end(); i++ ) {
|
| 120 |
+
map<string,float>::const_iterator dci = domainCount.find(*i);
|
| 121 |
+
if (dci == domainCount.end() ) {
|
| 122 |
+
denseValues.push_back(maybeLog( 1 ));
|
| 123 |
+
} else {
|
| 124 |
+
denseValues.push_back(maybeLog(exp( dci->second / count ) ));
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
void SparseRatioDomainFeature::add(const map<string,float>& domainCount,float count,
|
| 131 |
+
const MaybeLog& maybeLog,
|
| 132 |
+
std::vector<float>& denseValues,
|
| 133 |
+
std::map<std::string,float>& sparseValues) const
|
| 134 |
+
{
|
| 135 |
+
typedef map< string, float >::const_iterator I;
|
| 136 |
+
for (I i=domainCount.begin(); i != domainCount.end(); i++) {
|
| 137 |
+
sparseValues["domr_" + i->first] = (i->second / count);
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
void IndicatorDomainFeature::add(const map<string,float>& domainCount,float count,
|
| 143 |
+
const MaybeLog& maybeLog,
|
| 144 |
+
std::vector<float>& denseValues,
|
| 145 |
+
std::map<std::string,float>& sparseValues) const
|
| 146 |
+
{
|
| 147 |
+
typedef vector< string >::const_iterator I;
|
| 148 |
+
for (I i = m_domain.list.begin(); i != m_domain.list.end(); i++ ) {
|
| 149 |
+
map<string,float>::const_iterator dci = domainCount.find(*i);
|
| 150 |
+
if (dci == domainCount.end() ) {
|
| 151 |
+
denseValues.push_back(maybeLog( 1 ));
|
| 152 |
+
} else {
|
| 153 |
+
denseValues.push_back(maybeLog(2.718));
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
void SparseIndicatorDomainFeature::add(const map<string,float>& domainCount,float count,
|
| 159 |
+
const MaybeLog& maybeLog,
|
| 160 |
+
std::vector<float>& denseValues,
|
| 161 |
+
std::map<std::string,float>& sparseValues) const
|
| 162 |
+
{
|
| 163 |
+
typedef map< string, float >::const_iterator I;
|
| 164 |
+
for (I i=domainCount.begin(); i != domainCount.end(); i++) {
|
| 165 |
+
sparseValues["dom_" + i->first] = 1;
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
}
|
| 170 |
+
|
mosesdecoder/phrase-extract/DomainFeature.h
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// $Id$
|
| 2 |
+
|
| 3 |
+
#ifndef _DOMAIN_H
|
| 4 |
+
#define _DOMAIN_H
|
| 5 |
+
|
| 6 |
+
#include <iostream>
|
| 7 |
+
#include <fstream>
|
| 8 |
+
#include <cassert>
|
| 9 |
+
#include <cstdlib>
|
| 10 |
+
#include <string>
|
| 11 |
+
#include <queue>
|
| 12 |
+
#include <map>
|
| 13 |
+
#include <cmath>
|
| 14 |
+
|
| 15 |
+
#include "ScoreFeature.h"
|
| 16 |
+
|
| 17 |
+
namespace MosesTraining
|
| 18 |
+
{
|
| 19 |
+
|
| 20 |
+
class Domain
|
| 21 |
+
{
|
| 22 |
+
public:
|
| 23 |
+
std::vector< std::pair< int, std::string > > spec;
|
| 24 |
+
std::vector< std::string > list;
|
| 25 |
+
std::map< std::string, int > name2id;
|
| 26 |
+
void load( const std::string &fileName );
|
| 27 |
+
std::string getDomainOfSentence( int sentenceId ) const;
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
class DomainFeature : public ScoreFeature
|
| 31 |
+
{
|
| 32 |
+
public:
|
| 33 |
+
|
| 34 |
+
DomainFeature(const std::string& domainFile);
|
| 35 |
+
|
| 36 |
+
void addPropertiesToPhrasePair(ExtractionPhrasePair &phrasePair,
|
| 37 |
+
float count,
|
| 38 |
+
int sentenceId) const;
|
| 39 |
+
|
| 40 |
+
void add(const ScoreFeatureContext& context,
|
| 41 |
+
std::vector<float>& denseValues,
|
| 42 |
+
std::map<std::string,float>& sparseValues) const;
|
| 43 |
+
|
| 44 |
+
protected:
|
| 45 |
+
/** Overridden in subclass */
|
| 46 |
+
virtual void add(const std::map<std::string,float>& domainCounts, float count,
|
| 47 |
+
const MaybeLog& maybeLog,
|
| 48 |
+
std::vector<float>& denseValues,
|
| 49 |
+
std::map<std::string,float>& sparseValues) const = 0;
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
Domain m_domain;
|
| 53 |
+
|
| 54 |
+
const std::string m_propertyKey;
|
| 55 |
+
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
class SubsetDomainFeature : public DomainFeature
|
| 59 |
+
{
|
| 60 |
+
public:
|
| 61 |
+
SubsetDomainFeature(const std::string& domainFile) :
|
| 62 |
+
DomainFeature(domainFile) {}
|
| 63 |
+
|
| 64 |
+
protected:
|
| 65 |
+
virtual void add(const std::map<std::string,float>& domainCounts, float count,
|
| 66 |
+
const MaybeLog& maybeLog,
|
| 67 |
+
std::vector<float>& denseValues,
|
| 68 |
+
std::map<std::string,float>& sparseValues) const;
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
class SparseSubsetDomainFeature : public DomainFeature
|
| 72 |
+
{
|
| 73 |
+
public:
|
| 74 |
+
SparseSubsetDomainFeature(const std::string& domainFile) :
|
| 75 |
+
DomainFeature(domainFile) {}
|
| 76 |
+
|
| 77 |
+
protected:
|
| 78 |
+
virtual void add(const std::map<std::string,float>& domainCounts, float count,
|
| 79 |
+
const MaybeLog& maybeLog,
|
| 80 |
+
std::vector<float>& denseValues,
|
| 81 |
+
std::map<std::string,float>& sparseValues) const;
|
| 82 |
+
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
class IndicatorDomainFeature : public DomainFeature
|
| 86 |
+
{
|
| 87 |
+
public:
|
| 88 |
+
IndicatorDomainFeature(const std::string& domainFile) :
|
| 89 |
+
DomainFeature(domainFile) {}
|
| 90 |
+
|
| 91 |
+
protected:
|
| 92 |
+
virtual void add(const std::map<std::string,float>& domainCounts, float count,
|
| 93 |
+
const MaybeLog& maybeLog,
|
| 94 |
+
std::vector<float>& denseValues,
|
| 95 |
+
std::map<std::string,float>& sparseValues) const;
|
| 96 |
+
};
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class SparseIndicatorDomainFeature : public DomainFeature
|
| 100 |
+
{
|
| 101 |
+
public:
|
| 102 |
+
SparseIndicatorDomainFeature(const std::string& domainFile) :
|
| 103 |
+
DomainFeature(domainFile) {}
|
| 104 |
+
|
| 105 |
+
protected:
|
| 106 |
+
virtual void add(const std::map<std::string,float>& domainCounts, float count,
|
| 107 |
+
const MaybeLog& maybeLog,
|
| 108 |
+
std::vector<float>& denseValues,
|
| 109 |
+
std::map<std::string,float>& sparseValues) const;
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class RatioDomainFeature : public DomainFeature
|
| 114 |
+
{
|
| 115 |
+
public:
|
| 116 |
+
RatioDomainFeature(const std::string& domainFile) :
|
| 117 |
+
DomainFeature(domainFile) {}
|
| 118 |
+
|
| 119 |
+
protected:
|
| 120 |
+
virtual void add(const std::map<std::string,float>& domainCounts, float count,
|
| 121 |
+
const MaybeLog& maybeLog,
|
| 122 |
+
std::vector<float>& denseValues,
|
| 123 |
+
std::map<std::string,float>& sparseValues) const;
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class SparseRatioDomainFeature : public DomainFeature
|
| 128 |
+
{
|
| 129 |
+
public:
|
| 130 |
+
SparseRatioDomainFeature(const std::string& domainFile) :
|
| 131 |
+
DomainFeature(domainFile) {}
|
| 132 |
+
|
| 133 |
+
protected:
|
| 134 |
+
virtual void add(const std::map<std::string,float>& domainCounts, float count,
|
| 135 |
+
const MaybeLog& maybeLog,
|
| 136 |
+
std::vector<float>& denseValues,
|
| 137 |
+
std::map<std::string,float>& sparseValues) const;
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
#endif
|
mosesdecoder/phrase-extract/HoleCollection.cpp
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - factored phrase-based language decoder
|
| 3 |
+
Copyright (C) 2010 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#include "HoleCollection.h"
|
| 21 |
+
|
| 22 |
+
#include <algorithm>
|
| 23 |
+
|
| 24 |
+
namespace MosesTraining
|
| 25 |
+
{
|
| 26 |
+
|
| 27 |
+
void HoleCollection::SortSourceHoles()
|
| 28 |
+
{
|
| 29 |
+
assert(m_sortedSourceHoles.size() == 0);
|
| 30 |
+
|
| 31 |
+
// add
|
| 32 |
+
HoleList::iterator iter;
|
| 33 |
+
for (iter = m_holes.begin(); iter != m_holes.end(); ++iter) {
|
| 34 |
+
Hole &currHole = *iter;
|
| 35 |
+
m_sortedSourceHoles.push_back(&currHole);
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// sort
|
| 39 |
+
std::sort(m_sortedSourceHoles.begin(), m_sortedSourceHoles.end(), HoleSourceOrderer());
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
void HoleCollection::Add(int startT, int endT, int startS, int endS)
|
| 43 |
+
{
|
| 44 |
+
Hole hole(startS, endS, startT, endT);
|
| 45 |
+
m_scope.push_back(Scope(hole));
|
| 46 |
+
m_sourceHoleStartPoints.push_back(startS);
|
| 47 |
+
m_sourceHoleEndPoints.push_back(endS);
|
| 48 |
+
m_holes.push_back(hole);
|
| 49 |
+
m_sortedSourceHoles.clear();
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
void HoleCollection::RemoveLast()
|
| 53 |
+
{
|
| 54 |
+
m_scope.pop_back();
|
| 55 |
+
m_sourceHoleStartPoints.pop_back();
|
| 56 |
+
m_sourceHoleEndPoints.pop_back();
|
| 57 |
+
m_holes.pop_back();
|
| 58 |
+
m_sortedSourceHoles.clear();
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
int HoleCollection::Scope(const Hole &proposedHole) const
|
| 62 |
+
{
|
| 63 |
+
const int holeStart = proposedHole.GetStart(0);
|
| 64 |
+
const int holeEnd = proposedHole.GetEnd(0);
|
| 65 |
+
int scope = m_scope.back();
|
| 66 |
+
if (holeStart == m_sourcePhraseStart.back() ||
|
| 67 |
+
find(m_sourceHoleEndPoints.begin(), m_sourceHoleEndPoints.end(), holeStart-1) != m_sourceHoleEndPoints.end()) {
|
| 68 |
+
++scope; // Adding hole would introduce choice point at start of hole.
|
| 69 |
+
}
|
| 70 |
+
if (holeEnd == m_sourcePhraseEnd.back() ||
|
| 71 |
+
find(m_sourceHoleStartPoints.begin(), m_sourceHoleStartPoints.end(), holeEnd-1) != m_sourceHoleStartPoints.end()) {
|
| 72 |
+
++scope; // Adding hole would introduce choice point at end of hole.
|
| 73 |
+
}
|
| 74 |
+
return scope;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
}
|
mosesdecoder/phrase-extract/HoleCollection.h
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - factored phrase-based language decoder
|
| 3 |
+
Copyright (C) 2010 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#pragma once
|
| 21 |
+
#ifndef HOLECOLLECTION_H_INCLUDED_
|
| 22 |
+
#define HOLECOLLECTION_H_INCLUDED_
|
| 23 |
+
|
| 24 |
+
#include <set>
|
| 25 |
+
#include <vector>
|
| 26 |
+
|
| 27 |
+
#include "Hole.h"
|
| 28 |
+
|
| 29 |
+
namespace MosesTraining
|
| 30 |
+
{
|
| 31 |
+
|
| 32 |
+
class HoleCollection
|
| 33 |
+
{
|
| 34 |
+
protected:
|
| 35 |
+
HoleList m_holes;
|
| 36 |
+
std::vector<Hole*> m_sortedSourceHoles;
|
| 37 |
+
std::vector<int> m_sourceHoleStartPoints;
|
| 38 |
+
std::vector<int> m_sourceHoleEndPoints;
|
| 39 |
+
std::vector<int> m_scope;
|
| 40 |
+
std::vector<int> m_sourcePhraseStart;
|
| 41 |
+
std::vector<int> m_sourcePhraseEnd;
|
| 42 |
+
|
| 43 |
+
public:
|
| 44 |
+
HoleCollection(int sourcePhraseStart, int sourcePhraseEnd)
|
| 45 |
+
: m_scope(1, 0)
|
| 46 |
+
, m_sourcePhraseStart(1, sourcePhraseStart)
|
| 47 |
+
, m_sourcePhraseEnd(1, sourcePhraseEnd) {
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
const HoleList &GetHoles() const {
|
| 51 |
+
return m_holes;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
HoleList &GetHoles() {
|
| 55 |
+
return m_holes;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
std::vector<Hole*> &GetSortedSourceHoles() {
|
| 59 |
+
return m_sortedSourceHoles;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
void Add(int startT, int endT, int startS, int endS);
|
| 63 |
+
|
| 64 |
+
void RemoveLast();
|
| 65 |
+
|
| 66 |
+
bool OverlapSource(const Hole &sourceHole) const {
|
| 67 |
+
HoleList::const_iterator iter;
|
| 68 |
+
for (iter = m_holes.begin(); iter != m_holes.end(); ++iter) {
|
| 69 |
+
const Hole &currHole = *iter;
|
| 70 |
+
if (currHole.Overlap(sourceHole, 0))
|
| 71 |
+
return true;
|
| 72 |
+
}
|
| 73 |
+
return false;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
bool ConsecSource(const Hole &sourceHole) const {
|
| 77 |
+
HoleList::const_iterator iter;
|
| 78 |
+
for (iter = m_holes.begin(); iter != m_holes.end(); ++iter) {
|
| 79 |
+
const Hole &currHole = *iter;
|
| 80 |
+
if (currHole.Neighbor(sourceHole, 0))
|
| 81 |
+
return true;
|
| 82 |
+
}
|
| 83 |
+
return false;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Determine the scope that would result from adding the given hole.
|
| 87 |
+
int Scope(const Hole &proposedHole) const;
|
| 88 |
+
|
| 89 |
+
void SortSourceHoles();
|
| 90 |
+
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
#endif
|
mosesdecoder/phrase-extract/InputFileStream.cpp
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// $Id: InputFileStream.cpp 2780 2010-01-29 17:11:17Z bojar $
|
| 2 |
+
|
| 3 |
+
/***********************************************************************
|
| 4 |
+
Moses - factored phrase-based language decoder
|
| 5 |
+
Copyright (C) 2006 University of Edinburgh
|
| 6 |
+
|
| 7 |
+
This library is free software; you can redistribute it and/or
|
| 8 |
+
modify it under the terms of the GNU Lesser General Public
|
| 9 |
+
License as published by the Free Software Foundation; either
|
| 10 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 11 |
+
|
| 12 |
+
This library is distributed in the hope that it will be useful,
|
| 13 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 14 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 15 |
+
Lesser General Public License for more details.
|
| 16 |
+
|
| 17 |
+
You should have received a copy of the GNU Lesser General Public
|
| 18 |
+
License along with this library; if not, write to the Free Software
|
| 19 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 20 |
+
***********************************************************************/
|
| 21 |
+
|
| 22 |
+
#include "InputFileStream.h"
|
| 23 |
+
#include "gzfilebuf.h"
|
| 24 |
+
#include <iostream>
|
| 25 |
+
|
| 26 |
+
using namespace std;
|
| 27 |
+
|
| 28 |
+
namespace Moses
|
| 29 |
+
{
|
| 30 |
+
InputFileStream::InputFileStream(const std::string &filePath)
|
| 31 |
+
: std::istream(NULL)
|
| 32 |
+
, m_streambuf(NULL)
|
| 33 |
+
{
|
| 34 |
+
if (filePath.size() > 3 &&
|
| 35 |
+
filePath.substr(filePath.size() - 3, 3) == ".gz") {
|
| 36 |
+
m_streambuf = new gzfilebuf(filePath.c_str());
|
| 37 |
+
} else {
|
| 38 |
+
std::filebuf* fb = new std::filebuf();
|
| 39 |
+
fb = fb->open(filePath.c_str(), std::ios::in);
|
| 40 |
+
if (! fb) {
|
| 41 |
+
cerr << "Can't read " << filePath.c_str() << endl;
|
| 42 |
+
exit(1);
|
| 43 |
+
}
|
| 44 |
+
m_streambuf = fb;
|
| 45 |
+
}
|
| 46 |
+
this->init(m_streambuf);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
InputFileStream::~InputFileStream()
|
| 50 |
+
{
|
| 51 |
+
delete m_streambuf;
|
| 52 |
+
m_streambuf = NULL;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
void InputFileStream::Close()
|
| 56 |
+
{
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
}
|
| 61 |
+
|
mosesdecoder/phrase-extract/InputFileStream.h
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// $Id: InputFileStream.h 2939 2010-02-24 11:15:44Z jfouet $
|
| 2 |
+
|
| 3 |
+
/***********************************************************************
|
| 4 |
+
Moses - factored phrase-based language decoder
|
| 5 |
+
Copyright (C) 2006 University of Edinburgh
|
| 6 |
+
|
| 7 |
+
This library is free software; you can redistribute it and/or
|
| 8 |
+
modify it under the terms of the GNU Lesser General Public
|
| 9 |
+
License as published by the Free Software Foundation; either
|
| 10 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 11 |
+
|
| 12 |
+
This library is distributed in the hope that it will be useful,
|
| 13 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 14 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 15 |
+
Lesser General Public License for more details.
|
| 16 |
+
|
| 17 |
+
You should have received a copy of the GNU Lesser General Public
|
| 18 |
+
License along with this library; if not, write to the Free Software
|
| 19 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 20 |
+
***********************************************************************/
|
| 21 |
+
|
| 22 |
+
#ifndef moses_InputFileStream_h
|
| 23 |
+
#define moses_InputFileStream_h
|
| 24 |
+
|
| 25 |
+
#include <cstdlib>
|
| 26 |
+
#include <fstream>
|
| 27 |
+
#include <string>
|
| 28 |
+
|
| 29 |
+
namespace Moses
|
| 30 |
+
{
|
| 31 |
+
|
| 32 |
+
/** Used in place of std::istream, can read zipped files if it ends in .gz
|
| 33 |
+
*/
|
| 34 |
+
class InputFileStream : public std::istream
|
| 35 |
+
{
|
| 36 |
+
protected:
|
| 37 |
+
std::streambuf *m_streambuf;
|
| 38 |
+
public:
|
| 39 |
+
|
| 40 |
+
explicit InputFileStream(const std::string &filePath);
|
| 41 |
+
~InputFileStream();
|
| 42 |
+
|
| 43 |
+
void Close();
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
#endif
|
mosesdecoder/phrase-extract/InternalStructFeature.h
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <iostream>
|
| 2 |
+
#include <fstream>
|
| 3 |
+
#include <cassert>
|
| 4 |
+
#include <cstdlib>
|
| 5 |
+
#include <string>
|
| 6 |
+
#include <queue>
|
| 7 |
+
#include <map>
|
| 8 |
+
#include <cmath>
|
| 9 |
+
|
| 10 |
+
#include "ScoreFeature.h"
|
| 11 |
+
#include "extract-ghkm/Node.h"
|
| 12 |
+
|
| 13 |
+
namespace MosesTraining
|
| 14 |
+
{
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class InternalStructFeature : public ScoreFeature
|
| 18 |
+
{
|
| 19 |
+
public:
|
| 20 |
+
InternalStructFeature() : m_type(0) {};
|
| 21 |
+
/** Add the values for this feature function. */
|
| 22 |
+
void add(const ScoreFeatureContext& context,
|
| 23 |
+
std::vector<float>& denseValues,
|
| 24 |
+
std::map<std::string,float>& sparseValues) const;
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
protected:
|
| 28 |
+
/** Overridden in subclass */
|
| 29 |
+
virtual void add(const std::string *treeFragment,
|
| 30 |
+
float count,
|
| 31 |
+
std::vector<float>& denseValues,
|
| 32 |
+
std::map<std::string,float>& sparseValues) const = 0;
|
| 33 |
+
int m_type;
|
| 34 |
+
};
|
| 35 |
+
|
| 36 |
+
class InternalStructFeatureDense : public InternalStructFeature
|
| 37 |
+
{
|
| 38 |
+
public:
|
| 39 |
+
InternalStructFeatureDense()
|
| 40 |
+
:InternalStructFeature() {
|
| 41 |
+
m_type=1;
|
| 42 |
+
} //std::cout<<"InternalStructFeatureDense: Construct "<<m_type<<"\n";}
|
| 43 |
+
protected:
|
| 44 |
+
virtual void add(const std::string *treeFragment,
|
| 45 |
+
float count,
|
| 46 |
+
std::vector<float>& denseValues,
|
| 47 |
+
std::map<std::string,float>& sparseValues) const;
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
class InternalStructFeatureSparse : public InternalStructFeature
|
| 51 |
+
{
|
| 52 |
+
public:
|
| 53 |
+
InternalStructFeatureSparse()
|
| 54 |
+
:InternalStructFeature() {
|
| 55 |
+
m_type=2;
|
| 56 |
+
}// std::cout<<"InternalStructFeatureSparse: Construct "<<m_type<<"\n";}
|
| 57 |
+
protected:
|
| 58 |
+
virtual void add(const std::string *treeFragment,
|
| 59 |
+
float count,
|
| 60 |
+
std::vector<float>& denseValues,
|
| 61 |
+
std::map<std::string,float>& sparseValues) const;
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
}
|
mosesdecoder/phrase-extract/OutputFileStream.h
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// $Id: InputFileStream.h 2939 2010-02-24 11:15:44Z jfouet $
|
| 2 |
+
|
| 3 |
+
/***********************************************************************
|
| 4 |
+
Moses - factored phrase-based language decoder
|
| 5 |
+
Copyright (C) 2006 University of Edinburgh
|
| 6 |
+
|
| 7 |
+
This library is free software; you can redistribute it and/or
|
| 8 |
+
modify it under the terms of the GNU Lesser General Public
|
| 9 |
+
License as published by the Free Software Foundation; either
|
| 10 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 11 |
+
|
| 12 |
+
This library is distributed in the hope that it will be useful,
|
| 13 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 14 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 15 |
+
Lesser General Public License for more details.
|
| 16 |
+
|
| 17 |
+
You should have received a copy of the GNU Lesser General Public
|
| 18 |
+
License along with this library; if not, write to the Free Software
|
| 19 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 20 |
+
***********************************************************************/
|
| 21 |
+
|
| 22 |
+
#pragma once
|
| 23 |
+
|
| 24 |
+
#include <cstdlib>
|
| 25 |
+
#include <fstream>
|
| 26 |
+
#include <string>
|
| 27 |
+
#include <iostream>
|
| 28 |
+
#include <boost/iostreams/filtering_stream.hpp>
|
| 29 |
+
|
| 30 |
+
namespace Moses
|
| 31 |
+
{
|
| 32 |
+
|
| 33 |
+
/** Version of std::ostream with transparent compression.
|
| 34 |
+
*
|
| 35 |
+
* Transparently compresses output when writing to a file whose name ends in
|
| 36 |
+
* ".gz". Or, writes to stdout instead of a file when given a filename
|
| 37 |
+
* consisting of just a dash ("-").
|
| 38 |
+
*/
|
| 39 |
+
class OutputFileStream : public boost::iostreams::filtering_ostream
|
| 40 |
+
{
|
| 41 |
+
private:
|
| 42 |
+
/** File that needs flushing & closing when we close this stream.
|
| 43 |
+
*
|
| 44 |
+
* Is NULL when no file is opened, e.g. when writing to standard output.
|
| 45 |
+
*/
|
| 46 |
+
std::ofstream *m_outFile;
|
| 47 |
+
|
| 48 |
+
/// Is this stream open?
|
| 49 |
+
bool m_open;
|
| 50 |
+
|
| 51 |
+
public:
|
| 52 |
+
/** Create an unopened OutputFileStream.
|
| 53 |
+
*
|
| 54 |
+
* Until it's been opened, nothing can be done with this stream.
|
| 55 |
+
*/
|
| 56 |
+
OutputFileStream();
|
| 57 |
+
|
| 58 |
+
/// Create an OutputFileStream, and open it by calling Open().
|
| 59 |
+
OutputFileStream(const std::string &filePath);
|
| 60 |
+
virtual ~OutputFileStream();
|
| 61 |
+
|
| 62 |
+
// TODO: Can we please just always throw an exception when this fails?
|
| 63 |
+
/** Open stream.
|
| 64 |
+
*
|
| 65 |
+
* If filePath is "-" (just a dash), this opens the stream for writing to
|
| 66 |
+
* standard output. Otherwise, it opens the given file. If the filename
|
| 67 |
+
* has the ".gz" suffix, output will be transparently compressed.
|
| 68 |
+
*
|
| 69 |
+
* Call Close() to close the file.
|
| 70 |
+
*
|
| 71 |
+
* Returns whether opening the file was successful. It may also throw an
|
| 72 |
+
* exception on failure.
|
| 73 |
+
*/
|
| 74 |
+
bool Open(const std::string &filePath);
|
| 75 |
+
|
| 76 |
+
/// Flush and close stream. After this, the stream can be opened again.
|
| 77 |
+
void Close();
|
| 78 |
+
};
|
| 79 |
+
|
| 80 |
+
}
|
| 81 |
+
|
mosesdecoder/phrase-extract/PhraseExtractionOptions.h
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
/***********************************************************************
|
| 3 |
+
Moses - factored phrase-based language decoder
|
| 4 |
+
Copyright (C) 2010 University of Edinburgh
|
| 5 |
+
|
| 6 |
+
This library is free software; you can redistribute it and/or
|
| 7 |
+
modify it under the terms of the GNU Lesser General Public
|
| 8 |
+
License as published by the Free Software Foundation; either
|
| 9 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 10 |
+
|
| 11 |
+
This library is distributed in the hope that it will be useful,
|
| 12 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 13 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 14 |
+
Lesser General Public License for more details.
|
| 15 |
+
|
| 16 |
+
You should have received a copy of the GNU Lesser General Public
|
| 17 |
+
License along with this library; if not, write to the Free Software
|
| 18 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 19 |
+
***********************************************************************/
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
#include <string>
|
| 23 |
+
#include <vector>
|
| 24 |
+
|
| 25 |
+
namespace MosesTraining
|
| 26 |
+
{
|
| 27 |
+
enum REO_MODEL_TYPE {REO_MSD, REO_MSLR, REO_MONO};
|
| 28 |
+
enum REO_POS {LEFT, RIGHT, DLEFT, DRIGHT, UNKNOWN};
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class PhraseExtractionOptions
|
| 32 |
+
{
|
| 33 |
+
|
| 34 |
+
public:
|
| 35 |
+
int maxPhraseLength;
|
| 36 |
+
int minPhraseLength;
|
| 37 |
+
std::string separator;
|
| 38 |
+
|
| 39 |
+
private:
|
| 40 |
+
bool allModelsOutputFlag;
|
| 41 |
+
bool wordModel;
|
| 42 |
+
REO_MODEL_TYPE wordType;
|
| 43 |
+
bool phraseModel;
|
| 44 |
+
REO_MODEL_TYPE phraseType;
|
| 45 |
+
bool hierModel;
|
| 46 |
+
REO_MODEL_TYPE hierType;
|
| 47 |
+
bool orientationFlag;
|
| 48 |
+
bool translationFlag;
|
| 49 |
+
bool includeSentenceIdFlag; //include sentence id in extract file
|
| 50 |
+
bool onlyOutputSpanInfo;
|
| 51 |
+
bool gzOutput;
|
| 52 |
+
std::string instanceWeightsFile; //weights for each sentence
|
| 53 |
+
bool targetConstituentConstrainedFlag;
|
| 54 |
+
bool targetConstituentBoundariesFlag;
|
| 55 |
+
bool flexScoreFlag;
|
| 56 |
+
bool singleWordHeuristicFlag;
|
| 57 |
+
|
| 58 |
+
public:
|
| 59 |
+
std::vector<std::string> placeholders;
|
| 60 |
+
bool debug;
|
| 61 |
+
|
| 62 |
+
PhraseExtractionOptions(const int initmaxPhraseLength):
|
| 63 |
+
maxPhraseLength(initmaxPhraseLength),
|
| 64 |
+
minPhraseLength(3),
|
| 65 |
+
separator("|||"),
|
| 66 |
+
allModelsOutputFlag(false),
|
| 67 |
+
wordModel(false),
|
| 68 |
+
wordType(REO_MSD),
|
| 69 |
+
phraseModel(false),
|
| 70 |
+
phraseType(REO_MSD),
|
| 71 |
+
hierModel(false),
|
| 72 |
+
hierType(REO_MSD),
|
| 73 |
+
orientationFlag(false),
|
| 74 |
+
translationFlag(true),
|
| 75 |
+
includeSentenceIdFlag(false),
|
| 76 |
+
onlyOutputSpanInfo(false),
|
| 77 |
+
gzOutput(false),
|
| 78 |
+
targetConstituentConstrainedFlag(false),
|
| 79 |
+
targetConstituentBoundariesFlag(false),
|
| 80 |
+
flexScoreFlag(false),
|
| 81 |
+
singleWordHeuristicFlag(false),
|
| 82 |
+
debug(false) {
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
//functions for initialization of options
|
| 86 |
+
void initAllModelsOutputFlag(const bool initallModelsOutputFlag) {
|
| 87 |
+
allModelsOutputFlag=initallModelsOutputFlag;
|
| 88 |
+
}
|
| 89 |
+
void initWordModel(const bool initwordModel) {
|
| 90 |
+
wordModel=initwordModel;
|
| 91 |
+
}
|
| 92 |
+
void initWordType(REO_MODEL_TYPE initwordType ) {
|
| 93 |
+
wordType=initwordType;
|
| 94 |
+
}
|
| 95 |
+
void initPhraseModel(const bool initphraseModel ) {
|
| 96 |
+
phraseModel=initphraseModel;
|
| 97 |
+
}
|
| 98 |
+
void initPhraseType(REO_MODEL_TYPE initphraseType) {
|
| 99 |
+
phraseType=initphraseType;
|
| 100 |
+
}
|
| 101 |
+
void initHierModel(const bool inithierModel) {
|
| 102 |
+
hierModel=inithierModel;
|
| 103 |
+
}
|
| 104 |
+
void initHierType(REO_MODEL_TYPE inithierType) {
|
| 105 |
+
hierType=inithierType;
|
| 106 |
+
}
|
| 107 |
+
void initOrientationFlag(const bool initorientationFlag) {
|
| 108 |
+
orientationFlag=initorientationFlag;
|
| 109 |
+
}
|
| 110 |
+
void initTranslationFlag(const bool inittranslationFlag) {
|
| 111 |
+
translationFlag=inittranslationFlag;
|
| 112 |
+
}
|
| 113 |
+
void initIncludeSentenceIdFlag(const bool initincludeSentenceIdFlag) {
|
| 114 |
+
includeSentenceIdFlag=initincludeSentenceIdFlag;
|
| 115 |
+
}
|
| 116 |
+
void initOnlyOutputSpanInfo(const bool initonlyOutputSpanInfo) {
|
| 117 |
+
onlyOutputSpanInfo= initonlyOutputSpanInfo;
|
| 118 |
+
}
|
| 119 |
+
void initGzOutput (const bool initgzOutput) {
|
| 120 |
+
gzOutput= initgzOutput;
|
| 121 |
+
}
|
| 122 |
+
void initInstanceWeightsFile(const char* initInstanceWeightsFile) {
|
| 123 |
+
instanceWeightsFile = std::string(initInstanceWeightsFile);
|
| 124 |
+
}
|
| 125 |
+
void initTargetConstituentConstrainedFlag(const bool initTargetConstituentConstrainedFlag) {
|
| 126 |
+
targetConstituentConstrainedFlag = initTargetConstituentConstrainedFlag;
|
| 127 |
+
}
|
| 128 |
+
void initTargetConstituentBoundariesFlag(const bool initTargetConstituentBoundariesFlag) {
|
| 129 |
+
targetConstituentBoundariesFlag = initTargetConstituentBoundariesFlag;
|
| 130 |
+
}
|
| 131 |
+
void initFlexScoreFlag(const bool initflexScoreFlag) {
|
| 132 |
+
flexScoreFlag=initflexScoreFlag;
|
| 133 |
+
}
|
| 134 |
+
void initSingleWordHeuristicFlag(const bool initSingleWordHeuristicFlag) {
|
| 135 |
+
singleWordHeuristicFlag = initSingleWordHeuristicFlag;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
// functions for getting values
|
| 139 |
+
bool isAllModelsOutputFlag() const {
|
| 140 |
+
return allModelsOutputFlag;
|
| 141 |
+
}
|
| 142 |
+
bool isWordModel() const {
|
| 143 |
+
return wordModel;
|
| 144 |
+
}
|
| 145 |
+
REO_MODEL_TYPE isWordType() const {
|
| 146 |
+
return wordType;
|
| 147 |
+
}
|
| 148 |
+
bool isPhraseModel() const {
|
| 149 |
+
return phraseModel;
|
| 150 |
+
}
|
| 151 |
+
REO_MODEL_TYPE isPhraseType() const {
|
| 152 |
+
return phraseType;
|
| 153 |
+
}
|
| 154 |
+
bool isHierModel() const {
|
| 155 |
+
return hierModel;
|
| 156 |
+
}
|
| 157 |
+
REO_MODEL_TYPE isHierType() const {
|
| 158 |
+
return hierType;
|
| 159 |
+
}
|
| 160 |
+
bool isOrientationFlag() const {
|
| 161 |
+
return orientationFlag;
|
| 162 |
+
}
|
| 163 |
+
bool isTranslationFlag() const {
|
| 164 |
+
return translationFlag;
|
| 165 |
+
}
|
| 166 |
+
bool isIncludeSentenceIdFlag() const {
|
| 167 |
+
return includeSentenceIdFlag;
|
| 168 |
+
}
|
| 169 |
+
bool isOnlyOutputSpanInfo() const {
|
| 170 |
+
return onlyOutputSpanInfo;
|
| 171 |
+
}
|
| 172 |
+
bool isGzOutput () const {
|
| 173 |
+
return gzOutput;
|
| 174 |
+
}
|
| 175 |
+
std::string getInstanceWeightsFile() const {
|
| 176 |
+
return instanceWeightsFile;
|
| 177 |
+
}
|
| 178 |
+
bool isTargetConstituentConstrainedFlag() const {
|
| 179 |
+
return targetConstituentConstrainedFlag;
|
| 180 |
+
}
|
| 181 |
+
bool isTargetConstituentBoundariesFlag() const {
|
| 182 |
+
return targetConstituentBoundariesFlag;
|
| 183 |
+
}
|
| 184 |
+
bool isFlexScoreFlag() const {
|
| 185 |
+
return flexScoreFlag;
|
| 186 |
+
}
|
| 187 |
+
bool isSingleWordHeuristicFlag() const {
|
| 188 |
+
return singleWordHeuristicFlag;
|
| 189 |
+
}
|
| 190 |
+
};
|
| 191 |
+
|
| 192 |
+
}
|
| 193 |
+
|
mosesdecoder/phrase-extract/RuleExtractionOptions.h
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - factored phrase-based language decoder
|
| 3 |
+
Copyright (C) 2010 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#pragma once
|
| 21 |
+
|
| 22 |
+
namespace MosesTraining
|
| 23 |
+
{
|
| 24 |
+
|
| 25 |
+
struct RuleExtractionOptions {
|
| 26 |
+
public:
|
| 27 |
+
int maxSpan;
|
| 28 |
+
int minHoleSource;
|
| 29 |
+
int minHoleTarget;
|
| 30 |
+
int minWords;
|
| 31 |
+
int maxSymbolsTarget;
|
| 32 |
+
int maxSymbolsSource;
|
| 33 |
+
int maxNonTerm;
|
| 34 |
+
int maxScope;
|
| 35 |
+
bool onlyDirectFlag;
|
| 36 |
+
bool glueGrammarFlag;
|
| 37 |
+
bool unknownWordLabelFlag;
|
| 38 |
+
bool onlyOutputSpanInfo;
|
| 39 |
+
bool noFileLimit;
|
| 40 |
+
bool properConditioning;
|
| 41 |
+
bool nonTermFirstWord;
|
| 42 |
+
bool nonTermConsecTarget;
|
| 43 |
+
bool nonTermConsecSource;
|
| 44 |
+
bool requireAlignedWord;
|
| 45 |
+
bool sourceSyntax;
|
| 46 |
+
bool targetSyntax;
|
| 47 |
+
bool targetSyntacticPreferences;
|
| 48 |
+
bool duplicateRules;
|
| 49 |
+
bool fractionalCounting;
|
| 50 |
+
bool pcfgScore;
|
| 51 |
+
bool gzOutput;
|
| 52 |
+
bool unpairedExtractFormat;
|
| 53 |
+
bool conditionOnTargetLhs;
|
| 54 |
+
bool boundaryRules;
|
| 55 |
+
bool flexScoreFlag;
|
| 56 |
+
bool phraseOrientation;
|
| 57 |
+
|
| 58 |
+
RuleExtractionOptions()
|
| 59 |
+
: maxSpan(10)
|
| 60 |
+
, minHoleSource(2)
|
| 61 |
+
, minHoleTarget(1)
|
| 62 |
+
, minWords(1)
|
| 63 |
+
, maxSymbolsTarget(999)
|
| 64 |
+
, maxSymbolsSource(5)
|
| 65 |
+
, maxNonTerm(2)
|
| 66 |
+
, maxScope(999)
|
| 67 |
+
// int minHoleSize(1)
|
| 68 |
+
// int minSubPhraseSize(1) // minimum size of a remaining lexical phrase
|
| 69 |
+
, onlyDirectFlag(false)
|
| 70 |
+
, glueGrammarFlag(false)
|
| 71 |
+
, unknownWordLabelFlag(false)
|
| 72 |
+
, onlyOutputSpanInfo(false)
|
| 73 |
+
, noFileLimit(false)
|
| 74 |
+
//bool zipFiles(false)
|
| 75 |
+
, properConditioning(false)
|
| 76 |
+
, nonTermFirstWord(true)
|
| 77 |
+
, nonTermConsecTarget(true)
|
| 78 |
+
, nonTermConsecSource(false)
|
| 79 |
+
, requireAlignedWord(true)
|
| 80 |
+
, sourceSyntax(false)
|
| 81 |
+
, targetSyntax(false)
|
| 82 |
+
, targetSyntacticPreferences(false)
|
| 83 |
+
, duplicateRules(true)
|
| 84 |
+
, fractionalCounting(true)
|
| 85 |
+
, pcfgScore(false)
|
| 86 |
+
, gzOutput(false)
|
| 87 |
+
, unpairedExtractFormat(false)
|
| 88 |
+
, conditionOnTargetLhs(false)
|
| 89 |
+
, boundaryRules(false)
|
| 90 |
+
, flexScoreFlag(false)
|
| 91 |
+
, phraseOrientation(false) {}
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
}
|
| 95 |
+
|
mosesdecoder/phrase-extract/ScoreFeature.cpp
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - factored phrase-based language decoder
|
| 3 |
+
Copyright (C) 2012- University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#include <boost/algorithm/string/predicate.hpp>
|
| 21 |
+
#include "ScoreFeature.h"
|
| 22 |
+
#include "DomainFeature.h"
|
| 23 |
+
#include "InternalStructFeature.h"
|
| 24 |
+
|
| 25 |
+
using namespace std;
|
| 26 |
+
using namespace boost::algorithm;
|
| 27 |
+
|
| 28 |
+
namespace MosesTraining
|
| 29 |
+
{
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
const string& ScoreFeatureManager::usage() const
|
| 33 |
+
{
|
| 34 |
+
const static string& usage = "[--[Sparse]Domain[Indicator|Ratio|Subset|Bin] domain-file [bins]]" ;
|
| 35 |
+
return usage;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
void ScoreFeatureManager::configure(const std::vector<std::string> args)
|
| 39 |
+
{
|
| 40 |
+
bool domainAdded = false;
|
| 41 |
+
bool sparseDomainAdded = false;
|
| 42 |
+
|
| 43 |
+
for (size_t i = 0; i < args.size(); ++i) {
|
| 44 |
+
if (args[i] == "--IgnoreSentenceId") {
|
| 45 |
+
m_includeSentenceId = true;
|
| 46 |
+
} else if (starts_with(args[i], "--Domain")) {
|
| 47 |
+
string type = args[i].substr(8);
|
| 48 |
+
++i;
|
| 49 |
+
UTIL_THROW_IF(i == args.size(), ScoreFeatureArgumentException, "Missing domain file");
|
| 50 |
+
string domainFile = args[i];
|
| 51 |
+
UTIL_THROW_IF(domainAdded, ScoreFeatureArgumentException,
|
| 52 |
+
"Only allowed one domain feature");
|
| 53 |
+
if (type == "Subset") {
|
| 54 |
+
m_features.push_back(ScoreFeaturePtr(new SubsetDomainFeature(domainFile)));
|
| 55 |
+
} else if (type == "Ratio") {
|
| 56 |
+
m_features.push_back(ScoreFeaturePtr(new RatioDomainFeature(domainFile)));
|
| 57 |
+
} else if (type == "Indicator") {
|
| 58 |
+
m_features.push_back(ScoreFeaturePtr(new IndicatorDomainFeature(domainFile)));
|
| 59 |
+
} else {
|
| 60 |
+
UTIL_THROW(ScoreFeatureArgumentException, "Unknown domain feature type " << type);
|
| 61 |
+
}
|
| 62 |
+
domainAdded = true;
|
| 63 |
+
m_includeSentenceId = true;
|
| 64 |
+
} else if (starts_with(args[i], "--SparseDomain")) {
|
| 65 |
+
string type = args[i].substr(14);
|
| 66 |
+
++i;
|
| 67 |
+
UTIL_THROW_IF(i == args.size(), ScoreFeatureArgumentException, "Missing domain file");
|
| 68 |
+
string domainFile = args[i];
|
| 69 |
+
UTIL_THROW_IF(sparseDomainAdded, ScoreFeatureArgumentException,
|
| 70 |
+
"Only allowed one sparse domain feature");
|
| 71 |
+
if (type == "Subset") {
|
| 72 |
+
m_features.push_back(ScoreFeaturePtr(new SparseSubsetDomainFeature(domainFile)));
|
| 73 |
+
} else if (type == "Ratio") {
|
| 74 |
+
m_features.push_back(ScoreFeaturePtr(new SparseRatioDomainFeature(domainFile)));
|
| 75 |
+
} else if (type == "Indicator") {
|
| 76 |
+
m_features.push_back(ScoreFeaturePtr(new SparseIndicatorDomainFeature(domainFile)));
|
| 77 |
+
} else {
|
| 78 |
+
UTIL_THROW(ScoreFeatureArgumentException, "Unknown domain feature type " << type);
|
| 79 |
+
}
|
| 80 |
+
sparseDomainAdded = true;
|
| 81 |
+
m_includeSentenceId = true;
|
| 82 |
+
} else if(args[i] == "--TreeFeatureSparse") {
|
| 83 |
+
//MARIA
|
| 84 |
+
m_features.push_back(ScoreFeaturePtr(new InternalStructFeatureSparse()));
|
| 85 |
+
} else if(args[i] == "--TreeFeatureDense") {
|
| 86 |
+
//MARIA
|
| 87 |
+
m_features.push_back(ScoreFeaturePtr(new InternalStructFeatureDense()));
|
| 88 |
+
} else {
|
| 89 |
+
UTIL_THROW(ScoreFeatureArgumentException,"Unknown score argument " << args[i]);
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
void ScoreFeatureManager::addPropertiesToPhrasePair(ExtractionPhrasePair &phrasePair,
|
| 97 |
+
float count,
|
| 98 |
+
int sentenceId) const
|
| 99 |
+
{
|
| 100 |
+
for (size_t i = 0; i < m_features.size(); ++i) {
|
| 101 |
+
m_features[i]->addPropertiesToPhrasePair(phrasePair, count, sentenceId);
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
void ScoreFeatureManager::addFeatures(const ScoreFeatureContext& context,
|
| 106 |
+
std::vector<float>& denseValues,
|
| 107 |
+
std::map<std::string,float>& sparseValues) const
|
| 108 |
+
{
|
| 109 |
+
for (size_t i = 0; i < m_features.size(); ++i) {
|
| 110 |
+
m_features[i]->add(context, denseValues, sparseValues);
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
mosesdecoder/phrase-extract/SyntaxTree.h
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "syntax-common/tree.h"
|
| 4 |
+
|
| 5 |
+
#include "SyntaxNode.h"
|
| 6 |
+
|
| 7 |
+
namespace MosesTraining
|
| 8 |
+
{
|
| 9 |
+
|
| 10 |
+
typedef Syntax::Tree<SyntaxNode> SyntaxTree;
|
| 11 |
+
|
| 12 |
+
} // namespace MosesTraining
|
mosesdecoder/phrase-extract/consolidate-direct-main.cpp
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - factored phrase-based language decoder
|
| 3 |
+
Copyright (C) 2009 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#include <string.h>
|
| 21 |
+
#include <fstream>
|
| 22 |
+
#include <vector>
|
| 23 |
+
#include <string>
|
| 24 |
+
#include <iostream>
|
| 25 |
+
#include <cstdlib>
|
| 26 |
+
#include "InputFileStream.h"
|
| 27 |
+
#include "OutputFileStream.h"
|
| 28 |
+
#include "util/tokenize.hh"
|
| 29 |
+
|
| 30 |
+
using namespace std;
|
| 31 |
+
|
| 32 |
+
vector< string > splitLine(const char *line)
|
| 33 |
+
{
|
| 34 |
+
vector< string > item;
|
| 35 |
+
int start=0;
|
| 36 |
+
int i=0;
|
| 37 |
+
for(; line[i] != '\0'; i++) {
|
| 38 |
+
if (line[i] == ' ' &&
|
| 39 |
+
line[i+1] == '|' &&
|
| 40 |
+
line[i+2] == '|' &&
|
| 41 |
+
line[i+3] == '|' &&
|
| 42 |
+
line[i+4] == ' ') {
|
| 43 |
+
if (start > i) start = i; // empty item
|
| 44 |
+
item.push_back( string( line+start, i-start ) );
|
| 45 |
+
start = i+5;
|
| 46 |
+
i += 3;
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
item.push_back( string( line+start, i-start ) );
|
| 50 |
+
|
| 51 |
+
return item;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
bool getLine( istream &fileP, vector< string > &item )
|
| 55 |
+
{
|
| 56 |
+
if (fileP.eof())
|
| 57 |
+
return false;
|
| 58 |
+
|
| 59 |
+
string line;
|
| 60 |
+
if (getline(fileP, line)) {
|
| 61 |
+
item = splitLine(line.c_str());
|
| 62 |
+
return true;
|
| 63 |
+
} else {
|
| 64 |
+
return false;
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
int main(int argc, char* argv[])
|
| 70 |
+
{
|
| 71 |
+
cerr << "Starting..." << endl;
|
| 72 |
+
|
| 73 |
+
char* &fileNameDirect = argv[1];
|
| 74 |
+
Moses::InputFileStream fileDirect(fileNameDirect);
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
//fileDirect.open(fileNameDirect);
|
| 78 |
+
if (fileDirect.fail()) {
|
| 79 |
+
cerr << "ERROR: could not open extract file " << fileNameDirect << endl;
|
| 80 |
+
exit(1);
|
| 81 |
+
}
|
| 82 |
+
istream &fileDirectP = fileDirect;
|
| 83 |
+
|
| 84 |
+
char* &fileNameConsolidated = argv[2];
|
| 85 |
+
ostream *fileConsolidated;
|
| 86 |
+
|
| 87 |
+
if (strcmp(fileNameConsolidated, "-") == 0) {
|
| 88 |
+
fileConsolidated = &cout;
|
| 89 |
+
} else {
|
| 90 |
+
Moses::OutputFileStream *outputFile = new Moses::OutputFileStream();
|
| 91 |
+
bool success = outputFile->Open(fileNameConsolidated);
|
| 92 |
+
if (!success) {
|
| 93 |
+
cerr << "ERROR: could not open file phrase table file "
|
| 94 |
+
<< fileNameConsolidated << endl;
|
| 95 |
+
exit(1);
|
| 96 |
+
}
|
| 97 |
+
fileConsolidated = outputFile;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
int i=0;
|
| 101 |
+
while(true) {
|
| 102 |
+
i++;
|
| 103 |
+
if (i%1000 == 0) cerr << "." << flush;
|
| 104 |
+
if (i%10000 == 0) cerr << ":" << flush;
|
| 105 |
+
if (i%100000 == 0) cerr << "!" << flush;
|
| 106 |
+
|
| 107 |
+
vector< string > itemDirect;
|
| 108 |
+
if (! getLine(fileDirectP, itemDirect ))
|
| 109 |
+
break;
|
| 110 |
+
|
| 111 |
+
const vector< string > count = util::tokenize( itemDirect[4] );
|
| 112 |
+
float countEF = atof(count[0].c_str());
|
| 113 |
+
float countF = atof(count[1].c_str());
|
| 114 |
+
float prob = countF/countEF;
|
| 115 |
+
|
| 116 |
+
(*fileConsolidated) << itemDirect[0] << " ||| " // source
|
| 117 |
+
<< itemDirect[1] << " ||| " // target
|
| 118 |
+
<< prob << " ||| " // prob
|
| 119 |
+
<< itemDirect[2] << "||| " // alignment
|
| 120 |
+
<< itemDirect[4] << " " << countEF // counts
|
| 121 |
+
<< " ||| " << endl;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
fileConsolidated->flush();
|
| 125 |
+
if (fileConsolidated != &cout) {
|
| 126 |
+
delete fileConsolidated;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
cerr << "Finished" << endl;
|
| 130 |
+
}
|
| 131 |
+
|
mosesdecoder/phrase-extract/extract-lex.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <map>
|
| 4 |
+
#include <set>
|
| 5 |
+
#include <sstream>
|
| 6 |
+
#include <fstream>
|
| 7 |
+
#include <iostream>
|
| 8 |
+
|
| 9 |
+
namespace MosesTraining
|
| 10 |
+
{
|
| 11 |
+
|
| 12 |
+
// Node in a word co-occurrence structure used by extract-lex: stores the
// accumulated count for one word together with a map of co-occurring words
// (keyed by interned string pointer, see Vocab) to their own counts.
class WordCount
{
  friend std::ostream& operator<<(std::ostream&, const WordCount&);
public:
  float m_count;  // accumulated count for this word

  // co-occurrence counts, keyed by interned word pointer
  std::map<const std::string*, WordCount> m_coll;

  WordCount()
    :m_count(0) {
  }

  //WordCount(const WordCount &copy);

  WordCount(float count)
    :m_count(count) {
  }

  // Increment m_count by incr (defined in extract-lex.cpp).
  void AddCount(float incr);

  std::map<const std::string*, WordCount> &GetColl() {
    return m_coll;
  }
  const std::map<const std::string*, WordCount> &GetColl() const {
    return m_coll;
  }

  // NOTE: the original declared the return type as 'const float'; a
  // top-level const on a by-value return is meaningless and has been
  // dropped (binary- and source-compatible for all callers).
  float GetCount() const {
    return m_count;
  }

};
|
| 44 |
+
|
| 45 |
+
// Interning pool for word strings.  GetOrAdd() hands back a stable pointer
// to the single stored copy of a word, so callers can key maps (such as
// WordCount::m_coll) by address instead of by string value.
class Vocab
{
  std::set<std::string> m_coll;
public:
  // Return the interned copy of 'word', inserting it first if it is not
  // already present (defined in extract-lex.cpp).
  const std::string *GetOrAdd(const std::string &word);
};
|
| 51 |
+
|
| 52 |
+
class ExtractLex
|
| 53 |
+
{
|
| 54 |
+
Vocab m_vocab;
|
| 55 |
+
std::map<const std::string*, WordCount> m_collS2T, m_collT2S;
|
| 56 |
+
|
| 57 |
+
void Process(const std::string *target, const std::string *source);
|
| 58 |
+
void Process(WordCount &wcIn, const std::string *out);
|
| 59 |
+
void ProcessUnaligned(std::vector<std::string> &toksTarget, std::vector<std::string> &toksSource
|
| 60 |
+
, const std::vector<bool> &m_sourceAligned, const std::vector<bool> &m_targetAligned);
|
| 61 |
+
|
| 62 |
+
void Output(const std::map<const std::string*, WordCount> &coll, std::ofstream &outStream);
|
| 63 |
+
|
| 64 |
+
public:
|
| 65 |
+
void Process(std::vector<std::string> &toksTarget, std::vector<std::string> &toksSource, std::vector<std::string> &toksAlign, size_t lineCount);
|
| 66 |
+
void Output(std::ofstream &streamLexS2T, std::ofstream &streamLexT2S);
|
| 67 |
+
|
| 68 |
+
};
|
| 69 |
+
|
| 70 |
+
} // namespace
|
mosesdecoder/phrase-extract/filter-rule-table/CfgFilter.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <istream>
|
| 4 |
+
#include <ostream>
|
| 5 |
+
#include <string>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
namespace MosesTraining
|
| 9 |
+
{
|
| 10 |
+
namespace Syntax
|
| 11 |
+
{
|
| 12 |
+
namespace FilterRuleTable
|
| 13 |
+
{
|
| 14 |
+
|
| 15 |
+
// Abstract base for rule-table filters whose rules have CFG source-sides
// (concrete subclasses: StringCfgFilter and TreeCfgFilter).
class CfgFilter
{
public:
  virtual ~CfgFilter() {}

  // Read a rule table from 'in' and write the rules retained for the test
  // sentences to 'out'.
  virtual void Filter(std::istream &in, std::ostream &out) = 0;

protected:
};
|
| 27 |
+
|
| 28 |
+
} // namespace FilterRuleTable
|
| 29 |
+
} // namespace Syntax
|
| 30 |
+
} // namespace MosesTraining
|
mosesdecoder/phrase-extract/filter-rule-table/FilterRuleTable.h
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <vector>
|
| 4 |
+
#include <string>
|
| 5 |
+
|
| 6 |
+
#include <boost/shared_ptr.hpp>
|
| 7 |
+
|
| 8 |
+
#include "SyntaxTree.h"
|
| 9 |
+
|
| 10 |
+
#include "syntax-common/tool.h"
|
| 11 |
+
|
| 12 |
+
#include "StringForest.h"
|
| 13 |
+
|
| 14 |
+
namespace MosesTraining
|
| 15 |
+
{
|
| 16 |
+
namespace Syntax
|
| 17 |
+
{
|
| 18 |
+
namespace FilterRuleTable
|
| 19 |
+
{
|
| 20 |
+
|
| 21 |
+
struct Options;
|
| 22 |
+
|
| 23 |
+
class FilterRuleTable : public Tool
|
| 24 |
+
{
|
| 25 |
+
public:
|
| 26 |
+
FilterRuleTable() : Tool("filter-rule-table") {}
|
| 27 |
+
|
| 28 |
+
virtual int Main(int argc, char *argv[]);
|
| 29 |
+
|
| 30 |
+
private:
|
| 31 |
+
// Filter rule table (on std::cin) for test set (string version).
|
| 32 |
+
void Filter(const std::vector<std::vector<std::string> > &);
|
| 33 |
+
|
| 34 |
+
// Filter rule table (on std::cin) for test set (parse tree version).
|
| 35 |
+
void Filter(const std::vector<boost::shared_ptr<SyntaxTree> > &);
|
| 36 |
+
|
| 37 |
+
void ProcessOptions(int, char *[], Options &) const;
|
| 38 |
+
|
| 39 |
+
// Read test set (string version)
|
| 40 |
+
void ReadTestSet(std::istream &,
|
| 41 |
+
std::vector<boost::shared_ptr<std::string> > &);
|
| 42 |
+
|
| 43 |
+
// Read test set (tree version)
|
| 44 |
+
void ReadTestSet(std::istream &,
|
| 45 |
+
std::vector<boost::shared_ptr<SyntaxTree> > &);
|
| 46 |
+
|
| 47 |
+
// Read test set (forest version)
|
| 48 |
+
void ReadTestSet(std::istream &,
|
| 49 |
+
std::vector<boost::shared_ptr<StringForest> > &);
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
} // namespace FilterRuleTable
|
| 53 |
+
} // namespace Syntax
|
| 54 |
+
} // namespace MosesTraining
|
mosesdecoder/phrase-extract/filter-rule-table/Forest.h
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <vector>
|
| 4 |
+
|
| 5 |
+
namespace MosesTraining
|
| 6 |
+
{
|
| 7 |
+
namespace Syntax
|
| 8 |
+
{
|
| 9 |
+
namespace FilterRuleTable
|
| 10 |
+
{
|
| 11 |
+
|
| 12 |
+
// A hypergraph (packed forest) whose vertices carry values of type T.
// Ownership: the Forest owns its vertices, and each vertex owns its
// incoming hyperedges; both are released in the destructors below.
template<typename T>
struct Forest {
  struct Vertex;

  // One derivation step: a head vertex built from an ordered tail.
  struct Hyperedge {
    Vertex *head;
    std::vector<Vertex *> tail;
  };

  struct Vertex {
    ~Vertex();
    T value;
    std::vector<Hyperedge *> incoming;  // owned by this vertex
  };

  Forest() {}

  ~Forest();

  std::vector<Vertex *> vertices;  // owned by the forest

private:
  // Copying is not allowed (vertices are raw owned pointers).
  Forest(const Forest &);
  Forest &operator=(const Forest &);
};

template<typename T>
Forest<T>::~Forest()
{
  typedef typename std::vector<Vertex *>::iterator VertexIter;
  for (VertexIter it = vertices.begin(); it != vertices.end(); ++it) {
    delete *it;
  }
}

template<typename T>
Forest<T>::Vertex::~Vertex()
{
  typedef typename std::vector<Hyperedge *>::iterator EdgeIter;
  for (EdgeIter it = incoming.begin(); it != incoming.end(); ++it) {
    delete *it;
  }
}
|
| 56 |
+
|
| 57 |
+
} // namespace FilterRuleTable
|
| 58 |
+
} // namespace Syntax
|
| 59 |
+
} // namespace MosesTraining
|
mosesdecoder/phrase-extract/filter-rule-table/ForestTsgFilter.cpp
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "ForestTsgFilter.h"
|
| 2 |
+
|
| 3 |
+
#include <boost/make_shared.hpp>
|
| 4 |
+
|
| 5 |
+
namespace MosesTraining
|
| 6 |
+
{
|
| 7 |
+
namespace Syntax
|
| 8 |
+
{
|
| 9 |
+
namespace FilterRuleTable
|
| 10 |
+
{
|
| 11 |
+
|
| 12 |
+
// kMatchLimit is used to limit the effort spent trying to match an individual
|
| 13 |
+
// rule. It defines the maximum number of times that MatchFragment() can be
|
| 14 |
+
// called before the search is aborted and the rule is (possibly wrongly)
|
| 15 |
+
// accepted.
|
| 16 |
+
// FIXME Use a better matching algorithm.
|
| 17 |
+
const std::size_t ForestTsgFilter::kMatchLimit = 10000;
|
| 18 |
+
|
| 19 |
+
ForestTsgFilter::ForestTsgFilter(
|
| 20 |
+
const std::vector<boost::shared_ptr<StringForest> > &sentences)
|
| 21 |
+
{
|
| 22 |
+
// Convert each StringForest to an IdForest.
|
| 23 |
+
m_sentences.reserve(sentences.size());
|
| 24 |
+
for (std::vector<boost::shared_ptr<StringForest> >::const_iterator p =
|
| 25 |
+
sentences.begin(); p != sentences.end(); ++p) {
|
| 26 |
+
m_sentences.push_back(StringForestToIdForest(**p));
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// Construct a map from vocabulary Ids to IdForest nodes.
|
| 30 |
+
m_idToSentence.resize(m_testVocab.Size());
|
| 31 |
+
for (std::size_t i = 0; i < m_sentences.size(); ++i) {
|
| 32 |
+
const IdForest &forest = *(m_sentences[i]);
|
| 33 |
+
for (std::vector<IdForest::Vertex *>::const_iterator
|
| 34 |
+
p = forest.vertices.begin(); p != forest.vertices.end(); ++p) {
|
| 35 |
+
m_idToSentence[(*p)->value.id][i].push_back(*p);
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// Translate a StringForest into an IdForest whose vertex labels are
// m_testVocab Ids (unseen symbols are inserted into m_testVocab).  Spans
// and hypergraph structure are copied unchanged.
boost::shared_ptr<ForestTsgFilter::IdForest>
ForestTsgFilter::StringForestToIdForest(const StringForest &f)
{
  typedef StringForest::Vertex StringVertex;
  typedef StringForest::Hyperedge StringHyperedge;
  typedef IdForest::Vertex IdVertex;
  typedef IdForest::Hyperedge IdHyperedge;

  boost::shared_ptr<IdForest> g = boost::make_shared<IdForest>();

  // Correspondence between f's vertices and g's freshly created ones.
  boost::unordered_map<const StringVertex *, const IdVertex *> vertexMap;

  // Pass 1: clone every vertex and record the correspondence.
  for (std::vector<StringVertex *>::const_iterator sv = f.vertices.begin();
       sv != f.vertices.end(); ++sv) {
    const StringVertex *orig = *sv;
    IdVertex *clone = new IdVertex();
    clone->value.id = m_testVocab.Insert(orig->value.symbol);
    clone->value.start = orig->value.start;
    clone->value.end = orig->value.end;
    g->vertices.push_back(clone);
    vertexMap[orig] = clone;
  }

  // Pass 2: clone every hyperedge, remapping head and tail pointers
  // through vertexMap.  Each new edge is attached to its (new) head.
  for (std::vector<StringVertex *>::const_iterator sv = f.vertices.begin();
       sv != f.vertices.end(); ++sv) {
    for (std::vector<StringHyperedge *>::const_iterator
         he = (*sv)->incoming.begin(); he != (*sv)->incoming.end(); ++he) {
      IdHyperedge *edge = new IdHyperedge();
      edge->head = const_cast<IdVertex *>(vertexMap[(*he)->head]);
      edge->tail.reserve((*he)->tail.size());
      for (std::vector<StringVertex*>::const_iterator
           tv = (*he)->tail.begin(); tv != (*he)->tail.end(); ++tv) {
        edge->tail.push_back(const_cast<IdVertex *>(vertexMap[*tv]));
      }
      edge->head->incoming.push_back(edge);
    }
  }

  return g;
}
|
| 83 |
+
|
| 84 |
+
bool ForestTsgFilter::MatchFragment(const IdTree &fragment,
|
| 85 |
+
const std::vector<IdTree *> &leaves)
|
| 86 |
+
{
|
| 87 |
+
typedef std::vector<const IdTree *> TreeVec;
|
| 88 |
+
|
| 89 |
+
// Reset the match counter.
|
| 90 |
+
m_matchCount = 0;
|
| 91 |
+
|
| 92 |
+
// Determine which of the fragment's leaves occurs in the smallest number of
|
| 93 |
+
// sentences in the test set. If the fragment contains a rare word
|
| 94 |
+
// (which is pretty likely assuming a Zipfian distribution) then we only
|
| 95 |
+
// have to try matching the fragment against a small number of potential
|
| 96 |
+
// match sites.
|
| 97 |
+
const IdTree *rarestLeaf = leaves[0];
|
| 98 |
+
std::size_t lowestCount = m_idToSentence[rarestLeaf->value()].size();
|
| 99 |
+
for (std::size_t i = 1; i < leaves.size(); ++i) {
|
| 100 |
+
const IdTree *leaf = leaves[i];
|
| 101 |
+
std::size_t count = m_idToSentence[leaf->value()].size();
|
| 102 |
+
if (count < lowestCount) {
|
| 103 |
+
lowestCount = count;
|
| 104 |
+
rarestLeaf = leaf;
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// Try to match the rule fragment against the sentences where the rarest
|
| 109 |
+
// leaf was found.
|
| 110 |
+
const InnerMap &leafSentenceMap = m_idToSentence[rarestLeaf->value()];
|
| 111 |
+
const InnerMap &rootSentenceMap = m_idToSentence[fragment.value()];
|
| 112 |
+
|
| 113 |
+
std::vector<std::pair<std::size_t, std::size_t> > spans;
|
| 114 |
+
// For each forest i that contains the rarest leaf symbol...
|
| 115 |
+
for (InnerMap::const_iterator p = leafSentenceMap.begin();
|
| 116 |
+
p != leafSentenceMap.end(); ++p) {
|
| 117 |
+
std::size_t i = p->first;
|
| 118 |
+
// Get the set of candidate match sites in forest i (these are vertices
|
| 119 |
+
// with the same label as the root of the rule fragment).
|
| 120 |
+
InnerMap::const_iterator q = rootSentenceMap.find(i);
|
| 121 |
+
if (q == rootSentenceMap.end()) {
|
| 122 |
+
continue;
|
| 123 |
+
}
|
| 124 |
+
const std::vector<const IdForest::Vertex*> &candidates = q->second;
|
| 125 |
+
// Record the span(s) of the rare leaf symbol in forest i.
|
| 126 |
+
spans.clear();
|
| 127 |
+
for (std::vector<const IdForest::Vertex*>::const_iterator
|
| 128 |
+
r = p->second.begin(); r != p->second.end(); ++r) {
|
| 129 |
+
spans.push_back(std::make_pair((*r)->value.start, (*r)->value.end));
|
| 130 |
+
}
|
| 131 |
+
// For each candidate match site in forest i...
|
| 132 |
+
for (std::vector<const IdForest::Vertex*>::const_iterator
|
| 133 |
+
r = candidates.begin(); r != candidates.end(); ++r) {
|
| 134 |
+
const IdForest::Vertex &v = **r;
|
| 135 |
+
// Check that the subtrees rooted at v are at least as wide as the
|
| 136 |
+
// fragment (counting each non-terminal as being one token wide).
|
| 137 |
+
if (v.value.end - v.value.start + 1 < leaves.size()) {
|
| 138 |
+
continue;
|
| 139 |
+
}
|
| 140 |
+
// Check that the candidate's span covers one of the rare leaf symbols.
|
| 141 |
+
bool covered = false;
|
| 142 |
+
for (std::vector<std::pair<std::size_t, std::size_t> >::const_iterator
|
| 143 |
+
s = spans.begin(); s != spans.end(); ++s) {
|
| 144 |
+
if (v.value.start <= s->first && v.value.end >= s->second) {
|
| 145 |
+
covered = true;
|
| 146 |
+
break;
|
| 147 |
+
}
|
| 148 |
+
}
|
| 149 |
+
if (!covered) {
|
| 150 |
+
continue;
|
| 151 |
+
}
|
| 152 |
+
// Attempt to match the fragment at the candidate site.
|
| 153 |
+
if (MatchFragment(fragment, v)) {
|
| 154 |
+
return true;
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
return false;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
bool ForestTsgFilter::MatchFragment(const IdTree &fragment,
|
| 162 |
+
const IdForest::Vertex &v)
|
| 163 |
+
{
|
| 164 |
+
if (++m_matchCount >= kMatchLimit) {
|
| 165 |
+
return true;
|
| 166 |
+
}
|
| 167 |
+
if (fragment.value() != v.value.id) {
|
| 168 |
+
return false;
|
| 169 |
+
}
|
| 170 |
+
const std::vector<IdTree*> &children = fragment.children();
|
| 171 |
+
if (children.empty()) {
|
| 172 |
+
return true;
|
| 173 |
+
}
|
| 174 |
+
for (std::vector<IdForest::Hyperedge *>::const_iterator
|
| 175 |
+
p = v.incoming.begin(); p != v.incoming.end(); ++p) {
|
| 176 |
+
const std::vector<IdForest::Vertex*> &tail = (*p)->tail;
|
| 177 |
+
if (children.size() != tail.size()) {
|
| 178 |
+
continue;
|
| 179 |
+
}
|
| 180 |
+
bool match = true;
|
| 181 |
+
for (std::size_t i = 0; i < children.size(); ++i) {
|
| 182 |
+
if (!MatchFragment(*children[i], *tail[i])) {
|
| 183 |
+
match = false;
|
| 184 |
+
break;
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
if (match) {
|
| 188 |
+
return true;
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
return false;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
} // namespace FilterRuleTable
|
| 195 |
+
} // namespace Syntax
|
| 196 |
+
} // namespace MosesTraining
|
mosesdecoder/phrase-extract/filter-rule-table/ForestTsgFilter.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <istream>
|
| 4 |
+
#include <ostream>
|
| 5 |
+
#include <string>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
#include <boost/shared_ptr.hpp>
|
| 9 |
+
#include <boost/unordered_map.hpp>
|
| 10 |
+
#include <boost/unordered_set.hpp>
|
| 11 |
+
|
| 12 |
+
#include "syntax-common/numbered_set.h"
|
| 13 |
+
#include "syntax-common/tree.h"
|
| 14 |
+
#include "syntax-common/tree_fragment_tokenizer.h"
|
| 15 |
+
|
| 16 |
+
#include "Forest.h"
|
| 17 |
+
#include "StringForest.h"
|
| 18 |
+
#include "TsgFilter.h"
|
| 19 |
+
|
| 20 |
+
namespace MosesTraining
|
| 21 |
+
{
|
| 22 |
+
namespace Syntax
|
| 23 |
+
{
|
| 24 |
+
namespace FilterRuleTable
|
| 25 |
+
{
|
| 26 |
+
|
| 27 |
+
// Filters a rule table, discarding rules that cannot be applied to a given
|
| 28 |
+
// test set. The rule table must have a TSG source-side and the test sentences
|
| 29 |
+
// must be parse forests.
|
| 30 |
+
class ForestTsgFilter : public TsgFilter
|
| 31 |
+
{
|
| 32 |
+
public:
|
| 33 |
+
// Initialize the filter for a given set of test forests.
|
| 34 |
+
ForestTsgFilter(const std::vector<boost::shared_ptr<StringForest> > &);
|
| 35 |
+
|
| 36 |
+
private:
|
| 37 |
+
struct IdForestValue {
|
| 38 |
+
Vocabulary::IdType id;
|
| 39 |
+
std::size_t start;
|
| 40 |
+
std::size_t end;
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
static const std::size_t kMatchLimit;
|
| 44 |
+
|
| 45 |
+
// Represents a forest using integer vocabulary values.
|
| 46 |
+
typedef Forest<IdForestValue> IdForest;
|
| 47 |
+
|
| 48 |
+
typedef boost::unordered_map<std::size_t,
|
| 49 |
+
std::vector<const IdForest::Vertex*> > InnerMap;
|
| 50 |
+
|
| 51 |
+
typedef std::vector<InnerMap> IdToSentenceMap;
|
| 52 |
+
|
| 53 |
+
// Forest-specific implementation of virtual function.
|
| 54 |
+
bool MatchFragment(const IdTree &, const std::vector<IdTree *> &);
|
| 55 |
+
|
| 56 |
+
// Try to match a fragment against a specific vertex of a test forest.
|
| 57 |
+
bool MatchFragment(const IdTree &, const IdForest::Vertex &);
|
| 58 |
+
|
| 59 |
+
// Convert a StringForest to an IdForest (wrt m_testVocab). Inserts symbols
|
| 60 |
+
// into m_testVocab.
|
| 61 |
+
boost::shared_ptr<IdForest> StringForestToIdForest(const StringForest &);
|
| 62 |
+
|
| 63 |
+
std::vector<boost::shared_ptr<IdForest> > m_sentences;
|
| 64 |
+
IdToSentenceMap m_idToSentence;
|
| 65 |
+
std::size_t m_matchCount;
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
} // namespace FilterRuleTable
|
| 69 |
+
} // namespace Syntax
|
| 70 |
+
} // namespace MosesTraining
|
mosesdecoder/phrase-extract/filter-rule-table/Jamfile
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
exe filter-rule-table : [ glob *.cpp ] ..//syntax-common ..//deps ../..//boost_iostreams ../..//boost_program_options ../..//z : <include>.. ;
|
mosesdecoder/phrase-extract/filter-rule-table/StringCfgFilter.cpp
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "StringCfgFilter.h"
|
| 2 |
+
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
|
| 5 |
+
#include "util/string_piece_hash.hh"
|
| 6 |
+
|
| 7 |
+
namespace MosesTraining
|
| 8 |
+
{
|
| 9 |
+
namespace Syntax
|
| 10 |
+
{
|
| 11 |
+
namespace FilterRuleTable
|
| 12 |
+
{
|
| 13 |
+
|
| 14 |
+
const std::size_t StringCfgFilter::kMaxNGramLength = 5;
|
| 15 |
+
|
| 16 |
+
StringCfgFilter::StringCfgFilter(
|
| 17 |
+
const std::vector<boost::shared_ptr<std::string> > &sentences)
|
| 18 |
+
: m_maxSentenceLength(-1)
|
| 19 |
+
{
|
| 20 |
+
// Populate m_ngramCoordinateMap (except for the CoordinateTable's
|
| 21 |
+
// sentence vectors) and record the sentence lengths.
|
| 22 |
+
m_sentenceLengths.reserve(sentences.size());
|
| 23 |
+
const util::AnyCharacter delimiter(" \t");
|
| 24 |
+
std::vector<Vocabulary::IdType> vocabIds;
|
| 25 |
+
for (std::size_t i = 0; i < sentences.size(); ++i) {
|
| 26 |
+
vocabIds.clear();
|
| 27 |
+
for (util::TokenIter<util::AnyCharacter, true> p(*sentences[i], delimiter);
|
| 28 |
+
p; ++p) {
|
| 29 |
+
std::string tmp;
|
| 30 |
+
p->CopyToString(&tmp);
|
| 31 |
+
vocabIds.push_back(m_testVocab.Insert(tmp));
|
| 32 |
+
}
|
| 33 |
+
AddSentenceNGrams(vocabIds, i);
|
| 34 |
+
const int sentenceLength = static_cast<int>(vocabIds.size());
|
| 35 |
+
m_sentenceLengths.push_back(sentenceLength);
|
| 36 |
+
m_maxSentenceLength = std::max(sentenceLength, m_maxSentenceLength);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// Populate the CoordinateTable's sentence vectors.
|
| 40 |
+
for (NGramCoordinateMap::iterator p = m_ngramCoordinateMap.begin();
|
| 41 |
+
p != m_ngramCoordinateMap.end(); ++p) {
|
| 42 |
+
CoordinateTable &ct = p->second;
|
| 43 |
+
ct.sentences.reserve(ct.intraSentencePositions.size());
|
| 44 |
+
for (boost::unordered_map<int, PositionSeq>::const_iterator
|
| 45 |
+
q = ct.intraSentencePositions.begin();
|
| 46 |
+
q != ct.intraSentencePositions.end(); ++q) {
|
| 47 |
+
ct.sentences.push_back(q->first);
|
| 48 |
+
}
|
| 49 |
+
std::sort(ct.sentences.begin(), ct.sentences.end());
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
void StringCfgFilter::Filter(std::istream &in, std::ostream &out)
|
| 54 |
+
{
|
| 55 |
+
const util::MultiCharacter fieldDelimiter("|||");
|
| 56 |
+
const util::AnyCharacter symbolDelimiter(" \t");
|
| 57 |
+
|
| 58 |
+
std::string line;
|
| 59 |
+
std::string prevLine;
|
| 60 |
+
StringPiece source;
|
| 61 |
+
std::vector<StringPiece> symbols;
|
| 62 |
+
Pattern pattern;
|
| 63 |
+
bool keep = true;
|
| 64 |
+
int lineNum = 0;
|
| 65 |
+
|
| 66 |
+
while (std::getline(in, line)) {
|
| 67 |
+
++lineNum;
|
| 68 |
+
|
| 69 |
+
// Read the source-side of the rule.
|
| 70 |
+
util::TokenIter<util::MultiCharacter> it(line, fieldDelimiter);
|
| 71 |
+
|
| 72 |
+
// Check if this rule has the same source-side as the previous rule. If
|
| 73 |
+
// it does then we already know whether or not to keep the rule. This
|
| 74 |
+
// optimisation is based on the assumption that the rule table is sorted
|
| 75 |
+
// (which is the case in the standard Moses training pipeline).
|
| 76 |
+
if (*it == source) {
|
| 77 |
+
if (keep) {
|
| 78 |
+
out << line << std::endl;
|
| 79 |
+
}
|
| 80 |
+
continue;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
// The source-side is different from the previous rule's.
|
| 84 |
+
source = *it;
|
| 85 |
+
|
| 86 |
+
// Tokenize the source-side.
|
| 87 |
+
symbols.clear();
|
| 88 |
+
for (util::TokenIter<util::AnyCharacter, true> p(source, symbolDelimiter);
|
| 89 |
+
p; ++p) {
|
| 90 |
+
symbols.push_back(*p);
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// Generate a pattern (fails if any source-side terminal is not in the
|
| 94 |
+
// test set vocabulary) and attempt to match it against the test sentences.
|
| 95 |
+
keep = GeneratePattern(symbols, pattern) && MatchPattern(pattern);
|
| 96 |
+
if (keep) {
|
| 97 |
+
out << line << std::endl;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
// Retain line for the next iteration (in order that the source StringPiece
|
| 101 |
+
// remains valid).
|
| 102 |
+
prevLine.swap(line);
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
void StringCfgFilter::AddSentenceNGrams(
|
| 107 |
+
const std::vector<Vocabulary::IdType> &s, std::size_t sentNum)
|
| 108 |
+
{
|
| 109 |
+
const std::size_t len = s.size();
|
| 110 |
+
|
| 111 |
+
NGram ngram;
|
| 112 |
+
// For each starting position in the sentence:
|
| 113 |
+
for (std::size_t i = 0; i < len; ++i) {
|
| 114 |
+
// For each n-gram length: 1, 2, 3, ... kMaxNGramLength (or less when
|
| 115 |
+
// approaching the end of the sentence):
|
| 116 |
+
for (std::size_t n = 1; n <= std::min(kMaxNGramLength, len-i); ++n) {
|
| 117 |
+
ngram.clear();
|
| 118 |
+
for (std::size_t j = 0; j < n; ++j) {
|
| 119 |
+
ngram.push_back(s[i+j]);
|
| 120 |
+
}
|
| 121 |
+
m_ngramCoordinateMap[ngram].intraSentencePositions[sentNum].push_back(i);
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
bool StringCfgFilter::GeneratePattern(const std::vector<StringPiece> &symbols,
|
| 127 |
+
Pattern &pattern) const
|
| 128 |
+
{
|
| 129 |
+
pattern.subpatterns.clear();
|
| 130 |
+
pattern.minGapWidths.clear();
|
| 131 |
+
|
| 132 |
+
int gapWidth = 0;
|
| 133 |
+
|
| 134 |
+
// The first symbol is handled as a special case because there is always a
|
| 135 |
+
// leading gap / non-gap.
|
| 136 |
+
if (IsNonTerminal(symbols[0])) {
|
| 137 |
+
++gapWidth;
|
| 138 |
+
} else {
|
| 139 |
+
pattern.minGapWidths.push_back(0);
|
| 140 |
+
// Add the symbol to the first n-gram.
|
| 141 |
+
Vocabulary::IdType vocabId =
|
| 142 |
+
m_testVocab.Lookup(symbols[0], StringPieceCompatibleHash(),
|
| 143 |
+
StringPieceCompatibleEquals());
|
| 144 |
+
if (vocabId == Vocabulary::NullId()) {
|
| 145 |
+
return false;
|
| 146 |
+
}
|
| 147 |
+
pattern.subpatterns.push_back(NGram(1, vocabId));
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// Process the remaining symbols (except the last which is the RHS).
|
| 151 |
+
for (std::size_t i = 1; i < symbols.size()-1; ++i) {
|
| 152 |
+
// Is current symbol a non-terminal?
|
| 153 |
+
if (IsNonTerminal(symbols[i])) {
|
| 154 |
+
++gapWidth;
|
| 155 |
+
continue;
|
| 156 |
+
}
|
| 157 |
+
// Does the current terminal follow a non-terminal?
|
| 158 |
+
if (gapWidth > 0) {
|
| 159 |
+
pattern.minGapWidths.push_back(gapWidth);
|
| 160 |
+
gapWidth = 0;
|
| 161 |
+
pattern.subpatterns.resize(pattern.subpatterns.size()+1);
|
| 162 |
+
// Is the current n-gram full?
|
| 163 |
+
} else if (pattern.subpatterns.back().size() == kMaxNGramLength) {
|
| 164 |
+
pattern.minGapWidths.push_back(0);
|
| 165 |
+
pattern.subpatterns.resize(pattern.subpatterns.size()+1);
|
| 166 |
+
}
|
| 167 |
+
// Add the symbol to the current n-gram.
|
| 168 |
+
Vocabulary::IdType vocabId =
|
| 169 |
+
m_testVocab.Lookup(symbols[i], StringPieceCompatibleHash(),
|
| 170 |
+
StringPieceCompatibleEquals());
|
| 171 |
+
if (vocabId == Vocabulary::NullId()) {
|
| 172 |
+
return false;
|
| 173 |
+
}
|
| 174 |
+
pattern.subpatterns.back().push_back(vocabId);
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
// Add the final gap width value (0 if the last symbol was a terminal).
|
| 178 |
+
pattern.minGapWidths.push_back(gapWidth);
|
| 179 |
+
return true;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
bool StringCfgFilter::IsNonTerminal(const StringPiece &symbol) const
|
| 183 |
+
{
|
| 184 |
+
return symbol.size() >= 3 && symbol[0] == '[' &&
|
| 185 |
+
symbol[symbol.size()-1] == ']';
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
bool StringCfgFilter::MatchPattern(const Pattern &pattern) const
|
| 189 |
+
{
|
| 190 |
+
// Step 0: If the pattern is just a single gap (i.e. the original rule
|
| 191 |
+
// was fully non-lexical) then the pattern matches unless the
|
| 192 |
+
// minimum gap width is wider than any sentence.
|
| 193 |
+
if (pattern.subpatterns.empty()) {
|
| 194 |
+
assert(pattern.minGapWidths.size() == 1);
|
| 195 |
+
return pattern.minGapWidths[0] <= m_maxSentenceLength;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// Step 1: Look up all of the subpatterns in m_ngramCoordinateMap and record
|
| 199 |
+
// pointers to their CoordinateTables.
|
| 200 |
+
std::vector<const CoordinateTable *> tables;
|
| 201 |
+
for (std::vector<NGram>::const_iterator p = pattern.subpatterns.begin();
|
| 202 |
+
p != pattern.subpatterns.end(); ++p) {
|
| 203 |
+
NGramCoordinateMap::const_iterator q = m_ngramCoordinateMap.find(*p);
|
| 204 |
+
// If a subpattern doesn't appear in m_ngramCoordinateMap then the match
|
| 205 |
+
// has already failed.
|
| 206 |
+
if (q == m_ngramCoordinateMap.end()) {
|
| 207 |
+
return false;
|
| 208 |
+
}
|
| 209 |
+
tables.push_back(&(q->second));
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
// Step 2: Intersect the CoordinateTables' sentence sets to find the set of
|
| 213 |
+
// test set sentences in which all subpatterns occur.
|
| 214 |
+
std::vector<int> intersection = tables[0]->sentences;
|
| 215 |
+
std::vector<int> tmp(intersection.size());
|
| 216 |
+
for (std::size_t i = 1; i < tables.size(); ++i) {
|
| 217 |
+
std::vector<int>::iterator p = std::set_intersection(
|
| 218 |
+
intersection.begin(), intersection.end(), tables[i]->sentences.begin(),
|
| 219 |
+
tables[i]->sentences.end(), tmp.begin());
|
| 220 |
+
tmp.resize(p-tmp.begin());
|
| 221 |
+
if (tmp.empty()) {
|
| 222 |
+
return false;
|
| 223 |
+
}
|
| 224 |
+
intersection.swap(tmp);
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
// Step 3: For each sentence in the intersection, try to find a consistent
|
| 228 |
+
// sequence of intra-sentence positions (one for each subpattern).
|
| 229 |
+
// 'Consistent' here means that the subpatterns occur in the right
|
| 230 |
+
// order and are separated by at least the minimum widths required
|
| 231 |
+
// by the pattern's gaps).
|
| 232 |
+
for (std::vector<int>::const_iterator p = intersection.begin();
|
| 233 |
+
p != intersection.end(); ++p) {
|
| 234 |
+
if (MatchPattern(pattern, tables, *p)) {
|
| 235 |
+
return true;
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
return false;
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
// Try to match 'pattern' against the test sentence with index 'sentenceId'.
// 'tables' holds one CoordinateTable per subpattern (tables[i] corresponds to
// pattern.subpatterns[i]), recording where that subpattern occurs in the
// test set.
bool StringCfgFilter::MatchPattern(
  const Pattern &pattern,
  std::vector<const CoordinateTable *> &tables,
  int sentenceId) const
{
  const int sentenceLength = m_sentenceLengths[sentenceId];

  // In the for loop below, we need to know the set of start position ranges
  // where subpattern i is allowed to occur (rangeSet) and we are generating
  // the ranges for subpattern i+1 (nextRangeSet).
  // TODO Merge ranges if subpattern i follows a non-zero gap.
  std::vector<Range> rangeSet;
  std::vector<Range> nextRangeSet;

  // Calculate the range for the first subpattern: it must start after the
  // leading gap and leave enough room for the whole pattern to fit within
  // the sentence.
  int minStart = pattern.minGapWidths[0];
  int maxStart = sentenceLength - MinWidth(pattern, 0);
  rangeSet.push_back(Range(minStart, maxStart));

  // Attempt to match subpatterns.
  for (int i = 0; i < pattern.subpatterns.size(); ++i) {
    // Look-up the intra-sentence position sequence.
    boost::unordered_map<int, PositionSeq>::const_iterator r =
      tables[i]->intraSentencePositions.find(sentenceId);
    assert(r != tables[i]->intraSentencePositions.end());
    const PositionSeq &col = r->second;
    for (PositionSeq::const_iterator p = col.begin(); p != col.end(); ++p) {
      // Discard occurrences that fall outside every admissible range.
      bool inRange = false;
      for (std::vector<Range>::const_iterator q = rangeSet.begin();
           q != rangeSet.end(); ++q) {
        // TODO Use the fact that the ranges are ordered to break early.
        if (*p >= q->first && *p <= q->second) {
          inRange = true;
          break;
        }
      }
      if (!inRange) {
        continue;
      }
      // If this is the last subpattern then we're done.
      if (i+1 == pattern.subpatterns.size()) {
        return true;
      }
      nextRangeSet.push_back(CalcNextRange(pattern, i, *p, sentenceLength));
    }
    // No occurrence of subpattern i produced a usable follow-on range, so
    // the match fails for this sentence.
    if (nextRangeSet.empty()) {
      return false;
    }
    rangeSet.swap(nextRangeSet);
    nextRangeSet.clear();
  }
  // Only reachable when the pattern contains no subpatterns (gaps only).
  return true;
}
|
| 294 |
+
|
| 295 |
+
// Compute the admissible start-position range for subpattern i+1 given that
// subpattern i was matched at intra-sentence position x.
StringCfgFilter::Range StringCfgFilter::CalcNextRange(
  const Pattern &pattern, int i, int x, int sentenceLength) const
{
  assert(i+1 < pattern.subpatterns.size());
  // Position immediately following the end of subpattern i.
  const int followPos = x + pattern.subpatterns[i].size();
  const int gap = pattern.minGapWidths[i+1];
  if (gap == 0) {
    // The next subpattern follows this one without a gap: exactly one
    // admissible start position.
    return Range(followPos, followPos);
  }
  // Gapped: start anywhere after the minimum-width gap, up to the last
  // position that still leaves room for the remaining pattern suffix.
  // TODO MinWidth should only be computed once per subpattern.
  return Range(followPos + gap, sentenceLength - MinWidth(pattern, i+1));
}
|
| 310 |
+
|
| 311 |
+
int StringCfgFilter::MinWidth(const Pattern &pattern, int i) const
|
| 312 |
+
{
|
| 313 |
+
int minWidth = 0;
|
| 314 |
+
for (; i < pattern.subpatterns.size(); ++i) {
|
| 315 |
+
minWidth += pattern.subpatterns[i].size();
|
| 316 |
+
minWidth += pattern.minGapWidths[i+1];
|
| 317 |
+
}
|
| 318 |
+
return minWidth;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
} // namespace FilterRuleTable
|
| 322 |
+
} // namespace Syntax
|
| 323 |
+
} // namespace MosesTraining
|
mosesdecoder/phrase-extract/filter-rule-table/StringCfgFilter.h
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <string>
#include <vector>

#include "syntax-common/numbered_set.h"

#include <boost/shared_ptr.hpp>
#include <boost/unordered_map.hpp>

#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"

#include "CfgFilter.h"

namespace MosesTraining
{
namespace Syntax
{
namespace FilterRuleTable
{

// Filters a rule table, discarding rules that cannot be applied to a given
// test set. The rule table must have a CFG source-side and the test sentences
// must be strings.
class StringCfgFilter : public CfgFilter
{
public:
  // Initialize the filter for a given set of test sentences.
  StringCfgFilter(const std::vector<boost::shared_ptr<std::string> > &);

  // Read a rule table from 'in' and write the surviving rules to 'out'.
  void Filter(std::istream &in, std::ostream &out);

private:
  // Filtering works by converting the source LHSs of translation rules to
  // patterns containing variable length gaps and then pattern matching
  // against the test set.
  //
  // The algorithm is vaguely similar to Algorithm 1 from Rahman et al. (2006),
  // but with a slightly different definition of a pattern and designed for a
  // text containing sentence boundaries. Here the text is assumed to be
  // short (a few thousand sentences) and the number of patterns is assumed to
  // be large (tens of millions of rules).
  //
  // M. Sohel Rahman, Costas S. Iliopoulos, Inbok Lee, Manal Mohamed, and
  // William F. Smyth
  // "Finding Patterns with Variable Length Gaps or Don't Cares"
  // In proceedings of COCOON, 2006

  // Max NGram length.
  static const std::size_t kMaxNGramLength;

  // Maps words from strings to integers.
  typedef NumberedSet<std::string, std::size_t> Vocabulary;

  // A NGram is a sequence of words.
  typedef std::vector<Vocabulary::IdType> NGram;

  // A pattern is an alternating sequence of gaps and NGram subpatterns,
  // starting and ending with a gap. Every gap has a minimum width, which
  // can be any integer >= 0 (a gap of width 0 is really a non-gap).
  //
  // The source LHSs of translation rules are converted to patterns where each
  // sequence of m consecutive non-terminals is converted to a gap with minimum
  // width m. For example, if a rule has the source LHS:
  //
  //   [NP] and all the king 's men could n't [VB] [NP] together again
  //
  // and kMaxNGramLength is set to 5 then the following pattern is used:
  //
  //    * <and all the king 's> * <men could n't> * <together again> *
  //
  // where the gaps have minimum widths of 1, 0, 2, and 0.
  //
  struct Pattern {
    std::vector<NGram> subpatterns;
    std::vector<int> minGapWidths;
  };

  // A sorted (ascending) sequence of start positions.
  typedef std::vector<int> PositionSeq;

  // A range of start positions.
  typedef std::pair<int, int> Range;

  // A CoordinateTable records the set of sentences in which a single
  // n-gram occurs and, for each of those sentences, the start positions
  // at which it occurs.
  struct CoordinateTable {
    // Sentences IDs (ascending). This contains the same values as the key set
    // from intraSentencePositions but sorted into ascending order.
    std::vector<int> sentences;
    // Map from sentence ID to set of intra-sentence start positions.
    boost::unordered_map<int, PositionSeq> intraSentencePositions;
  };

  // NGramCoordinateMap is the main search structure. It maps a NGram to
  // a CoordinateTable containing the positions that the n-gram occurs at
  // in the test set.
  typedef boost::unordered_map<NGram, CoordinateTable> NGramCoordinateMap;

  // Add all n-grams and coordinates for a single sentence s with index i.
  void AddSentenceNGrams(const std::vector<Vocabulary::IdType> &s,
                         std::size_t i);

  // Calculate the range of possible start positions for subpattern i+1
  // assuming that subpattern i has position x.
  Range CalcNextRange(const Pattern &p, int i, int x, int sentenceLength) const;

  // Generate the pattern corresponding to the given source-side of a rule.
  // This will fail if the rule's source-side contains any terminals that
  // do not occur in the test sentence vocabulary.
  bool GeneratePattern(const std::vector<StringPiece> &, Pattern &) const;

  // Calculate the minimum width of the pattern suffix starting
  // at subpattern i.
  int MinWidth(const Pattern &p, int i) const;

  // Test whether a symbol is a non-terminal.
  bool IsNonTerminal(const StringPiece &symbol) const;

  // Try to match the pattern p against any sentence in the test set.
  bool MatchPattern(const Pattern &p) const;

  // Try to match the pattern p against the sentence with the given ID.
  bool MatchPattern(const Pattern &p,
                    std::vector<const CoordinateTable *> &tables,
                    int id) const;

  // The main search structure constructed from the test set sentences.
  NGramCoordinateMap m_ngramCoordinateMap;

  // The lengths of the test sentences.
  std::vector<int> m_sentenceLengths;

  // The maximum length of any test sentence.
  int m_maxSentenceLength;

  // The symbol vocabulary of the test sentences.
  Vocabulary m_testVocab;
};

}  // namespace FilterRuleTable
}  // namespace Syntax
}  // namespace MosesTraining
|
mosesdecoder/phrase-extract/filter-rule-table/StringForest.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <string>

#include "Forest.h"

namespace MosesTraining
{
namespace Syntax
{
namespace FilterRuleTable
{

// Node value for forests over plain string sentences: the node's label plus
// the span of sentence positions it covers.
struct StringForestValue {
  std::string symbol; // terminal or non-terminal (without square brackets)
  std::size_t start;  // first covered position
  std::size_t end;    // last covered position
};

typedef Forest<StringForestValue> StringForest;

}  // namespace FilterRuleTable
}  // namespace Syntax
}  // namespace MosesTraining
|
mosesdecoder/phrase-extract/filter-rule-table/TreeTsgFilter.h
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <istream>
#include <ostream>
#include <string>
#include <vector>

#include <boost/shared_ptr.hpp>
#include <boost/unordered_map.hpp>

#include "SyntaxTree.h"

#include "syntax-common/numbered_set.h"
#include "syntax-common/tree.h"
#include "syntax-common/tree_fragment_tokenizer.h"

#include "TsgFilter.h"

namespace MosesTraining
{
namespace Syntax
{
namespace FilterRuleTable
{

// Filters a rule table, discarding rules that cannot be applied to a given
// test set. The rule table must have a TSG source-side and the test sentences
// must be parse trees.
class TreeTsgFilter : public TsgFilter
{
public:
  // Initialize the filter for a given set of test sentences.
  TreeTsgFilter(const std::vector<boost::shared_ptr<SyntaxTree> > &);

private:
  // Add an entry to m_labelToTree for every subtree of the given tree.
  void AddNodesToMap(const IdTree &);

  // Tree-specific implementation of virtual function.
  bool MatchFragment(const IdTree &, const std::vector<IdTree *> &);

  // Try to match a fragment against a specific subtree of a test tree.
  bool MatchFragment(const IdTree &, const IdTree &);

  // Convert a SyntaxTree to an IdTree (wrt m_testVocab). Inserts symbols into
  // m_testVocab.
  IdTree *SyntaxTreeToIdTree(const SyntaxTree &);

  // The test sentences converted to IdTrees.
  std::vector<boost::shared_ptr<IdTree> > m_sentences;
  // Test-tree nodes grouped by label (populated via AddNodesToMap).
  std::vector<std::vector<const IdTree *> > m_labelToTree;
};

}  // namespace FilterRuleTable
}  // namespace Syntax
}  // namespace MosesTraining
|
mosesdecoder/phrase-extract/filter-rule-table/TsgFilter.h
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <istream>
#include <ostream>
#include <string>
#include <vector>

#include "syntax-common/numbered_set.h"
#include "syntax-common/tree.h"
#include "syntax-common/tree_fragment_tokenizer.h"

namespace MosesTraining
{
namespace Syntax
{
namespace FilterRuleTable
{

// Base class for TreeTsgFilter and ForestTsgFilter, both of which filter rule
// tables where the source-side is TSG.
class TsgFilter
{
public:
  virtual ~TsgFilter() {}

  // Read a rule table from 'in' and filter it according to the test sentences.
  void Filter(std::istream &in, std::ostream &out);

protected:
  // Maps symbols (terminals and non-terminals) from strings to integers.
  typedef NumberedSet<std::string, std::size_t> Vocabulary;

  // Represents a tree using integer vocabulary values.
  typedef Tree<Vocabulary::IdType> IdTree;

  // Build an IdTree (wrt m_testVocab) for the tree beginning at position i of
  // the token sequence or return 0 if any symbol in the fragment is not in
  // m_testVocab. If successful then on return, i will be set to the position
  // immediately after the last token of the tree and leaves will contain the
  // pointers to the fragment's leaves. If the build fails then i and leaves
  // are undefined.
  IdTree *BuildTree(const std::vector<TreeFragmentToken> &tokens, int &i,
                    std::vector<IdTree *> &leaves);

  // Try to match a fragment. The implementation depends on whether the test
  // sentences are trees or forests.
  virtual bool MatchFragment(const IdTree &, const std::vector<IdTree *> &) = 0;

  // The symbol vocabulary of the test sentences.
  Vocabulary m_testVocab;
};

}  // namespace FilterRuleTable
}  // namespace Syntax
}  // namespace MosesTraining
|
mosesdecoder/phrase-extract/lexical-reordering/InputFileStream.cpp
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// $Id: InputFileStream.cpp 2780 2010-01-29 17:11:17Z bojar $
|
| 2 |
+
|
| 3 |
+
/***********************************************************************
|
| 4 |
+
Moses - factored phrase-based language decoder
|
| 5 |
+
Copyright (C) 2006 University of Edinburgh
|
| 6 |
+
|
| 7 |
+
This library is free software; you can redistribute it and/or
|
| 8 |
+
modify it under the terms of the GNU Lesser General Public
|
| 9 |
+
License as published by the Free Software Foundation; either
|
| 10 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 11 |
+
|
| 12 |
+
This library is distributed in the hope that it will be useful,
|
| 13 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 14 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 15 |
+
Lesser General Public License for more details.
|
| 16 |
+
|
| 17 |
+
You should have received a copy of the GNU Lesser General Public
|
| 18 |
+
License along with this library; if not, write to the Free Software
|
| 19 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 20 |
+
***********************************************************************/
|
| 21 |
+
|
| 22 |
+
#include "InputFileStream.h"
|
| 23 |
+
#include "gzfilebuf.h"
|
| 24 |
+
#include <iostream>
|
| 25 |
+
#include <boost/algorithm/string/predicate.hpp>
|
| 26 |
+
|
| 27 |
+
using namespace std;
|
| 28 |
+
using namespace boost::algorithm;
|
| 29 |
+
|
| 30 |
+
namespace Moses
|
| 31 |
+
{
|
| 32 |
+
InputFileStream::InputFileStream(const std::string &filePath)
|
| 33 |
+
: std::istream(NULL)
|
| 34 |
+
, m_streambuf(NULL)
|
| 35 |
+
{
|
| 36 |
+
Open(filePath);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
InputFileStream::~InputFileStream()
|
| 40 |
+
{
|
| 41 |
+
Close();
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
void InputFileStream::Open(const std::string &filePath)
|
| 45 |
+
{
|
| 46 |
+
if (ends_with(filePath, ".gz")) {
|
| 47 |
+
m_streambuf = new gzfilebuf(filePath.c_str());
|
| 48 |
+
} else {
|
| 49 |
+
std::filebuf* fb = new std::filebuf();
|
| 50 |
+
fb = fb->open(filePath.c_str(), std::ios::in);
|
| 51 |
+
if (! fb) {
|
| 52 |
+
cerr << "Can't read " << filePath.c_str() << endl;
|
| 53 |
+
exit(1);
|
| 54 |
+
}
|
| 55 |
+
m_streambuf = fb;
|
| 56 |
+
}
|
| 57 |
+
this->init(m_streambuf);
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
void InputFileStream::Close()
|
| 61 |
+
{
|
| 62 |
+
delete m_streambuf;
|
| 63 |
+
m_streambuf = NULL;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
}
|
| 68 |
+
|
mosesdecoder/phrase-extract/lexical-reordering/InputFileStream.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// $Id: InputFileStream.h 2939 2010-02-24 11:15:44Z jfouet $
|
| 2 |
+
|
| 3 |
+
/***********************************************************************
|
| 4 |
+
Moses - factored phrase-based language decoder
|
| 5 |
+
Copyright (C) 2006 University of Edinburgh
|
| 6 |
+
|
| 7 |
+
This library is free software; you can redistribute it and/or
|
| 8 |
+
modify it under the terms of the GNU Lesser General Public
|
| 9 |
+
License as published by the Free Software Foundation; either
|
| 10 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 11 |
+
|
| 12 |
+
This library is distributed in the hope that it will be useful,
|
| 13 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 14 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 15 |
+
Lesser General Public License for more details.
|
| 16 |
+
|
| 17 |
+
You should have received a copy of the GNU Lesser General Public
|
| 18 |
+
License along with this library; if not, write to the Free Software
|
| 19 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 20 |
+
***********************************************************************/
|
| 21 |
+
|
| 22 |
+
#ifndef moses_InputFileStream_h
#define moses_InputFileStream_h

#include <cstdlib>
#include <fstream>
#include <string>

namespace Moses
{

/** Used in place of std::istream, can read zipped files if it ends in .gz
 */
class InputFileStream : public std::istream
{
protected:
  // Owned stream buffer created by Open() and deleted by Close().
  std::streambuf *m_streambuf;
public:

  explicit InputFileStream(const std::string &filePath);
  ~InputFileStream();

  // Attach a buffer for the given path (transparently decompressing *.gz).
  void Open(const std::string &filePath);
  // Delete the underlying stream buffer.
  void Close();
};

}

#endif
|
mosesdecoder/phrase-extract/lexical-reordering/Jamfile
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
exe lexical-reordering-score : InputFileStream.cpp reordering_classes.cpp score.cpp ../OutputFileStream.cpp ../..//boost_iostreams ../..//boost_filesystem ../../util//kenutil ../..//z ;
|
| 2 |
+
|
mosesdecoder/phrase-extract/lexical-reordering/gzfilebuf.h
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef moses_gzfile_buf_h
|
| 2 |
+
#define moses_gzfile_buf_h
|
| 3 |
+
|
| 4 |
+
#include <stdexcept>
|
| 5 |
+
#include <streambuf>
|
| 6 |
+
#include <zlib.h>
|
| 7 |
+
#include <cstring>
|
| 8 |
+
|
| 9 |
+
class gzfilebuf : public std::streambuf
|
| 10 |
+
{
|
| 11 |
+
public:
|
| 12 |
+
gzfilebuf(const char *filename) {
|
| 13 |
+
_gzf = gzopen(filename, "rb");
|
| 14 |
+
if (!_gzf)
|
| 15 |
+
throw std::runtime_error("Could not open " + std::string(filename) + ".");
|
| 16 |
+
setg (_buff+sizeof(int), // beginning of putback area
|
| 17 |
+
_buff+sizeof(int), // read position
|
| 18 |
+
_buff+sizeof(int)); // end position
|
| 19 |
+
}
|
| 20 |
+
~gzfilebuf() {
|
| 21 |
+
gzclose(_gzf);
|
| 22 |
+
}
|
| 23 |
+
protected:
|
| 24 |
+
virtual int_type overflow (int_type c) {
|
| 25 |
+
throw;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
// write multiple characters
|
| 29 |
+
virtual
|
| 30 |
+
std::streamsize xsputn (const char* s,
|
| 31 |
+
std::streamsize num) {
|
| 32 |
+
throw;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
virtual std::streampos seekpos ( std::streampos sp, std::ios_base::openmode which = std::ios_base::in | std::ios_base::out ) {
|
| 36 |
+
throw;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
//read one character
|
| 40 |
+
virtual int_type underflow () {
|
| 41 |
+
// is read position before end of _buff?
|
| 42 |
+
if (gptr() < egptr()) {
|
| 43 |
+
return traits_type::to_int_type(*gptr());
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
/* process size of putback area
|
| 47 |
+
* - use number of characters read
|
| 48 |
+
* - but at most four
|
| 49 |
+
*/
|
| 50 |
+
unsigned int numPutback = gptr() - eback();
|
| 51 |
+
if (numPutback > sizeof(int)) {
|
| 52 |
+
numPutback = sizeof(int);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
/* copy up to four characters previously read into
|
| 56 |
+
* the putback _buff (area of first four characters)
|
| 57 |
+
*/
|
| 58 |
+
std::memmove (_buff+(sizeof(int)-numPutback), gptr()-numPutback,
|
| 59 |
+
numPutback);
|
| 60 |
+
|
| 61 |
+
// read new characters
|
| 62 |
+
int num = gzread(_gzf, _buff+sizeof(int), _buffsize-sizeof(int));
|
| 63 |
+
if (num <= 0) {
|
| 64 |
+
// ERROR or EOF
|
| 65 |
+
return EOF;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// reset _buff pointers
|
| 69 |
+
setg (_buff+(sizeof(int)-numPutback), // beginning of putback area
|
| 70 |
+
_buff+sizeof(int), // read position
|
| 71 |
+
_buff+sizeof(int)+num); // end of buffer
|
| 72 |
+
|
| 73 |
+
// return next character
|
| 74 |
+
return traits_type::to_int_type(*gptr());
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
std::streamsize xsgetn (char* s,
|
| 78 |
+
std::streamsize num) {
|
| 79 |
+
return gzread(_gzf,s,num);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
private:
|
| 83 |
+
gzFile _gzf;
|
| 84 |
+
static const unsigned int _buffsize = 1024;
|
| 85 |
+
char _buff[_buffsize];
|
| 86 |
+
};
|
| 87 |
+
|
| 88 |
+
#endif
|
mosesdecoder/phrase-extract/lexical-reordering/reordering_classes.cpp
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#include <vector>
|
| 3 |
+
#include <iostream>
|
| 4 |
+
#include <cstdlib>
|
| 5 |
+
#include <numeric>
|
| 6 |
+
#include <cstdio>
|
| 7 |
+
#include <sstream>
|
| 8 |
+
#include <string>
|
| 9 |
+
#include "zlib.h"
|
| 10 |
+
|
| 11 |
+
#include "reordering_classes.h"
|
| 12 |
+
|
| 13 |
+
using namespace std;
|
| 14 |
+
|
| 15 |
+
// Create one zeroed counter per orientation class (MONO..NOMONO) in each of
// the four count tables.
ModelScore::ModelScore()
{
  const size_t numTypes = NOMONO - MONO + 1;
  count_fe_prev.assign(numTypes, 0);
  count_fe_next.assign(numTypes, 0);
  count_f_prev.assign(numTypes, 0);
  count_f_next.assign(numTypes, 0);
}

ModelScore::~ModelScore() {}
|
| 26 |
+
|
| 27 |
+
// Factory: map a model-type name to a freshly allocated scorer; the caller
// owns the returned object. Exits with an error message on an unknown type.
ModelScore* ModelScore::createModelScore(const string& modeltype)
{
  if (modeltype == "mslr")
    return new ModelScoreMSLR();
  if (modeltype == "msd")
    return new ModelScoreMSD();
  if (modeltype == "monotonicity")
    return new ModelScoreMonotonicity();
  if (modeltype == "leftright")
    return new ModelScoreLR();
  cerr << "Illegal model type given for lexical reordering model scoring: "
       << modeltype
       << ". The allowed types are: mslr, msd, monotonicity, leftright"
       << endl;
  exit(1);
}
|
| 45 |
+
|
| 46 |
+
void ModelScore::reset_fe()
|
| 47 |
+
{
|
| 48 |
+
for(int i=MONO; i<=NOMONO; ++i) {
|
| 49 |
+
count_fe_prev[i] = 0;
|
| 50 |
+
count_fe_next[i] = 0;
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
void ModelScore::reset_f()
|
| 55 |
+
{
|
| 56 |
+
for(int i=MONO; i<=NOMONO; ++i) {
|
| 57 |
+
count_f_prev[i] = 0;
|
| 58 |
+
count_f_next[i] = 0;
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
void ModelScore::add_example
|
| 63 |
+
(const StringPiece& previous, const StringPiece& next, float weight)
|
| 64 |
+
{
|
| 65 |
+
count_fe_prev[getType(previous)]+=weight;
|
| 66 |
+
count_f_prev[getType(previous)]+=weight;
|
| 67 |
+
count_fe_next[getType(next)]+=weight;
|
| 68 |
+
count_f_next[getType(next)]+=weight;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
// Accessors for the accumulated orientation counts (indexed by ORIENTATION).

const vector<double>& ModelScore::get_scores_fe_prev() const
{
  return count_fe_prev;
}

const vector<double>& ModelScore::get_scores_fe_next() const
{
  return count_fe_next;
}

const vector<double>& ModelScore::get_scores_f_prev() const
{
  return count_f_prev;
}

const vector<double>& ModelScore::get_scores_f_next() const
{
  return count_f_next;
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
ORIENTATION ModelScore::getType(const StringPiece& s)
|
| 93 |
+
{
|
| 94 |
+
if (s.compare("mono") == 0) {
|
| 95 |
+
return MONO;
|
| 96 |
+
} else if (s.compare("swap") == 0) {
|
| 97 |
+
return SWAP;
|
| 98 |
+
} else if (s.compare("dright") == 0) {
|
| 99 |
+
return DRIGHT;
|
| 100 |
+
} else if (s.compare("dleft") == 0) {
|
| 101 |
+
return DLEFT;
|
| 102 |
+
} else if (s.compare("other") == 0) {
|
| 103 |
+
return OTHER;
|
| 104 |
+
} else if (s.compare("nomono") == 0) {
|
| 105 |
+
return NOMONO;
|
| 106 |
+
} else {
|
| 107 |
+
cerr << "Illegal reordering type used: " << s << endl;
|
| 108 |
+
exit(1);
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
ORIENTATION ModelScoreMSLR::getType(const StringPiece& s)
|
| 114 |
+
{
|
| 115 |
+
if (s.compare("mono") == 0) {
|
| 116 |
+
return MONO;
|
| 117 |
+
} else if (s.compare("swap") == 0) {
|
| 118 |
+
return SWAP;
|
| 119 |
+
} else if (s.compare("dright") == 0) {
|
| 120 |
+
return DRIGHT;
|
| 121 |
+
} else if (s.compare("dleft") == 0) {
|
| 122 |
+
return DLEFT;
|
| 123 |
+
} else if (s.compare("other") == 0 || s.compare("nomono") == 0) {
|
| 124 |
+
cerr << "Illegal reordering type used: " << s << " for model type mslr. You have to re-run step 5 in order to train such a model." << endl;
|
| 125 |
+
exit(1);
|
| 126 |
+
} else {
|
| 127 |
+
cerr << "Illegal reordering type used: " << s << endl;
|
| 128 |
+
exit(1);
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
ORIENTATION ModelScoreLR::getType(const StringPiece& s)
|
| 134 |
+
{
|
| 135 |
+
if (s.compare("mono") == 0 || s.compare("dright") == 0) {
|
| 136 |
+
return DRIGHT;
|
| 137 |
+
} else if (s.compare("swap") == 0 || s.compare("dleft") == 0) {
|
| 138 |
+
return DLEFT;
|
| 139 |
+
} else if (s.compare("other") == 0 || s.compare("nomono") == 0) {
|
| 140 |
+
cerr << "Illegal reordering type used: " << s << " for model type LeftRight. You have to re-run step 5 in order to train such a model." << endl;
|
| 141 |
+
exit(1);
|
| 142 |
+
} else {
|
| 143 |
+
cerr << "Illegal reordering type used: " << s << endl;
|
| 144 |
+
exit(1);
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
ORIENTATION ModelScoreMSD::getType(const StringPiece& s)
|
| 150 |
+
{
|
| 151 |
+
if (s.compare("mono") == 0) {
|
| 152 |
+
return MONO;
|
| 153 |
+
} else if (s.compare("swap") == 0) {
|
| 154 |
+
return SWAP;
|
| 155 |
+
} else if (s.compare("dleft") == 0 ||
|
| 156 |
+
s.compare("dright") == 0 ||
|
| 157 |
+
s.compare("other") == 0) {
|
| 158 |
+
return OTHER;
|
| 159 |
+
} else if (s.compare("nomono") == 0) {
|
| 160 |
+
cerr << "Illegal reordering type used: " << s << " for model type msd. You have to re-run step 5 in order to train such a model." << endl;
|
| 161 |
+
exit(1);
|
| 162 |
+
} else {
|
| 163 |
+
cerr << "Illegal reordering type used: " << s << endl;
|
| 164 |
+
exit(1);
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
ORIENTATION ModelScoreMonotonicity::getType(const StringPiece& s)
|
| 169 |
+
{
|
| 170 |
+
if (s.compare("mono") == 0) {
|
| 171 |
+
return MONO;
|
| 172 |
+
} else if (s.compare("swap") == 0 ||
|
| 173 |
+
s.compare("dleft") == 0 ||
|
| 174 |
+
s.compare("dright") == 0 ||
|
| 175 |
+
s.compare("other") == 0 ||
|
| 176 |
+
s.compare("nomono") == 0 ) {
|
| 177 |
+
return NOMONO;
|
| 178 |
+
} else {
|
| 179 |
+
cerr << "Illegal reordering type used: " << s << endl;
|
| 180 |
+
exit(1);
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
void ScorerMSLR::score(const vector<double>& all_scores, vector<double>& scores) const
|
| 187 |
+
{
|
| 188 |
+
scores.push_back(all_scores[MONO]);
|
| 189 |
+
scores.push_back(all_scores[SWAP]);
|
| 190 |
+
scores.push_back(all_scores[DLEFT]);
|
| 191 |
+
scores.push_back(all_scores[DRIGHT]);
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
void ScorerMSD::score(const vector<double>& all_scores, vector<double>& scores) const
|
| 195 |
+
{
|
| 196 |
+
scores.push_back(all_scores[MONO]);
|
| 197 |
+
scores.push_back(all_scores[SWAP]);
|
| 198 |
+
scores.push_back(all_scores[DRIGHT]+all_scores[DLEFT]+all_scores[OTHER]);
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
void ScorerMonotonicity::score(const vector<double>& all_scores, vector<double>& scores) const
|
| 202 |
+
{
|
| 203 |
+
scores.push_back(all_scores[MONO]);
|
| 204 |
+
scores.push_back(all_scores[SWAP]+all_scores[DRIGHT]+all_scores[DLEFT]+all_scores[OTHER]+all_scores[NOMONO]);
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
void ScorerLR::score(const vector<double>& all_scores, vector<double>& scores) const
|
| 209 |
+
{
|
| 210 |
+
scores.push_back(all_scores[MONO]+all_scores[DRIGHT]);
|
| 211 |
+
scores.push_back(all_scores[SWAP]+all_scores[DLEFT]);
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
void ScorerMSLR::createSmoothing(const vector<double>& scores, double weight, vector<double>& smoothing) const
|
| 216 |
+
{
|
| 217 |
+
double total = accumulate(scores.begin(), scores.end(), 0);
|
| 218 |
+
smoothing.push_back(weight*(scores[MONO]+0.1)/total);
|
| 219 |
+
smoothing.push_back(weight*(scores[SWAP]+0.1)/total);
|
| 220 |
+
smoothing.push_back(weight*(scores[DLEFT]+0.1)/total);
|
| 221 |
+
smoothing.push_back(weight*(scores[DRIGHT]+0.1)/total);
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
void ScorerMSLR::createConstSmoothing(double weight, vector<double>& smoothing) const
|
| 225 |
+
{
|
| 226 |
+
for (int i=1; i<=4; ++i) {
|
| 227 |
+
smoothing.push_back(weight);
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
void ScorerMSD::createSmoothing(const vector<double>& scores, double weight, vector<double>& smoothing) const
|
| 233 |
+
{
|
| 234 |
+
double total = accumulate(scores.begin(), scores.end(), 0);
|
| 235 |
+
smoothing.push_back(weight*(scores[MONO]+0.1)/total);
|
| 236 |
+
smoothing.push_back(weight*(scores[SWAP]+0.1)/total);
|
| 237 |
+
smoothing.push_back(weight*(scores[DLEFT]+scores[DRIGHT]+scores[OTHER]+0.1)/total);
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
void ScorerMSD::createConstSmoothing(double weight, vector<double>& smoothing) const
|
| 241 |
+
{
|
| 242 |
+
for (int i=1; i<=3; ++i) {
|
| 243 |
+
smoothing.push_back(weight);
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
void ScorerMonotonicity::createSmoothing(const vector<double>& scores, double weight, vector<double>& smoothing) const
|
| 248 |
+
{
|
| 249 |
+
double total = accumulate(scores.begin(), scores.end(), 0);
|
| 250 |
+
smoothing.push_back(weight*(scores[MONO]+0.1)/total);
|
| 251 |
+
smoothing.push_back(weight*(scores[SWAP]+scores[DLEFT]+scores[DRIGHT]+scores[OTHER]+scores[NOMONO]+0.1)/total);
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
void ScorerMonotonicity::createConstSmoothing(double weight, vector<double>& smoothing) const
|
| 255 |
+
{
|
| 256 |
+
for (double i=1; i<=2; ++i) {
|
| 257 |
+
smoothing.push_back(weight);
|
| 258 |
+
}
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
void ScorerLR::createSmoothing(const vector<double>& scores, double weight, vector<double>& smoothing) const
|
| 263 |
+
{
|
| 264 |
+
double total = accumulate(scores.begin(), scores.end(), 0);
|
| 265 |
+
smoothing.push_back(weight*(scores[MONO]+scores[DRIGHT]+0.1)/total);
|
| 266 |
+
smoothing.push_back(weight*(scores[SWAP]+scores[DLEFT])/total);
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
void ScorerLR::createConstSmoothing(double weight, vector<double>& smoothing) const
|
| 270 |
+
{
|
| 271 |
+
for (int i=1; i<=2; ++i) {
|
| 272 |
+
smoothing.push_back(weight);
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
void Model::score_fe(const string& f, const string& e)
|
| 277 |
+
{
|
| 278 |
+
if (!fe) //Make sure we do not do anything if it is not a fe model
|
| 279 |
+
return;
|
| 280 |
+
outputFile << f << " ||| " << e << " |||";
|
| 281 |
+
//condition on the previous phrase
|
| 282 |
+
if (previous) {
|
| 283 |
+
vector<double> scores;
|
| 284 |
+
scorer->score(modelscore->get_scores_fe_prev(), scores);
|
| 285 |
+
double sum = 0;
|
| 286 |
+
for(size_t i=0; i<scores.size(); ++i) {
|
| 287 |
+
scores[i] += smoothing_prev[i];
|
| 288 |
+
sum += scores[i];
|
| 289 |
+
}
|
| 290 |
+
for(size_t i=0; i<scores.size(); ++i) {
|
| 291 |
+
outputFile << " " << (scores[i]/sum);
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
//condition on the next phrase
|
| 295 |
+
if (next) {
|
| 296 |
+
vector<double> scores;
|
| 297 |
+
scorer->score(modelscore->get_scores_fe_next(), scores);
|
| 298 |
+
double sum = 0;
|
| 299 |
+
for(size_t i=0; i<scores.size(); ++i) {
|
| 300 |
+
scores[i] += smoothing_next[i];
|
| 301 |
+
sum += scores[i];
|
| 302 |
+
}
|
| 303 |
+
for(size_t i=0; i<scores.size(); ++i) {
|
| 304 |
+
outputFile << " " << (scores[i]/sum);
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
outputFile << endl;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
void Model::score_f(const string& f)
|
| 311 |
+
{
|
| 312 |
+
if (fe) //Make sure we do not do anything if it is not a f model
|
| 313 |
+
return;
|
| 314 |
+
cout << f << " |||";
|
| 315 |
+
//condition on the previous phrase
|
| 316 |
+
if (previous) {
|
| 317 |
+
vector<double> scores;
|
| 318 |
+
scorer->score(modelscore->get_scores_f_prev(), scores);
|
| 319 |
+
double sum = 0;
|
| 320 |
+
for(size_t i=0; i<scores.size(); ++i) {
|
| 321 |
+
scores[i] += smoothing_prev[i];
|
| 322 |
+
sum += scores[i];
|
| 323 |
+
}
|
| 324 |
+
for(size_t i=0; i<scores.size(); ++i) {
|
| 325 |
+
outputFile << " " << (scores[i]/sum);
|
| 326 |
+
}
|
| 327 |
+
}
|
| 328 |
+
//condition on the next phrase
|
| 329 |
+
if (next) {
|
| 330 |
+
vector<double> scores;
|
| 331 |
+
scorer->score(modelscore->get_scores_f_next(), scores);
|
| 332 |
+
double sum = 0;
|
| 333 |
+
for(size_t i=0; i<scores.size(); ++i) {
|
| 334 |
+
scores[i] += smoothing_next[i];
|
| 335 |
+
sum += scores[i];
|
| 336 |
+
}
|
| 337 |
+
for(size_t i=0; i<scores.size(); ++i) {
|
| 338 |
+
outputFile << " " << (scores[i]/sum);
|
| 339 |
+
}
|
| 340 |
+
}
|
| 341 |
+
outputFile << endl;
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
Model::Model(ModelScore* ms, Scorer* sc, const string& dir, const string& lang, const string& fn)
|
| 345 |
+
: modelscore(ms), scorer(sc), filename(fn)
|
| 346 |
+
{
|
| 347 |
+
outputFile.Open( (filename+".gz").c_str() );
|
| 348 |
+
fe = false;
|
| 349 |
+
if (lang.compare("fe") == 0) {
|
| 350 |
+
fe = true;
|
| 351 |
+
} else if (lang.compare("f") != 0) {
|
| 352 |
+
cerr << "You have given an illegal language to condition on: " << lang
|
| 353 |
+
<< "\nLegal types: fe (on both languages), f (only on source language)\n";
|
| 354 |
+
exit(1);
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
previous = true;
|
| 358 |
+
next = true;
|
| 359 |
+
if (dir.compare("backward") == 0) {
|
| 360 |
+
next = false;
|
| 361 |
+
} else if (dir.compare("forward") == 0) {
|
| 362 |
+
previous = false;
|
| 363 |
+
}
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
Model::~Model()
|
| 367 |
+
{
|
| 368 |
+
outputFile.Close();
|
| 369 |
+
delete modelscore;
|
| 370 |
+
delete scorer;
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
void Model::split_config(const string& config, string& dir, string& lang, string& orient)
|
| 374 |
+
{
|
| 375 |
+
istringstream is(config);
|
| 376 |
+
string type;
|
| 377 |
+
getline(is, type, '-');
|
| 378 |
+
getline(is, orient, '-');
|
| 379 |
+
getline(is, dir, '-');
|
| 380 |
+
getline(is, lang, '-');
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
Model* Model::createModel(ModelScore* modelscore, const string& config, const string& filepath)
|
| 384 |
+
{
|
| 385 |
+
string dir, lang, orient, filename;
|
| 386 |
+
split_config(config,dir,lang,orient);
|
| 387 |
+
|
| 388 |
+
filename = filepath + config;
|
| 389 |
+
if (orient.compare("mslr") == 0) {
|
| 390 |
+
return new Model(modelscore, new ScorerMSLR(), dir, lang, filename);
|
| 391 |
+
} else if (orient.compare("msd") == 0) {
|
| 392 |
+
return new Model(modelscore, new ScorerMSD(), dir, lang, filename);
|
| 393 |
+
} else if (orient.compare("monotonicity") == 0) {
|
| 394 |
+
return new Model(modelscore, new ScorerMonotonicity(), dir, lang, filename);
|
| 395 |
+
} else if (orient.compare("leftright") == 0) {
|
| 396 |
+
return new Model(modelscore, new ScorerLR(), dir, lang, filename);
|
| 397 |
+
} else {
|
| 398 |
+
cerr << "Illegal orientation type of reordering model: " << orient
|
| 399 |
+
<< "\n allowed types: mslr, msd, monotonicity, leftright\n";
|
| 400 |
+
exit(1);
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
void Model::createSmoothing(double w)
|
| 407 |
+
{
|
| 408 |
+
scorer->createSmoothing(modelscore->get_scores_fe_prev(), w, smoothing_prev);
|
| 409 |
+
scorer->createSmoothing(modelscore->get_scores_fe_next(), w, smoothing_next);
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
void Model::createConstSmoothing(double w)
|
| 413 |
+
{
|
| 414 |
+
scorer->createConstSmoothing(w, smoothing_prev);
|
| 415 |
+
scorer->createConstSmoothing(w, smoothing_next);
|
| 416 |
+
}
|
mosesdecoder/phrase-extract/lexical-reordering/reordering_classes.h
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* reordering_classes.h
|
| 3 |
+
* Utility classes for lexical reordering table scoring
|
| 4 |
+
*
|
| 5 |
+
* Created by: Sara Stymne - Linköping University
|
| 6 |
+
* Machine Translation Marathon 2010, Dublin
|
| 7 |
+
*/
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <vector>
|
| 12 |
+
#include <string>
|
| 13 |
+
#include <fstream>
|
| 14 |
+
|
| 15 |
+
#include "util/string_piece.hh"
|
| 16 |
+
#include "../OutputFileStream.h"
|
| 17 |
+
|
| 18 |
+
enum ORIENTATION {MONO, SWAP, DRIGHT, DLEFT, OTHER, NOMONO};
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
//Keeps the counts for the different reordering types
|
| 22 |
+
//(Instantiated in 1-3 instances, one for each type of model (hier, phrase, wbe))
|
| 23 |
+
class ModelScore
|
| 24 |
+
{
|
| 25 |
+
private:
|
| 26 |
+
std::vector<double> count_fe_prev;
|
| 27 |
+
std::vector<double> count_fe_next;
|
| 28 |
+
std::vector<double> count_f_prev;
|
| 29 |
+
std::vector<double> count_f_next;
|
| 30 |
+
|
| 31 |
+
protected:
|
| 32 |
+
virtual ORIENTATION getType(const StringPiece& s);
|
| 33 |
+
|
| 34 |
+
public:
|
| 35 |
+
ModelScore();
|
| 36 |
+
virtual ~ModelScore();
|
| 37 |
+
void add_example(const StringPiece& previous, const StringPiece& next, float weight);
|
| 38 |
+
void reset_fe();
|
| 39 |
+
void reset_f();
|
| 40 |
+
const std::vector<double>& get_scores_fe_prev() const;
|
| 41 |
+
const std::vector<double>& get_scores_fe_next() const;
|
| 42 |
+
const std::vector<double>& get_scores_f_prev() const;
|
| 43 |
+
const std::vector<double>& get_scores_f_next() const;
|
| 44 |
+
|
| 45 |
+
static ModelScore* createModelScore(const std::string& modeltype);
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
class ModelScoreMSLR : public ModelScore
|
| 49 |
+
{
|
| 50 |
+
protected:
|
| 51 |
+
virtual ORIENTATION getType(const StringPiece& s);
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
class ModelScoreLR : public ModelScore
|
| 55 |
+
{
|
| 56 |
+
protected:
|
| 57 |
+
virtual ORIENTATION getType(const StringPiece& s);
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
class ModelScoreMSD : public ModelScore
|
| 61 |
+
{
|
| 62 |
+
protected:
|
| 63 |
+
virtual ORIENTATION getType(const StringPiece& s);
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
class ModelScoreMonotonicity : public ModelScore
|
| 67 |
+
{
|
| 68 |
+
protected:
|
| 69 |
+
virtual ORIENTATION getType(const StringPiece& s);
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
//Class for calculating total counts, and to calculate smoothing
|
| 73 |
+
class Scorer
|
| 74 |
+
{
|
| 75 |
+
public:
|
| 76 |
+
virtual ~Scorer() {}
|
| 77 |
+
virtual void score(const std::vector<double>&, std::vector<double>&) const = 0;
|
| 78 |
+
virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const = 0;
|
| 79 |
+
virtual void createConstSmoothing(double, std::vector<double>&) const = 0;
|
| 80 |
+
};
|
| 81 |
+
|
| 82 |
+
class ScorerMSLR : public Scorer
|
| 83 |
+
{
|
| 84 |
+
public:
|
| 85 |
+
virtual void score(const std::vector<double>&, std::vector<double>&) const;
|
| 86 |
+
virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const;
|
| 87 |
+
virtual void createConstSmoothing(double, std::vector<double>&) const;
|
| 88 |
+
};
|
| 89 |
+
|
| 90 |
+
class ScorerMSD : public Scorer
|
| 91 |
+
{
|
| 92 |
+
public:
|
| 93 |
+
virtual void score(const std::vector<double>&, std::vector<double>&) const;
|
| 94 |
+
virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const;
|
| 95 |
+
virtual void createConstSmoothing(double, std::vector<double>&) const;
|
| 96 |
+
};
|
| 97 |
+
|
| 98 |
+
class ScorerMonotonicity : public Scorer
|
| 99 |
+
{
|
| 100 |
+
public:
|
| 101 |
+
virtual void score(const std::vector<double>&, std::vector<double>&) const;
|
| 102 |
+
virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const;
|
| 103 |
+
virtual void createConstSmoothing(double, std::vector<double>&) const;
|
| 104 |
+
};
|
| 105 |
+
|
| 106 |
+
class ScorerLR : public Scorer
|
| 107 |
+
{
|
| 108 |
+
public:
|
| 109 |
+
virtual void score(const std::vector<double>&, std::vector<double>&) const;
|
| 110 |
+
virtual void createSmoothing(const std::vector<double>&, double, std::vector<double>&) const;
|
| 111 |
+
virtual void createConstSmoothing(double, std::vector<double>&) const;
|
| 112 |
+
};
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
//Class for representing each model
|
| 116 |
+
//Contains a modelscore and scorer (which can be of different model types (mslr, msd...)),
|
| 117 |
+
//and file handling.
|
| 118 |
+
//This class also keeps track of bidirectionality, and which language to condition on
|
| 119 |
+
class Model
|
| 120 |
+
{
|
| 121 |
+
private:
|
| 122 |
+
ModelScore* modelscore;
|
| 123 |
+
Scorer* scorer;
|
| 124 |
+
|
| 125 |
+
std::string filename;
|
| 126 |
+
Moses::OutputFileStream outputFile;
|
| 127 |
+
|
| 128 |
+
bool fe;
|
| 129 |
+
bool previous;
|
| 130 |
+
bool next;
|
| 131 |
+
|
| 132 |
+
std::vector<double> smoothing_prev;
|
| 133 |
+
std::vector<double> smoothing_next;
|
| 134 |
+
|
| 135 |
+
static void split_config(const std::string& config, std::string& dir,
|
| 136 |
+
std::string& lang, std::string& orient);
|
| 137 |
+
public:
|
| 138 |
+
Model(ModelScore* ms, Scorer* sc, const std::string& dir,
|
| 139 |
+
const std::string& lang, const std::string& fn);
|
| 140 |
+
~Model();
|
| 141 |
+
static Model* createModel(ModelScore*, const std::string&, const std::string&);
|
| 142 |
+
void createSmoothing(double w);
|
| 143 |
+
void createConstSmoothing(double w);
|
| 144 |
+
void score_fe(const std::string& f, const std::string& e);
|
| 145 |
+
void score_f(const std::string& f);
|
| 146 |
+
void zipFile();
|
| 147 |
+
};
|
| 148 |
+
|
mosesdecoder/phrase-extract/lexical-reordering/score.cpp
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* score_reordering.cpp
|
| 3 |
+
*
|
| 4 |
+
* Created by: Sara Stymne - Linköping University
|
| 5 |
+
* Machine Translation Marathon 2010, Dublin
|
| 6 |
+
*/
|
| 7 |
+
|
| 8 |
+
#include <string>
|
| 9 |
+
#include <vector>
|
| 10 |
+
#include <map>
|
| 11 |
+
#include <iostream>
|
| 12 |
+
#include <fstream>
|
| 13 |
+
#include <sstream>
|
| 14 |
+
#include <cstdlib>
|
| 15 |
+
#include <cstring>
|
| 16 |
+
|
| 17 |
+
#include "util/exception.hh"
|
| 18 |
+
#include "util/file_piece.hh"
|
| 19 |
+
#include "util/string_piece.hh"
|
| 20 |
+
#include "util/tokenize_piece.hh"
|
| 21 |
+
|
| 22 |
+
#include "InputFileStream.h"
|
| 23 |
+
#include "reordering_classes.h"
|
| 24 |
+
|
| 25 |
+
using namespace std;
|
| 26 |
+
|
| 27 |
+
void split_line(const StringPiece& line, StringPiece& foreign, StringPiece& english, StringPiece& wbe, StringPiece& phrase, StringPiece& hier, float& weight);
|
| 28 |
+
void get_orientations(const StringPiece& pair, StringPiece& previous, StringPiece& next);
|
| 29 |
+
|
| 30 |
+
class FileFormatException : public util::Exception
|
| 31 |
+
{
|
| 32 |
+
public:
|
| 33 |
+
FileFormatException() throw() {
|
| 34 |
+
*this << "Invalid extract file format: ";
|
| 35 |
+
}
|
| 36 |
+
~FileFormatException() throw() {}
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
int main(int argc, char* argv[])
|
| 40 |
+
{
|
| 41 |
+
|
| 42 |
+
cerr << "Lexical Reordering Scorer\n"
|
| 43 |
+
<< "scores lexical reordering models of several types (hierarchical, phrase-based and word-based-extraction\n";
|
| 44 |
+
|
| 45 |
+
if (argc < 3) {
|
| 46 |
+
cerr << "syntax: score_reordering extractFile smoothingValue filepath (--model \"type max-orientation (specification-strings)\" )+\n";
|
| 47 |
+
exit(1);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
char* extractFileName = argv[1];
|
| 51 |
+
double smoothingValue = atof(argv[2]);
|
| 52 |
+
string filepath = argv[3];
|
| 53 |
+
|
| 54 |
+
util::FilePiece eFile(extractFileName);
|
| 55 |
+
|
| 56 |
+
bool smoothWithCounts = false;
|
| 57 |
+
map<string,ModelScore*> modelScores;
|
| 58 |
+
vector<Model*> models;
|
| 59 |
+
bool hier = false;
|
| 60 |
+
bool phrase = false;
|
| 61 |
+
bool wbe = false;
|
| 62 |
+
|
| 63 |
+
StringPiece e,f,w,p,h;
|
| 64 |
+
StringPiece prev, next;
|
| 65 |
+
|
| 66 |
+
int i = 4;
|
| 67 |
+
while (i<argc) {
|
| 68 |
+
if (strcmp(argv[i],"--SmoothWithCounts") == 0) {
|
| 69 |
+
smoothWithCounts = true;
|
| 70 |
+
} else if (strcmp(argv[i],"--model") == 0) {
|
| 71 |
+
if (i+1 >= argc) {
|
| 72 |
+
cerr << "score: syntax error, no model information provided to the option" << argv[i] << endl;
|
| 73 |
+
exit(1);
|
| 74 |
+
}
|
| 75 |
+
istringstream is(argv[++i]);
|
| 76 |
+
string m,t;
|
| 77 |
+
is >> m >> t;
|
| 78 |
+
modelScores[m] = ModelScore::createModelScore(t);
|
| 79 |
+
if (m.compare("hier") == 0) {
|
| 80 |
+
hier = true;
|
| 81 |
+
} else if (m.compare("phrase") == 0) {
|
| 82 |
+
phrase = true;
|
| 83 |
+
}
|
| 84 |
+
if (m.compare("wbe") == 0) {
|
| 85 |
+
wbe = true;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
if (!hier && !phrase && !wbe) {
|
| 89 |
+
cerr << "WARNING: No models specified for lexical reordering. No lexical reordering table will be trained.\n";
|
| 90 |
+
return 0;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
string config;
|
| 94 |
+
//Store all models
|
| 95 |
+
while (is >> config) {
|
| 96 |
+
models.push_back(Model::createModel(modelScores[m],config,filepath));
|
| 97 |
+
}
|
| 98 |
+
} else {
|
| 99 |
+
cerr << "illegal option given to lexical reordering model score\n";
|
| 100 |
+
exit(1);
|
| 101 |
+
}
|
| 102 |
+
i++;
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
////////////////////////////////////
|
| 106 |
+
//calculate smoothing
|
| 107 |
+
if (smoothWithCounts) {
|
| 108 |
+
util::FilePiece eFileForCounts(extractFileName);
|
| 109 |
+
while (true) {
|
| 110 |
+
StringPiece line;
|
| 111 |
+
try {
|
| 112 |
+
line = eFileForCounts.ReadLine();
|
| 113 |
+
} catch (util::EndOfFileException &e) {
|
| 114 |
+
break;
|
| 115 |
+
}
|
| 116 |
+
float weight = 1;
|
| 117 |
+
split_line(line,e,f,w,p,h,weight);
|
| 118 |
+
if (hier) {
|
| 119 |
+
get_orientations(h, prev, next);
|
| 120 |
+
modelScores["hier"]->add_example(prev,next,weight);
|
| 121 |
+
}
|
| 122 |
+
if (phrase) {
|
| 123 |
+
get_orientations(p, prev, next);
|
| 124 |
+
modelScores["phrase"]->add_example(prev,next,weight);
|
| 125 |
+
}
|
| 126 |
+
if (wbe) {
|
| 127 |
+
get_orientations(w, prev, next);
|
| 128 |
+
modelScores["wbe"]->add_example(prev,next,weight);
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
// calculate smoothing for each model
|
| 133 |
+
for (size_t i=0; i<models.size(); ++i) {
|
| 134 |
+
models[i]->createSmoothing(smoothingValue);
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
} else {
|
| 138 |
+
//constant smoothing
|
| 139 |
+
for (size_t i=0; i<models.size(); ++i) {
|
| 140 |
+
models[i]->createConstSmoothing(smoothingValue);
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
////////////////////////////////////
|
| 145 |
+
//calculate scores for reordering table
|
| 146 |
+
string f_current,e_current;
|
| 147 |
+
bool first = true;
|
| 148 |
+
while (true) {
|
| 149 |
+
StringPiece line;
|
| 150 |
+
try {
|
| 151 |
+
line = eFile.ReadLine();
|
| 152 |
+
} catch (util::EndOfFileException &e) {
|
| 153 |
+
break;
|
| 154 |
+
}
|
| 155 |
+
float weight = 1;
|
| 156 |
+
split_line(line,f,e,w,p,h,weight);
|
| 157 |
+
|
| 158 |
+
if (first) {
|
| 159 |
+
f_current = f.as_string(); //FIXME: Avoid the copy.
|
| 160 |
+
e_current = e.as_string();
|
| 161 |
+
first = false;
|
| 162 |
+
} else if (f.compare(f_current) != 0 || e.compare(e_current) != 0) {
|
| 163 |
+
//fe - score
|
| 164 |
+
for (size_t i=0; i<models.size(); ++i) {
|
| 165 |
+
models[i]->score_fe(f_current,e_current);
|
| 166 |
+
}
|
| 167 |
+
//reset
|
| 168 |
+
for(map<string,ModelScore*>::const_iterator it = modelScores.begin(); it != modelScores.end(); ++it) {
|
| 169 |
+
it->second->reset_fe();
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
if (f.compare(f_current) != 0) {
|
| 173 |
+
//f - score
|
| 174 |
+
for (size_t i=0; i<models.size(); ++i) {
|
| 175 |
+
models[i]->score_f(f_current);
|
| 176 |
+
}
|
| 177 |
+
//reset
|
| 178 |
+
for(map<string,ModelScore*>::const_iterator it = modelScores.begin(); it != modelScores.end(); ++it) {
|
| 179 |
+
it->second->reset_f();
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
f_current = f.as_string();
|
| 183 |
+
e_current = e.as_string();
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
// uppdate counts
|
| 187 |
+
if (hier) {
|
| 188 |
+
get_orientations(h, prev, next);
|
| 189 |
+
modelScores["hier"]->add_example(prev,next,weight);
|
| 190 |
+
}
|
| 191 |
+
if (phrase) {
|
| 192 |
+
get_orientations(p, prev, next);
|
| 193 |
+
modelScores["phrase"]->add_example(prev,next,weight);
|
| 194 |
+
}
|
| 195 |
+
if (wbe) {
|
| 196 |
+
get_orientations(w, prev, next);
|
| 197 |
+
modelScores["wbe"]->add_example(prev,next,weight);
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
//Score the last phrases
|
| 201 |
+
for (size_t i=0; i<models.size(); ++i) {
|
| 202 |
+
models[i]->score_fe(f_current,e_current);
|
| 203 |
+
}
|
| 204 |
+
for (size_t i=0; i<models.size(); ++i) {
|
| 205 |
+
models[i]->score_f(f_current);
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
// delete model objects (and close files)
|
| 209 |
+
for (size_t i=0; i<models.size(); ++i) {
|
| 210 |
+
delete models[i];
|
| 211 |
+
}
|
| 212 |
+
return 0;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
template <class It> StringPiece
|
| 216 |
+
GrabOrDie(It &it, const StringPiece& line)
|
| 217 |
+
{
|
| 218 |
+
UTIL_THROW_IF(!it, FileFormatException, line.as_string());
|
| 219 |
+
return *it++;
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
void split_line(
|
| 224 |
+
const StringPiece& line,
|
| 225 |
+
StringPiece& foreign,
|
| 226 |
+
StringPiece& english,
|
| 227 |
+
StringPiece& wbe,
|
| 228 |
+
StringPiece& phrase,
|
| 229 |
+
StringPiece& hier,
|
| 230 |
+
float& weight)
|
| 231 |
+
{
|
| 232 |
+
/*Format is source ||| target ||| orientations
|
| 233 |
+
followed by one of the following 4 possibilities
|
| 234 |
+
eps
|
| 235 |
+
||| weight
|
| 236 |
+
| phrase | hier
|
| 237 |
+
| phrase | hier ||| weight
|
| 238 |
+
*/
|
| 239 |
+
|
| 240 |
+
util::TokenIter<util::MultiCharacter> pipes(line, util::MultiCharacter(" ||| "));
|
| 241 |
+
foreign = GrabOrDie(pipes,line);
|
| 242 |
+
english = GrabOrDie(pipes,line);
|
| 243 |
+
StringPiece next = GrabOrDie(pipes,line);
|
| 244 |
+
|
| 245 |
+
util::TokenIter<util::MultiCharacter> singlePipe(next, util::MultiCharacter(" | "));
|
| 246 |
+
wbe = GrabOrDie(singlePipe,line);
|
| 247 |
+
if (singlePipe) {
|
| 248 |
+
phrase = GrabOrDie(singlePipe, line);
|
| 249 |
+
hier = GrabOrDie(singlePipe, line);
|
| 250 |
+
} else {
|
| 251 |
+
phrase.clear();
|
| 252 |
+
hier.clear();
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
if (pipes) {
|
| 256 |
+
// read the weight
|
| 257 |
+
char* errIndex;
|
| 258 |
+
next = *pipes++;
|
| 259 |
+
weight = static_cast<float>(strtod(next.data(), &errIndex));
|
| 260 |
+
UTIL_THROW_IF(errIndex == next.data(), FileFormatException, line.as_string());
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
void get_orientations(const StringPiece& pair, StringPiece& previous, StringPiece& next)
|
| 265 |
+
{
|
| 266 |
+
util::TokenIter<util::SingleCharacter> tok(pair, util::SingleCharacter(' '));
|
| 267 |
+
previous = GrabOrDie(tok,pair);
|
| 268 |
+
next = GrabOrDie(tok,pair);
|
| 269 |
+
}
|
mosesdecoder/phrase-extract/pcfg-extract/Jamfile
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
exe pcfg-extract : [ glob *.cc ] ..//syntax-common ../..//boost_program_options : <include>.. ;
|
mosesdecoder/phrase-extract/pcfg-extract/options.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - statistical machine translation system
|
| 3 |
+
Copyright (C) 2006-2012 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#pragma once
|
| 21 |
+
#ifndef PCFG_EXTRACT_OPTIONS_H_
|
| 22 |
+
#define PCFG_EXTRACT_OPTIONS_H_
|
| 23 |
+
|
| 24 |
+
#include <string>
|
| 25 |
+
|
| 26 |
+
namespace MosesTraining
|
| 27 |
+
{
|
| 28 |
+
namespace Syntax
|
| 29 |
+
{
|
| 30 |
+
namespace PCFG
|
| 31 |
+
{
|
| 32 |
+
|
| 33 |
+
struct Options {
|
| 34 |
+
std::string corpus_file;
|
| 35 |
+
};
|
| 36 |
+
|
| 37 |
+
} // namespace PCFG
|
| 38 |
+
} // namespace Syntax
|
| 39 |
+
} // namespace MosesTraining
|
| 40 |
+
|
| 41 |
+
#endif
|
mosesdecoder/phrase-extract/pcfg-extract/pcfg_extract.cc
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - statistical machine translation system
|
| 3 |
+
Copyright (C) 2006-2012 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#include "pcfg_extract.h"
|
| 21 |
+
|
| 22 |
+
#include <cassert>
|
| 23 |
+
#include <cstdlib>
|
| 24 |
+
#include <fstream>
|
| 25 |
+
#include <iostream>
|
| 26 |
+
#include <map>
|
| 27 |
+
#include <memory>
|
| 28 |
+
#include <set>
|
| 29 |
+
#include <string>
|
| 30 |
+
#include <vector>
|
| 31 |
+
|
| 32 |
+
#include <boost/program_options.hpp>
|
| 33 |
+
|
| 34 |
+
#include "syntax-common/exception.h"
|
| 35 |
+
#include "syntax-common/pcfg.h"
|
| 36 |
+
#include "syntax-common/vocabulary.h"
|
| 37 |
+
#include "syntax-common/xml_tree_parser.h"
|
| 38 |
+
|
| 39 |
+
#include "SyntaxTree.h"
|
| 40 |
+
|
| 41 |
+
#include "options.h"
|
| 42 |
+
#include "rule_collection.h"
|
| 43 |
+
#include "rule_extractor.h"
|
| 44 |
+
|
| 45 |
+
namespace MosesTraining
|
| 46 |
+
{
|
| 47 |
+
namespace Syntax
|
| 48 |
+
{
|
| 49 |
+
namespace PCFG
|
| 50 |
+
{
|
| 51 |
+
|
| 52 |
+
int PcfgExtract::Main(int argc, char *argv[])
|
| 53 |
+
{
|
| 54 |
+
// Process command-line options.
|
| 55 |
+
Options options;
|
| 56 |
+
ProcessOptions(argc, argv, options);
|
| 57 |
+
|
| 58 |
+
// Extract PCFG rules from corpus.
|
| 59 |
+
Vocabulary non_term_vocab;
|
| 60 |
+
RuleExtractor rule_extractor(non_term_vocab);
|
| 61 |
+
RuleCollection rule_collection;
|
| 62 |
+
XmlTreeParser parser;
|
| 63 |
+
std::string line;
|
| 64 |
+
std::size_t line_num = 0;
|
| 65 |
+
std::auto_ptr<MosesTraining::SyntaxTree> tree;
|
| 66 |
+
while (std::getline(std::cin, line)) {
|
| 67 |
+
++line_num;
|
| 68 |
+
try {
|
| 69 |
+
tree = parser.Parse(line);
|
| 70 |
+
} catch (Exception &e) {
|
| 71 |
+
std::ostringstream msg;
|
| 72 |
+
msg << "line " << line_num << ": " << e.msg();
|
| 73 |
+
Error(msg.str());
|
| 74 |
+
}
|
| 75 |
+
if (!tree.get()) {
|
| 76 |
+
std::ostringstream msg;
|
| 77 |
+
msg << "no tree at line " << line_num;
|
| 78 |
+
Warn(msg.str());
|
| 79 |
+
continue;
|
| 80 |
+
}
|
| 81 |
+
rule_extractor.Extract(*tree, rule_collection);
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// Score rules and write PCFG to output.
|
| 85 |
+
Pcfg pcfg;
|
| 86 |
+
rule_collection.CreatePcfg(pcfg);
|
| 87 |
+
pcfg.Write(non_term_vocab, std::cout);
|
| 88 |
+
|
| 89 |
+
return 0;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
void PcfgExtract::ProcessOptions(int argc, char *argv[],
|
| 93 |
+
Options &options) const
|
| 94 |
+
{
|
| 95 |
+
namespace po = boost::program_options;
|
| 96 |
+
|
| 97 |
+
std::ostringstream usage_top;
|
| 98 |
+
usage_top << "Usage: " << name() << "\n\n" << "Options";
|
| 99 |
+
|
| 100 |
+
// Declare the command line options that are visible to the user.
|
| 101 |
+
po::options_description visible(usage_top.str());
|
| 102 |
+
visible.add_options()
|
| 103 |
+
("help", "print help message and exit")
|
| 104 |
+
;
|
| 105 |
+
|
| 106 |
+
// Declare the command line options that are hidden from the user
|
| 107 |
+
// (these are used as positional options).
|
| 108 |
+
po::options_description hidden("Hidden options");
|
| 109 |
+
hidden.add_options();
|
| 110 |
+
|
| 111 |
+
// Compose the full set of command-line options.
|
| 112 |
+
po::options_description cmd_line_options;
|
| 113 |
+
cmd_line_options.add(visible).add(hidden);
|
| 114 |
+
|
| 115 |
+
// Register the positional options.
|
| 116 |
+
po::positional_options_description p;
|
| 117 |
+
|
| 118 |
+
// Process the command-line.
|
| 119 |
+
po::variables_map vm;
|
| 120 |
+
try {
|
| 121 |
+
po::store(po::command_line_parser(argc, argv).style(MosesOptionStyle()).
|
| 122 |
+
options(cmd_line_options).positional(p).run(), vm);
|
| 123 |
+
po::notify(vm);
|
| 124 |
+
} catch (const std::exception &e) {
|
| 125 |
+
std::ostringstream msg;
|
| 126 |
+
msg << e.what() << "\n\n" << visible;
|
| 127 |
+
Error(msg.str());
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
if (vm.count("help")) {
|
| 131 |
+
std::cout << visible << std::endl;
|
| 132 |
+
std::exit(0);
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
} // namespace PCFG
|
| 137 |
+
} // namespace Syntax
|
| 138 |
+
} // namespace MosesTraining
|
mosesdecoder/phrase-extract/pcfg-extract/pcfg_extract.h
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - statistical machine translation system
|
| 3 |
+
Copyright (C) 2006-2012 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#pragma once
|
| 21 |
+
#ifndef PCFG_EXTRACT_PCFG_EXTRACT_H_
|
| 22 |
+
#define PCFG_EXTRACT_PCFG_EXTRACT_H_
|
| 23 |
+
|
| 24 |
+
#include "syntax-common/tool.h"
|
| 25 |
+
|
| 26 |
+
namespace MosesTraining
|
| 27 |
+
{
|
| 28 |
+
namespace Syntax
|
| 29 |
+
{
|
| 30 |
+
namespace PCFG
|
| 31 |
+
{
|
| 32 |
+
|
| 33 |
+
struct Options;
|
| 34 |
+
|
| 35 |
+
class PcfgExtract : public Tool
|
| 36 |
+
{
|
| 37 |
+
public:
|
| 38 |
+
PcfgExtract() : Tool("pcfg-extract") {}
|
| 39 |
+
virtual int Main(int, char *[]);
|
| 40 |
+
private:
|
| 41 |
+
void ProcessOptions(int, char *[], Options &) const;
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
} // namespace PCFG
|
| 45 |
+
} // namespace Syntax
|
| 46 |
+
} // namespace MosesTraining
|
| 47 |
+
|
| 48 |
+
#endif
|
mosesdecoder/phrase-extract/pcfg-extract/rule_collection.h
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - statistical machine translation system
|
| 3 |
+
Copyright (C) 2006-2012 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#pragma once
|
| 21 |
+
#ifndef PCFG_EXTRACT_RULE_COLLECTION_H_
|
| 22 |
+
#define PCFG_EXTRACT_RULE_COLLECTION_H_
|
| 23 |
+
|
| 24 |
+
#include <vector>
|
| 25 |
+
|
| 26 |
+
#include <boost/unordered_map.hpp>
|
| 27 |
+
|
| 28 |
+
#include "syntax-common/pcfg.h"
|
| 29 |
+
|
| 30 |
+
namespace MosesTraining
|
| 31 |
+
{
|
| 32 |
+
namespace Syntax
|
| 33 |
+
{
|
| 34 |
+
namespace PCFG
|
| 35 |
+
{
|
| 36 |
+
|
| 37 |
+
// Contains PCFG rules and their counts.
|
| 38 |
+
class RuleCollection
|
| 39 |
+
{
|
| 40 |
+
public:
|
| 41 |
+
typedef boost::unordered_map<std::vector<std::size_t>, std::size_t> RhsCountMap;
|
| 42 |
+
typedef boost::unordered_map<std::size_t, RhsCountMap> Map;
|
| 43 |
+
typedef Map::iterator iterator;
|
| 44 |
+
typedef Map::const_iterator const_iterator;
|
| 45 |
+
|
| 46 |
+
RuleCollection() {}
|
| 47 |
+
|
| 48 |
+
iterator begin() {
|
| 49 |
+
return collection_.begin();
|
| 50 |
+
}
|
| 51 |
+
const_iterator begin() const {
|
| 52 |
+
return collection_.begin();
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
iterator end() {
|
| 56 |
+
return collection_.end();
|
| 57 |
+
}
|
| 58 |
+
const_iterator end() const {
|
| 59 |
+
return collection_.end();
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
void Add(std::size_t, const std::vector<std::size_t> &);
|
| 63 |
+
void CreatePcfg(Pcfg &);
|
| 64 |
+
|
| 65 |
+
private:
|
| 66 |
+
Map collection_;
|
| 67 |
+
};
|
| 68 |
+
|
| 69 |
+
} // namespace PCFG
|
| 70 |
+
} // namespace Synatx
|
| 71 |
+
} // namespace MosesTraining
|
| 72 |
+
|
| 73 |
+
#endif
|
mosesdecoder/phrase-extract/pcfg-extract/rule_extractor.h
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***********************************************************************
|
| 2 |
+
Moses - statistical machine translation system
|
| 3 |
+
Copyright (C) 2006-2012 University of Edinburgh
|
| 4 |
+
|
| 5 |
+
This library is free software; you can redistribute it and/or
|
| 6 |
+
modify it under the terms of the GNU Lesser General Public
|
| 7 |
+
License as published by the Free Software Foundation; either
|
| 8 |
+
version 2.1 of the License, or (at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This library is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
| 13 |
+
Lesser General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU Lesser General Public
|
| 16 |
+
License along with this library; if not, write to the Free Software
|
| 17 |
+
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
| 18 |
+
***********************************************************************/
|
| 19 |
+
|
| 20 |
+
#pragma once
|
| 21 |
+
#ifndef PCFG_EXTRACT_RULE_EXTRACTOR_H_
|
| 22 |
+
#define PCFG_EXTRACT_RULE_EXTRACTOR_H_
|
| 23 |
+
|
| 24 |
+
#include "SyntaxTree.h"
|
| 25 |
+
|
| 26 |
+
#include "syntax-common/vocabulary.h"
|
| 27 |
+
|
| 28 |
+
#include "rule_collection.h"
|
| 29 |
+
|
| 30 |
+
namespace MosesTraining
|
| 31 |
+
{
|
| 32 |
+
namespace Syntax
|
| 33 |
+
{
|
| 34 |
+
namespace PCFG
|
| 35 |
+
{
|
| 36 |
+
|
| 37 |
+
// Extracts PCFG rules from syntax trees and adds them to a RuleCollection.
|
| 38 |
+
class RuleExtractor
|
| 39 |
+
{
|
| 40 |
+
public:
|
| 41 |
+
RuleExtractor(Vocabulary &);
|
| 42 |
+
void Extract(const SyntaxTree &, RuleCollection &) const;
|
| 43 |
+
private:
|
| 44 |
+
Vocabulary &non_term_vocab_;
|
| 45 |
+
};
|
| 46 |
+
|
| 47 |
+
} // namespace PCFG
|
| 48 |
+
} // namespace Syntax
|
| 49 |
+
} // namespace MosesTraining
|
| 50 |
+
|
| 51 |
+
#endif
|