File size: 4,455 Bytes
0240c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import yaml
import random
import pickle
import logging
from collections import defaultdict
from tqdm import tqdm


def _get_config():
    """Load config.yml and, when present, attach the pickled main-class list.

    Returns the parsed config dict. When the file at the config's
    'cls_classes_path' exists, also sets 'main_classes' (the unpickled
    object) and 'main_classes_count' (its length).
    """
    # safe_load: yaml.load() without an explicit Loader is a TypeError in
    # PyYAML >= 6.0 and unsafe on untrusted input in older versions.
    with open('config.yml', 'r') as f:
        config = yaml.safe_load(f)

    classes_path = config['cls_classes_path']
    if os.path.exists(classes_path):
        # NOTE(review): pickle.load is only safe for locally produced files
        # (see save_dictionary_items below) — never point this at untrusted data.
        with open(classes_path, 'rb') as f:
            config['main_classes'] = pickle.load(f)
            config['main_classes_count'] = len(config['main_classes'])

    return config


# Module-level configuration and shared constants, loaded once at import time.
CONFIG = _get_config()
# Root-logger setup happens as an import side effect of this module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%d.%m.%Y %H:%M:%S'
)
DATASET_PATH = CONFIG['dataset_path']  # directory for *_dataset.pkl files
DICS_PATH = CONFIG['dics_path']        # directory for *_dict_items.pkl files
# Seeded RNG instance for reproducible sampling.
RANDOM = random.Random(CONFIG['random_seed'])
GRAMMEMES_TYPES = CONFIG['grammemes_types']  # grammeme type -> class definitions


class MyDefaultDict(defaultdict):
    """defaultdict variant whose factory receives the missing key.

    The stock defaultdict calls its factory with no arguments; this
    subclass passes the missing key through, so defaults can depend on it.
    """

    def __missing__(self, key):
        factory = self.default_factory
        if factory is None:
            # Same contract as dict: no factory means a plain KeyError.
            raise KeyError(key)
        value = factory(key)
        self[key] = value  # cache, so the factory runs once per key
        return value


def get_grams_info(config):
    """Build lookup tables from the grammeme configuration.

    Parameters
    ----------
    config : dict with 'grammemes_types' (type -> {'classes': {cls ->
        {'index': int, 'keys': [str, ...]}}}) and 'dict_post_types'
        (post type -> {'keys': [str, ...]}) sections.

    Returns
    -------
    (src_convert, classes_indexes)
        src_convert maps every lowercased source key to a
        (grammeme_type, class_key) tuple — post-type keys map to
        ('post', post_type_lowercased).
        classes_indexes maps grammeme_type -> {class_key: index}.
    """
    dict_post_types = config['dict_post_types']
    grammemes_types = config['grammemes_types']
    src_convert = {}
    classes_indexes = {}

    for gram_key, gram in grammemes_types.items():
        cls_dic = gram['classes']
        classes_indexes[gram_key] = {}
        for cls_key, cls_obj in cls_dic.items():
            classes_indexes[gram_key][cls_key] = cls_obj['index']
            for key in cls_obj['keys']:
                src_convert[key.lower()] = (gram_key, cls_key)

    # Bug fix: the original looked up dict_post_types[post_key.lower()]
    # while iterating dict_post_types' own keys, so any non-lowercase key
    # raised KeyError. Iterate items() and use the value directly.
    for post_key, cls_obj in dict_post_types.items():
        p_key = ('post', post_key.lower())
        for key in cls_obj['keys']:
            src_convert[key.lower()] = p_key

    return src_convert, classes_indexes


def decode_word(vect_mas):
    """Decode a sequence of character indices into a string.

    Decoding stops at the configured end token. Indices outside the
    character table are rendered as the placeholder "0".
    """
    cfg = CONFIG
    chars = cfg['chars']
    end_token = cfg['end_token']
    table_size = len(chars)

    letters = []
    for index in vect_mas:
        if index == end_token:
            break  # everything after the end token is ignored
        letters.append(chars[index] if index < table_size else "0")

    return "".join(letters)


