|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function, unicode_literals, division |
|
|
import sys |
|
|
import codecs |
|
|
import argparse |
|
|
from collections import Counter |
|
|
from textwrap import dedent |
|
|
|
|
|
|
|
|
from io import open |
|
|
argparse.open = open |
|
|
|
|
|
try: |
|
|
from lxml import etree as ET |
|
|
except ImportError: |
|
|
from xml.etree import cElementTree as ET |
|
|
|
|
|
|
|
|
# Description shown by ``--help``: one summary line followed by one line per
# generated file.  The literal is indented for readability only; ``dedent``
# strips the common leading whitespace again.
HELP_TEXT = dedent("""\
    generate 5 vocabulary files from parsed corpus in moses XML format
    [PREFIX].special: around 40 symbols reserved for RDLM
    [PREFIX].preterminals: preterminal symbols
    [PREFIX].nonterminals: nonterminal symbols (which are not preterminal)
    [PREFIX].terminals: terminal symbols
    [PREFIX].all: all of the above
    """)
|
|
|
|
|
|
|
|
def create_parser():
    """Build the command-line parser of the script.

    Returns:
        argparse.ArgumentParser: a parser that understands

        * ``--input`` / ``-i``: file opened for reading; standard input
          when the option is omitted.
        * ``--output`` / ``-o``: prefix of the generated files; ``'vocab'``
          when the option is omitted.
    """
    # The raw formatter keeps the line breaks of HELP_TEXT as written.
    cli = argparse.ArgumentParser(
        description=HELP_TEXT,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    input_options = dict(
        type=argparse.FileType('r'),
        default=sys.stdin,
        metavar='PATH',
        help="Input text (default: standard input).")
    cli.add_argument('--input', '-i', **input_options)

    output_options = dict(
        type=str,
        default='vocab',
        metavar='PREFIX',
        help="Output prefix (default: 'vocab')")
    cli.add_argument('--output', '-o', **output_options)

    return cli
|
|
|
|
|
|
|
|
def escape_text(s):
    """Replace characters reserved by Moses with their entity escapes.

    Args:
        s: the text to escape.

    Returns:
        A copy of ``s`` in which ``|``, ``[``, ``]``, ``'`` and ``"`` are
        replaced by ``&#124;``, ``&#91;``, ``&#93;``, ``&apos;`` and
        ``&quot;`` respectively.  All other characters (including ``&``)
        are left untouched.
    """
    # None of the entities contains a character that is itself escaped, so
    # the order of the replacements does not influence the result.
    replacements = (
        ('|', '&#124;'),   # Moses factor separator
        ('[', '&#91;'),    # Moses syntax non-terminal bracket
        (']', '&#93;'),    # Moses syntax non-terminal bracket
        ('\'', '&apos;'),  # XML special character
        ('"', '&quot;'),   # XML special character
    )
    for char, entity in replacements:
        s = s.replace(char, entity)
    return s
|
|
|
|
|
|
|
|
def get_head(xml, args):
    """Deterministic heuristic to get head of subtree.

    The head is taken from the first direct child of ``xml`` that has no
    children of its own; later childless children are ignored.

    Args:
        xml: element whose direct children are inspected.
        args: parsed command-line arguments (not used by this function).

    Returns:
        tuple: ``(head, preterminal)`` where ``head`` is the stripped and
        escaped text of that child and ``preterminal`` is its ``label``
        attribute (``None`` if the attribute is missing).  ``(None, None)``
        is returned when every child has children of its own, or when
        ``xml`` has no children at all.
    """
    for child in xml:
        if len(child):
            continue
        # ``text`` is None for an element without any text (``<x/>``);
        # treat that as empty text instead of failing on ``None.strip()``.
        text = child.text or ''
        return escape_text(text.strip()), child.get('label')

    return None, None
|
|
|
|
|
|
|
|
def get_vocab(xml, args):
    """Recursively count the vocabulary items found below ``xml``.

    For ``xml`` and every descendant that has children, the head and
    preterminal reported by ``get_head`` and the element's own ``label``
    attribute are counted in the module-level counters ``heads``,
    ``preterminals`` and ``nonterminals``.  Elements without children are
    not visited.

    Args:
        xml: root element of the (sub)tree to process.
        args: parsed command-line arguments, passed through to ``get_head``.
    """
    # Childless elements contribute nothing on their own.
    if not len(xml):
        return

    head, preterminal = get_head(xml, args)
    if not head:
        # No usable head: count the placeholder for both vocabularies.
        head = preterminal = '<null>'

    heads[head] += 1
    preterminals[preterminal] += 1
    nonterminals[xml.get('label')] += 1

    # Descend only into children that have children themselves.
    for child in xml:
        if len(child):
            get_vocab(child, args)
|
|
|
|
|
|
|
|
def _by_frequency(counter):
    """Return the keys of ``counter`` ordered from most to least frequent."""
    return sorted(counter, key=counter.get, reverse=True)


def _write_lines(path, items):
    """Write every item of ``items`` on its own line to ``path`` (UTF-8)."""
    # The context manager closes the file even if a write fails.
    with open(path, 'w', encoding='UTF-8') as f:
        for item in items:
            f.write(item + '\n')


def main(args):
    """Extract the vocabularies from the input and write the five files.

    Every non-blank input line is parsed as one XML tree and counted with
    ``get_vocab``.  Afterwards the following files are written, one symbol
    per line, the counted symbols ordered by decreasing frequency:

    * ``<prefix>.special``: the fixed list of reserved symbols
    * ``<prefix>.preterminals``, ``<prefix>.nonterminals``,
      ``<prefix>.terminals``: the counted symbols of each kind
    * ``<prefix>.all``: reserved symbols, then nonterminals, preterminals
      and terminals, each symbol listed only once

    Progress is reported on standard error: a dot every 50000 parsed trees
    and the running total every 1000000 parsed trees.

    Args:
        args: namespace with ``input`` (an iterable of text lines) and
            ``output`` (the prefix of the files to write).

    Side effects:
        Rebinds the module-level counters ``heads``, ``preterminals`` and
        ``nonterminals``, which ``get_vocab`` updates.
    """
    global heads
    global preterminals
    global nonterminals

    heads = Counter()
    preterminals = Counter()
    nonterminals = Counter()

    parsed = 0
    for line in args.input:
        # Skip blank lines first so that they cannot repeat a progress
        # marker while the counter sits on a reporting boundary.
        if line == '\n':
            continue

        if parsed and not parsed % 50000:
            sys.stderr.write('.')
        if parsed and not parsed % 1000000:
            sys.stderr.write('{0}\n'.format(parsed))

        xml = ET.fromstring(line)
        get_vocab(xml, args)
        parsed += 1

    special_tokens = [
        '<unk>',
        '<null>',
        '<null_label>',
        '<null_head>',
        '<head_label>',
        '<root_label>',
        '<start_label>',
        '<stop_label>',
        '<head_head>',
        '<root_head>',
        '<start_head>',
        '<dummy_head>',
        '<stop_head>',
    ]
    special_tokens.extend('<null_{0}>'.format(n) for n in range(30))

    _write_lines(args.output + '.special', special_tokens)
    _write_lines(args.output + '.preterminals', _by_frequency(preterminals))
    _write_lines(args.output + '.nonterminals', _by_frequency(nonterminals))
    _write_lines(args.output + '.terminals', _by_frequency(heads))

    # Combined vocabulary: every symbol appears once, at its first position.
    # ``seen`` mirrors ``all_tokens`` so membership tests stay constant time.
    all_tokens = list(special_tokens)
    seen = set(special_tokens)
    for counter in (nonterminals, preterminals):
        for item in _by_frequency(counter):
            if item not in seen:
                all_tokens.append(item)
                seen.add(item)
    all_tokens.extend(
        item for item in _by_frequency(heads) if item not in seen)

    _write_lines(args.output + '.all', all_tokens)
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Python 2 only: wrap the standard streams in UTF-8 codecs.  This is
    # done before the parser is created, because the parser is given
    # ``sys.stdin`` as the default value of its input option.
    if sys.version_info[0] < 3:
        utf8_writer = codecs.getwriter('UTF-8')
        utf8_reader = codecs.getreader('UTF-8')
        sys.stderr = utf8_writer(sys.stderr)
        sys.stdout = utf8_writer(sys.stdout)
        sys.stdin = utf8_reader(sys.stdin)

    parser = create_parser()
    args = parser.parse_args()
    main(args)
|
|
|