File size: 2,169 Bytes
2d06dcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
import random
from typing import Any, Dict, Union
def create_negative_dataset(train_dataloader):
    """Group every example yielded by a dataloader by its integer label.

    Args:
        train_dataloader: Iterable of batches. Each batch must unpack into
            exactly four items ``(input_ids, input_masks, segment_ids,
            label_ids)``, each exposing ``.tolist()`` (e.g. ``torch.Tensor``)
            and sharing the same leading (batch) dimension.

    Returns:
        A plain ``dict`` mapping ``int(label)`` to the list of feature dicts
        seen for that label, in encounter order. Every feature dict has the
        keys ``"input_ids"``, ``"token_type_ids"`` and ``"attention_mask"``
        holding plain Python lists.
    """
    negative_dataset = {}
    for input_ids, input_masks, segment_ids, label_ids in train_dataloader:
        # Iterate over *all* rows of the batch.  The previous index loop used
        # ``range(len(input_ids) - 1)`` and therefore silently discarded the
        # last example of every batch (and whole batches of size 1).
        rows = zip(
            input_ids.tolist(),
            input_masks.tolist(),
            segment_ids.tolist(),
            label_ids.tolist(),
        )
        for input_id, input_mask, segment_id, label_id in rows:
            features = {
                "input_ids": input_id,
                "token_type_ids": segment_id,
                "attention_mask": input_mask,
            }
            # setdefault keeps the result a regular dict, so looking up an
            # unseen label later still raises KeyError.
            negative_dataset.setdefault(int(label_id), []).append(features)
    return negative_dataset
def generate_positive_sample(negative_data, args, label: torch.Tensor):
    """Draw same-label examples for every entry of ``label`` and batch them.

    Args:
        negative_data: Mapping from integer label to a list of feature dicts
            (presumably the output of ``create_negative_dataset`` -- confirm
            against the caller).
        args: Object exposing ``positive_num``, the number of examples drawn
            per label entry.
        label: Tensor whose first dimension enumerates the labels; each
            ``label[i]`` must be convertible with ``int()``.

    Returns:
        The result of ``list_item_to_tensor`` applied to the drawn examples,
        ``label.shape[0] * args.positive_num`` of them in total.

    Raises:
        KeyError: If a label is not a key of ``negative_data``.
        ValueError: Propagated from ``random.sample`` when a label has fewer
            than ``args.positive_num`` stored examples.
    """
    per_anchor = args.positive_num
    drawn = []
    for row in range(label.shape[0]):
        # Sampling is without replacement within one label's pool.
        pool = negative_data[int(label[row])]
        drawn += random.sample(pool, per_anchor)
    return list_item_to_tensor(drawn)
def _prepare_inputs(device, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
    """Move every tensor value of ``inputs`` onto ``device``.

    The dictionary is updated in place; non-tensor values are left untouched.

    Args:
        device: Target passed verbatim to ``Tensor.to``.
        inputs: Mapping of names to tensors or arbitrary other values.

    Returns:
        The same ``inputs`` object that was passed in.
    """
    # Decide which entries need moving first, then rebind only those keys.
    tensor_keys = [name for name, value in inputs.items() if isinstance(value, torch.Tensor)]
    for name in tensor_keys:
        inputs[name] = inputs[name].to(device)
    return inputs
def list_item_to_tensor(inputs_list):
    """Collate a list of feature dicts into one dict of stacked tensors.

    Args:
        inputs_list: Non-empty list of dicts. The keys of the first dict
            define the output keys; every later dict may only use those keys.

    Returns:
        A dict with the first item's keys, where each value is
        ``torch.tensor`` built from that key's values collected across all
        items in order.

    Raises:
        IndexError: If ``inputs_list`` is empty.
        KeyError: If a later item contains a key absent from the first item.
    """
    # The first record fixes the set (and order) of columns.
    columns = {field: [] for field in inputs_list[0]}
    for record in inputs_list:
        for field, entry in record.items():
            columns[field].append(entry)
    return {field: torch.tensor(entries) for field, entries in columns.items()}
|