| import os | |
| import pickle | |
| import numpy as np | |
| from tqdm import tqdm | |
| from sklearn.preprocessing import normalize | |
| from collections import defaultdict | |
| from utils import get_grams_info, CONFIG, save_dataset | |
# Locations and grammeme metadata pulled from the shared project config.
VECT_PATH, CLS_CLASSES_PATH, GRAMMEMES_TYPES = (
    CONFIG['vect_words_path'],   # pickled word -> {'vect', 'forms'} mapping
    CONFIG['cls_classes_path'],  # where the 'main' class index is pickled
    CONFIG['grammemes_types'],   # grammeme keys that make up a 'main' class
)
# Only CLASSES_INDEXES is used in the visible code; SRC_CONVERT is kept for
# any other consumers of this module.
SRC_CONVERT, CLASSES_INDEXES = get_grams_info(CONFIG)
def generate_dataset(vec_words, cls_type, cls_dic):
    """Build and save a multi-label classification dataset for one class type.

    Args:
        vec_words: mapping word -> dict with keys ``'vect'`` (the input
            vector stored as ``'x'``) and ``'forms'`` (iterable of form
            dicts, each of which may contain ``cls_type``).
        cls_type: form key to classify by.
        cls_dic: mapping class value -> column index in the label vector;
            indexes are expected to cover ``0 .. len(cls_dic) - 1``.

    For every word with at least one form carrying ``cls_type``, one item
    ``{'src', 'x', 'y', 'weight'}`` is produced:

    * ``y`` is a 0/1 int vector with a 1 for every class seen among the
      word's forms;
    * ``weight`` is the largest per-class weight among those classes, where
      a class weight is ``1 - count / ||counts||_2`` over form-level class
      counts, so rarer classes get larger weights.

    Items are grouped by the class of the word's *last* labelled form and
    passed to ``save_dataset(groups, cls_type)``.  Returns ``None``.
    """
    n_classes = len(cls_dic)

    # Count how many forms (not words) carry each class.
    counts = np.zeros(n_classes)
    for word in tqdm(vec_words, desc=f"Calculating {cls_type} weights"):
        for form in vec_words[word]['forms']:
            if cls_type in form:
                counts[cls_dic[form[cls_type]]] += 1
    # normalize() L2-normalises the single row (an all-zero row stays zero);
    # ravel() flattens the (1, K) result back to one weight per class.
    weights = 1.0 - normalize(counts.reshape(1, -1)).ravel()

    rez_items = defaultdict(list)
    for word in tqdm(vec_words, desc=f"Generating classification {cls_type} dataset"):
        # np.int was removed in NumPy 1.24; builtin int is the same dtype.
        y = np.zeros(n_classes, dtype=int)
        cur_cls = None
        for form in vec_words[word]['forms']:
            if cls_type in form:
                cur_cls = form[cls_type]
                y[cls_dic[cur_cls]] = 1
        if not y.any():
            # No form of this word carries cls_type: nothing to label.
            continue
        rez_items[cur_cls].append({
            'src': word,
            'x': vec_words[word]['vect'],
            'y': y,
            'weight': weights[y == 1].max(),
        })
    save_dataset(rez_items, cls_type)
def generate_all(vec_words):
    """Generate every per-type dataset plus the combined ``'main'`` dataset.

    Steps:

    1. For each class type in ``CLASSES_INDEXES``, call ``generate_dataset``
       with that type's class -> index mapping.
    2. Give every form a ``'main'`` key: the tuple of its values for each
       key in ``GRAMMEMES_TYPES`` (``None`` where the form lacks the key).
       Distinct tuples are indexed in first-seen order.
    3. Pickle the mutated ``vec_words`` back to ``VECT_PATH``.
    4. Generate the ``'main'`` dataset, print the number of main classes,
       and pickle the tuple -> index mapping to ``CLS_CLASSES_PATH``.

    Args:
        vec_words: mapping word -> dict with a ``'forms'`` list of form
            dicts; the form dicts are modified in place.
    """
    for cls_type, cls_index in CLASSES_INDEXES.items():
        generate_dataset(vec_words, cls_type, cls_index)

    # A dict gives O(1) membership (the former list scan was O(n) per form)
    # and preserves insertion order, so indexes follow first occurrence
    # exactly as before.
    cls_dic = {}
    for word in tqdm(vec_words, desc="Setting main class"):
        for form in vec_words[word]['forms']:
            tpl = tuple(form.get(key) for key in GRAMMEMES_TYPES)
            cls_dic.setdefault(tpl, len(cls_dic))
            form['main'] = tpl

    # Overwrite the input pickle so the added 'main' keys are persisted.
    with open(VECT_PATH, 'wb') as f:
        pickle.dump(vec_words, f)

    generate_dataset(vec_words, 'main', cls_dic)
    print(f"Main classes count: {len(cls_dic)}")
    with open(CLS_CLASSES_PATH, 'wb') as f:
        pickle.dump(cls_dic, f)
# Script entry: load the pickled word vectors and build all datasets.
# NOTE(review): pickle.load trusts VECT_PATH; only point it at files this
# pipeline produced.
with open(VECT_PATH, 'rb') as vect_file:
    vwords = pickle.load(vect_file)
generate_all(vwords)