import os
import random
import pickle
import logging
from collections import defaultdict

import yaml
from tqdm import tqdm


def _get_config():
    """Load config.yml and, if available, attach the pickled main class list."""
    with open('config.yml', 'r') as f:
        config = yaml.safe_load(f)
    classes_path = config['cls_classes_path']
    if os.path.exists(classes_path):
        with open(classes_path, 'rb') as f:
            config['main_classes'] = pickle.load(f)
        config['main_classes_count'] = len(config['main_classes'])
    return config


CONFIG = _get_config()

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%d.%m.%Y %H:%M:%S'
)

DATASET_PATH = CONFIG['dataset_path']
DICS_PATH = CONFIG['dics_path']
RANDOM = random.Random(CONFIG['random_seed'])
GRAMMEMES_TYPES = CONFIG['grammemes_types']


class MyDefaultDict(defaultdict):
    """defaultdict variant whose default_factory receives the missing key."""

    def __missing__(self, key):
        if self.default_factory is None:
            raise KeyError(key)
        ret = self[key] = self.default_factory(key)
        return ret
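
# Example: unlike collections.defaultdict, the factory above is called with the
# missing key itself, e.g. MyDefaultDict(lambda key: len(key))["word"] returns
# (and caches) 4.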


def get_grams_info(config):
    """Build lookup tables from the grammeme config.

    Returns:
        src_convert: maps every source tag key (lowercased) to a
            (grammeme_type, class) pair.
        classes_indexes: maps grammeme_type -> class -> numeric index.
    """
    dict_post_types = config['dict_post_types']
    grammemes_types = config['grammemes_types']
    src_convert = {}
    classes_indexes = {}
    for gram_key, gram in grammemes_types.items():
        classes_indexes[gram_key] = {}
        for cls_key, cls_obj in gram['classes'].items():
            classes_indexes[gram_key][cls_key] = cls_obj['index']
            for key in cls_obj['keys']:
                src_convert[key.lower()] = (gram_key, cls_key)
    for post_key, cls_obj in dict_post_types.items():
        p_key = ('post', post_key.lower())
        for key in cls_obj['keys']:
            src_convert[key.lower()] = p_key
    return src_convert, classes_indexes
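
# Illustrative config.yml fragment for get_grams_info. The grammeme and class
# names are made up; only the nesting and the 'index'/'keys' fields are what the
# code above relies on:
#
#   grammemes_types:
#     case:
#       classes:
#         nomn: {index: 0, keys: [Nom, nomn]}
#         gent: {index: 1, keys: [Gen, gent]}
#   dict_post_types:
#     punct: {keys: [PUNCT]}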


def decode_word(vect_mas):
    """Turn a vector of character indexes back into a word string."""
    conf = CONFIG
    word = []
    for ci in vect_mas:
        if ci == conf['end_token']:
            return "".join(word)
        elif ci < len(conf['chars']):
            word.append(conf['chars'][ci])
        else:
            # Index outside the known alphabet: substitute a placeholder.
            word.append("0")
    return "".join(word)


def select_uniform_items(items_dict, percent, ds_info):
    """Move roughly `percent` percent of each class's items out of `items_dict`
    and yield them, so every class is represented in the resulting split."""
    for cls in tqdm(items_dict, desc=f"Selecting {ds_info} dataset"):
        i = 0
        items = items_dict[cls]
        per_group_count = percent * len(items) / 100
        while i <= per_group_count and len(items) > 0:
            yield items.pop(0)
            i += 1


def save_dataset(items_dict, file_prefix):
    """Shuffle the items of each class, carve out test/validation splits and
    pickle the train/valid/test datasets under DATASET_PATH."""
    if not os.path.isdir(DATASET_PATH):
        os.makedirs(DATASET_PATH)
    total_count = sum(len(items) for items in items_dict.values())
    logging.info(f"Class '{file_prefix}': {total_count}")
    for key in tqdm(items_dict, desc=f"Shuffling {file_prefix} items"):
        # Use the seeded RNG so the splits are reproducible across runs.
        RANDOM.shuffle(items_dict[key])
    test_items = list(select_uniform_items(items_dict, CONFIG['test_persent'], f"test {file_prefix}"))
    valid_items = list(select_uniform_items(items_dict, CONFIG['validation_persent'], f"valid {file_prefix}"))
    items = []
    for key in items_dict:
        items.extend(items_dict[key])
    RANDOM.shuffle(items)
    logging.info(f"Saving '{file_prefix}' train dataset")
    with open(os.path.join(DATASET_PATH, f"{file_prefix}_train_dataset.pkl"), 'wb+') as f:
        pickle.dump(items, f)
    logging.info(f"Saving '{file_prefix}' valid dataset")
    with open(os.path.join(DATASET_PATH, f"{file_prefix}_valid_dataset.pkl"), 'wb+') as f:
        pickle.dump(valid_items, f)
    logging.info(f"Saving '{file_prefix}' test dataset")
    with open(os.path.join(DATASET_PATH, f"{file_prefix}_test_dataset.pkl"), 'wb+') as f:
        pickle.dump(test_items, f)


def get_dict_path(file_prefix):
    return os.path.join(DICS_PATH, f"{file_prefix}_dict_items.pkl")


def save_dictionary_items(items, file_prefix):
    with open(get_dict_path(file_prefix), 'wb+') as f:
        pickle.dump(items, f)


def create_cls_tuple(item):
    return tuple(
        item[key] if key in item else None
        for key in GRAMMEMES_TYPES
    )
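
# For illustration: if GRAMMEMES_TYPES were ordered ('post', 'case') and
# item == {'post': 'noun'}, create_cls_tuple(item) would give ('noun', None).
# The grammeme names here are hypothetical; the real ordering comes from config.yml.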


def load_datasets(main_type, *ds_types):
    """Load and concatenate the pickled datasets of the given split types
    (e.g. 'train', 'valid', 'test') for the given main type."""
    words = []

    def load_words(ds_type):
        path = os.path.join(DATASET_PATH, f"{main_type}_{ds_type}_dataset.pkl")
        with open(path, 'rb') as f:
            words.extend(pickle.load(f))

    for key in ds_types:
        load_words(key)
    return words
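

# Minimal usage sketch, not part of the original pipeline: it assumes config.yml
# provides dataset_path, dics_path, random_seed, test_persent and
# validation_persent, and that items grouped per class tuple come from some
# upstream step. The example items and the "main" prefix are purely illustrative.
if __name__ == "__main__":
    grouped_items = {
        ("noun",): ["cat", "house"],
        ("verb",): ["run", "sleep"],
    }
    save_dataset(grouped_items, "main")
    train_items = load_datasets("main", "train")
    logging.info("Loaded %d training items", len(train_items))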