"""Model definitions and loading helpers for code-change / review models.

Contains:

* ``enrich_vocab`` / ``build_or_load_gen_model`` -- tokenizer, config and
  model construction.
* ``CloneModel`` / ``DefectModel`` -- classifiers on top of an encoder or
  encoder-decoder backbone.
* ``Seq2Seq`` / ``Beam`` -- RoBERTa encoder + Transformer decoder with beam
  search (adapted from CodeBERT's code2nl model).
* ``CodeChangeModel`` -- T5 with an additional classification head.
* ``FocalLoss`` -- focal loss over probabilities.
"""
import logging
from typing import Union

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.nn.functional import one_hot
from transformers import (BartConfig, BartForConditionalGeneration,
                          BartTokenizer, RobertaConfig, RobertaModel,
                          RobertaTokenizer, T5Config,
                          T5ForConditionalGeneration, T5Model, T5Tokenizer)

logger = logging.getLogger(__name__)


def get_model_size(model):
    """Return the trainable-parameter count of ``model`` as e.g. ``"223M"``.

    Only parameters with ``requires_grad=True`` are counted; the count is
    rounded to the nearest million.
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return "{}M".format(round(trainable / 1e+6))


def enrich_vocab(
    args,
    tokenizer,
    config,
    load_extra_ids=True
):
    """Register task-specific special tokens and mirror their ids.

    The tokenizer is extended with additional special tokens, then the ids of
    the relevant tokens are copied onto ``config`` (``*_token_id``) and onto
    the tokenizer itself (``*_id`` attributes and ``special_dict``).

    Args:
        args: namespace; only ``args.add_lang_ids`` is read here.
        tokenizer: HuggingFace tokenizer, modified in place.
        config: model config, modified in place.
        load_extra_ids: when exactly ``True``, the first batch of 100 indexed
            special tokens is added.

    Returns:
        The ``(tokenizer, config)`` pair that was passed in.

    NOTE(review): every special-token literal below is an empty string in
    this copy of the file -- they look like angle-bracket tokens that were
    stripped by some text-extraction step. They are reproduced unchanged;
    TODO restore the real token strings before running this code.
    NOTE(review): the nesting of the ``if`` blocks was reconstructed from a
    whitespace-collapsed source; confirm which ``add_special_tokens`` calls
    belong under ``load_extra_ids``.
    """
    add_tokens = ["", "", "", "", "", "", "", "", "", "", "", "", ""]
    # Build the vocabulary dict once: get_vocab() creates a fresh dict of the
    # full vocabulary on every call.
    vocab = tokenizer.get_vocab()
    add_token_ids = [tok for tok in add_tokens if tok not in vocab]
    if add_token_ids:
        tokenizer.add_special_tokens(
            {"additional_special_tokens": add_token_ids}
        )
    if load_extra_ids is True:
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "".format(i) for i in range(99, -1, -1)
                ]
            }
        )
    tokenizer.add_special_tokens(
        {
            "additional_special_tokens": [
                "".format(i) for i in range(99, -1, -1)
            ]
        }
    )
    tokenizer.add_special_tokens({"additional_special_tokens": [""]})
    langs = [
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "",
    ]
    # Add DFG ids
    # tokenizer.add_special_tokens(
    #     {
    #         "additional_special_tokens": [
    #             "".format(i) for i in range(99, -1, -1)
    #         ]
    #     }
    # )
    # tokenizer.add_special_tokens({"additional_special_tokens": ["", "", "comesFrom", "computedFrom"]})
    # tokenizer.add_special_tokens({"additional_special_tokens": ["", ""]})
    if args.add_lang_ids:
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": langs
            }
        )

    # All tokens are registered now; take one final snapshot for the lookups.
    vocab = tokenizer.get_vocab()

    if args.add_lang_ids:
        config.lang_id = {lang: vocab[lang] for lang in langs}

    config.vocab_size = len(tokenizer)
    config.bos_token_id = vocab[""]
    config.pad_token_id = vocab[""]
    config.eos_token_id = vocab[""]
    config.mask_token_id = vocab[""]
    config.keep_token_id = vocab[""]
    config.add_token_id = vocab[""]
    config.del_token_id = vocab[""]
    config.start_token_id = vocab[""]
    config.end_token_id = vocab[""]
    config.lang_tokens = langs

    tokenizer.special_dict = {
        f"": vocab[f""] for i in range(99, -1, -1)
    }
    tokenizer.mask_id = vocab[""]
    tokenizer.bos_id = vocab[""]
    tokenizer.pad_id = vocab[""]
    tokenizer.eos_id = vocab[""]
    tokenizer.msg_id = vocab[""]
    tokenizer.keep_id = vocab[""]
    tokenizer.add_id = vocab[""]
    tokenizer.del_id = vocab[""]
    tokenizer.start_id = vocab[""]
    tokenizer.end_id = vocab[""]
    # tokenizer.edge_id = tokenizer.get_vocab()[""]

    return tokenizer, config


def build_or_load_gen_model(args, load_model=True):
    """Build config, tokenizer and model for ``args.model_type``.

    Args:
        args: namespace providing ``model_type``, ``model_name_or_path``,
            ``config_name``, ``tokenizer_name``, ``from_scratch``,
            ``load_model_path``, ``device``, ``add_lang_ids`` and, for the
            ``'roberta'`` model type, ``beam_size`` and ``max_target_length``.
            ``args.model_name_or_path`` is filled in with a default when empty.
        load_model: when exactly ``True`` and ``args.load_model_path`` is not
            ``None``, a state dict is loaded from that path.

    Returns:
        ``(config, model, tokenizer)``; the model is moved to ``args.device``.
    """
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Resolve the default checkpoint first so that the config can fall back
    # to it when neither a config name nor a model path was supplied.
    if not args.model_name_or_path:
        if args.model_type == "codet5":
            args.model_name_or_path = "Salesforce/codet5-base"
        else:
            args.model_name_or_path = "t5-base"

    config = config_class.from_pretrained(
        args.config_name or args.model_name_or_path)

    # Default to the CodeT5 tokenizer. The original assigned this default to
    # an unused local and then loaded ``args.tokenizer_name`` regardless.
    tokenizer_name = args.tokenizer_name or "Salesforce/codet5-base"
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name)

    if args.model_type == 'roberta':
        encoder = model_class.from_pretrained(
            args.model_name_or_path, config=config)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_size, nhead=config.num_attention_heads)
        decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                        beam_size=args.beam_size,
                        max_length=args.max_target_length,
                        sos_id=tokenizer.cls_token_id,
                        eos_id=tokenizer.sep_token_id)
    elif args.from_scratch is True:
        model = model_class(config)
    else:
        model = model_class.from_pretrained(args.model_name_or_path)

    tokenizer, config = enrich_vocab(args, tokenizer, config)
    model.config = config  # update the config in model
    model.resize_token_embeddings(len(tokenizer))

    logger.info(
        "Finish loading model [%s] from %s",
        get_model_size(model),
        args.model_name_or_path
    )
    model.to(args.device)

    if args.load_model_path is not None and load_model is True:
        logger.info("Load model from %s", args.load_model_path)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(
            args.load_model_path, map_location="cpu"))
        logger.info("Model from %s has been loaded", args.load_model_path)
        model.to(args.device)

    return config, model, tokenizer


class RobertaClassificationHeadMF(nn.Module):
    """Classification head combining the first-token state with manual features.

    The manual feature vector (14 values) is projected to ``config.d_model``,
    passed through ``tanh``, concatenated with the first token's hidden state
    and mapped to 2 logits.
    """

    def __init__(self, config):
        super().__init__()
        self.manual_dense = nn.Linear(14, config.d_model)
        self.dropout = nn.Dropout(0.1)
        self.out_proj_new = nn.Linear(
            config.d_model + config.d_model, 2, bias=True)
        # self.out_proj_new = nn.Linear(config.hidden_size + config.hidden_size, 1)

    def forward(self, features, manual_features=None, **kwargs):
        """Return logits for ``features`` (indexed as ``[:, 0, :]``).

        Raises:
            ValueError: if ``manual_features`` is not given; the head cannot
                be evaluated without them.
        """
        if manual_features is None:
            raise ValueError("manual_features is required by this head.")
        # take the first token (equivalent to [CLS]) [bs, hidden_size]
        x = features[:, 0, :]
        y = manual_features.float()  # [bs, feature_size]
        y = torch.tanh(self.manual_dense(y))

        x = torch.cat((x, y), dim=-1)
        x = self.dropout(x)
        return self.out_proj_new(x)


class RobertaClassificationHead(nn.Module):
    """Classification head over pairs of vectors.

    Consecutive rows of the input are concatenated pairwise (the last
    dimension is doubled) before being projected to 2 logits.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.d_model * 2, config.d_model)
        self.out_proj = nn.Linear(config.d_model, 2)

    def forward(self, x, **kwargs):
        """Return 2-way logits for each pair of input rows."""
        x = x.reshape(-1, x.size(-1) * 2)
        x = torch.tanh(self.dense(x))
        return self.out_proj(x)


