supcon / code /src /modeling.py
IGandarillas1's picture
Add model contrastive classifier
fc1c2b8
raw
history blame
11.9 kB
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss
from transformers import AutoModel, AutoConfig
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings over the non-padding positions.

    Args:
        model_output: transformer output; element 0 holds the per-token
            embeddings of shape [batch, seq_len, hidden].
        attention_mask: [batch, seq_len] mask, 1 for real tokens, 0 for padding.

    Returns:
        [batch, hidden] tensor of masked mean embeddings.
    """
    embeddings = model_output[0]  # first element contains all token embeddings
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = torch.sum(embeddings * mask, 1)
    counts = torch.clamp(mask.sum(1), min=1e-9)  # guard against all-padding rows
    return summed / counts
class BaseEncoder(nn.Module):
    """Thin wrapper around a pretrained Hugging Face transformer backbone.

    Resizes the token-embedding matrix to the given tokenizer length so that
    any tokens added to the tokenizer have embedding rows.
    """

    def __init__(self, len_tokenizer, model='huawei-noah/TinyBERT_General_4L_312D'):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model)
        # Grow/shrink the embedding table to match the tokenizer vocabulary.
        self.transformer.resize_token_embeddings(len_tokenizer)

    def forward(self, input_ids, attention_mask):
        """Run the backbone and return its raw model output object."""
        return self.transformer(input_ids, attention_mask)
# self-supervised contrastive model
class ContrastiveSelfSupervisedPretrainModel(nn.Module):
    """Contrastive pre-training model that builds `num_augments` views per example.

    Each view is produced by re-encoding the same batch; the views differ only
    through stochastic layers (e.g. dropout) while the model is in train mode.
    NOTE(review): in eval mode all views are identical — confirm this model is
    only used with the encoder in train mode.

    Args:
        len_tokenizer: tokenizer vocabulary size for embedding resize.
        model: Hugging Face model name for the backbone encoder.
        ssv: if True, SupConLoss is called without labels (SimCLR-style
            self-supervised loss); if False, `labels` drive the supervised loss.
        pool: if True, mean-pool token embeddings; otherwise use 'pooler_output'.
        proj: projection head type for ContrastivePretrainHead ('linear'/'mlp').
        temperature: SupConLoss temperature.
        num_augments: number of views per example (n_views for SupConLoss).
    """

    def __init__(self, len_tokenizer, model='huawei-noah/TinyBERT_General_4L_312D', ssv=True, pool=False, proj='mlp', temperature=0.07, num_augments=2):
        super().__init__()
        self.ssv = ssv
        self.pool = pool
        self.proj = proj
        self.temperature = temperature
        self.num_augments = num_augments
        self.criterion = SupConLoss(self.temperature)
        self.encoder = BaseEncoder(len_tokenizer, model)
        self.config = self.encoder.transformer.config
        self.contrastive_head = ContrastivePretrainHead(self.config.hidden_size, self.proj)

    def forward(self, input_ids, attention_mask, labels):
        """Encode the batch num_augments times and return (loss,)."""
        views = []
        for _ in range(self.num_augments):
            if self.pool:
                out = self.encoder(input_ids, attention_mask)
                # Bug fix: the original referenced an undefined name `attention`
                # here, and the first pooled view was missing unsqueeze(1),
                # which broke the torch.cat below on mismatched ranks.
                out = mean_pooling(out, attention_mask)
            else:
                out = self.encoder(input_ids, attention_mask)['pooler_output']
            views.append(out.unsqueeze(1))  # [bsz, 1, hidden]
        output = torch.cat(views, 1)  # [bsz, num_augments, hidden]
        output = F.normalize(output, dim=-1)
        proj_output = self.contrastive_head(output)
        proj_output = F.normalize(proj_output, dim=-1)
        if self.ssv:
            # Self-supervised: SupConLoss degenerates to SimCLR without labels.
            loss = self.criterion(proj_output)
        else:
            loss = self.criterion(proj_output, labels)
        return (loss,)
# supervised contrastive model
class ContrastivePretrainModel(nn.Module):
    """Supervised contrastive pre-training over left/right input pairs.

    Both sides are encoded with a shared encoder, stacked as two views per
    example, L2-normalized, and scored with SupConLoss using `labels`.
    """

    def __init__(self, len_tokenizer, model='huawei-noah/TinyBERT_General_4L_312D', pool=True, proj='mlp', temperature=0.07):
        super().__init__()
        self.pool = pool
        self.proj = proj
        self.temperature = temperature
        self.criterion = SupConLoss(self.temperature)
        self.encoder = BaseEncoder(len_tokenizer, model)
        self.config = self.encoder.transformer.config

    def forward(self, input_ids, attention_mask, labels, input_ids_right, attention_mask_right):
        """Encode both sides and return (loss,)."""
        if self.pool:
            left = mean_pooling(self.encoder(input_ids, attention_mask), attention_mask)
            right = mean_pooling(self.encoder(input_ids_right, attention_mask_right), attention_mask_right)
        else:
            left = self.encoder(input_ids, attention_mask)['pooler_output']
            right = self.encoder(input_ids_right, attention_mask_right)['pooler_output']
        # [bsz, 2, hidden]: the pair forms the two contrastive views.
        views = torch.stack((left, right), dim=1)
        views = F.normalize(views, dim=-1)
        return (self.criterion(views, labels),)
class ContrastivePretrainHead(nn.Module):
    """Projection head mapping encoder embeddings into the contrastive space.

    Args:
        hidden_size: input and output feature dimension.
        proj: 'linear' for a single linear layer, 'mlp' for a two-layer MLP
            with ReLU.

    Raises:
        ValueError: for an unknown `proj` value. (The original silently skipped
        building self.proj and failed later with AttributeError in forward().)
    """

    def __init__(self, hidden_size, proj='mlp'):
        super().__init__()
        if proj == 'linear':
            self.proj = nn.Linear(hidden_size, hidden_size)
        elif proj == 'mlp':
            self.proj = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size)
            )
        else:
            raise ValueError('Unknown projection type: {}'.format(proj))

    def forward(self, hidden_states):
        """Apply the projection; output shape matches the input shape."""
        return self.proj(hidden_states)
# cross-entropy fine-tuning model
class ContrastiveClassifierModel(nn.Module):
    """Binary pair classifier on top of an (optionally frozen) pre-trained encoder.

    Encodes both inputs, combines the embeddings according to `comb_fct`, and
    scores the pair with a single-logit head trained with BCEWithLogitsLoss.

    Args:
        len_tokenizer: tokenizer vocabulary size for embedding resize.
        checkpoint_path: optional path to a contrastive-pretraining checkpoint
            loaded non-strictly into this module.
        model: Hugging Face model name for the backbone encoder.
        pool: if True, mean-pool token embeddings; otherwise use 'pooler_output'.
        comb_fct: how to combine left/right embeddings; must be one of the
            combinations handled in forward() (and by ClassificationHead).
        frozen: if True, freeze all encoder parameters.
        pos_neg: falsy for unweighted BCE; otherwise used as the positive-class
            weight (pos_weight) of BCEWithLogitsLoss.
    """

    def __init__(self, len_tokenizer, checkpoint_path, model='huawei-noah/TinyBERT_General_4L_312D', pool=True, comb_fct='concat-abs-diff-mult', frozen=True, pos_neg=False):
        super().__init__()
        self.pool = pool
        self.frozen = frozen
        self.checkpoint_path = checkpoint_path
        self.comb_fct = comb_fct
        self.pos_neg = pos_neg
        self.encoder = BaseEncoder(len_tokenizer, model)
        self.config = self.encoder.transformer.config
        # pos_neg doubles as the positive-class weight when truthy.
        if self.pos_neg:
            self.criterion = BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_neg]))
        else:
            self.criterion = BCEWithLogitsLoss()
        self.classification_head = ClassificationHead(self.config, self.comb_fct)
        if self.checkpoint_path:
            # NOTE(review): loads tensors onto whatever device the checkpoint
            # was saved from; consider map_location for CPU-only use — confirm.
            checkpoint = torch.load(self.checkpoint_path)
            self.load_state_dict(checkpoint, strict=False)
        if self.frozen:
            for param in self.encoder.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask, labels, input_ids_right, attention_mask_right):
        """Return (loss, sigmoid probabilities) for the batch of pairs."""
        if self.pool:
            output_left = self.encoder(input_ids, attention_mask)
            output_left = mean_pooling(output_left, attention_mask)
            output_right = self.encoder(input_ids_right, attention_mask_right)
            output_right = mean_pooling(output_right, attention_mask_right)
        else:
            output_left = self.encoder(input_ids, attention_mask)['pooler_output']
            output_right = self.encoder(input_ids_right, attention_mask_right)['pooler_output']
        if self.comb_fct == 'concat-abs-diff':
            output = torch.cat((output_left, output_right, torch.abs(output_left - output_right)), -1)
        elif self.comb_fct == 'concat-mult':
            output = torch.cat((output_left, output_right, output_left * output_right), -1)
        elif self.comb_fct == 'concat':
            output = torch.cat((output_left, output_right), -1)
        elif self.comb_fct == 'abs-diff':
            output = torch.abs(output_left - output_right)
        elif self.comb_fct == 'mult':
            output = output_left * output_right
        elif self.comb_fct == 'abs-diff-mult':
            output = torch.cat((torch.abs(output_left - output_right), output_left * output_right), -1)
        elif self.comb_fct == 'concat-abs-diff-mult':
            output = torch.cat((output_left, output_right, torch.abs(output_left - output_right), output_left * output_right), -1)
        else:
            # Bug fix: previously an unknown comb_fct left `output` unbound and
            # raised a confusing NameError below.
            raise ValueError('Unknown comb_fct: {}'.format(self.comb_fct))
        proj_output = self.classification_head(output)
        loss = self.criterion(proj_output.view(-1), labels.float())
        proj_output = torch.sigmoid(proj_output)
        return (loss, proj_output)
class ClassificationHead(nn.Module):
    """Dropout + single-logit linear head for pair classification.

    The input width is a multiple of config.hidden_size determined by how the
    two sentence embeddings were combined (`comb_fct`); kept in sync with the
    combinations produced in ContrastiveClassifierModel.forward.

    Raises:
        ValueError: for an unknown `comb_fct`. (The original never set
        self.hidden_size in that case and failed later with AttributeError.)
    """

    # hidden-size multiplier per combination function
    _WIDTH = {
        'concat-abs-diff': 3,
        'concat-mult': 3,
        'concat': 2,
        'abs-diff-mult': 2,
        'abs-diff': 1,
        'mult': 1,
        'concat-abs-diff-mult': 4,
    }

    def __init__(self, config, comb_fct):
        super().__init__()
        try:
            multiplier = self._WIDTH[comb_fct]
        except KeyError:
            raise ValueError('Unknown comb_fct: {}'.format(comb_fct)) from None
        self.hidden_size = multiplier * config.hidden_size
        classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(self.hidden_size, 1)

    def forward(self, features):
        """Project the combined features to one logit per example."""
        x = self.dropout(features)
        return self.out_proj(x)
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        # temperature: softmax temperature applied to the similarity logits.
        # contrast_mode: 'all' uses every view as an anchor; 'one' only view 0.
        # base_temperature: reference temperature used to rescale the loss.
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        # NOTE(review): only distinguishes cuda vs cpu; any other device type
        # falls through to cpu — confirm that is acceptable for this project.
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))
        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            # Flatten any trailing feature dims so each view is one vector.
            features = features.view(features.shape[0], features.shape[1], -1)
        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # Unsupervised (SimCLR) case: each sample is its own class.
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            # Positives are all samples that share the same label.
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)
        contrast_count = features.shape[1]
        # [bsz*n_views, dim]: all views stacked along the batch dimension.
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
        # compute logits: temperature-scaled pairwise dot products
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability: subtract the per-row max before exp
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        # tile mask to the [anchor_count*bsz, contrast_count*bsz] logits grid
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases (zero the diagonal of each anchor row)
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask
        # compute log_prob over all non-self contrast pairs
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        # compute mean of log-likelihood over positive
        # NOTE(review): mask.sum(1) is zero for an anchor with no positives,
        # which yields NaN — confirm inputs always provide >= 1 positive view.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        # loss, rescaled by temperature / base_temperature as in the paper
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss