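"""
Parse tokenized text files (or VLSP-style XML tree files) with a constituency parser and write out the trees
"""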
import os
import logging

from stanza.models.common import utils
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.constituency.utils import retag_tags
from stanza.models.constituency.trainer import Trainer
from stanza.models.constituency.tree_reader import read_trees
from stanza.utils.get_tqdm import get_tqdm

logger = logging.getLogger('stanza')

tqdm = get_tqdm()
def read_tokenized_file(tokenized_file):
    """
    Read sentences from a tokenized file, potentially replacing _ with space for languages such as VI
    """
    with open(tokenized_file, encoding='utf-8') as fin:
        lines = fin.readlines()
    lines = [x.strip() for x in lines]
    lines = [x for x in lines if x]
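    # keep words which are entirely underscores as-is, but otherwise treat
    # underscores as joining the syllables of a single word (as in VI datasets)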
    docs = [[word if all(x == '_' for x in word) else word.replace("_", " ") for word in sentence.split()] for sentence in lines]
    ids = [None] * len(docs)
    return docs, ids
def read_xml_tree_file(tree_file):
    """
    Read sentences from a file of the format unique to VLSP test sets

    in particular, it should be multiple blocks of
      <s id=1>
      (tree ...)
      </s>
    """
    with open(tree_file, encoding='utf-8') as fin:
        lines = fin.readlines()
    lines = [x.strip() for x in lines]
    lines = [x for x in lines if x]
    docs = []
    ids = []
    tree_id = None
    tree_text = []
    for line in lines:
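        # an opening <s id=N> tag starts a new block; remember its id, if any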
if line.startswith("<s"):
tree_id = line.split("=")
if len(tree_id) > 1:
tree_id = tree_id[1]
if tree_id.endswith(">"):
tree_id = tree_id[:-1]
tree_id = int(tree_id)
else:
tree_id = None
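        # a closing </s> tag ends the block: parse the accumulated lines
        # as a single tree and extract its (de-underscored) leaf text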
        elif line.startswith("</s"):
            if len(tree_text) == 0:
                raise ValueError("Found a blank tree in %s" % tree_file)
            ids.append(tree_id)
            tree_text = "\n".join(tree_text)
            trees = read_trees(tree_text)
            # TODO: perhaps the processing can be put into read_trees instead
            trees = [t.prune_none().simplify_labels() for t in trees]
            if len(trees) != 1:
                raise ValueError("Found a block with %d trees instead of 1 in %s" % (len(trees), tree_file))
            tree = trees[0]
            text = tree.leaf_labels()
            text = [word if all(x == '_' for x in word) else word.replace("_", " ") for word in text]
            docs.append(text)
            tree_text = []
            tree_id = None
        else:
            tree_text.append(line)
    return docs, ids
def parse_tokenized_sentences(args, model, retag_pipeline, sentences):
    """
    Parse the given sentences, return a list of ParseResult objects
    """
    tags = retag_tags(sentences, retag_pipeline, model.uses_xpos())
    words = [[(word, tag) for word, tag in zip(s_words, s_tags)] for s_words, s_tags in zip(sentences, tags)]
    logger.info("Retagging finished. Parsing tagged text")
    assert len(words) == len(sentences)
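    # parse with gradients disabled; keep_scores=False discards the per-tree model scores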
    treebank = model.parse_sentences_no_grad(iter(tqdm(words)), model.build_batch_from_tagged_words, args['eval_batch_size'], model.predict, keep_scores=False)
    return treebank
def parse_text(args, model, retag_pipeline, tokenized_file=None, predict_file=None):
    """
    Use the given model to parse text and write it

    refactored so it can be used elsewhere, such as Ensemble
    """
    model.eval()

    if predict_file is None:
        if args['predict_file']:
            predict_file = args['predict_file']
            if args['predict_dir']:
                predict_file = os.path.join(args['predict_dir'], predict_file)
    if tokenized_file is None:
        tokenized_file = args['tokenized_file']

    docs, ids = None, None
    if tokenized_file is not None:
        docs, ids = read_tokenized_file(tokenized_file)
    elif args['xml_tree_file']:
        logger.info("Reading trees from %s", args['xml_tree_file'])
        docs, ids = read_xml_tree_file(args['xml_tree_file'])

    if not docs:
        logger.error("No sentences to process!")
        return

    logger.info("Processing %d sentences", len(docs))
    with utils.output_stream(predict_file) as fout:
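        # parse in chunks so that progress is visible and each chunk's
        # trees are written out instead of accumulating in memory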
        chunk_size = 10000
        for chunk_start in range(0, len(docs), chunk_size):
            chunk = docs[chunk_start:chunk_start+chunk_size]
            ids_chunk = ids[chunk_start:chunk_start+chunk_size]
            logger.info("Processing trees %d to %d", chunk_start, chunk_start+len(chunk))
            treebank = parse_tokenized_sentences(args, model, retag_pipeline, chunk)
            for result, tree_id in zip(treebank, ids_chunk):
                tree = result.predictions[0].tree
                if tree_id is not None:
                    tree.tree_id = tree_id
                fout.write(args['predict_format'].format(tree))
                fout.write("\n")
def parse_dir(args, model, retag_pipeline, tokenized_dir, predict_dir):
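    """
    Parse each file in tokenized_dir, writing one .mrg output file per input file to predict_dir
    """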
    os.makedirs(predict_dir, exist_ok=True)
    for filename in os.listdir(tokenized_dir):
        input_path = os.path.join(tokenized_dir, filename)
        output_path = os.path.join(predict_dir, os.path.splitext(filename)[0] + ".mrg")
        logger.info("Processing %s to %s", input_path, output_path)
        parse_text(args, model, retag_pipeline, tokenized_file=input_path, predict_file=output_path)
def load_model_parse_text(args, model_file, retag_pipeline):
    """
    Load a model, then parse text and write it to stdout or args['predict_file']

    retag_pipeline: a list of Pipelines to use for retagging
    """
    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
    load_args = {
        'wordvec_pretrain_file': args['wordvec_pretrain_file'],
        'charlm_forward_file': args['charlm_forward_file'],
        'charlm_backward_file': args['charlm_backward_file'],
        'device': args['device'],
    }
    trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)
    model = trainer.model
    model.eval()
    logger.info("Loaded model from %s", model_file)
    if args['tokenized_dir']:
        if not args['predict_dir']:
            raise ValueError("Must specify --predict_dir to go with --tokenized_dir")
        parse_dir(args, model, retag_pipeline, args['tokenized_dir'], args['predict_dir'])
    else:
        parse_text(args, model, retag_pipeline)
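
# A minimal usage sketch (the filenames and argument values below are hypothetical;
# in practice the args dict comes from the constituency parser's argument parser):
#
#   from stanza import Pipeline
#
#   retag_pipeline = [Pipeline("vi", processors="tokenize,pos", tokenize_pretokenized=True)]
#   args = {
#       'wordvec_pretrain_file': 'vi_pretrain.pt',      # hypothetical paths
#       'charlm_forward_file': 'vi_forward.pt',
#       'charlm_backward_file': 'vi_backward.pt',
#       'device': 'cuda',
#       'tokenized_file': 'input.txt',                  # one tokenized sentence per line
#       'tokenized_dir': None,
#       'predict_file': 'output.mrg',
#       'predict_dir': None,
#       'xml_tree_file': None,
#       'predict_format': '{}',
#       'eval_batch_size': 50,
#   }
#   load_model_parse_text(args, 'vi_constituency.pt', retag_pipeline)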