# supcon/code/src/datasets.py
import json
import random

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac
from sklearn.preprocessing import LabelEncoder

# Fix the random seeds so sampling and augmentation are reproducible
np.random.seed(42)
random.seed(42)

def assign_clusterid(identifier, cluster_id_dict, cluster_id_amount):
    # Look up the cluster id for an entity; ids that are not part of any
    # correspondence cluster get the sentinel value cluster_id_amount.
    try:
        result = cluster_id_dict[identifier]
    except KeyError:
        result = cluster_id_amount
    return result

# Methods for serializing examples by dataset.
# Each serializer concatenates the attributes of an offer into a single
# '[COL] attribute [VAL] value' string, truncating long attributes.
def serialize_sample_lspc(sample):
    string = ''
    string = f'{string}[COL] brand [VAL] {" ".join(sample["brand"].split(" ")[:5])}'.strip()
    string = f'{string} [COL] title [VAL] {" ".join(sample["title"].split(" ")[:50])}'.strip()
    string = f'{string} [COL] description [VAL] {" ".join(sample["description"].split(" ")[:100])}'.strip()
    string = f'{string} [COL] specTableContent [VAL] {" ".join(sample["specTableContent"].split(" ")[:200])}'.strip()
    return string

def serialize_sample_abtbuy(sample):
    string = ''
    string = f'{string}[COL] brand [VAL] {" ".join(sample["brand"].split())}'.strip()
    string = f'{string} [COL] title [VAL] {" ".join(sample["name"].split())}'.strip()
    string = f'{string} [COL] price [VAL] {" ".join(str(sample["price"]).split())}'.strip()
    string = f'{string} [COL] description [VAL] {" ".join(sample["description"].split()[:100])}'.strip()
    return string

def serialize_sample_amazongoogle(sample):
    string = ''
    string = f'{string}[COL] brand [VAL] {" ".join(sample["manufacturer"].split())}'.strip()
    string = f'{string} [COL] title [VAL] {" ".join(sample["title"].split())}'.strip()
    string = f'{string} [COL] price [VAL] {" ".join(str(sample["price"]).split())}'.strip()
    string = f'{string} [COL] description [VAL] {" ".join(sample["description"].split()[:100])}'.strip()
    return string
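
# Example of the resulting format (hypothetical Abt-Buy record with
# brand='sony', name='bravia 40in tv', price=499.0, description='lcd tv'):
#   '[COL] brand [VAL] sony [COL] title [VAL] bravia 40in tv [COL] price [VAL] 499.0 [COL] description [VAL] lcd tv'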

# Class for Data Augmentation
class Augmenter:
    def __init__(self, aug):
        # Never augment the serialization markup or the attribute names
        stopwords = ['[COL]', '[VAL]', 'title', 'name', 'description', 'manufacturer', 'brand', 'specTableContent']

        aug_typo = nac.KeyboardAug(stopwords=stopwords, aug_char_p=0.1, aug_word_p=0.1)
        aug_swap = naw.RandomWordAug(action="swap", stopwords=stopwords, aug_p=0.1)
        aug_del = naw.RandomWordAug(action="delete", stopwords=stopwords, aug_p=0.1)
        aug_crop = naw.RandomWordAug(action="crop", stopwords=stopwords, aug_p=0.1)
        aug_sub = naw.RandomWordAug(action="substitute", stopwords=stopwords, aug_p=0.1)
        aug_split = naw.SplitAug(stopwords=stopwords, aug_p=0.1)

        # ignore stray leading/trailing dashes in the config value
        aug = aug.strip('-')

        if aug == 'all':
            self.augs = [aug_typo, aug_swap, aug_split, aug_sub, aug_del, aug_crop, None]
        elif aug == 'typo':
            self.augs = [aug_typo, None]
        elif aug == 'swap':
            self.augs = [aug_swap, None]
        elif aug == 'delete':
            self.augs = [aug_del, None]
        elif aug == 'crop':
            self.augs = [aug_crop, None]
        elif aug == 'substitute':
            self.augs = [aug_sub, None]
        elif aug == 'split':
            self.augs = [aug_split, None]

    def apply_aug(self, string):
        # None in self.augs means the example is sometimes left unchanged
        aug = random.choice(self.augs)
        if aug is None:
            return string
        # newer nlpaug versions return a list of strings, older ones a plain string
        result = aug.augment(string)
        return result[0] if isinstance(result, list) else result
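
# Usage sketch (hypothetical input): each call draws one augmenter uniformly at
# random, so with aug='all' roughly 1 in 7 calls returns the input unchanged.
#   augmenter = Augmenter('all')
#   augmented = augmenter.apply_aug('[COL] title [VAL] dell xps 13 laptop')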

# Dataset class for general Contrastive Pre-training for WDC computers
class ContrastivePretrainDataset(torch.utils.data.Dataset):
    def __init__(self, path, deduction_set, tokenizer='huawei-noah/TinyBERT_General_4L_312D', max_length=128, intermediate_set=None, clean=False, dataset='lspc', only_interm=False, aug=False):
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, additional_special_tokens=('[COL]', '[VAL]'))
        self.dataset = dataset
        self.aug = aug

        if self.aug:
            self.augmenter = Augmenter(self.aug)

        data = pd.read_pickle(path)

        # make sure the columns the serializers expect are present (empty strings)
        if dataset == 'abt-buy':
            data['brand'] = ''
        if dataset == 'amazon-google':
            data['description'] = ''

        if intermediate_set is not None:
            interm_data = pd.read_pickle(intermediate_set)
            if only_interm:
                data = interm_data
            else:
                # DataFrame.append was removed in pandas 2.0; concat is equivalent
                data = pd.concat([data, interm_data])

        data = data.reset_index(drop=True)
        data = data.fillna('')
        data = self._prepare_data(data)

        self.data = data

    def __getitem__(self, idx):
        # for every example in batch, sample one positive from the dataset
        example = self.data.loc[idx].copy()

        selection = self.data[self.data['labels'] == example['labels']]
        # if len(selection) > 1:
        #     selection = selection.drop(idx)
        pos = selection.sample(1).iloc[0].copy()

        # apply augmentation if set
        if self.aug:
            example['features'] = self.augmenter.apply_aug(example['features'])
            pos['features'] = self.augmenter.apply_aug(pos['features'])

        return (example, pos)

    def __len__(self):
        return len(self.data)

    def _prepare_data(self, data):
        if self.dataset == 'lspc':
            data['features'] = data.apply(serialize_sample_lspc, axis=1)
        elif self.dataset == 'abt-buy':
            data['features'] = data.apply(serialize_sample_abtbuy, axis=1)
        elif self.dataset == 'amazon-google':
            data['features'] = data.apply(serialize_sample_amazongoogle, axis=1)

        label_enc = LabelEncoder()
        data['labels'] = label_enc.fit_transform(data['cluster_id'])
        self.label_encoder = label_enc

        data = data[['features', 'labels']]

        return data
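
# Usage sketch (path and batch size are hypothetical; the training script is
# expected to supply a collator that tokenizes the serialized 'features'):
#   train_set = ContrastivePretrainDataset('../../data/processed/abt-buy-train.pkl', None, dataset='abt-buy', aug='all')
#   loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
#   example, pos = train_set[0]  # two rows sharing the same cluster label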

# Dataset class for Contrastive Pre-training for Abt-Buy and Amazon-Google.
# Builds a correspondence graph from train+val positives and derives two
# source-aware sampling datasets; split=False disables source-aware sampling.
class ContrastivePretrainDatasetDeepmatcher(torch.utils.data.Dataset):
    def __init__(self, path, deduction_set, tokenizer='huawei-noah/TinyBERT_General_4L_312D', max_length=128, intermediate_set=None, clean=False, dataset='abt-buy', aug=False, split=True):
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, additional_special_tokens=('[COL]', '[VAL]'))
        self.dataset = dataset
        self.aug = aug

        if self.aug:
            self.augmenter = Augmenter(self.aug)

        data = pd.read_pickle(path)

        if dataset == 'abt-buy':
            data['brand'] = ''
        if dataset == 'amazon-google':
            data['description'] = ''

        # the remainder of the constructor assumes clean=True, which is how this
        # class is used for Abt-Buy and Amazon-Google pre-training
        if clean:
            train_data = pd.read_json(deduction_set, lines=True)

            if dataset == 'abt-buy':
                val = pd.read_csv('../../data/interim/abt-buy/abt-buy-valid.csv')
            elif dataset == 'amazon-google':
                val = pd.read_csv('../../data/interim/amazon-google/amazon-google-valid.csv')

            # use 80% of train and val set positives to build the correspondence graph
            val_set = train_data[train_data['pair_id'].isin(val['pair_id'])]
            val_set_pos = val_set[val_set['label'] == 1]
            val_set_pos = val_set_pos.sample(frac=0.80)

            val_ids = set()
            val_ids.update(val_set['pair_id'])

            train_data = train_data[~train_data['pair_id'].isin(val_ids)]
            train_data = train_data[train_data['label'] == 1]
            train_data = train_data.sample(frac=0.80)
            # DataFrame.append was removed in pandas 2.0; concat is equivalent
            train_data = pd.concat([train_data, val_set_pos])
            # build the connected components by applying binning
            # (note: two buckets that should be merged by a later pair are not
            # joined, so this is an approximation of true connected components)
            bucket_list = []
            for i, row in train_data.iterrows():
                left = f'{row["id_left"]}'
                right = f'{row["id_right"]}'
                found = False
                for bucket in bucket_list:
                    if left in bucket and row['label'] == 1:
                        bucket.add(right)
                        found = True
                        break
                    elif right in bucket and row['label'] == 1:
                        bucket.add(left)
                        found = True
                        break
                if not found:
                    bucket_list.append({left, right})

            cluster_id_amount = len(bucket_list)

            # assign labels to connected components and single nodes
            # (at this point all single nodes share the same sentinel label)
            cluster_id_dict = {}
            for i, id_set in enumerate(bucket_list):
                for v in id_set:
                    cluster_id_dict[v] = i

            data = data.set_index('id', drop=False)
            data['cluster_id'] = data['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount))
            # data = data[data['cluster_id'] != cluster_id_amount]

            single_entities = data[data['cluster_id'] == cluster_id_amount].copy()
            index = single_entities.index

            if dataset == 'abt-buy':
                left_index = [x for x in index if 'abt' in x]
                right_index = [x for x in index if 'buy' in x]
            elif dataset == 'amazon-google':
                left_index = [x for x in index if 'amazon' in x]
                right_index = [x for x in index if 'google' in x]

            # assign an increasing integer label to every single node
            single_entities = single_entities.reset_index(drop=True)
            single_entities['cluster_id'] = single_entities['cluster_id'] + single_entities.index
            single_entities = single_entities.set_index('id', drop=False)
            single_entities_left = single_entities.loc[left_index]
            single_entities_right = single_entities.loc[right_index]

            # source-aware sampling: build one sampling dataset per source
            if split:
                data1 = pd.concat([data.copy().drop(single_entities['id']), single_entities_left])
                data2 = pd.concat([data.copy().drop(single_entities['id']), single_entities_right])
            else:
                data1 = pd.concat([data.copy().drop(single_entities['id']), single_entities_left, single_entities_right])
                data2 = pd.concat([data.copy().drop(single_entities['id']), single_entities_left, single_entities_right])
        if intermediate_set is not None:
            interm_data = pd.read_pickle(intermediate_set)
            if dataset != 'lspc':
                # align intermediate attribute names with the target schema
                cols = data.columns
                if 'name' in cols:
                    interm_data = interm_data.rename(columns={'title': 'name'})
                if 'manufacturer' in cols:
                    interm_data = interm_data.rename(columns={'brand': 'manufacturer'})
            # offset keeps intermediate cluster ids disjoint from the ones assigned above
            interm_data['cluster_id'] = interm_data['cluster_id'] + 10000
            data1 = pd.concat([data1, interm_data])
            data2 = pd.concat([data2, interm_data])

        data1 = data1.reset_index(drop=True)
        data2 = data2.reset_index(drop=True)

        # fit one encoder over the union of cluster ids so labels agree across both sides
        label_enc = LabelEncoder()
        cluster_id_set = set()
        cluster_id_set.update(data1['cluster_id'])
        cluster_id_set.update(data2['cluster_id'])
        label_enc.fit(list(cluster_id_set))
        data1['labels'] = label_enc.transform(data1['cluster_id'])
        data2['labels'] = label_enc.transform(data2['cluster_id'])

        self.label_encoder = label_enc

        data1 = data1.reset_index(drop=True)
        data1 = data1.fillna('')
        data1 = self._prepare_data(data1)

        data2 = data2.reset_index(drop=True)
        data2 = data2.fillna('')
        data2 = self._prepare_data(data2)

        # oversample the smaller side so both sampling datasets have the same
        # length and can be indexed in parallel by __getitem__
        diff = abs(len(data1) - len(data2))
        if len(data1) > len(data2):
            sample = data2.sample(diff, replace=len(data2) < diff)
            data2 = pd.concat([data2, sample])
            data2 = data2.reset_index(drop=True)
        elif len(data2) > len(data1):
            sample = data1.sample(diff, replace=len(data1) < diff)
            data1 = pd.concat([data1, sample])
            data1 = data1.reset_index(drop=True)

        self.data1 = data1
        self.data2 = data2
    def __getitem__(self, idx):
        # for every example, sample one positive from the respective sampling dataset
        example1 = self.data1.loc[idx].copy()
        selection1 = self.data1[self.data1['labels'] == example1['labels']]
        # if len(selection1) > 1:
        #     selection1 = selection1.drop(idx)
        pos1 = selection1.sample(1).iloc[0].copy()

        example2 = self.data2.loc[idx].copy()
        selection2 = self.data2[self.data2['labels'] == example2['labels']]
        # if len(selection2) > 1:
        #     selection2 = selection2.drop(idx)
        pos2 = selection2.sample(1).iloc[0].copy()

        # apply augmentation if set
        if self.aug:
            example1['features'] = self.augmenter.apply_aug(example1['features'])
            pos1['features'] = self.augmenter.apply_aug(pos1['features'])
            example2['features'] = self.augmenter.apply_aug(example2['features'])
            pos2['features'] = self.augmenter.apply_aug(pos2['features'])

        return ((example1, pos1), (example2, pos2))

    def __len__(self):
        return len(self.data1)

    def _prepare_data(self, data):
        if self.dataset == 'lspc':
            data['features'] = data.apply(serialize_sample_lspc, axis=1)
        elif self.dataset == 'abt-buy':
            data['features'] = data.apply(serialize_sample_abtbuy, axis=1)
        elif self.dataset == 'amazon-google':
            data['features'] = data.apply(serialize_sample_amazongoogle, axis=1)

        data = data[['features', 'labels']]

        return data
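
# Usage sketch (paths are hypothetical): every item is a pair of (anchor,
# positive) tuples, one drawn from each source-aware sampling dataset:
#   train_set = ContrastivePretrainDatasetDeepmatcher('../../data/processed/abt-buy-train.pkl', '../../data/processed/abt-buy-train.json.gz', dataset='abt-buy', clean=True)
#   (example1, pos1), (example2, pos2) = train_set[0]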

# Dataset class for pair-wise cross-entropy fine-tuning
class ContrastiveClassificationDataset(torch.utils.data.Dataset):
    def __init__(self, path, dataset_type, size=None, tokenizer='huawei-noah/TinyBERT_General_4L_312D', max_length=128, dataset='lspc', aug=False):
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, additional_special_tokens=('[COL]', '[VAL]'))
        self.dataset_type = dataset_type
        self.dataset = dataset
        self.aug = aug

        if self.aug:
            self.augmenter = Augmenter(self.aug)

        # for 'serialized', path is expected to be a JSON string of already serialized pairs
        if dataset == 'serialized':
            data = json.loads(path)
            self.data = data
            return

        if dataset == 'lspc':
            data = pd.read_pickle(path)
        else:
            data = pd.read_json(path, lines=True)

        if dataset == 'abt-buy':
            data['brand_left'] = ''
            data['brand_right'] = ''
        if dataset == 'amazon-google':
            data['description_left'] = ''
            data['description_right'] = ''

        data = data.fillna('')

        # split train/validation by the fixed validation pair ids
        if self.dataset_type != 'test':
            if dataset == 'lspc':
                validation_ids = pd.read_csv(f'../../data/raw/wdc-lspc/validation-sets/computers_valid_{size}.csv')
            elif dataset == 'abt-buy':
                validation_ids = pd.read_csv('../../data/interim/abt-buy/abt-buy-valid.csv')
            elif dataset == 'amazon-google':
                validation_ids = pd.read_csv('../../data/interim/amazon-google/amazon-google-valid.csv')

            if self.dataset_type == 'train':
                data = data[~data['pair_id'].isin(validation_ids['pair_id'])]
            else:
                data = data[data['pair_id'].isin(validation_ids['pair_id'])]

        data = data.reset_index(drop=True)
        data = self._prepare_data(data)

        self.data = data

    def __getitem__(self, idx):
        example = self.data.loc[idx].copy()

        if self.aug:
            example['features_left'] = self.augmenter.apply_aug(example['features_left'])
            example['features_right'] = self.augmenter.apply_aug(example['features_right'])

        return example

    def __len__(self):
        return len(self.data)

    def _prepare_data(self, data):
        if self.dataset == 'lspc':
            data['features_left'] = data.apply(self.serialize_sample_lspc, args=('left',), axis=1)
            data['features_right'] = data.apply(self.serialize_sample_lspc, args=('right',), axis=1)
        elif self.dataset == 'abt-buy':
            data['features_left'] = data.apply(self.serialize_sample_abtbuy, args=('left',), axis=1)
            data['features_right'] = data.apply(self.serialize_sample_abtbuy, args=('right',), axis=1)
        elif self.dataset == 'amazon-google':
            data['features_left'] = data.apply(self.serialize_sample_amazongoogle, args=('left',), axis=1)
            data['features_right'] = data.apply(self.serialize_sample_amazongoogle, args=('right',), axis=1)

        data = data[['features_left', 'features_right', 'label']]
        data = data.rename(columns={'label': 'labels'})

        return data
    def serialize_sample_lspc(self, sample, side):
        string = ''
        string = f'{string}[COL] brand [VAL] {" ".join(sample[f"brand_{side}"].split(" ")[:5])}'.strip()
        string = f'{string} [COL] title [VAL] {" ".join(sample[f"title_{side}"].split(" ")[:50])}'.strip()
        string = f'{string} [COL] description [VAL] {" ".join(sample[f"description_{side}"].split(" ")[:100])}'.strip()
        string = f'{string} [COL] specTableContent [VAL] {" ".join(sample[f"specTableContent_{side}"].split(" ")[:200])}'.strip()
        return string

    def serialize_sample_abtbuy(self, sample, side):
        string = ''
        string = f'{string}[COL] brand [VAL] {" ".join(sample[f"brand_{side}"].split())}'.strip()
        string = f'{string} [COL] title [VAL] {" ".join(sample[f"name_{side}"].split())}'.strip()
        string = f'{string} [COL] price [VAL] {" ".join(str(sample[f"price_{side}"]).split())}'.strip()
        string = f'{string} [COL] description [VAL] {" ".join(sample[f"description_{side}"].split()[:100])}'.strip()
        return string

    def serialize_sample_amazongoogle(self, sample, side):
        string = ''
        string = f'{string}[COL] brand [VAL] {" ".join(sample[f"manufacturer_{side}"].split())}'.strip()
        string = f'{string} [COL] title [VAL] {" ".join(sample[f"title_{side}"].split())}'.strip()
        string = f'{string} [COL] price [VAL] {" ".join(str(sample[f"price_{side}"]).split())}'.strip()
        string = f'{string} [COL] description [VAL] {" ".join(sample[f"description_{side}"].split()[:100])}'.strip()
        return string
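
# Usage sketch (path is hypothetical): fine-tuning rows carry serialized left
# and right offers plus a binary match label:
#   train_set = ContrastiveClassificationDataset('../../data/processed/abt-buy-train.json.gz', 'train', dataset='abt-buy')
#   row = train_set[0]  # row['labels'] == 1 for a matching pair, 0 otherwise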