supcon / code /src /data_collators.py
IGandarillas1's picture
add code
0797029
raw
history blame
5.56 kB
import numpy as np
np.random.seed(42)
import random
random.seed(42)
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
from pdb import set_trace
# collator for pair-wise cross-entropy fine-tuning
@dataclass
class DataCollatorContrastiveClassification:
tokenizer: PreTrainedTokenizerBase
max_length: Optional[int] = 128
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, input):
features_left = [x['features_left'] for x in input]
features_right = [x['features_right'] for x in input]
labels = [x['labels'] for x in input]
batch_left = self.tokenizer(features_left, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors)
batch_right = self.tokenizer(features_right, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors)
batch = batch_left
if 'token_type_ids' in batch.keys():
del batch['token_type_ids']
batch['input_ids_right'] = batch_right['input_ids']
batch['attention_mask_right'] = batch_right['attention_mask']
batch['labels'] = torch.LongTensor(labels)
return batch
# collator for self-supervised contrastive pre-training
@dataclass
class DataCollatorContrastivePretrainSelfSupervised:
tokenizer: PreTrainedTokenizerBase
max_length: Optional[int] = 128
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, input):
features_left = [x[0]['features'] for x in input]
labels = [x[0]['labels'] for x in input]
batch = self.tokenizer(features_left, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors)
if 'token_type_ids' in batch.keys():
del batch['token_type_ids']
batch['labels'] = torch.LongTensor(labels)
return batch
# collator for supervised contrastive pre-training for WDC Computers
@dataclass
class DataCollatorContrastivePretrain:
tokenizer: PreTrainedTokenizerBase
max_length: Optional[int] = 128
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, input):
features_left = [x[0]['features'] for x in input]
features_right = [x[1]['features'] for x in input]
labels = [x[0]['labels'] for x in input]
batch_left = self.tokenizer(features_left, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors)
batch_right = self.tokenizer(features_right, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors)
batch = batch_left
if 'token_type_ids' in batch.keys():
del batch['token_type_ids']
batch['input_ids_right'] = batch_right['input_ids']
batch['attention_mask_right'] = batch_right['attention_mask']
batch['labels'] = torch.LongTensor(labels)
return batch
# collator for supervised contrastive pre-training for Abt-Buy and Amazon-Google
# randomly chooses the sampling dataset when using source-aware sampling
@dataclass
class DataCollatorContrastivePretrainDeepmatcher:
tokenizer: PreTrainedTokenizerBase
max_length: Optional[int] = 128
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, input_both):
rnd = random.choice([0,1])
input = [x[rnd] for x in input_both]
features_left = [x[0]['features'] for x in input]
features_right = [x[1]['features'] for x in input]
labels = [x[0]['labels'] for x in input]
batch_left = self.tokenizer(features_left, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors)
batch_right = self.tokenizer(features_right, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors)
batch = batch_left
if 'token_type_ids' in batch.keys():
del batch['token_type_ids']
batch['input_ids_right'] = batch_right['input_ids']
batch['attention_mask_right'] = batch_right['attention_mask']
batch['labels'] = torch.LongTensor(labels)
return batch
# collator for pair-wise cross-entropy fine-tuning
@dataclass
class DataCollatorContrastiveClassification:
tokenizer: PreTrainedTokenizerBase
max_length: Optional[int] = 128
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, input):
features_left = [x['features_left'] for x in input]
features_right = [x['features_right'] for x in input]
labels = [x['labels'] for x in input]
batch_left = self.tokenizer(features_left, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors)
batch_right = self.tokenizer(features_right, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors)
batch = batch_left
if 'token_type_ids' in batch.keys():
del batch['token_type_ids']
batch['input_ids_right'] = batch_right['input_ids']
batch['attention_mask_right'] = batch_right['attention_mask']
batch['labels'] = torch.LongTensor(labels)
return batch