File size: 2,169 Bytes
2d06dcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import random
from typing import Any, Dict, Union


def create_negative_dataset(train_dataloader):
    """Group every training example by its integer label.

    Args:
        train_dataloader: Iterable of batches. Each batch unpacks into
            ``(input_ids, input_masks, segment_ids, label_ids)`` and every
            element must support ``.tolist()`` (e.g. ``torch.Tensor``).
            Assumes dim 0 is the batch dimension and all four share the same
            batch size -- TODO confirm against the caller's DataLoader.

    Returns:
        A plain ``dict`` mapping ``int(label)`` to the list of that label's
        examples, in encounter order. Each example is a dict with keys
        ``"input_ids"``, ``"token_type_ids"`` and ``"attention_mask"`` whose
        values are Python lists (the shape expected by ``list_item_to_tensor``).
    """
    negative_dataset = {}
    for batch in train_dataloader:
        input_ids, input_masks, segment_ids, label_ids = batch
        rows = zip(
            input_ids.tolist(),
            input_masks.tolist(),
            segment_ids.tolist(),
            label_ids.tolist(),
        )
        # Visit every example in the batch; the former
        # ``range(len(input_ids) - 1)`` silently skipped the last one.
        for input_id, input_mask, segment_id, label_id in rows:
            features = {
                "input_ids": input_id,
                "token_type_ids": segment_id,
                "attention_mask": input_mask,
            }
            # Plain dict (not defaultdict) so lookups of unseen labels in
            # downstream code still raise KeyError.
            negative_dataset.setdefault(int(label_id), []).append(features)
    return negative_dataset


def generate_positive_sample(negative_data, args, label: torch.Tensor):
    """Draw same-label examples for every entry of ``label`` and batch them.

    For each position along dim 0 of ``label``, ``args.positive_num``
    distinct examples are drawn (without replacement, via ``random.sample``)
    from ``negative_data[int(label[position])]``. The draws are concatenated
    in label order and collated with ``list_item_to_tensor``.

    Raises:
        KeyError: if a label has no entry in ``negative_data``.
        ValueError: (from ``random.sample``) if a label's pool holds fewer
            than ``args.positive_num`` examples.
    """
    per_label = args.positive_num
    drawn = []
    for row in range(label.shape[0]):
        pool = negative_data[int(label[row])]
        drawn += random.sample(pool, per_label)
    return list_item_to_tensor(drawn)

def _prepare_inputs(device, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.to(device)
                    
    return inputs


def list_item_to_tensor(inputs_list):
    """Collate a list of per-example feature dicts into a dict of tensors.

    Args:
        inputs_list: Sequence of dicts that all share the same keys; each
            value must be something ``torch.tensor`` accepts (ints, nested
            lists of numbers, ...).

    Returns:
        A dict with the first item's keys (in its key order), each mapped to
        ``torch.tensor([item[key] for item in inputs_list])``. An empty
        ``inputs_list`` yields an empty dict.

    Raises:
        ValueError: if any item's key set differs from the first item's
            (which would otherwise produce tensors of mismatched batch size
            or an opaque KeyError).
    """
    if not inputs_list:
        return {}
    expected_keys = inputs_list[0].keys()
    for position, item in enumerate(inputs_list):
        if item.keys() != expected_keys:
            raise ValueError(
                f"item {position} has keys {sorted(item.keys())}, "
                f"expected {sorted(expected_keys)}"
            )
    return {
        key: torch.tensor([item[key] for item in inputs_list])
        for key in expected_keys
    }