| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | from __future__ import print_function, unicode_literals, division |
| | import sys |
| | import codecs |
| | import argparse |
| | from collections import Counter |
| | from textwrap import dedent |
| |
|
| | |
| | from io import open |
| | argparse.open = open |
| |
|
| | try: |
| | from lxml import etree as ET |
| | except ImportError: |
| | from xml.etree import cElementTree as ET |
| |
|
| |
|
# Description printed by ``--help``. create_parser() passes it with
# RawDescriptionHelpFormatter, so the line breaks below are shown as written.
# It lists the five files that main() writes under the --output prefix.
HELP_TEXT = dedent("""\
generate 5 vocabulary files from parsed corpus in moses XML format
[PREFIX].special: around 40 symbols reserved for RDLM
[PREFIX].preterminals: preterminal symbols
[PREFIX].nonterminals: nonterminal symbols (which are not preterminal)
[PREFIX].terminals: terminal symbols
[PREFIX].all: all of the above
""")
| |
|
| |
|
def create_parser():
    """Build the command-line parser for this script.

    Options:
      --input / -i PATH     corpus to read, opened with argparse.FileType('r');
                            defaults to sys.stdin as it is at call time
      --output / -o PREFIX  file-name prefix for the vocabulary files;
                            defaults to 'vocab'

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    arg_parser = argparse.ArgumentParser(
        description=HELP_TEXT,
        # Keep the hand-made line breaks of HELP_TEXT in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    option_specs = (
        (('--input', '-i'),
         dict(type=argparse.FileType('r'), default=sys.stdin,
              metavar='PATH',
              help="Input text (default: standard input).")),
        (('--output', '-o'),
         dict(type=str, default='vocab', metavar='PREFIX',
              help="Output prefix (default: 'vocab')")),
    )
    for flags, settings in option_specs:
        arg_parser.add_argument(*flags, **settings)

    return arg_parser
| |
|
| |
|
def escape_text(s):
    """Re-escape characters of a terminal string for the vocabulary files.

    Leaf text reaches this function after XML parsing, so any entities in
    the corpus have already been decoded back to raw characters. This
    function replaces five of them with character/entity references:

      ``|`` -> ``&#124;``   ``[`` -> ``&#91;``   ``]`` -> ``&#93;``
      ``'`` -> ``&apos;``   ``"`` -> ``&quot;``

    ``&``, ``<`` and ``>`` are left unchanged. None of the replacement
    strings contains any of the five source characters, so the order of
    the replacements does not matter.

    Args:
        s: text of a leaf node (the caller strips it first).

    Returns:
        *s* with the characters above replaced.
    """
    s = s.replace('|', '&#124;')  # presumably: Moses factor separator
    s = s.replace('[', '&#91;')   # presumably: syntax non-terminal bracket
    s = s.replace(']', '&#93;')   # presumably: syntax non-terminal bracket
    s = s.replace("'", '&apos;')  # XML special character
    s = s.replace('"', '&quot;')  # XML special character
    return s
| |
|
| |
|
def get_head(xml, args):
    """Deterministic heuristic to get head of subtree.

    The head is taken from the first direct child of *xml* that has no
    element children (a preterminal). Its text, stripped and passed through
    escape_text, is the head word. Its ``label`` attribute is the
    preterminal symbol.

    Args:
        xml: tree element whose direct children are inspected.
        args: parsed command-line options; not used here.

    Returns:
        tuple: ``(head, preterminal)``.
        ``(None, None)`` if no direct child is a leaf.
        ``head`` is ``''`` when that leaf is empty or whitespace-only.
    """
    for child in xml:
        if len(child):
            continue  # internal node, not a preterminal
        # An empty element such as <tree label="X"/> has .text == None;
        # treat it as '' instead of crashing on None.strip().
        text = child.text or ''
        return escape_text(text.strip()), child.get('label')

    return None, None
| |
|
| |
|
def get_vocab(xml, args):
    """Count vocabulary items in the subtree rooted at *xml*.

    For every node that has element children, this increments three
    module-level Counters, which main() creates before calling it:
      heads         -- the head word from get_head, or '<null>' if none
      preterminals  -- the head's preterminal label, or '<null>' if no head
      nonterminals  -- the node's own ``label`` attribute
    It then recurses, in document order, into each child that itself has
    children. Leaf children are only seen through get_head.

    Args:
        xml: tree element to count.
        args: parsed command-line options, forwarded to get_head.
    """
    if not len(xml):
        return  # a leaf contributes nothing at this level

    head, preterminal = get_head(xml, args)
    if not head:
        # No leaf child, or its text was blank.
        head = preterminal = '<null>'
    label = xml.get('label')

    heads[head] += 1
    preterminals[preterminal] += 1
    nonterminals[label] += 1

    for subtree in filter(len, xml):
        get_vocab(subtree, args)
| |
|
| |
|
def _by_frequency(counter):
    """Return the keys of *counter* from most to least frequent.

    Ties keep the counter's iteration order, because sorted() is stable.
    """
    return sorted(counter, key=counter.get, reverse=True)


def _write_vocab(path, items):
    """Write *items* to *path* as UTF-8 text, one item per line."""
    with open(path, 'w', encoding='UTF-8') as out:
        for item in items:
            out.write(item + '\n')


def main(args):
    """Collect vocabulary counts from ``args.input`` and write the vocab files.

    Each non-blank input line is parsed as one XML tree and passed to
    get_vocab, which fills the module-level Counters ``heads``,
    ``preterminals`` and ``nonterminals``. Progress goes to stderr: a '.'
    after every 50,000 trees and the running count after every 1,000,000.

    Files written (PREFIX is ``args.output``):
      PREFIX.special       13 named reserved tokens plus <null_0>..<null_29>
      PREFIX.preterminals  preterminal labels, most frequent first
      PREFIX.nonterminals  nonterminal labels, most frequent first
      PREFIX.terminals     head words, most frequent first
      PREFIX.all           special tokens, then nonterminals, preterminals
                           and head words; each item only at its first
                           occurrence

    Args:
        args: namespace from create_parser(). Uses ``input`` (an iterable
            of lines) and ``output`` (the file-name prefix).
    """
    # get_vocab reads and updates these as module globals.
    global heads
    global preterminals
    global nonterminals

    heads = Counter()
    preterminals = Counter()
    nonterminals = Counter()

    n_trees = 0
    for line in args.input:
        # Skip blank lines, including '\r\n' and whitespace-only ones,
        # which the XML parser would reject.
        if not line.strip():
            continue
        get_vocab(ET.fromstring(line), args)
        n_trees += 1
        # Report after counting, so runs of blank lines cannot repeat a dot.
        if not n_trees % 50000:
            sys.stderr.write('.')
        if not n_trees % 1000000:
            sys.stderr.write('{0}\n'.format(n_trees))

    special_tokens = [
        '<unk>',
        '<null>',
        '<null_label>',
        '<null_head>',
        '<head_label>',
        '<root_label>',
        '<start_label>',
        '<stop_label>',
        '<head_head>',
        '<root_head>',
        '<start_head>',
        '<dummy_head>',
        '<stop_head>',
    ]
    special_tokens.extend('<null_{0}>'.format(k) for k in range(30))

    preterminal_vocab = _by_frequency(preterminals)
    nonterminal_vocab = _by_frequency(nonterminals)
    terminal_vocab = _by_frequency(heads)

    _write_vocab(args.output + '.special', special_tokens)
    _write_vocab(args.output + '.preterminals', preterminal_vocab)
    _write_vocab(args.output + '.nonterminals', nonterminal_vocab)
    _write_vocab(args.output + '.terminals', terminal_vocab)

    # .all is the union of the above in priority order. The set makes each
    # duplicate check O(1) instead of a scan of the growing list.
    all_vocab = list(special_tokens)
    seen = set(all_vocab)
    for item in nonterminal_vocab + preterminal_vocab + terminal_vocab:
        if item not in seen:
            seen.add(item)
            all_vocab.append(item)
    _write_vocab(args.output + '.all', all_vocab)
| |
|
| |
|
if __name__ == '__main__':
    # On Python 2, wrap the standard streams in UTF-8 codecs readers and
    # writers. This must happen before create_parser(), which captures
    # sys.stdin as the --input default.
    if sys.version_info[0] < 3:
        utf8_writer = codecs.getwriter('UTF-8')
        sys.stderr = utf8_writer(sys.stderr)
        sys.stdout = utf8_writer(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)

    main(create_parser().parse_args())
| |
|