File size: 3,759 Bytes
239ff1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch, scipy
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report
import seaborn as sns
import matplotlib.pyplot as plt

class TextDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
        self.data_len=len(encodings['input_ids'])

    def __getitem__(self, idx):
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        return item

    def __len__(self):
        return self.data_len

class CEFRDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, sent_levels_a, sent_levels_b):
        self.encodings = encodings
        self.slabels_low = np.minimum(sent_levels_a, sent_levels_b)
        self.slabels_high = np.maximum(sent_levels_a, sent_levels_b)

    def __getitem__(self, idx):
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        item['slabels_low'] = self.slabels_low[idx].clone().detach()
        item['slabels_high'] = self.slabels_high[idx].clone().detach()
        return item

    def __len__(self):
        return len(self.slabels_high)


class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)


def read_corpus(path, num_labels):
    levels_a, levels_b, sents = [], [], []
    with open(path) as f:
        for line in f:
            array = line.strip().split('\t')
            sents.append(array[0].split(' '))
            levels_a.append(float(array[1]) - 1)  # Convert 1-6 to 0-5
            levels_b.append(float(array[2]) - 1)  # Convert 1-6 to 0-5

    levels_a = np.array(levels_a)
    levels_b = np.array(levels_b)

    return levels_a, levels_b, sents


def convert_numeral_to_six_levels(levels):
    level_thresholds = np.array([0.0, 0.5, 1.5, 2.5, 3.5, 4.5])
    return _conversion(level_thresholds, levels)


def _conversion(level_thresholds, values):
    thresh_array = np.tile(level_thresholds, reps=(values.shape[0], 1))
    array = np.tile(values, reps=(1, level_thresholds.shape[0]))
    levels = np.maximum(np.zeros((values.shape[0], 1)),
                        np.count_nonzero(thresh_array <= array, axis=1, keepdims=True) - 1).astype(int)
    return levels


# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(token_embeddings, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# Take attention mask into account for excluding padding
def token_embeddings_filtering_padding(token_embeddings, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return token_embeddings * input_mask_expanded


def eval_multiclass(out_path, labels, predictions):
    cm = confusion_matrix(labels, predictions)
    plt.figure()
    sns.heatmap(cm)
    plt.savefig(out_path + '_confusion_matrix.png')

    report = classification_report(labels, predictions, digits=4)
    print(report)
    with open(out_path + '_test_report.txt', 'w') as fw:
        fw.write('{0}\n'.format(report))


def mean_confidence_interval(data, confidence=0.95):
    if len(data) > 5:
        data.remove(max(data))
        data.remove(min(data))
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h