Commit
·
897eef4
1
Parent(s):
02b98f5
contrastive commit 3
Browse files
data/{twitter-unsup.csv → amazon-polarity.parquet}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dbe4770cfa6be45add6c9a322044bd4c1901520dde5a2707eca402a74fbe854e
|
| 3 |
+
size 870289
|
unsup_simcse.py
CHANGED
|
@@ -3,6 +3,7 @@ import torch
|
|
| 3 |
import random
|
| 4 |
import argparse
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
|
| 8 |
from tqdm import tqdm
|
|
@@ -20,7 +21,7 @@ from classifier import SentimentDataset, BertSentimentClassifier
|
|
| 20 |
TQDM_DISABLE = False
|
| 21 |
|
| 22 |
|
| 23 |
-
class
|
| 24 |
def __init__(self, dataset, args):
|
| 25 |
self.dataset = dataset
|
| 26 |
self.p = args
|
|
@@ -31,19 +32,22 @@ class TwitterDataset(Dataset):
|
|
| 31 |
def __getitem__(self, idx):
|
| 32 |
return self.dataset[idx]
|
| 33 |
|
| 34 |
-
def pad_data(self,
|
|
|
|
|
|
|
| 35 |
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
|
| 36 |
token_ids = torch.LongTensor(encoding['input_ids'])
|
| 37 |
attension_mask = torch.LongTensor(encoding['attention_mask'])
|
| 38 |
|
| 39 |
-
return token_ids, attension_mask
|
| 40 |
|
| 41 |
-
def collate_fn(self,
|
| 42 |
-
token_ids, attention_mask = self.pad_data(
|
| 43 |
|
| 44 |
batched_data = {
|
| 45 |
'token_ids': token_ids,
|
| 46 |
'attention_mask': attention_mask,
|
|
|
|
| 47 |
}
|
| 48 |
|
| 49 |
return batched_data
|
|
@@ -51,36 +55,36 @@ class TwitterDataset(Dataset):
|
|
| 51 |
|
| 52 |
def load_data(filename, flag='train'):
|
| 53 |
'''
|
| 54 |
-
- for
|
| 55 |
-
- for
|
|
|
|
| 56 |
'''
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
if flag == 'twitter':
|
| 62 |
-
for cnt, record in enumerate(csv.DictReader(fp, delimiter = ',')):
|
| 63 |
-
sent = record['clean_text'].lower().strip()
|
| 64 |
-
data.append(sent)
|
| 65 |
-
if cnt == 10000: break
|
| 66 |
-
elif flag == 'test':
|
| 67 |
-
for record in csv.DictReader(fp, delimiter = '\t'):
|
| 68 |
-
sent = record['sentence'].lower().strip()
|
| 69 |
-
sent_id = record['id'].lower().strip()
|
| 70 |
-
data.append((sent,sent_id))
|
| 71 |
-
else:
|
| 72 |
-
for record in csv.DictReader(fp, delimiter = '\t'):
|
| 73 |
-
sent = record['sentence'].lower().strip()
|
| 74 |
-
sent_id = record['id'].lower().strip()
|
| 75 |
-
label = int(record['sentiment'].strip())
|
| 76 |
-
num_labels.add(label)
|
| 77 |
-
data.append((sent, label, sent_id))
|
| 78 |
-
print(f"load {len(data)} data from {filename}")
|
| 79 |
-
|
| 80 |
-
if flag == 'train':
|
| 81 |
-
return data, len(num_labels)
|
| 82 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
return data
|
|
|
|
|
|
|
| 84 |
|
| 85 |
|
| 86 |
def save_model(model, optimizer, args, config, filepath):
|
|
@@ -98,11 +102,6 @@ def save_model(model, optimizer, args, config, filepath):
|
|
| 98 |
print(f"save the model to {filepath}")
|
| 99 |
|
| 100 |
|
| 101 |
-
# def model_eval(dataloader, model, device):
|
| 102 |
-
# model.eval()
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
def contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
|
| 107 |
'''
|
| 108 |
embeds_1: [batch_size, hidden_size]
|
|
@@ -131,7 +130,7 @@ def train(args):
|
|
| 131 |
'''
|
| 132 |
Training Pipeline
|
| 133 |
-----------------
|
| 134 |
-
1. Load the
|
| 135 |
2. Determine batch_size (64) and number of batches (?).
|
| 136 |
3. Initialize SentimentClassifier (including bert).
|
| 137 |
4. Looping through 10 epoches.
|
|
@@ -142,16 +141,16 @@ def train(args):
|
|
| 142 |
9. If dev_acc > best_dev_acc: save_model(...)
|
| 143 |
'''
|
| 144 |
|
| 145 |
-
|
| 146 |
train_data, num_labels = load_data(args.train, 'train')
|
| 147 |
dev_data = load_data(args.dev, 'valid')
|
| 148 |
|
| 149 |
-
|
| 150 |
train_dataset = SentimentDataset(train_data, args)
|
| 151 |
dev_dataset = SentimentDataset(dev_data, args)
|
| 152 |
|
| 153 |
-
|
| 154 |
-
num_workers=args.num_cpu_cores, collate_fn=
|
| 155 |
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
|
| 156 |
num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
|
| 157 |
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
|
|
@@ -177,7 +176,7 @@ def train(args):
|
|
| 177 |
for epoch in range(args.epochs):
|
| 178 |
model.bert.train()
|
| 179 |
train_loss = num_batches = 0
|
| 180 |
-
for batch in tqdm(
|
| 181 |
b_ids, b_mask = batch['token_ids'], batch['attention_mask']
|
| 182 |
b_ids = b_ids.to(device)
|
| 183 |
b_mask = b_mask.to(device)
|
|
@@ -189,11 +188,13 @@ def train(args):
|
|
| 189 |
# Calculate mean SimCSE loss function
|
| 190 |
loss = contrastive_loss(logits_1, logits_2)
|
| 191 |
|
|
|
|
|
|
|
| 192 |
loss.backward()
|
| 193 |
optimizer_cse.step()
|
| 194 |
|
| 195 |
train_loss += loss.item()
|
| 196 |
-
num_batches +=
|
| 197 |
|
| 198 |
train_loss = train_loss / num_batches
|
| 199 |
print(f"Epoch {epoch}: train loss :: {train_loss :.3f}")
|
|
@@ -205,11 +206,12 @@ def get_args():
|
|
| 205 |
parser.add_argument("--num-cpu-cores", type=int, default=4)
|
| 206 |
parser.add_argument("--epochs", type=int, default=10)
|
| 207 |
parser.add_argument("--use_gpu", action='store_true')
|
| 208 |
-
parser.add_argument("--batch_size_cse",
|
| 209 |
-
parser.add_argument("--
|
|
|
|
| 210 |
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
|
| 211 |
-
parser.add_argument("--lr_cse", default=2e-5)
|
| 212 |
-
parser.add_argument("--lr_classifier", default=1e-5)
|
| 213 |
|
| 214 |
args = parser.parse_args()
|
| 215 |
return args
|
|
@@ -229,9 +231,9 @@ if __name__ == "__main__":
|
|
| 229 |
use_gpu=args.use_gpu,
|
| 230 |
epochs=args.epochs,
|
| 231 |
batch_size_cse=args.batch_size_cse,
|
| 232 |
-
batch_size_classifier=args.
|
| 233 |
hidden_dropout_prob=args.hidden_dropout_prob,
|
| 234 |
-
train_bert='data/
|
| 235 |
train='data/ids-sst-train.csv',
|
| 236 |
dev='data/ids-sst-dev.csv',
|
| 237 |
test='data/ids-sst-test-student.csv'
|
|
|
|
| 3 |
import random
|
| 4 |
import argparse
|
| 5 |
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
import torch.nn.functional as F
|
| 8 |
|
| 9 |
from tqdm import tqdm
|
|
|
|
| 21 |
TQDM_DISABLE = False
|
| 22 |
|
| 23 |
|
| 24 |
+
class AmazonDataset(Dataset):
|
| 25 |
def __init__(self, dataset, args):
|
| 26 |
self.dataset = dataset
|
| 27 |
self.p = args
|
|
|
|
| 32 |
def __getitem__(self, idx):
|
| 33 |
return self.dataset[idx]
|
| 34 |
|
| 35 |
+
def pad_data(self, data):
|
| 36 |
+
sents = [x[0] for x in data]
|
| 37 |
+
sent_ids = [x[1] for x in data]
|
| 38 |
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
|
| 39 |
token_ids = torch.LongTensor(encoding['input_ids'])
|
| 40 |
attension_mask = torch.LongTensor(encoding['attention_mask'])
|
| 41 |
|
| 42 |
+
return token_ids, attension_mask, sent_ids
|
| 43 |
|
| 44 |
+
def collate_fn(self, data):
|
| 45 |
+
token_ids, attention_mask, sent_ids = self.pad_data(data)
|
| 46 |
|
| 47 |
batched_data = {
|
| 48 |
'token_ids': token_ids,
|
| 49 |
'attention_mask': attention_mask,
|
| 50 |
+
'sent_ids': sent_ids
|
| 51 |
}
|
| 52 |
|
| 53 |
return batched_data
|
|
|
|
| 55 |
|
| 56 |
def load_data(filename, flag='train'):
|
| 57 |
'''
|
| 58 |
+
- for amazon dataset: list of (sent, sent_id)
|
| 59 |
+
- for test dataset: list of (sent, sent_id)
|
| 60 |
+
- for train dataset: list of (sent, label, sent_id)
|
| 61 |
'''
|
| 62 |
|
| 63 |
+
if flag == 'amazon':
|
| 64 |
+
df = pd.read_parquet(filename)
|
| 65 |
+
data = list(zip(df['content'], df.index))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
else:
|
| 67 |
+
data, num_labels = [], set()
|
| 68 |
+
|
| 69 |
+
with open(filename, 'r') as fp:
|
| 70 |
+
if flag == 'test':
|
| 71 |
+
for record in csv.DictReader(fp, delimiter = '\t'):
|
| 72 |
+
sent = record['sentence'].lower().strip()
|
| 73 |
+
sent_id = record['id'].lower().strip()
|
| 74 |
+
data.append((sent,sent_id))
|
| 75 |
+
else:
|
| 76 |
+
for record in csv.DictReader(fp, delimiter = '\t'):
|
| 77 |
+
sent = record['sentence'].lower().strip()
|
| 78 |
+
sent_id = record['id'].lower().strip()
|
| 79 |
+
label = int(record['sentiment'].strip())
|
| 80 |
+
num_labels.add(label)
|
| 81 |
+
data.append((sent, label, sent_id))
|
| 82 |
+
|
| 83 |
+
print(f"load {len(data)} data from {filename}")
|
| 84 |
+
if flag in ['test', 'amazon']:
|
| 85 |
return data
|
| 86 |
+
else:
|
| 87 |
+
return data, len(num_labels)
|
| 88 |
|
| 89 |
|
| 90 |
def save_model(model, optimizer, args, config, filepath):
|
|
|
|
| 102 |
print(f"save the model to {filepath}")
|
| 103 |
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
def contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
|
| 106 |
'''
|
| 107 |
embeds_1: [batch_size, hidden_size]
|
|
|
|
| 130 |
'''
|
| 131 |
Training Pipeline
|
| 132 |
-----------------
|
| 133 |
+
1. Load the Amazon Polarity and SST Dataset.
|
| 134 |
2. Determine batch_size (64) and number of batches (?).
|
| 135 |
3. Initialize SentimentClassifier (including bert).
|
| 136 |
4. Looping through 10 epoches.
|
|
|
|
| 141 |
9. If dev_acc > best_dev_acc: save_model(...)
|
| 142 |
'''
|
| 143 |
|
| 144 |
+
amazon_data = load_data(args.train_bert, 'amazon')
|
| 145 |
train_data, num_labels = load_data(args.train, 'train')
|
| 146 |
dev_data = load_data(args.dev, 'valid')
|
| 147 |
|
| 148 |
+
amazon_dataset = AmazonDataset(amazon_data, args)
|
| 149 |
train_dataset = SentimentDataset(train_data, args)
|
| 150 |
dev_dataset = SentimentDataset(dev_data, args)
|
| 151 |
|
| 152 |
+
amazon_dataloader = DataLoader(amazon_dataset, shuffle=True, batch_size=args.batch_size_cse,
|
| 153 |
+
num_workers=args.num_cpu_cores, collate_fn=amazon_dataset.collate_fn)
|
| 154 |
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
|
| 155 |
num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
|
| 156 |
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
|
|
|
|
| 176 |
for epoch in range(args.epochs):
|
| 177 |
model.bert.train()
|
| 178 |
train_loss = num_batches = 0
|
| 179 |
+
for batch in tqdm(amazon_dataloader, f'train-amazon-{epoch}', leave=False, disable=TQDM_DISABLE):
|
| 180 |
b_ids, b_mask = batch['token_ids'], batch['attention_mask']
|
| 181 |
b_ids = b_ids.to(device)
|
| 182 |
b_mask = b_mask.to(device)
|
|
|
|
| 188 |
# Calculate mean SimCSE loss function
|
| 189 |
loss = contrastive_loss(logits_1, logits_2)
|
| 190 |
|
| 191 |
+
# Back propagation
|
| 192 |
+
optimizer_cse.zero_grad()
|
| 193 |
loss.backward()
|
| 194 |
optimizer_cse.step()
|
| 195 |
|
| 196 |
train_loss += loss.item()
|
| 197 |
+
num_batches += 1
|
| 198 |
|
| 199 |
train_loss = train_loss / num_batches
|
| 200 |
print(f"Epoch {epoch}: train loss :: {train_loss :.3f}")
|
|
|
|
| 206 |
parser.add_argument("--num-cpu-cores", type=int, default=4)
|
| 207 |
parser.add_argument("--epochs", type=int, default=10)
|
| 208 |
parser.add_argument("--use_gpu", action='store_true')
|
| 209 |
+
parser.add_argument("--batch_size_cse", type=int, default=8)
|
| 210 |
+
parser.add_argument("--batch_size_sst", type=int, default=64)
|
| 211 |
+
parser.add_argument("--batch_size_cfimdb", type=int, default=8)
|
| 212 |
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
|
| 213 |
+
parser.add_argument("--lr_cse", type=float, default=2e-5)
|
| 214 |
+
parser.add_argument("--lr_classifier", type=float, default=1e-5)
|
| 215 |
|
| 216 |
args = parser.parse_args()
|
| 217 |
return args
|
|
|
|
| 231 |
use_gpu=args.use_gpu,
|
| 232 |
epochs=args.epochs,
|
| 233 |
batch_size_cse=args.batch_size_cse,
|
| 234 |
+
batch_size_classifier=args.batch_size_sst,
|
| 235 |
hidden_dropout_prob=args.hidden_dropout_prob,
|
| 236 |
+
train_bert='data/amazon-polarity.parquet',
|
| 237 |
train='data/ids-sst-train.csv',
|
| 238 |
dev='data/ids-sst-dev.csv',
|
| 239 |
test='data/ids-sst-test-student.csv'
|