import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss

from transformers import AutoModel, AutoConfig


# Maps each supported `comb_fct` name to the ordered pieces that are
# concatenated along the last dimension. The number of pieces is also the
# multiplier applied to the encoder hidden size in `ClassificationHead`.
_COMBINATION_PARTS = {
    'concat-abs-diff': ('left', 'right', 'abs-diff'),
    'concat-mult': ('left', 'right', 'mult'),
    'concat': ('left', 'right'),
    'abs-diff': ('abs-diff',),
    'mult': ('mult',),
    'abs-diff-mult': ('abs-diff', 'mult'),
    'concat-abs-diff-mult': ('left', 'right', 'abs-diff', 'mult'),
}


def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked tokens.

    Args:
        model_output: indexable encoder output whose first element holds the
            token embeddings.
        attention_mask: mask with one entry per token; it is broadcast over
            the embedding dimension and used as a weight.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask sum. The
        divisor is clamped to 1e-9, so a fully masked sequence yields zeros
        instead of a division by zero.
    """
    # First element of model_output contains all token embeddings
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    summed = torch.sum(token_embeddings * input_mask_expanded, 1)
    counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return summed / counts


def _encode(encoder, input_ids, attention_mask, pool):
    """Run `encoder` once and reduce its output to one vector per sequence."""
    output = encoder(input_ids, attention_mask)
    if pool:
        return mean_pooling(output, attention_mask)
    return output['pooler_output']


def _combine(left, right, comb_fct):
    """Merge two embedding batches according to the `comb_fct` recipe."""
    if comb_fct not in _COMBINATION_PARTS:
        raise ValueError('Unknown comb_fct: {}'.format(comb_fct))

    pieces = []
    for part in _COMBINATION_PARTS[comb_fct]:
        if part == 'left':
            pieces.append(left)
        elif part == 'right':
            pieces.append(right)
        elif part == 'abs-diff':
            pieces.append(torch.abs(left - right))
        else:  # 'mult'
            pieces.append(left * right)

    if len(pieces) == 1:
        return pieces[0]
    return torch.cat(pieces, -1)


class BaseEncoder(nn.Module):
    """Pretrained transformer whose embedding table is resized to the tokenizer."""

    def __init__(self, len_tokenizer, model='huawei-noah/TinyBERT_General_4L_312D'):
        """Load `model` with `AutoModel` and resize its token embeddings.

        Args:
            len_tokenizer: vocabulary size passed to `resize_token_embeddings`.
            model: name or path given to `AutoModel.from_pretrained`.
        """
        super().__init__()

        self.transformer = AutoModel.from_pretrained(model)
        self.transformer.resize_token_embeddings(len_tokenizer)

    def forward(self, input_ids, attention_mask):
        """Return the raw transformer output for the given batch."""
        # Keyword arguments keep the mask from binding to a different second
        # positional parameter on architectures whose signature differs.
        return self.transformer(input_ids=input_ids, attention_mask=attention_mask)


# self-supervised contrastive model
class ContrastiveSelfSupervisedPretrainModel(nn.Module):
    """Contrastive pre-training over repeated encodings of the same input."""

    def __init__(self, len_tokenizer, model='huawei-noah/TinyBERT_General_4L_312D',
                 ssv=True, pool=False, proj='mlp', temperature=0.07, num_augments=2):
        """Build the encoder, projection head and contrastive criterion.

        Args:
            len_tokenizer: vocabulary size forwarded to `BaseEncoder`.
            model: name or path of the pretrained transformer.
            ssv: if True the loss is computed without labels; otherwise the
                labels passed to `forward` are used.
            pool: if True use `mean_pooling`, else the `pooler_output`.
            proj: projection type for `ContrastivePretrainHead`.
            temperature: temperature of `SupConLoss`.
            num_augments: total number of views (encoder passes) per sample.
        """
        super().__init__()

        self.ssv = ssv
        self.pool = pool
        self.proj = proj
        self.temperature = temperature
        self.num_augments = num_augments
        self.criterion = SupConLoss(self.temperature)

        self.encoder = BaseEncoder(len_tokenizer, model)
        self.config = self.encoder.transformer.config

        self.contrastive_head = ContrastivePretrainHead(self.config.hidden_size, self.proj)

    def forward(self, input_ids, attention_mask, labels):
        """Encode the batch `num_augments` times and return `(loss,)`.

        `labels` is only read when `self.ssv` is False.
        """
        # Every view is a separate pass over the same inputs.
        # NOTE(review): the views presumably differ only through dropout
        # noise while in training mode - confirm with the training script.
        first_view = _encode(
            self.encoder, input_ids, attention_mask, self.pool
        ).unsqueeze(1)
        views = [first_view]
        for _ in range(self.num_augments - 1):
            views.append(
                _encode(self.encoder, input_ids, attention_mask, self.pool).unsqueeze(1)
            )

        # [bsz, n_views, hidden], the layout SupConLoss expects.
        output = torch.cat(views, 1)
        output = F.normalize(output, dim=-1)

        proj_output = self.contrastive_head(output)
        proj_output = F.normalize(proj_output, dim=-1)

        if self.ssv:
            loss = self.criterion(proj_output)
        else:
            loss = self.criterion(proj_output, labels)

        return (loss,)


# supervised contrastive model
class ContrastivePretrainModel(nn.Module):
    """Supervised contrastive pre-training over (left, right) input pairs."""

    def __init__(self, len_tokenizer, model='huawei-noah/TinyBERT_General_4L_312D',
                 pool=True, proj='mlp', temperature=0.07):
        """Build the encoder and the contrastive criterion.

        `proj` is stored but no projection head is created by this class.
        """
        super().__init__()

        self.pool = pool
        self.proj = proj
        self.temperature = temperature
        self.criterion = SupConLoss(self.temperature)

        self.encoder = BaseEncoder(len_tokenizer, model)
        self.config = self.encoder.transformer.config

    def forward(self, input_ids, attention_mask, labels, input_ids_right, attention_mask_right):
        """Encode both sides, stack them as two views and return `(loss,)`."""
        output_left = _encode(self.encoder, input_ids, attention_mask, self.pool)
        output_right = _encode(
            self.encoder, input_ids_right, attention_mask_right, self.pool
        )

        output = torch.cat((output_left.unsqueeze(1), output_right.unsqueeze(1)), 1)
        output = F.normalize(output, dim=-1)

        loss = self.criterion(output, labels)

        return (loss,)


class ContrastivePretrainHead(nn.Module):
    """Projection head applied to encoder embeddings before the loss."""

    def __init__(self, hidden_size, proj='mlp'):
        """Create a `'linear'` or `'mlp'` projection of width `hidden_size`.

        Raises:
            ValueError: if `proj` is neither `'linear'` nor `'mlp'`.
        """
        super().__init__()
        if proj == 'linear':
            self.proj = nn.Linear(hidden_size, hidden_size)
        elif proj == 'mlp':
            self.proj = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size)
            )
        else:
            # Fail here rather than with an AttributeError in forward().
            raise ValueError('Unknown proj: {}'.format(proj))

    def forward(self, hidden_states):
        """Apply the projection to `hidden_states`."""
        return self.proj(hidden_states)


# cross-entropy fine-tuning model
class ContrastiveClassifierModel(nn.Module):
    """Binary pair classifier on top of an (optionally frozen) encoder."""

    def __init__(self, len_tokenizer, checkpoint_path,
                 model='huawei-noah/TinyBERT_General_4L_312D', pool=True,
                 comb_fct='concat-abs-diff-mult', frozen=True, pos_neg=False):
        """Build encoder, loss and head; optionally load weights and freeze.

        Args:
            len_tokenizer: vocabulary size forwarded to `BaseEncoder`.
            checkpoint_path: state-dict file to load non-strictly; skipped
                when falsy.
            model: name or path of the pretrained transformer.
            pool: if True use `mean_pooling`, else the `pooler_output`.
            comb_fct: how the two embeddings are combined (see `_combine`).
            frozen: if True the encoder parameters do not receive gradients.
            pos_neg: when truthy, used as `pos_weight` of the BCE loss.
        """
        super().__init__()

        self.pool = pool
        self.frozen = frozen
        self.checkpoint_path = checkpoint_path
        self.comb_fct = comb_fct
        self.pos_neg = pos_neg

        self.encoder = BaseEncoder(len_tokenizer, model)
        self.config = self.encoder.transformer.config

        if self.pos_neg:
            self.criterion = BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_neg]))
        else:
            self.criterion = BCEWithLogitsLoss()

        self.classification_head = ClassificationHead(self.config, self.comb_fct)

        if self.checkpoint_path:
            # map_location='cpu' lets a checkpoint saved on a GPU be restored
            # on a CPU-only host; load_state_dict copies the values into the
            # existing parameters, so their device does not change.
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
            self.load_state_dict(checkpoint, strict=False)

        if self.frozen:
            for param in self.encoder.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask, labels, input_ids_right, attention_mask_right):
        """Return `(loss, probabilities)` for a batch of input pairs.

        The loss is computed on the raw logits; the second element is the
        sigmoid of those logits.
        """
        output_left = _encode(self.encoder, input_ids, attention_mask, self.pool)
        output_right = _encode(
            self.encoder, input_ids_right, attention_mask_right, self.pool
        )

        output = _combine(output_left, output_right, self.comb_fct)

        proj_output = self.classification_head(output)

        loss = self.criterion(proj_output.view(-1), labels.float())

        proj_output = torch.sigmoid(proj_output)

        return (loss, proj_output)


class ClassificationHead(nn.Module):
    """Dropout followed by a single-logit linear layer."""

    def __init__(self, config, comb_fct):
        """Size the input layer for the chosen combination function.

        Args:
            config: object providing `hidden_size` and `hidden_dropout_prob`.
            comb_fct: combination name; decides the input-width multiplier.

        Raises:
            ValueError: if `comb_fct` is not a supported combination.
        """
        super().__init__()

        if comb_fct not in _COMBINATION_PARTS:
            raise ValueError('Unknown comb_fct: {}'.format(comb_fct))

        # One hidden_size-wide slice per concatenated piece.
        self.hidden_size = len(_COMBINATION_PARTS[comb_fct]) * config.hidden_size

        classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(self.hidden_size, 1)

    def forward(self, features):
        """Return one logit per row of `features`."""
        x = self.dropout(features)
        x = self.out_proj(x)
        return x


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
        """
        # Follow the tensor's own device so non-default GPUs work as well.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # An anchor without any positive has a zero numerator; replacing its
        # zero count by one makes it contribute 0 instead of NaN (0 / 0).
        positives_per_anchor = mask.sum(1)
        positives_per_anchor = torch.where(
            positives_per_anchor < 1e-6,
            torch.ones_like(positives_per_anchor),
            positives_per_anchor,
        )
        mean_log_prob_pos = (mask * log_prob).sum(1) / positives_per_anchor

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss