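"""
Processes tokenized text or VLSP-style XML tree files with a trained
constituency parser, writing the predicted trees to a file.
"""
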
import logging
import os

from stanza.models.common import utils
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.constituency.utils import retag_tags
from stanza.models.constituency.trainer import Trainer
from stanza.models.constituency.tree_reader import read_trees
from stanza.utils.get_tqdm import get_tqdm

logger = logging.getLogger('stanza')
tqdm = get_tqdm()

def read_tokenized_file(tokenized_file):
    """
    Read sentences from a tokenized file, potentially replacing _ with space for languages such as VI
    """
    with open(tokenized_file, encoding='utf-8') as fin:
        lines = fin.readlines()
    lines = [x.strip() for x in lines]
    lines = [x for x in lines if x]
    # a token made up entirely of underscores is kept as-is;
    # otherwise each _ represents a space within a multisyllable token
    docs = [[word if all(x == '_' for x in word) else word.replace("_", " ") for word in sentence.split()] for sentence in lines]
    ids = [None] * len(docs)
    return docs, ids
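
# A minimal usage sketch (the file name and contents here are hypothetical):
#
#   with open("vi_sentences.txt", "w", encoding="utf-8") as fout:
#       fout.write("Thuế_suất là 10%\n")
#   docs, ids = read_tokenized_file("vi_sentences.txt")
#   assert docs == [["Thuế suất", "là", "10%"]]
#   assert ids == [None]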

def read_xml_tree_file(tree_file):
    """
    Read sentences from a file of the format unique to VLSP test sets

    in particular, it should be multiple blocks of

    <s id=1>
      (tree ...)
    </s>
    """
    with open(tree_file, encoding='utf-8') as fin:
        lines = fin.readlines()
    lines = [x.strip() for x in lines]
    lines = [x for x in lines if x]
    docs = []
    ids = []
    tree_id = None
    tree_text = []
    for line in lines:
        if line.startswith("<s"):
            tree_id = line.split("=")
            if len(tree_id) > 1:
                tree_id = tree_id[1]
                if tree_id.endswith(">"):
                    tree_id = tree_id[:-1]
                tree_id = int(tree_id)
            else:
                tree_id = None
        elif line.startswith("</s"):
            if len(tree_text) == 0:
                raise ValueError("Found a blank tree in %s" % tree_file)
            ids.append(tree_id)
            tree_text = "\n".join(tree_text)
            trees = read_trees(tree_text)
            # TODO: perhaps the processing can be put into read_trees instead
            trees = [t.prune_none().simplify_labels() for t in trees]
            if len(trees) != 1:
                raise ValueError("Found a block with %d trees instead of 1 in %s" % (len(trees), tree_file))
            tree = trees[0]
            text = tree.leaf_labels()
            # same underscore treatment as in read_tokenized_file
            text = [word if all(x == '_' for x in word) else word.replace("_", " ") for word in text]
            docs.append(text)
            tree_text = []
            tree_id = None
        else:
            tree_text.append(line)

    return docs, ids
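
# A minimal usage sketch (the file name and tree here are hypothetical):
#
#   with open("vlsp_test.xml", "w", encoding="utf-8") as fout:
#       fout.write("<s id=1>\n(ROOT (NP (N Thuế_suất)))\n</s>\n")
#   docs, ids = read_xml_tree_file("vlsp_test.xml")
#   assert docs == [["Thuế suất"]]
#   assert ids == [1]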


def parse_tokenized_sentences(args, model, retag_pipeline, sentences):
    """
    Parse the given sentences, return a list of ParseResult objects
    """
    tags = retag_tags(sentences, retag_pipeline, model.uses_xpos())
    words = [[(word, tag) for word, tag in zip(s_words, s_tags)] for s_words, s_tags in zip(sentences, tags)]
    logger.info("Retagging finished.  Parsing tagged text")

    assert len(words) == len(sentences)
    treebank = model.parse_sentences_no_grad(iter(tqdm(words)), model.build_batch_from_tagged_words, args['eval_batch_size'], model.predict, keep_scores=False)
    return treebank
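
# Each ParseResult in the returned list holds the model's predictions for one
# sentence; parse_text below reads the best tree as result.predictions[0].tree.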

def parse_text(args, model, retag_pipeline, tokenized_file=None, predict_file=None):
    """
    Use the given model to parse text and write it

    refactored so it can be used elsewhere, such as Ensemble
    """
    model.eval()

    if predict_file is None:
        if args['predict_file']:
            predict_file = args['predict_file']
            if args['predict_dir']:
                predict_file = os.path.join(args['predict_dir'], predict_file)

    if tokenized_file is None:
        tokenized_file = args['tokenized_file']

    docs, ids = None, None
    if tokenized_file is not None:
        docs, ids = read_tokenized_file(tokenized_file)
    elif args['xml_tree_file']:
        logger.info("Reading trees from %s" % args['xml_tree_file'])
        docs, ids = read_xml_tree_file(args['xml_tree_file'])

    if not docs:
        logger.error("No sentences to process!")
        return

    logger.info("Processing %d sentences", len(docs))

    with utils.output_stream(predict_file) as fout:
        # parse the sentences in chunks, writing each chunk's trees before starting the next
        chunk_size = 10000
        for chunk_start in range(0, len(docs), chunk_size):
            chunk = docs[chunk_start:chunk_start+chunk_size]
            ids_chunk = ids[chunk_start:chunk_start+chunk_size]
            logger.info("Processing trees %d to %d", chunk_start, chunk_start+len(chunk))
            treebank = parse_tokenized_sentences(args, model, retag_pipeline, chunk)

            for result, tree_id in zip(treebank, ids_chunk):
                tree = result.predictions[0].tree
                if tree_id is not None:
                    tree.tree_id = tree_id
                fout.write(args['predict_format'].format(tree))
                fout.write("\n")

def parse_dir(args, model, retag_pipeline, tokenized_dir, predict_dir):
    """
    Parse each tokenized file in tokenized_dir, writing one .mrg file per input file to predict_dir
    """
    os.makedirs(predict_dir, exist_ok=True)
    for filename in os.listdir(tokenized_dir):
        input_path = os.path.join(tokenized_dir, filename)
        output_path = os.path.join(predict_dir, os.path.splitext(filename)[0] + ".mrg")
        logger.info("Processing %s to %s", input_path, output_path)
        parse_text(args, model, retag_pipeline, tokenized_file=input_path, predict_file=output_path)


def load_model_parse_text(args, model_file, retag_pipeline):
    """
    Load a model, then parse text and write it to stdout or args['predict_file']

    retag_pipeline: a list of Pipelines to use for retagging
    """
    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
    load_args = {
        'wordvec_pretrain_file': args['wordvec_pretrain_file'],
        'charlm_forward_file': args['charlm_forward_file'],
        'charlm_backward_file': args['charlm_backward_file'],
        'device': args['device'],
    }
    trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)
    model = trainer.model
    model.eval()
    logger.info("Loaded model from %s", model_file)

    if args['tokenized_dir']:
        if not args['predict_dir']:
            raise ValueError("Must specific --predict_dir to go with --tokenized_dir")
        parse_dir(args, model, retag_pipeline, args['tokenized_dir'], args['predict_dir'])
    else:
        parse_text(args, model, retag_pipeline)
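
# A minimal usage sketch, assuming a saved constituency model, a retagging
# pipeline, and an args dict with the keys read above (all file names here
# are hypothetical):
#
#   args = {
#       'wordvec_pretrain_file': 'vi_pretrain.pt',
#       'charlm_forward_file': None,
#       'charlm_backward_file': None,
#       'device': 'cpu',
#       'tokenized_file': 'vi_sentences.txt',
#       'tokenized_dir': None,
#       'xml_tree_file': None,
#       'predict_file': 'vi_sentences.mrg',
#       'predict_dir': None,
#       'eval_batch_size': 50,
#       'predict_format': '{}',
#   }
#   load_model_parse_text(args, 'vi_constituency.pt', retag_pipeline)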