File size: 5,671 Bytes
19b8775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import argparse
import json
import os
import re
import sys

from collections import Counter

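"""
Data is output in 4 files:

a file containing the mwt information
a file containing the words and sentences in conllu format
a file containing the raw text of each paragraph
a file of character-level labels for the raw text, marking word and sentence breaks:
  0: not the end of a token
  1: end of word
  2: end of sentence
  3: end of a multi-word token
  4: end of a multi-word token that also ends a sentence

This script reads the conllu file and the raw text file and writes the label file
and the mwt file.
"""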

PARAGRAPH_BREAK = re.compile(r'\n\s*\n')

def is_para_break(index, text):
    """Check whether a paragraph break starts at offset ``index`` of ``text``.

    A paragraph break is a newline, then any whitespace, then another newline
    (``PARAGRAPH_BREAK``). The match is greedy, so it extends up to the last
    newline of a run of whitespace.

    Returns:
        ``(True, length)`` with the number of characters in the break, or
        ``(False, 0)`` when no paragraph break starts at ``index``.

    Raises:
        IndexError: if ``index`` is not a valid offset into ``text``.
    """
    # Cheap single-character test first; the regex only runs on newlines.
    if text[index] != '\n':
        return False, 0
    match = PARAGRAPH_BREAK.match(text, index)
    if match is None:
        return False, 0
    return True, match.end() - match.start()

class TokenizerAlignmentError(ValueError, AssertionError):
    """Raised when the raw text cannot be aligned with the CoNLL-U tokens.

    It also derives from AssertionError because these checks were previously
    plain ``assert`` statements; handlers written for that keep working.
    """

def find_next_word(index, text, word, output):
    """
    Locate the next word in the text. In case a paragraph break is found, also write paragraph break to labels.

    Starting at character offset ``index`` of ``text``, characters are consumed
    until every character of ``word`` has been matched. Whitespace in ``text``
    that is not part of ``word`` is skipped over but kept in the returned string.
    A newline in ``text`` may match a space in ``word``. When a paragraph break
    is found, two newline characters are written to ``output``, the break is
    skipped, and any whitespace collected just before it is discarded.

    Args:
        index: character offset in ``text`` to start searching from.
        text: the complete raw text.
        word: the token to locate.
        output: writable text stream that receives paragraph-break markers.

    Returns:
        ``(index, word_sofar)``: the offset just past the matched word, and the
        characters consumed for it (kept leading whitespace plus the word).

    Raises:
        TokenizerAlignmentError: if non-whitespace text precedes a paragraph
            break, if a character of ``text`` does not match ``word``, or if
            ``text`` ends before ``word`` has been fully matched.
    """
    idx = 0  # number of characters of ``word`` matched so far
    word_sofar = ''
    # Validation uses explicit raises rather than ``assert`` so that it still
    # runs under ``python -O``, which strips assert statements.
    while index < len(text) and idx < len(word):
        para_break, break_len = is_para_break(index, text)
        char = text[index]
        if para_break:
            # multiple newlines found, paragraph break
            if word_sofar and not word_sofar.isspace():
                raise TokenizerAlignmentError(
                    'Found non-empty string at the end of a paragraph that doesn\'t match any token: |{}|'.format(word_sofar))
            word_sofar = ''
            output.write('\n\n')
            # together with the increment below, skip the whole break
            index += break_len - 1
        elif char.isspace() and not word[idx].isspace():
            # whitespace found, and whitespace is not part of a word
            word_sofar += char
        else:
            # non-whitespace char, or a whitespace char that's part of a word
            word_sofar += char
            if char.replace('\n', ' ') != word[idx]:
                raise TokenizerAlignmentError(
                    "Character mismatch: raw text contains |%s| but the next word is |%s|." % (word_sofar, word))
            idx += 1
        index += 1
    if idx < len(word):
        # Returning a partial match here would silently misalign every later
        # label, so fail loudly instead.
        raise TokenizerAlignmentError(
            "Reached the end of the raw text while looking for the word |%s|; matched only |%s|." % (word, word_sofar))
    return index, word_sofar

def _ensure_parent_dir(path):
    """Create the directory that will hold ``path`` if it does not exist yet."""
    # os.path.dirname() returns '' for a bare file name such as 'labels.txt', and
    # os.makedirs('') raises FileNotFoundError, so only create a real directory.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

def _end_sentence(buf):
    """Return the label string ``buf`` with its last label raised by one.

    The last label of a token is 1 (end of word) or 3 (end of a multi-word
    token); at a sentence boundary it becomes 2 or 4 respectively.
    """
    last_label = int(buf[-1])
    assert last_label >= 1
    return buf[:-1] + str(last_label + 1)

def main(args):
    """Write character-level tokenizer labels for a raw text / CoNLL-U pair.

    Each token of ``conllu_file`` (the surface form, for multi-word tokens) is
    located in ``plaintext_file`` with ``find_next_word``. One label is written
    per character that call returns: 0 for characters that do not end the
    token, then 1 at the end of a word or 3 at the end of a multi-word token.
    The last label becomes 2 or 4 respectively when the token ends a sentence.
    Labels go to ``--output`` or stdout.

    Multi-word token expansions (lower-cased) are counted. The counts are
    printed to stdout, or written to ``--mwt_output`` as JSON.

    Args:
        args: command-line arguments without the program name, e.g.
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('plaintext_file', type=str, help="Plaintext file containing the raw input")
    parser.add_argument('conllu_file', type=str, help="CoNLL-U file containing tokens and sentence breaks")
    parser.add_argument('-o', '--output', default=None, type=str, help="Output file name; output to the console if not specified (the default)")
    parser.add_argument('-m', '--mwt_output', default=None, type=str, help="Output file name for MWT expansions; output to the console if not specified (the default)")

    args = parser.parse_args(args=args)

    with open(args.plaintext_file, 'r', encoding='utf-8') as f:
        text = f.read()

    if args.output is None:
        output = sys.stdout
    else:
        _ensure_parent_dir(args.output)
        output = open(args.output, 'w', encoding='utf-8')

    index = 0  # character offset in the raw text
    mwt_expansions = []
    try:
        with open(args.conllu_file, 'r', encoding='utf-8') as f:
            # Labels of the most recent token. They are written only once the
            # next line shows whether that token also ends the sentence.
            buf = ''
            mwtbegin = 0
            mwtend = -1
            lastmwt = None
            expanded = []
            last_comments = ''
            for line in f:
                line = line.strip()
                if not line:
                    # sentence break found
                    if buf:
                        output.write(_end_sentence(buf))
                        buf = ''
                    last_comments = ''
                    continue

                if line[0] == '#':
                    # comment: keep the sentence's first one for the MWT warning below
                    if not last_comments:
                        last_comments = line
                    continue

                fields = line.split('\t')
                token_id = fields[0]
                if '.' in token_id:
                    # the tokenizer doesn't deal with ellipsis (decimal ids, i.e. empty nodes)
                    continue

                word = fields[1]
                if '-' in token_id:
                    # multiword token: remember its surface form and word-id range
                    mwtbegin, mwtend = [int(x) for x in token_id.split('-')]
                    lastmwt = word
                    expanded = []
                elif mwtbegin <= int(token_id) < mwtend:
                    # word inside the current multiword token; no raw-text labels
                    expanded.append(word)
                    continue
                elif int(token_id) == mwtend:
                    expanded.append(word)
                    # Check capitalisation BEFORE lower-casing: once lower-cased,
                    # expanded[0][0].islower() is true for every cased letter, so a
                    # capitalised expansion of a lower-case surface form was never
                    # reported.
                    if lastmwt[0].islower() and not expanded[0][0].islower():
                        print('Sentence ID with potential wrong MWT expansion: ', last_comments, file=sys.stderr)
                    # evaluation doesn't care about case
                    mwt_expansions.append((lastmwt, tuple(x.lower() for x in expanded)))
                    mwtbegin = 0
                    mwtend = -1
                    lastmwt = None
                    continue

                if buf:
                    output.write(buf)
                index, word_found = find_next_word(index, text, word, output)
                buf = '0' * (len(word_found) - 1) + ('1' if '-' not in token_id else '3')

        # CoNLL-U ends every sentence, including the last, with a blank line. If
        # the file lacks that final blank line, the last token's labels would
        # never be written, so close the final sentence here.
        if buf:
            output.write(_end_sentence(buf))
    finally:
        if output is not sys.stdout:
            output.close()

    status = []
    if args.output:
        status.append('Tokenizer labels written to %s' % args.output)

    mwts = Counter(mwt_expansions)
    if args.mwt_output is None:
        print('MWTs:', mwts)
    else:
        _ensure_parent_dir(args.mwt_output)
        with open(args.mwt_output, 'w', encoding='utf-8') as f:
            json.dump(list(mwts.items()), f, indent=2)
        status.append('{} unique MWTs found in data.  MWTs written to {}'.format(len(mwts), args.mwt_output))

    # Report whatever was written to files, not only when --mwt_output is set.
    if status:
        print('\n  '.join(status))

if __name__ == '__main__':
    # Hand everything after the program name to main(), which parses it with argparse.
    main(args=sys.argv[1:])