File size: 2,672 Bytes
0240c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import pickle
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import normalize
from collections import defaultdict
from utils import get_grams_info, CONFIG, save_dataset


# Settings read from the shared project CONFIG: the pickled word-vector file
# (read and rewritten by this script), the output path for the pickled
# 'main' class -> index mapping, and the grammeme types combined into the
# 'main' class tuple.
VECT_PATH, CLS_CLASSES_PATH, GRAMMEMES_TYPES = (
    CONFIG[setting]
    for setting in ('vect_words_path', 'cls_classes_path', 'grammemes_types')
)
# CLASSES_INDEXES maps each grammeme type to its {class value: column index}
# dict (see generate_all); SRC_CONVERT is not used in this script.
SRC_CONVERT, CLASSES_INDEXES = get_grams_info(CONFIG)


def generate_dataset(vec_words, cls_type, cls_dic):
    """Build and save a multi-label classification dataset for one grammeme type.

    Every word that has at least one form carrying ``cls_type`` becomes one
    sample. Its target ``y`` is a multi-hot vector. Position
    ``cls_dic[form[cls_type]]`` is set for each such form. Words with no
    such form are skipped.

    Args:
        vec_words: mapping ``word -> {'vect': ..., 'forms': [form, ...]}``.
            Each form is a mapping from grammeme type to its value.
        cls_type: the form key to classify on, e.g. ``'main'``.
        cls_dic: mapping ``class value -> column index``. Indices are assumed
            to be ``0 .. len(cls_dic) - 1``.

    Side effects:
        Calls ``save_dataset(rez_items, cls_type)``. ``rez_items`` maps a
        class value to the list of sample dicts
        (``'src'``, ``'x'``, ``'y'``, ``'weight'``).
    """
    n_classes = len(cls_dic)

    # Class frequencies over all forms, used to down-weight frequent classes.
    counts = np.zeros(n_classes)
    for word in tqdm(vec_words, desc=f"Calculating {cls_type} weights"):
        for form in vec_words[word]['forms']:
            if cls_type in form:
                counts[cls_dic[form[cls_type]]] += 1

    # L2-normalise the counts and invert them. The more often a class
    # occurs, the closer its weight is to 0.
    weights = 1.0 - normalize(counts.reshape(1, -1)).ravel()

    rez_items = defaultdict(list)
    for word in tqdm(vec_words, desc=f"Generating classification {cls_type} dataset"):
        # np.int was removed in NumPy 1.24; it was an alias of the builtin int.
        y = np.zeros((n_classes,), dtype=int)
        cur_cls = None
        has_classes = False
        for form in vec_words[word]['forms']:
            if cls_type in form:
                cur_cls = form[cls_type]
                y[cls_dic[cur_cls]] = 1
                has_classes = True

        if not has_classes:
            continue

        # The sample is grouped under the class of the LAST matching form.
        rez_items[cur_cls].append({
            'src': word,
            'x': vec_words[word]['vect'],
            'y': y,
            # Largest weight among the word's classes, i.e. its least
            # frequent class.
            'weight': weights[y == 1].max(),
        })

    save_dataset(rez_items, cls_type)


def generate_all(vec_words):
    """Generate every per-type dataset plus the combined 'main' dataset.

    Steps:
      1. Call ``generate_dataset`` for each grammeme type in ``CLASSES_INDEXES``.
      2. Tag every form with ``form['main']``. This is a tuple of the form's
         values for each type in ``GRAMMEMES_TYPES``, with ``None`` where the
         form lacks that type.
      3. Re-pickle the mutated ``vec_words`` to ``VECT_PATH``.
      4. Index each distinct 'main' tuple in first-seen order. Build the
         'main' dataset, print the class count, and pickle the
         ``tuple -> index`` mapping to ``CLS_CLASSES_PATH``.

    Args:
        vec_words: mapping ``word -> {'forms': [form, ...], ...}``. Forms are
            mutated in place (a ``'main'`` key is added).
    """
    for cls_type in CLASSES_INDEXES:
        cls_dic = CLASSES_INDEXES[cls_type]
        generate_dataset(vec_words, cls_type, cls_dic)

    # A dict keeps insertion order, so indices match first-seen order (as the
    # former list + enumerate did). Membership tests are now O(1) instead of
    # O(number of classes) per form.
    cls_dic = {}
    for word in tqdm(vec_words, desc="Setting main class"):
        for form in vec_words[word]['forms']:
            tpl = tuple(form.get(key) for key in GRAMMEMES_TYPES)
            cls_dic.setdefault(tpl, len(cls_dic))
            form['main'] = tpl

    # Persist forms together with their new 'main' tags.
    with open(VECT_PATH, 'wb+') as f:
        pickle.dump(vec_words, f)

    generate_dataset(vec_words, 'main', cls_dic)
    print(f"Main classes count: {len(cls_dic)}")
    with open(CLS_CLASSES_PATH, 'wb+') as f:
        pickle.dump(cls_dic, f)


def _load_vect_words(path):
    """Unpickle and return the vectorised-words mapping stored at *path*."""
    with open(path, 'rb') as vect_file:
        return pickle.load(vect_file)


# Script entry: load the word vectors and regenerate every dataset.
# NOTE(review): this runs at import time; an `if __name__ == "__main__":`
# guard would let the module be imported without side effects.
vwords = _load_vect_words(VECT_PATH)
generate_all(vwords)