#!/usr/bin/env python # -*- coding: utf-8 -*- # Author: Rico Sennrich # # This file is part of moses. Its use is licensed under the GNU Lesser General # Public License version 2.1 or, at your option, any later version. # extract 5 vocabulary files from parsed corpus in moses XML format from __future__ import print_function, unicode_literals, division import sys import codecs import argparse from collections import Counter from textwrap import dedent # hack for python2/3 compatibility from io import open argparse.open = open try: from lxml import etree as ET except ImportError: from xml.etree import cElementTree as ET HELP_TEXT = dedent("""\ generate 5 vocabulary files from parsed corpus in moses XML format [PREFIX].special: around 40 symbols reserved for RDLM [PREFIX].preterminals: preterminal symbols [PREFIX].nonterminals: nonterminal symbols (which are not preterminal) [PREFIX].terminals: terminal symbols [PREFIX].all: all of the above """) def create_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description=HELP_TEXT) parser.add_argument( '--input', '-i', type=argparse.FileType('r'), default=sys.stdin, metavar='PATH', help="Input text (default: standard input).") parser.add_argument( '--output', '-o', type=str, default='vocab', metavar='PREFIX', help="Output prefix (default: 'vocab')") return parser def escape_text(s): s = s.replace('|', '|') # factor separator s = s.replace('[', '[') # syntax non-terminal s = s.replace(']', ']') # syntax non-terminal s = s.replace('\'', ''') # xml special character s = s.replace('"', '"') # xml special character return s def get_head(xml, args): """Deterministic heuristic to get head of subtree.""" head = None preterminal = None for child in xml: if not len(child): preterminal = child.get('label') head = escape_text(child.text.strip()) return head, preterminal return head, preterminal def get_vocab(xml, args): if len(xml): head, preterminal = get_head(xml, args) if not head: head = '' preterminal = '' heads[head] += 1 preterminals[preterminal] += 1 label = xml.get('label') nonterminals[label] += 1 for child in xml: if not len(child): continue get_vocab(child, args) def main(args): global heads global preterminals global nonterminals heads = Counter() preterminals = Counter() nonterminals = Counter() i = 0 for line in args.input: if i and not i % 50000: sys.stderr.write('.') if i and not i % 1000000: sys.stderr.write('{0}\n'.format(i)) if line == '\n': continue xml = ET.fromstring(line) get_vocab(xml, args) i += 1 special_tokens = [ '', '', '', '', '', '', '', '', '', '', '', '', '', ] for i in range(30): special_tokens.append(''.format(i)) f = open(args.output + '.special', 'w', encoding='UTF-8') for item in special_tokens: f.write(item + '\n') f.close() f = open(args.output + '.preterminals', 'w', encoding='UTF-8') for item in sorted(preterminals, key=preterminals.get, reverse=True): f.write(item + '\n') f.close() f = open(args.output + '.nonterminals', 'w', encoding='UTF-8') for item in sorted(nonterminals, key=nonterminals.get, reverse=True): f.write(item + '\n') f.close() f = open(args.output + '.terminals', 'w', encoding='UTF-8') for item in sorted(heads, key=heads.get, reverse=True): f.write(item + '\n') f.close() f = open(args.output + '.all', 'w', encoding='UTF-8') special_tokens_set = set(special_tokens) for item in sorted(nonterminals, key=nonterminals.get, reverse=True): if item not in special_tokens: special_tokens.append(item) special_tokens_set.add(item) for item in sorted(preterminals, key=preterminals.get, reverse=True): if item not in special_tokens: special_tokens.append(item) special_tokens_set.add(item) for item in special_tokens: f.write(item + '\n') i = len(special_tokens) for item in sorted(heads, key=heads.get, reverse=True): if item in special_tokens_set: continue i += 1 f.write(item + '\n') f.close() if __name__ == '__main__': if sys.version_info < (3, 0): sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) sys.stdin = codecs.getreader('UTF-8')(sys.stdin) parser = create_parser() args = parser.parse_args() main(args)