"""Stanza models classifier data functions."""
import collections
from collections import namedtuple
import logging
import json
import random
import re
from typing import List
from stanza.models.classifiers.utils import WVType
from stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID
import stanza.models.constituency.tree_reader as tree_reader
logger = logging.getLogger('stanza')

class SentimentDatum:
    """A single labeled example: a sentiment label, a list of tokens, and an optional constituency tree."""
    def __init__(self, sentiment, text, constituency=None):
        self.sentiment = sentiment
        self.text = text
        self.constituency = constituency

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, SentimentDatum):
            return False
        return self.sentiment == other.sentiment and self.text == other.text and self.constituency == other.constituency

    def __str__(self):
        return str(self._asdict())

    def _asdict(self):
        # dict view of the datum, with the constituency tree stringified if present
        if self.constituency is None:
            return {'sentiment': self.sentiment, 'text': self.text}
        else:
            return {'sentiment': self.sentiment, 'text': self.text, 'constituency': str(self.constituency)}
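
# Illustrative usage (not from the original module); the constituency tree
# is optional, so a plain labeled token list stands on its own:
#   >>> d = SentimentDatum("1", ["a", "great", "movie"])
#   >>> str(d)
#   "{'sentiment': '1', 'text': ['a', 'great', 'movie']}"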

def update_text(sentence: List[str], wordvec_type: WVType) -> List[str]:
    """
    Process a tokenized sentence into a list of strings normalized
    for the given word vector type.
    """
    # the Stanford Sentiment dataset has a lot of random - and /
    # remove those characters and flatten the newly created sublists into one list each time
    sentence = [y for x in sentence for y in x.split("-") if y]
    sentence = [y for x in sentence for y in x.split("/") if y]
    sentence = [x.strip() for x in sentence]
    sentence = [x for x in sentence if x]
    if not sentence:
        # removed too much - keep a placeholder token instead of an empty sentence
        sentence = ["-"]
    # our current word vectors are all entirely lowercased
    sentence = [word.lower() for word in sentence]
    if wordvec_type == WVType.WORD2VEC:
        return sentence
    elif wordvec_type == WVType.GOOGLE:
        # the Google News vectors replace digits with '#', aside from the tokens 0 and 1
        new_sentence = []
        for word in sentence:
            if word not in ('0', '1'):
                word = re.sub('[0-9]', '#', word)
            new_sentence.append(word)
        return new_sentence
    elif wordvec_type == WVType.FASTTEXT:
        return sentence
    elif wordvec_type == WVType.OTHER:
        return sentence
    else:
        raise ValueError("Unknown wordvec_type {}".format(wordvec_type))
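
# A sketch of the normalization above (made-up tokens): '-' and '/' are
# split out, everything is lowercased, and GOOGLE additionally hashes digits.
#   >>> update_text(["The", "movie-goers", "rated", "it", "7/10"], WVType.WORD2VEC)
#   ['the', 'movie', 'goers', 'rated', 'it', '7', '10']
#   >>> update_text(["The", "movie-goers", "rated", "it", "7/10"], WVType.GOOGLE)
#   ['the', 'movie', 'goers', 'rated', 'it', '#', '##']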

def read_dataset(dataset, wordvec_type: WVType, min_len: int) -> List[SentimentDatum]:
    """
    Read one or more comma-separated JSON files into a list of SentimentDatum.

    Each datum carries a label, a list of tokens, and possibly a constituency tree.
    """
    lines = []
    for filename in str(dataset).split(","):
        with open(filename, encoding="utf-8") as fin:
            new_lines = json.load(fin)
        new_lines = [(str(x['sentiment']), x['text'], x.get('constituency', None)) for x in new_lines]
        lines.extend(new_lines)
    # TODO: maybe do this processing later, once the model is built.
    # then move the processing into the model so we can use
    # overloading to potentially make future model types
    lines = [SentimentDatum(x[0], update_text(x[1], wordvec_type), tree_reader.read_trees(x[2])[0] if x[2] else None)
             for x in lines]
    if min_len:
        lines = [x for x in lines if len(x.text) >= min_len]
    return lines
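
# The JSON layout read_dataset expects, as implied by the code above, is a
# list of objects with a "sentiment" label, a pre-tokenized "text" list, and
# an optional bracketed "constituency" tree.  A hand-written, assumed sample:
#   [{"sentiment": "1",
#     "text": ["a", "great", "movie"],
#     "constituency": "(ROOT (NP (DT a) (JJ great) (NN movie)))"}]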

def dataset_labels(dataset):
    """
    Returns a sorted list of label names
    """
    labels = set(x.sentiment for x in dataset)
    if all(re.match("^[0-9]+$", label) for label in labels):
        # if all of the labels are integers, sort numerically
        # maybe not super important, but it is nicer than having 10 sort before 2
        labels = [str(x) for x in sorted(map(int, labels))]
    else:
        labels = sorted(labels)
    return labels
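
# For instance, all-numeric labels sort numerically rather than lexically
# (an illustrative call with assumed data):
#   >>> dataset_labels([SentimentDatum("10", ["a"]), SentimentDatum("2", ["b"])])
#   ['2', '10']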

def dataset_vocab(dataset):
    """
    Returns a list of the words in the dataset, with PAD and UNK at the reserved indices
    """
    vocab = set()
    for line in dataset:
        for word in line.text:
            vocab.add(word)
    vocab = [PAD, UNK] + list(vocab)
    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:
        raise ValueError("Unexpected values for PAD and UNK!")
    return vocab
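
# The first two entries are always PAD and UNK; the rest come from a set,
# so their order is arbitrary.  An illustrative call with assumed data:
#   >>> vocab = dataset_vocab([SentimentDatum("1", ["good", "fun"])])
#   >>> vocab[PAD_ID] == PAD and vocab[UNK_ID] == UNK
#   True
#   >>> sorted(vocab[2:])
#   ['fun', 'good']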

def sort_dataset_by_len(dataset, keep_index=False):
    """
    Returns a dict mapping length -> list of items of that length.

    An OrderedDict is used so that the keys are sorted from smallest to largest.
    If keep_index is set, each entry is an (item, original index) pair.
    """
    sorted_dataset = collections.OrderedDict()
    lengths = sorted(set(len(x.text) for x in dataset))
    for length in lengths:
        sorted_dataset[length] = []
    for item_idx, item in enumerate(dataset):
        if keep_index:
            sorted_dataset[len(item.text)].append((item, item_idx))
        else:
            sorted_dataset[len(item.text)].append(item)
    return sorted_dataset
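
# Sketch of the resulting structure, assuming data of lengths 1 and 3:
#   >>> data = [SentimentDatum("0", ["bad"]), SentimentDatum("1", ["a", "great", "movie"])]
#   >>> buckets = sort_dataset_by_len(data)
#   >>> [(length, len(items)) for length, items in buckets.items()]
#   [(1, 1), (3, 1)]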

def shuffle_dataset(sorted_dataset, batch_size, batch_single_item):
    """
    Given a dataset sorted by length, shuffles the items within each
    length so that batches are built from items of roughly the same
    size, then groups them into batches of at most batch_size items.

    If batch_single_item > 0, items of at least that length each get a
    batch to themselves.  Returns a shuffled list of batches.
    """
    dataset = []
    for length in sorted_dataset.keys():
        items = list(sorted_dataset[length])
        random.shuffle(items)
        dataset.extend(items)

    batches = []
    next_batch = []
    for item in dataset:
        if batch_single_item > 0 and len(item.text) >= batch_single_item:
            batches.append([item])
        else:
            next_batch.append(item)
            if len(next_batch) >= batch_size:
                batches.append(next_batch)
                next_batch = []
    if next_batch:
        batches.append(next_batch)
    random.shuffle(batches)
    return batches
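
# A typical call (hypothetical names and parameter values): with
# batch_single_item set to 0 the single-item special case is disabled,
# so this yields shuffled batches of up to 50 similar-length items each.
#   >>> batches = shuffle_dataset(sort_dataset_by_len(train_set), 50, 0)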

def check_labels(labels, dataset):
    """
    Check that all of the labels in the dataset are in the known labels.

    Unknown labels could arguably be acceptable if we just treated the model
    as always wrong on them, but this is a good sanity check that the datasets match.
    """
    new_labels = dataset_labels(dataset)
    not_found = [i for i in new_labels if i not in labels]
    if not_found:
        raise RuntimeError('Dataset contains labels which the model does not know about: ' + str(not_found))