|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import pickle |
|
|
import re |
|
|
import string |
|
|
from typing import Dict |
|
|
|
|
|
from nemo.collections.nlp.data.data_utils.data_preprocessing import ( |
|
|
fill_class_weights, |
|
|
get_freq_weights, |
|
|
get_label_stats, |
|
|
) |
|
|
from nemo.utils import logging |
|
|
|
|
|
__all__ = ['get_label_ids', 'create_text_and_labels'] |
|
|
|
|
|
|
|
|
def remove_punctuation(word: str) -> str:
    """
    Removes all punctuation marks from a word except for '
    that is often a part of word: don't, it's, and so on

    Args:
        word: input token, possibly containing punctuation marks

    Returns:
        the word with every punctuation mark except the apostrophe removed
    """
    # Use a str.translate deletion table instead of re.sub: the original built a
    # regex character class by concatenating raw string.punctuation, which contains
    # ']', '\', '-' and '^' and only parsed correctly by accident ('\' happens to
    # precede ']' in string.punctuation). translate is also a single C-level pass.
    all_punct_marks = string.punctuation.replace("'", '')
    return word.translate(str.maketrans('', '', all_punct_marks))
|
|
|
|
|
|
|
|
def create_text_and_labels(output_dir: str, file_path: str, punct_marks: str = ',.?'):
    """
    Create datasets for training and evaluation.

    Args:
        output_dir: path to the output data directory
        file_path: path to file name
        punct_marks: supported punctuation marks

    Raises:
        ValueError: if ``file_path`` does not exist.

    The data will be split into 2 files: text.txt and labels.txt. \
        Each line of the text.txt file contains text sequences, where words\
        are separated with spaces. The labels.txt file contains \
        corresponding labels for each word in text.txt, the labels are \
        separated with spaces. Each line of the files should follow the \
        format:  \
        [WORD] [SPACE] [WORD] [SPACE] [WORD] (for text.txt) and \
        [LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).'
    """
    if not os.path.exists(file_path):
        raise ValueError(f'{file_path} not found')

    os.makedirs(output_dir, exist_ok=True)

    base_name = os.path.basename(file_path)
    labels_file = os.path.join(output_dir, 'labels_' + base_name)
    text_file = os.path.join(output_dir, 'text_' + base_name)

    # One `with` statement scopes all three handles together instead of nesting.
    with open(file_path, 'r') as f, open(text_file, 'w') as text_f, open(labels_file, 'w') as labels_f:
        for line in f:
            words = []
            labels = []
            for word in line.split():
                # First label character: the trailing punctuation mark, or 'O' for none.
                label = word[-1] if word[-1] in punct_marks else 'O'
                word = remove_punctuation(word)
                if len(word) == 0:
                    # Skip tokens that were pure punctuation: emitting a label with
                    # no corresponding word would misalign text.txt and labels.txt
                    # (the original code wrote a 1-char label and a stray space here).
                    continue
                # Second label character encodes capitalization: 'U' = upper-cased word.
                label += 'U' if word[0].isupper() else 'O'
                words.append(word.lower())
                labels.append(label)

            text_f.write(' '.join(words) + '\n')
            labels_f.write(' '.join(labels) + '\n')

    # Use logging for consistency with the rest of this module (original used print).
    logging.info(f'{text_file} and {labels_file} created from {file_path}.')