def _last_eos_decoder_state(model, source_ids, pad_token_id, eos_token_id):
    """Return the last decoder hidden state at the final EOS position.

    ``model`` is called as an encoder-decoder with ``source_ids`` used both as
    input and as labels. Shared by the T5 and BART paths of ``CloneModel`` and
    ``DefectModel``.

    Raises:
        ValueError: if the rows of ``source_ids`` do not all contain the same
            number of ``eos_token_id`` tokens.
    """
    attention_mask = source_ids.ne(pad_token_id)
    outputs = model(input_ids=source_ids, attention_mask=attention_mask,
                    labels=source_ids, decoder_attention_mask=attention_mask,
                    output_hidden_states=True)
    hidden_states = outputs['decoder_hidden_states'][-1]
    eos_mask = source_ids.eq(eos_token_id)

    if len(torch.unique(eos_mask.sum(1))) > 1:
        raise ValueError(
            "All examples must have the same number of  tokens.")
    # Group the EOS states per example and keep the last one of each.
    per_example = hidden_states[eos_mask, :].view(
        hidden_states.size(0), -1, hidden_states.size(-1))
    return per_example[:, -1, :]


class CloneModel(nn.Module):
    """Pair classifier (clone detection) on top of a pretrained backbone."""

    def __init__(self, encoder, config, tokenizer, args):
        super(CloneModel, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = RobertaClassificationHead(config)
        self.args = args

    def get_t5_vec(self, source_ids):
        """Sequence vector for T5-style backbones (last EOS decoder state)."""
        return _last_eos_decoder_state(
            self.encoder, source_ids,
            self.tokenizer.pad_token_id, self.config.eos_token_id)

    def get_bart_vec(self, source_ids):
        """Sequence vector for BART backbones (last EOS decoder state)."""
        return _last_eos_decoder_state(
            self.encoder, source_ids,
            self.tokenizer.pad_token_id, self.config.eos_token_id)

    def get_roberta_vec(self, source_ids):
        """Sequence vector for RoBERTa backbones (first-token state)."""
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        vec = self.encoder(input_ids=source_ids,
                           attention_mask=attention_mask)[0][:, 0, :]
        return vec

    def forward(self, source_ids=None, labels=None):
        """Classify inputs reshaped to ``(-1, args.max_source_length)``.

        Returns:
            ``(loss, prob)`` when ``labels`` is given, otherwise ``prob``.

        Raises:
            ValueError: if ``args.model_type`` is not one of ``'codet5'``,
                ``'bart'`` or ``'roberta'``.
        """
        source_ids = source_ids.view(-1, self.args.max_source_length)

        model_type = self.args.model_type
        if model_type == 'codet5':
            vec = self.get_t5_vec(source_ids)
        elif model_type == 'bart':
            vec = self.get_bart_vec(source_ids)
        elif model_type == 'roberta':
            vec = self.get_roberta_vec(source_ids)
        else:
            raise ValueError(
                "Unsupported model_type: {!r}".format(model_type))

        logits = self.classifier(vec)
        # Explicit dim: the implicit-dim form of softmax is deprecated.
        prob = nn.functional.softmax(logits, dim=-1)

        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
            return loss, prob
        return prob


class DefectModel(nn.Module):
    """Binary sequence classifier (defect detection) on a pretrained backbone."""

    def __init__(self, encoder, config, tokenizer, args):
        super(DefectModel, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = nn.Linear(config.hidden_size, 2)
        self.args = args

    def get_t5_vec(self, source_ids):
        """Sequence vector for T5-style backbones (last EOS decoder state)."""
        return _last_eos_decoder_state(
            self.encoder, source_ids,
            self.tokenizer.pad_token_id, self.config.eos_token_id)

    def get_bart_vec(self, source_ids):
        """Sequence vector for BART backbones (last EOS decoder state)."""
        return _last_eos_decoder_state(
            self.encoder, source_ids,
            self.tokenizer.pad_token_id, self.config.eos_token_id)

    def get_roberta_vec(self, source_ids):
        """Sequence vector for RoBERTa backbones (first-token state)."""
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        vec = self.encoder(input_ids=source_ids,
                           attention_mask=attention_mask)[0][:, 0, :]
        return vec

    def forward(self, source_ids=None, labels=None):
        """Classify inputs reshaped to ``(-1, args.max_source_length)``.

        Returns:
            ``(loss, prob)`` when ``labels`` is given, otherwise ``prob``.

        Raises:
            ValueError: if ``args.model_type`` is not one of ``'codet5'``,
                ``'bart'`` or ``'roberta'``.
        """
        source_ids = source_ids.view(-1, self.args.max_source_length)

        model_type = self.args.model_type
        if model_type == 'codet5':
            vec = self.get_t5_vec(source_ids)
        elif model_type == 'bart':
            vec = self.get_bart_vec(source_ids)
        elif model_type == 'roberta':
            vec = self.get_roberta_vec(source_ids)
        else:
            raise ValueError(
                "Unsupported model_type: {!r}".format(model_type))

        logits = self.classifier(vec)
        # Explicit dim: the implicit-dim form of softmax is deprecated.
        prob = nn.functional.softmax(logits, dim=-1)

        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
            return loss, prob
        return prob


# https://github.com/microsoft/CodeBERT/blob/master/CodeBERT/code2nl/model.py
class Seq2Seq(nn.Module):
    """
        Build Sequence-to-Sequence.

        Parameters:

        * `encoder`- encoder of seq2seq model. e.g. roberta
        * `decoder`- decoder of seq2seq model. e.g. transformer
        * `config`- configuration of encoder model.
        * `beam_size`- beam size for beam search.
        * `max_length`- max length of target for beam search.
        * `sos_id`- start of symbol ids in target for beam search.
        * `eos_id`- end of symbol ids in target for beam search.
    """

    def __init__(self, encoder, decoder, config, beam_size=None,
                 max_length=None, sos_id=None, eos_id=None):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Lower-triangular matrix used to build the causal decoder mask.
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()

        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending on whether we are using
            TorchScript or not.
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are
            cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.encoder.embeddings.word_embeddings)

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                decoder_attention_mask=None, args=None):
        """Compute the training loss, or decode with beam search.

        Args:
            input_ids: source token ids; the first dimension is the batch.
            attention_mask: source mask, truthy for real tokens. Bool and
                0/1 integer masks are both accepted.
            labels: target token ids. When given, the loss is computed;
                when ``None``, beam search is run instead.
            decoder_attention_mask: target mask; positions equal to 0 are
                excluded from the loss. Only used when ``labels`` is given.
            args: unused.

        Returns:
            With ``labels``: ``(loss, loss * n_active, n_active)`` where
            ``n_active`` is the number of target positions in the loss.
            Without ``labels``: a tensor of predicted ids built from
            ``beam_size`` hypotheses per example, each padded with zeros to
            ``max_length``.
        """
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        encoder_output = outputs[0].permute(
            [1, 0, 2]).contiguous()  # last_hidden_state
        # ``~`` is only a logical NOT on bool tensors, so coerce first.
        source_padding_mask = ~attention_mask.bool()

        if labels is not None:
            tgt_len = labels.shape[1]
            attn_mask = -1e4 * (1 - self.bias[:tgt_len, :tgt_len])
            tgt_embeddings = self.encoder.embeddings(
                labels).permute([1, 0, 2]).contiguous()
            out = self.decoder(tgt_embeddings, encoder_output,
                               tgt_mask=attn_mask,
                               memory_key_padding_mask=source_padding_mask)
            hidden_states = torch.tanh(self.dense(
                out)).permute([1, 0, 2]).contiguous()
            lm_logits = self.lm_head(hidden_states)
            # Shift so that tokens < n predict n
            active_loss = decoder_attention_mask[..., 1:].ne(0).view(-1)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                shift_labels.view(-1)[active_loss])

            n_active = active_loss.sum()
            return loss, loss * n_active, n_active

        # Predict
        device = encoder_output.device
        preds = []
        zero = torch.zeros(1, dtype=torch.long, device=device)
        batch_size = input_ids.shape[0]
        for i in range(batch_size):
            context = encoder_output[:, i:i + 1]
            context_padding_mask = source_padding_mask[i:i + 1, :]
            beam = Beam(self.beam_size, self.sos_id, self.eos_id,
                        device=device)
            beam_ids = beam.getCurrentState()
            context = context.repeat(1, self.beam_size, 1)
            context_padding_mask = context_padding_mask.repeat(
                self.beam_size, 1)
            for _ in range(self.max_length):
                if beam.done():
                    break
                cur_len = beam_ids.shape[1]
                attn_mask = -1e4 * (1 - self.bias[:cur_len, :cur_len])
                tgt_embeddings = self.encoder.embeddings(
                    beam_ids).permute([1, 0, 2]).contiguous()
                out = self.decoder(
                    tgt_embeddings, context, tgt_mask=attn_mask,
                    memory_key_padding_mask=context_padding_mask)
                out = torch.tanh(self.dense(out))
                hidden_states = out.permute(
                    [1, 0, 2]).contiguous()[:, -1, :]
                out = self.lsm(self.lm_head(hidden_states)).data
                beam.advance(out)
                # Reorder the prefixes to follow the surviving beams, then
                # append the newly chosen token of each beam.
                beam_ids.data.copy_(beam_ids.data.index_select(
                    0, beam.getCurrentOrigin()))
                beam_ids = torch.cat(
                    (beam_ids, beam.getCurrentState()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:self.beam_size]
            pred = [torch.cat([x.view(-1) for x in p] + [zero]
                              * (self.max_length - len(p))).view(1, -1)
                    for p in pred]
            preds.append(torch.cat(pred, 0).unsqueeze(0))

        preds = torch.cat(preds, 0)
        return preds


class Beam(object):
    """Beam-search state for a single example.

    Args:
        size: beam width.
        sos: id of the start-of-sequence token.
        eos: id of the end-of-sequence token.
        device: device for the internal tensors. Defaults to CUDA when
            available, otherwise CPU.
    """

    def __init__(self, size, sos, eos, device=None):
        self.size = size
        if device is None:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        # The score for each translation on the beam.
        self.scores = torch.zeros(size, dtype=torch.float, device=device)
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        first_step = torch.zeros(size, dtype=torch.long, device=device)
        first_step[0] = sos
        self.nextYs = [first_step]
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep as a (size, 1) tensor."
        # Clone so that callers mutating the result do not alter beam state.
        return self.nextYs[-1].clone().view(-1, 1)

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Given log-probs over words for every beam (`wordLk`), compute and
        update the beam search state.

        Parameters:

        * `wordLk`- scores of advancing from the last step (K x words)
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            # First step: only the first beam holds the start token.
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        "True once EOS tops the beam and at least `size` hypotheses finished."
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        """Return up to `size` ``(score, timestep, k)`` entries, best first.

        If fewer than `size` hypotheses finished, the best unfinished ones
        from the last step are appended.
        """
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size - len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back through the backpointers to construct the full hypothesis
        for every ``(score, timestep, k)`` entry in `beam_res`.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        "Truncate every hypothesis in `preds` at its first EOS token."
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence


class CodeChangeModel(T5ForConditionalGeneration):
    """T5 conditional-generation model with an extra classification head."""

    def __init__(self, config):
        super().__init__(config)
        self.cls_head = nn.Linear(self.config.d_model, 2, bias=True)
        self.init()

    @staticmethod
    def from_pretrained(path, *model_args, **kwargs):
        """Load T5 weights from ``path`` and attach a fresh ``cls_head``.

        Extra positional/keyword arguments are forwarded to
        ``T5ForConditionalGeneration.from_pretrained``.
        """
        model = T5ForConditionalGeneration.from_pretrained(
            path, *model_args, **kwargs)
        # Re-class the loaded instance so it gains this class's methods.
        model.__class__ = CodeChangeModel
        # change quality estimation
        model.cls_head = nn.Linear(model.config.d_model, 2, bias=True)
        model.init()
        return model

    def init(self):
        """Initialise ``cls_head``: scaled normal weights, zero bias."""
        factor = self.config.initializer_factor
        self.cls_head.weight.data.normal_(
            mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
        self.cls_head.bias.data.zero_()

    def init_classifier(self):
        """Replace ``cls_head`` with a new linear 2-way head."""
        self.cls_head = nn.Linear(self.config.d_model, 2, bias=True)
        self.cls_head.to(self.encoder.device)

    def init_MF_classifier(self):
        """Replace ``cls_head`` with a ``RobertaClassificationHeadMF``."""
        self.cls_head = RobertaClassificationHeadMF(self.config)
        self.cls_head.to(self.encoder.device)

    def forward(
        self, *argv, **kwargs
    ):
        r"""Dispatch on the keyword arguments that are present.

        * ``cls`` present -- classification. Requires ``input_ids``,
          ``labels`` and ``attention_mask``. The variant is selected by the
          presence of ``manual_feature`` (``cls_MF``), ``SF`` (``cls_SF``) or
          ``DQE`` (``cls_DQE``); otherwise ``cls`` is used.
        * ``input_labels`` present -- review generation via
          ``review_forward``. Requires ``input_ids``, ``decoder_input_ids``,
          ``attention_mask`` and ``decoder_attention_mask``;
          ``encoder_loss`` defaults to ``True``.
        * otherwise -- the regular ``T5ForConditionalGeneration.forward``.
        """
        if "cls" in kwargs:
            # for classification task
            assert (
                ("input_ids" in kwargs or "old_ids" in kwargs)
                and "labels" in kwargs
                and ("attention_mask" in kwargs
                     or "old_attention_mask" in kwargs)
            )
            if "manual_feature" in kwargs:
                return self.cls_MF(
                    input_ids=kwargs["input_ids"],
                    manual_feature=kwargs["manual_feature"],
                    labels=kwargs["labels"],
                    attention_mask=kwargs["attention_mask"],
                )
            elif "SF" in kwargs:
                return self.cls_SF(
                    input_ids=kwargs["input_ids"],
                    labels=kwargs["labels"],
                    attention_mask=kwargs["attention_mask"],
                )
            elif "DQE" in kwargs:
                return self.cls_DQE(
                    input_ids=kwargs["input_ids"],
                    labels=kwargs["labels"],
                    attention_mask=kwargs["attention_mask"],
                )
            else:
                return self.cls(
                    input_ids=kwargs["input_ids"],
                    labels=kwargs["labels"],
                    attention_mask=kwargs["attention_mask"],
                )

        if "input_labels" in kwargs:
            # review generation task
            assert (
                "input_ids" in kwargs
                and "input_labels" in kwargs
                and "decoder_input_ids" in kwargs
                and "attention_mask" in kwargs
                and "decoder_attention_mask" in kwargs
            ), "Please give these arg keys."
            return self.review_forward(
                kwargs["input_ids"],  # source ids
                kwargs["input_labels"],
                kwargs["decoder_input_ids"],  # target ids
                kwargs["attention_mask"],
                kwargs["decoder_attention_mask"],
                kwargs.get("encoder_loss", True),
            )

        return super().forward(*argv, **kwargs)

    def _first_token_logits(self, input_ids, attention_mask):
        """Encode the input and classify the first token's hidden state."""
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False
        )
        first_hidden = encoder_outputs[0][:, 0, :]
        # A module built inline (``nn.Dropout(0.3)(x)``) is always in training
        # mode and would keep dropping units during evaluation; the
        # functional form follows this model's train/eval state.
        first_hidden = nn.functional.dropout(
            first_hidden, p=0.3, training=self.training)
        return self.cls_head(first_hidden)

    @staticmethod
    def _weighted_focal_loss(logits, labels):
        """Focal loss (gamma=2, class weights [1, 3]) over softmaxed logits."""
        weights = torch.tensor([1.0, 3.0], device=logits.device)
        loss_fct = FocalLoss(gamma=2, weights=weights)
        return loss_fct(torch.softmax(logits, dim=-1), labels)

    def cls(
        self,
        input_ids,
        labels,
        attention_mask
    ):
        """Return the cross-entropy loss, or the logits if ``labels`` is None."""
        logits = self._first_token_logits(input_ids, attention_mask)
        if labels is not None:
            return CrossEntropyLoss()(logits, labels)  # version 1
        return logits

    def cls_DQE(
        self,
        input_ids,
        labels,
        attention_mask
    ):
        """Return the cross-entropy loss, or the logits if ``labels`` is None."""
        logits = self._first_token_logits(input_ids, attention_mask)
        if labels is not None:
            return CrossEntropyLoss()(logits, labels)  # version 1
        return logits

    def cls_SF(
        self,
        input_ids,
        labels,
        attention_mask
    ):
        """Return the weighted focal loss, or the logits if ``labels`` is None."""
        logits = self._first_token_logits(input_ids, attention_mask)
        if labels is not None:
            return self._weighted_focal_loss(logits, labels)  # version 2
        return logits

    def cls_MF(
        self,
        input_ids,
        manual_feature,
        labels,
        attention_mask
    ):
        """Classify with manual features passed to ``cls_head``.

        ``cls_head`` is called with the encoder states and ``manual_feature``,
        so it must accept both (see ``init_MF_classifier``).
        Returns the weighted focal loss, or the logits if ``labels`` is None.
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False
        )
        logits = self.cls_head(encoder_outputs[0], manual_feature)
        if labels is not None:
            return self._weighted_focal_loss(logits, labels)  # version 2
        return logits

    def review_forward(
        self,
        input_ids,
        input_labels,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        encoder_loss=True
    ):
        """Compute the generation loss, optionally plus an encoder loss.

        The decoder is fed ``decoder_input_ids`` shifted right and trained to
        predict ``decoder_input_ids`` (positions equal to
        ``config.pad_token_id`` are ignored). When ``encoder_loss`` is truthy
        and ``input_labels`` is given, a token-level cross-entropy between the
        encoder states projected onto the input embeddings and
        ``input_labels`` (``-100`` ignored) is added.

        Returns:
            The summed loss.
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False
        )
        hidden_states = encoder_outputs[0]
        decoder_inputs = self._shift_right(decoder_input_ids)
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_inputs,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False
        )
        sequence_output = decoder_outputs[0]
        if self.config.tie_word_embeddings:  # this is True default
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        cls_logits = None
        if encoder_loss:
            cls_logits = nn.functional.linear(
                hidden_states, self.encoder.get_input_embeddings().weight)
        lm_logits = self.lm_head(sequence_output)

        if decoder_input_ids is not None:
            lm_loss_fct = CrossEntropyLoss(
                ignore_index=self.config.pad_token_id)
            loss = lm_loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)),
                decoder_input_ids.view(-1))
            if encoder_loss and input_labels is not None:
                cls_loss_fct = CrossEntropyLoss(ignore_index=-100)
                loss = loss + cls_loss_fct(
                    cls_logits.view(-1, cls_logits.size(-1)),
                    input_labels.view(-1))
            return loss
        return cls_logits, lm_logits