def select_uniform_items(items_dict, persent, ds_info):
    """Yield roughly `persent` percent of the items of every class.

    Selected items are *consumed* (removed from the lists in
    `items_dict`), so after exhaustion the caller is left with the
    remaining items for the training split.

    Parameters
    ----------
    items_dict : dict mapping class -> mutable list of items.
    persent : percentage (0-100) of each class's items to take.
    ds_info : label shown in the tqdm progress-bar description.
    """
    for cls in tqdm(items_dict, desc=f"Selecting {ds_info} dataset"):
        items = items_dict[cls]
        # Quota is fixed before consumption starts.
        per_group_count = persent * len(items) / 100

        taken = 0
        # NOTE(review): `<=` takes one extra item per class (one item even
        # when the quota is 0) — kept as-is since split sizes depend on it.
        while taken <= per_group_count and items:
            # pop(0) removes the head directly; the original's
            # items.remove(items[0]) re-scanned the list for the first
            # equal element — same element, but a redundant O(n) pass.
            yield items.pop(0)
            taken += 1


def save_dataset(items_dict, file_prefix):
    """Split items into train/valid/test sets and pickle them to DATASET_PATH.

    Test and validation portions are drawn uniformly per class by
    select_uniform_items (which consumes items_dict); everything left
    over becomes the shuffled training set.

    Parameters
    ----------
    items_dict : dict mapping class -> list of items (mutated in place).
    file_prefix : prefix for the three output .pkl file names.
    """
    # makedirs(exist_ok=True) avoids both the missing-parent failure and
    # the race between the isdir check and mkdir of the original.
    os.makedirs(DATASET_PATH, exist_ok=True)

    total_count = sum(len(items) for items in items_dict.values())
    logging.info(f"Class '{file_prefix}': {total_count}")
    # NOTE(review): this uses the *unseeded* global random module although a
    # seeded RANDOM instance exists at module level — confirm whether these
    # shuffles were meant to be reproducible.
    for key in tqdm(items_dict, desc=f"Shuffling {file_prefix} items"):
        random.shuffle(items_dict[key])

    test_items = list(select_uniform_items(items_dict, CONFIG['test_persent'], f"test {file_prefix}"))
    valid_items = list(select_uniform_items(items_dict, CONFIG['validation_persent'], f"valid {file_prefix}"))
    train_items = []
    for key in items_dict:
        train_items.extend(items_dict[key])
    random.shuffle(train_items)

    def _dump(split_name, split_items):
        # One writer for all three splits; 'wb' truncates exactly like the
        # original 'wb+' but without requesting an unused read handle.
        logging.info(f"Saving '{file_prefix}' {split_name} dataset")
        path = os.path.join(DATASET_PATH, f"{file_prefix}_{split_name}_dataset.pkl")
        with open(path, 'wb') as f:
            pickle.dump(split_items, f)

    _dump("train", train_items)
    _dump("valid", valid_items)
    _dump("test", test_items)


def get_dict_path(file_prefix):
    """Return the path of the pickled dictionary items for `file_prefix`."""
    filename = file_prefix + "_dict_items.pkl"
    return os.path.join(DICS_PATH, filename)


def save_dictionary_items(items, file_prefix):
    """Pickle `items` to the dictionary path for `file_prefix`."""
    path = get_dict_path(file_prefix)
    with open(path, 'wb+') as f:
        pickle.dump(items, f)


def create_cls_tuple(item):
    """Project `item` onto the GRAMMEMES_TYPES key order as a tuple.

    Grammeme types absent from `item` become None, so every item yields
    a tuple of the same length (usable as a hashable class key).
    """
    # dict.get(key) returns None for absent keys — one lookup instead of
    # the original's `key in item` test followed by item[key].
    return tuple(item.get(key) for key in GRAMMEMES_TYPES)


def load_datasets(main_type, *ds_types):
    """Load and concatenate one or more pickled dataset splits.

    Parameters
    ----------
    main_type : dataset family prefix used in the file names.
    *ds_types : split names to load, e.g. "train", "valid", "test"
        (positional varargs, so the rename from the builtin-shadowing
        original names is caller-compatible).

    Returns
    -------
    list of all items from every requested split, in request order.
    """
    words = []
    # Inlined the original closure; its parameter was named `type`,
    # shadowing the builtin.
    for split in ds_types:
        path = os.path.join(CONFIG['dataset_path'], f"{main_type}_{split}_dataset.pkl")
        with open(path, 'rb') as f:
            words.extend(pickle.load(f))

    return words