|
|
|
|
|
|
|
|
def get_label_ids(
    label_file: str,
    is_training: bool = False,
    pad_label: str = 'O',
    label_ids_dict: Dict[str, int] = None,
    get_weights: bool = True,
    class_labels_file_artifact='label_ids.csv',
):
    """
    Generates str to int labels mapping for training data or checks correctness of the label_ids_dict
    file for non-training files or if label_ids_dict is specified

    Args:
        label_file: the path of the label file to process
        is_training: indicates whether the label_file is used for training
        pad_label: token used for padding
        label_ids_dict: str label name to int ids mapping. Required for non-training data.
            If specified, the check that all labels from label_file are present in label_ids_dict will be performed.
            For training data, if label_ids_dict is None, a new mapping will be generated from label_file.
        get_weights: set to True to calculate class weights, required for Weighted Loss.
        class_labels_file_artifact: name of the file to save in .nemo

    Returns:
        tuple of (label_ids_dict, label_ids_filename, class_weights) where
        class_weights is None when get_weights is False.

    Raises:
        ValueError: if label_file is missing, if label_ids_dict is missing for
            non-training data, or if a label from label_file is absent from the
            provided label_ids_dict.
    """
    if not os.path.exists(label_file):
        raise ValueError(f'File {label_file} was not found.')

    logging.info(f'Processing {label_file}')
    if not is_training and label_ids_dict is None:
        raise ValueError(
            f'For non training data, label_ids_dict created during preprocessing of the training data '
            f'should be provided'
        )

    data_dir = os.path.dirname(label_file)
    # {pad_label}, not set(pad_label): the original set(pad_label) iterated the
    # string and would split a multi-character pad label (e.g. 'PAD') into
    # individual characters {'P', 'A', 'D'}.
    unique_labels = {pad_label}
    all_labels = []
    with open(label_file, 'r') as f:
        for line in f:
            line = line.strip().split()
            all_labels.extend(line)
            unique_labels.update(line)

    if label_ids_dict:
        logging.info(f'Using provided labels mapping {label_ids_dict}')
        for name in unique_labels:
            if name not in label_ids_dict:
                raise ValueError(f'{name} class from {label_file} not found in the provided mapping: {label_ids_dict}')
    else:
        # pad_label is always id 0; remaining labels get ids in sorted order.
        label_ids_dict = {pad_label: 0}
        if pad_label in unique_labels:
            unique_labels.remove(pad_label)
        for label in sorted(unique_labels):
            label_ids_dict[label] = len(label_ids_dict)

    label_ids_filename = os.path.join(data_dir, class_labels_file_artifact)
    if is_training:
        with open(label_ids_filename, 'w') as f:
            # Write label names one per line, ordered by their integer id.
            labels, _ = zip(*sorted(label_ids_dict.items(), key=lambda x: x[1]))
            f.write('\n'.join(labels))
        logging.info(f'Labels mapping {label_ids_dict} saved to : {label_ids_filename}')

    base_name = os.path.splitext(os.path.basename(label_file))[0]
    stats_file = os.path.join(data_dir, f'{base_name}_label_stats.tsv')
    if os.path.exists(stats_file) and not is_training and not get_weights:
        logging.info(f'{stats_file} found, skipping stats calculation.')
    else:
        all_labels = [label_ids_dict[label] for label in all_labels]
        logging.info(f'Three most popular labels in {label_file}:')
        total_labels, label_frequencies, max_id = get_label_stats(all_labels, stats_file)
        logging.info(f'Total labels: {total_labels}. Label frequencies - {label_frequencies}')

    if get_weights:
        class_weights_pkl = os.path.join(data_dir, f'{base_name}_weights.p')
        if os.path.exists(class_weights_pkl):
            # NOTE: unpickling a locally-generated cache file; do not point this at
            # untrusted data. `with` closes the handle (the original leaked it via
            # pickle.load(open(...))).
            with open(class_weights_pkl, 'rb') as f:
                class_weights = pickle.load(f)
            logging.info(f'Class weights restored from {class_weights_pkl}')
        else:
            class_weights_dict = get_freq_weights(label_frequencies)
            logging.info(f'Class Weights: {class_weights_dict}')
            class_weights = fill_class_weights(class_weights_dict, max_id)

            with open(class_weights_pkl, 'wb') as f:
                pickle.dump(class_weights, f)
            logging.info(f'Class weights saved to {class_weights_pkl}')
    else:
        class_weights = None

    return label_ids_dict, label_ids_filename, class_weights
|
|
|