class FocalLoss(nn.Module):
    """Computes the focal loss between input and target
    as described here https://arxiv.org/abs/1708.02002v2

    Args:
        gamma (float): The focal loss focusing parameter.
        weights (Union[None, Tensor]): Rescaling weight given to each class.
            If given, has to be a Tensor of size C. optional.
        reduction (str): Specifies the reduction to apply to the output.
            it should be one of the following 'none', 'mean', or 'sum'.
            default 'mean'.
        ignore_index (int): Specifies a target value that is ignored and
            does not contribute to the input gradient. optional.
        eps (float): smoothing to prevent log from returning inf.
    """

    def __init__(
        self,
        gamma,
        weights: Union[None, Tensor] = None,
        reduction: str = 'mean',
        ignore_index=-100,
        eps=1e-16
    ) -> None:
        super().__init__()
        if reduction not in ['mean', 'none', 'sum']:
            raise NotImplementedError(
                'Reduction {} not implemented.'.format(reduction)
            )
        assert weights is None or isinstance(weights, Tensor), \
            'weights should be of type Tensor or None, but {} given'.format(
                type(weights))
        self.reduction = reduction
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.eps = eps
        self.weights = weights

    def _get_weights(self, target: Tensor) -> Tensor:
        """Return the per-sample weight selected by the one-hot ``target``."""
        if self.weights is None:
            return torch.ones(target.shape[0])
        weights = target * self.weights
        return weights.sum(dim=-1)

    def _process_target(
        self, target: Tensor, num_classes: int, mask: Tensor
    ) -> Tensor:
        """Flatten ``target`` and one-hot encode it (``mask`` is unused)."""
        # convert all ignore_index elements to zero to avoid error in one_hot
        # note - the choice of value 0 is arbitrary, but it should not matter
        # as these elements will be ignored in the loss calculation
        target = target * (target != self.ignore_index)
        target = target.view(-1)
        return one_hot(target, num_classes=num_classes)

    def _process_preds(self, x: Tensor) -> Tensor:
        """Reshape predictions to ``(N, C)``.

        A 1-D input is treated as the positive-class probability of a binary
        problem and expanded to two columns ``[1 - x, x]``.
        """
        if x.dim() == 1:
            x = torch.vstack([1 - x, x])
            x = x.permute(1, 0)
            return x
        return x.view(-1, x.shape[-1])

    def _calc_pt(
        self, target: Tensor, x: Tensor, mask: Tensor
    ) -> Tensor:
        """Probability of the true class, zeroed at ignored positions."""
        p = target * x
        p = p.sum(dim=-1)
        p = p * ~mask
        return p

    def forward(self, x: Tensor, target: Tensor) -> Tensor:
        """Compute the focal loss.

        Args:
            x: predicted probabilities in ``[0, 1]``.
            target: class indices; entries equal to ``ignore_index`` are
                excluded from the loss.

        Raises:
            ValueError: if any value of ``x`` lies outside ``[0, 1]``.
        """
        # Raise explicitly: an ``assert`` would be stripped under ``-O``.
        if not torch.all((x >= 0.0) & (x <= 1.0)):
            raise ValueError(
                'The predictions values should be between 0 and 1, '
                'make sure to pass the values to sigmoid for binary '
                'classification or softmax for multi-class classification'
            )
        mask = (target == self.ignore_index).view(-1)
        x = self._process_preds(x)
        num_classes = x.shape[-1]
        target = self._process_target(target, num_classes, mask)
        weights = self._get_weights(target).to(x.device)
        pt = self._calc_pt(target, x, mask)
        focal = 1 - pt
        nll = -torch.log(self.eps + pt)
        nll = nll.masked_fill(mask, 0)
        loss = weights * (focal ** self.gamma) * nll
        return self._reduce(loss, mask, weights)

    def _reduce(self, x: Tensor, mask: Tensor, weights: Tensor) -> Tensor:
        """Apply the configured reduction.

        ``'mean'`` divides by the summed weights of the non-ignored samples.
        """
        if self.reduction == 'mean':
            return x.sum() / (~mask * weights).sum()
        elif self.reduction == 'sum':
            return x.sum()
        else:
            return x


# model_type -> (config class, model class, tokenizer class)
MODEL_CLASSES = {
    'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer),
    't5': (T5Config, T5ForConditionalGeneration, T5Tokenizer),
    'codet5': (T5Config, T5ForConditionalGeneration, RobertaTokenizer),
    'bart': (BartConfig, BartForConditionalGeneration, BartTokenizer),
    'myt5': (T5Config, T5Model, RobertaTokenizer),
    'codet5_CC': (T5Config, CodeChangeModel, RobertaTokenizer),
}