import csv
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from bert import BertModel
from optimizer import AdamW
from tokenizer import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
class SentimentDataset(Dataset):
    '''Labeled sentiment data: each item is a (sent, label, sent_id) tuple.'''

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        labels = [x[1] for x in data]
        sent_ids = [x[2] for x in data]

        encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])
        labels = torch.LongTensor(labels)

        return token_ids, attention_mask, labels, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data
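
# A minimal usage sketch, not part of the original module: it shows how one of
# these datasets plugs into a torch DataLoader through its collate_fn. The
# file path and batch size are hypothetical placeholders.
def _example_sentiment_loader():
    from torch.utils.data import DataLoader

    train_data, num_labels = load_data('data/train.tsv', flag='train')  # hypothetical path
    dataset = SentimentDataset(train_data)
    loader = DataLoader(dataset, batch_size=32, shuffle=True,
                        collate_fn=dataset.collate_fn)
    batch = next(iter(loader))
    # token_ids and attention_mask are (B, T); labels is (B,).
    return batch['token_ids'], batch['attention_mask'], batch['labels']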
class SentimentTestDataset(Dataset):
    '''Unlabeled sentiment test data: each item is a (sent, sent_id) tuple.'''

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]

        encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data
class AmazonDataset(Dataset):
    '''Unlabeled Amazon review data: each item is a (sent, sent_id) tuple.'''

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]

        encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sent_ids

    def collate_fn(self, data):
        token_ids, attention_mask, sent_ids = self.pad_data(data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sent_ids': sent_ids
        }

        return batched_data
class SemanticDataset(Dataset):
    '''STS-B style data: each item is a (sentence1, sentence2, score, sent_id) tuple.'''

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents1 = [x[0] for x in data]
        sents2 = [x[1] for x in data]
        scores = [x[2] for x in data]
        sent_ids = [x[3] for x in data]

        # Tokenize both sentence lists in a single call so they share one
        # padded length; collate_fn splits the halves back apart.
        encoding = tokenizer(sents1 + sents2, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, scores, sent_ids

    def collate_fn(self, data):
        token_ids, attention_mask, scores, sent_ids = self.pad_data(data)
        n = len(sent_ids)

        batched_data = {
            'token_ids_1': token_ids[:n],
            'token_ids_2': token_ids[n:],
            'attention_mask_1': attention_mask[:n],
            'attention_mask_2': attention_mask[n:],
            'score': scores,
            'sent_ids': sent_ids
        }

        return batched_data
class InferenceDataset(Dataset):
    '''NLI triplet data: each item is an (anchor, positive, negative, sent_id) tuple.'''

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        anchor = [x[0] for x in data]
        positive = [x[1] for x in data]
        negative = [x[2] for x in data]
        sent_ids = [x[3] for x in data]

        # Tokenize anchors, positives, and negatives together so all three
        # share one padded length; collate_fn slices the thirds back apart.
        encoding = tokenizer(anchor + positive + negative, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sent_ids

    def collate_fn(self, data):
        token_ids, attention_mask, sent_ids = self.pad_data(data)
        n = len(sent_ids)

        batched_data = {
            'anchor_ids': token_ids[:n],
            'positive_ids': token_ids[n:2*n],
            'negative_ids': token_ids[2*n:],
            'anchor_masks': attention_mask[:n],
            'positive_masks': attention_mask[n:2*n],
            'negative_masks': attention_mask[2*n:],
            'sent_ids': sent_ids
        }

        return batched_data
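
# A minimal sketch, not part of the original module, of how the triplet batch
# built above might be consumed with a standard triplet loss. The `encode`
# callable standing in for the BERT encoder is hypothetical; only the batch
# keys come from InferenceDataset.collate_fn.
def _example_triplet_loss(batch, encode):
    # encode: hypothetical callable mapping (token_ids, attention_mask)
    # to sentence embeddings of shape (n, hidden).
    anchor = encode(batch['anchor_ids'], batch['anchor_masks'])
    positive = encode(batch['positive_ids'], batch['positive_masks'])
    negative = encode(batch['negative_ids'], batch['negative_masks'])
    return torch.nn.TripletMarginLoss(margin=1.0)(anchor, positive, negative)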
def load_data(filename, flag='train'):
    '''
    Load data from `filename` according to `flag`:
        - 'amazon': list of (sent, sent_id)
        - 'nli':    list of (anchor, positive, negative, sent_id)
        - 'stsb':   list of (sentence1, sentence2, score, sent_id)
        - 'test':   list of (sent, sent_id)
        - 'train':  list of (sent, label, sent_id), plus the number of distinct labels
    '''
    if flag == 'amazon':
        df = pd.read_parquet(filename)
        data = list(zip(df['content'], df.index))
    elif flag == 'nli':
        df = pd.read_parquet(filename)
        data = list(zip(df['anchor'], df['positive'], df['negative'], df.index))
    elif flag == 'stsb':
        df = pd.read_parquet(filename)
        data = list(zip(df['sentence1'], df['sentence2'], df['score'], df.index))
    else:
        data, label_set = [], set()
        with open(filename, 'r') as fp:
            if flag == 'test':
                for record in csv.DictReader(fp, delimiter='\t'):
                    sent = record['sentence'].lower().strip()
                    sent_id = record['id'].lower().strip()
                    data.append((sent, sent_id))
            else:
                for record in csv.DictReader(fp, delimiter='\t'):
                    sent = record['sentence'].lower().strip()
                    sent_id = record['id'].lower().strip()
                    label = int(record['sentiment'].strip())
                    label_set.add(label)
                    data.append((sent, label, sent_id))

    print(f"Loaded {len(data)} examples from {filename}")

    if flag == 'train':
        return data, len(label_set)
    else:
        return data
def seed_everything(seed=11711):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
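
# A minimal smoke-test sketch, not part of the original module; the data path
# is a hypothetical placeholder for wherever the sentiment TSV lives.
if __name__ == '__main__':
    seed_everything(11711)
    train_data, num_labels = load_data('data/train.tsv', flag='train')  # hypothetical path
    print(f"num_labels: {num_labels}; first example: {train_data[0]}")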