| import logging |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.nn import CrossEntropyLoss |
| from transformers import (BartConfig, BartForConditionalGeneration, |
| BartTokenizer, RobertaConfig, RobertaModel, |
| RobertaTokenizer, T5Config, |
| T5ForConditionalGeneration, T5Model, T5Tokenizer) |
| from torch.nn.functional import one_hot |
| from torch import Tensor |
| from typing import Union |
| logger = logging.getLogger(__name__) |
|
|
|
|
def get_model_size(model):
    """Return the number of trainable parameters of ``model`` in millions.

    Only parameters whose ``requires_grad`` flag is set are counted. The total
    is divided by one million, rounded to the nearest integer and rendered
    with an ``"M"`` suffix (for example ``"223M"``).
    """
    trainable_shapes = (
        param.size() for param in model.parameters() if param.requires_grad
    )
    total = sum(np.prod(shape) for shape in trainable_shapes)
    return "{}M".format(round(total / 1e+6))
|
|
|
|
def enrich_vocab(
    args,
    tokenizer,
    config,
    load_extra_ids=True
):
    """Register the project's special tokens and publish their ids.

    Tokens are added in a fixed order (base markers, ``<extra_id_N>``,
    ``<eN>``, ``<msg>``, language tags) because the ids handed out by the
    tokenizer depend on insertion order. The resulting ids are then copied
    onto ``config`` (``*_token_id`` attributes, ``vocab_size``, ``lang_id``,
    ``lang_tokens``) and onto ``tokenizer`` (``*_id`` attributes and
    ``special_dict``).

    Args:
        args: Object exposing a boolean ``add_lang_ids`` attribute.
        tokenizer: Tokenizer providing ``get_vocab``, ``add_special_tokens``
            and ``__len__``. It is modified in place.
        config: Model configuration object. It is modified in place.
        load_extra_ids: When exactly ``True``, the 100 ``<extra_id_N>`` and
            100 ``<eN>`` sentinel tokens are added.

    Returns:
        The same ``(tokenizer, config)`` pair that was passed in.
    """
    add_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<mask>",
                  "<keep>", "<add>", "<del>", "<start>", "<end>",
                  "<issue_id>", "<version_id>", "<commit_id>"]
    existing_vocab = tokenizer.get_vocab()
    missing_tokens = [tok for tok in add_tokens if tok not in existing_vocab]
    if missing_tokens:
        tokenizer.add_special_tokens(
            {"additional_special_tokens": missing_tokens}
        )

    if load_extra_ids is True:
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<extra_id_{}>".format(i) for i in range(99, -1, -1)
                ]
            }
        )
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<e{}>".format(i) for i in range(99, -1, -1)
                ]
            }
        )

    tokenizer.add_special_tokens({"additional_special_tokens": ["<msg>"]})

    langs = [
        "<en>",
        "<python>",
        "<java>",
        "<javascript>",
        "<ruby>",
        "<php>",
        "<go>",
        "<c>",
        "<c_sharp>",
        "<c_plus_plus>",
    ]

    if args.add_lang_ids:
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": langs
            }
        )

    # Every token has been added by now, so one vocabulary snapshot serves
    # all of the lookups below instead of rebuilding the dict per lookup.
    vocab = tokenizer.get_vocab()

    if args.add_lang_ids:
        config.lang_id = {lang: vocab[lang] for lang in langs}

    config.vocab_size = len(tokenizer)
    config.bos_token_id = vocab["<s>"]
    config.pad_token_id = vocab["<pad>"]
    config.eos_token_id = vocab["</s>"]
    config.mask_token_id = vocab["<mask>"]
    config.keep_token_id = vocab["<keep>"]
    config.add_token_id = vocab["<add>"]
    config.del_token_id = vocab["<del>"]
    config.start_token_id = vocab["<start>"]
    config.end_token_id = vocab["<end>"]
    config.lang_tokens = langs

    # The <eN> sentinels only exist when load_extra_ids was True (or the
    # tokenizer already shipped them); skip absent ones instead of raising
    # KeyError.
    sentinel_names = ["<e{}>".format(i) for i in range(99, -1, -1)]
    tokenizer.special_dict = {
        name: vocab[name] for name in sentinel_names if name in vocab
    }

    tokenizer.mask_id = vocab["<mask>"]
    tokenizer.bos_id = vocab["<s>"]
    tokenizer.pad_id = vocab["<pad>"]
    tokenizer.eos_id = vocab["</s>"]
    tokenizer.msg_id = vocab["<msg>"]
    tokenizer.keep_id = vocab["<keep>"]
    tokenizer.add_id = vocab["<add>"]
    tokenizer.del_id = vocab["<del>"]
    tokenizer.start_id = vocab["<start>"]
    tokenizer.end_id = vocab["<end>"]

    return tokenizer, config
|
|
|
|
def build_or_load_gen_model(args, load_model=True):
    """Build the config, model and tokenizer selected by ``args.model_type``.

    Args:
        args: Run configuration. Reads ``model_type``, ``config_name``,
            ``model_name_or_path``, ``tokenizer_name``, ``from_scratch``,
            ``beam_size``, ``max_target_length``, ``device`` and
            ``load_model_path``. ``args.model_name_or_path`` is overwritten
            with a default when it is empty.
        load_model: When exactly ``True`` and ``args.load_model_path`` is not
            ``None``, the state dict stored at that path is loaded into the
            freshly built model.

    Returns:
        ``(config, model, tokenizer)``.
    """
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Resolve the defaults before anything is loaded. Previously the config
    # was fetched before the model-name default was applied, and the
    # tokenizer default was stored in a local that was never used, so an
    # empty name reached ``from_pretrained``.
    if not args.model_name_or_path:
        if args.model_type == "codet5":
            args.model_name_or_path = "Salesforce/codet5-base"
        else:
            args.model_name_or_path = "t5-base"
    tokenizer_name = args.tokenizer_name or "Salesforce/codet5-base"

    config = config_class.from_pretrained(
        args.config_name or args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name)

    if args.model_type == 'roberta':
        encoder = model_class.from_pretrained(
            args.model_name_or_path, config=config)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_size, nhead=config.num_attention_heads)
        decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        # NOTE(review): the Seq2Seq wrapper is a plain nn.Module; the
        # ``resize_token_embeddings`` call below looks like it would fail on
        # this branch -- confirm whether the 'roberta' path is still in use.
        model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                        beam_size=args.beam_size, max_length=args.max_target_length,
                        sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
    elif args.from_scratch is True:
        model = model_class(config)
    else:
        model = model_class.from_pretrained(args.model_name_or_path)

    # Extend the vocabulary, then grow the embedding matrix to match it.
    tokenizer, config = enrich_vocab(args, tokenizer, config)
    model.config = config
    model.resize_token_embeddings(len(tokenizer))
    logger.info(
        "Finish loading model [%s] from %s",
        get_model_size(model),
        args.model_name_or_path
    )

    model.to(args.device)
    if args.load_model_path is not None and load_model is True:
        logger.info("Load model from %s", args.load_model_path)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(
            args.load_model_path, map_location="cpu"))
        logger.info("Model from %s has been loaded", args.load_model_path)
        model.to(args.device)
    return config, model, tokenizer
|
|
|
|
class RobertaClassificationHeadMF(nn.Module):
    """Two-way classification head that fuses encoder output with manual features.

    The first position of the encoder output is concatenated with a
    ``tanh``-activated linear projection of a 14-dimensional manual-feature
    vector; dropout and a final linear layer produce two logits.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        # Projects the 14 hand-crafted features up to the encoder width.
        self.manual_dense = nn.Linear(14, width)
        self.dropout = nn.Dropout(0.1)
        self.out_proj_new = nn.Linear(width + width, 2, bias=True)

    def forward(self, features, manual_features=None, **kwargs):
        """Return logits of shape ``(batch, 2)``.

        Args:
            features: Tensor indexed as ``features[:, 0, :]``; only the first
                position along the second axis is used.
            manual_features: Tensor whose last dimension is 14. It is cast to
                float before projection. Despite the ``None`` default it must
                be provided.
            **kwargs: Ignored.
        """
        first_position = features[:, 0, :]
        manual_repr = torch.tanh(self.manual_dense(manual_features.float()))
        fused = torch.cat((first_position, manual_repr), dim=-1)
        return self.out_proj_new(self.dropout(fused))
|
|
class RobertaClassificationHead(nn.Module):
    """Two-way classification head applied to pairs of vectors.

    Consecutive rows of the input are merged two at a time into one row of
    twice the width, passed through a ``tanh``-activated dense layer and then
    projected to two logits.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.dense = nn.Linear(width * 2, width)
        self.out_proj = nn.Linear(width, 2)

    def forward(self, x, **kwargs):
        """Return logits of shape ``(x.numel() // (2 * x.size(-1)), 2)``.

        ``**kwargs`` is accepted and ignored.
        """
        # Each output row holds two neighbouring input rows side by side.
        paired = x.reshape(-1, 2 * x.size(-1))
        return self.out_proj(torch.tanh(self.dense(paired)))
|
|
|
|
class CloneModel(nn.Module):
    """Clone-detection classifier built on top of a pretrained encoder.

    ``forward`` turns every source sequence into one vector (how depends on
    ``args.model_type``) and feeds the vectors to a
    ``RobertaClassificationHead``, which merges consecutive vectors in pairs
    and emits two logits per pair.
    """

    def __init__(self, encoder, config, tokenizer, args):
        super(CloneModel, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = RobertaClassificationHead(config)
        self.args = args

    def _last_eos_decoder_state(self, source_ids):
        """Return the last decoder-layer state at each row's final ``<eos>``.

        The encoder is called as a sequence-to-sequence model with
        ``source_ids`` used as both input and labels.

        Raises:
            ValueError: If rows contain differing numbers of ``<eos>`` tokens.
        """
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        outputs = self.encoder(input_ids=source_ids, attention_mask=attention_mask,
                               labels=source_ids, decoder_attention_mask=attention_mask,
                               output_hidden_states=True)
        hidden_states = outputs['decoder_hidden_states'][-1]
        eos_mask = source_ids.eq(self.config.eos_token_id)

        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError(
                "All examples must have the same number of <eos> tokens.")
        # Gather every <eos> state, regroup them per example and keep the last.
        per_example = hidden_states[eos_mask, :].view(
            hidden_states.size(0), -1, hidden_states.size(-1))
        return per_example[:, -1, :]

    def get_t5_vec(self, source_ids):
        """Sequence vector for T5-style models (state at the last ``<eos>``)."""
        return self._last_eos_decoder_state(source_ids)

    def get_bart_vec(self, source_ids):
        """Sequence vector for BART-style models (state at the last ``<eos>``)."""
        return self._last_eos_decoder_state(source_ids)

    def get_roberta_vec(self, source_ids):
        """Sequence vector for RoBERTa-style encoders (first-position state)."""
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        vec = self.encoder(input_ids=source_ids,
                           attention_mask=attention_mask)[0][:, 0, :]
        return vec

    def forward(self, source_ids=None, labels=None):
        """Classify the given sequences.

        Args:
            source_ids: Token ids; reshaped to rows of
                ``args.max_source_length`` tokens.
            labels: Optional class indices for the cross-entropy loss.

        Returns:
            ``(loss, prob)`` when ``labels`` is given, otherwise ``prob``,
            where ``prob`` is the softmax of the classifier logits.

        Raises:
            ValueError: If ``args.model_type`` is not ``'codet5'``, ``'bart'``
                or ``'roberta'``.
        """
        source_ids = source_ids.view(-1, self.args.max_source_length)

        model_type = self.args.model_type
        if model_type == 'codet5':
            vec = self.get_t5_vec(source_ids)
        elif model_type == 'bart':
            vec = self.get_bart_vec(source_ids)
        elif model_type == 'roberta':
            vec = self.get_roberta_vec(source_ids)
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                "Unsupported model_type: {!r}".format(model_type))

        logits = self.classifier(vec)
        # Explicit dim: the implicit-dimension form of softmax is deprecated.
        prob = nn.functional.softmax(logits, dim=-1)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, prob
        return prob
|
|
|
|
class DefectModel(nn.Module):
    """Defect-detection classifier built on top of a pretrained encoder.

    ``forward`` turns every source sequence into one vector (how depends on
    ``args.model_type``) and maps it to two logits with a linear layer.
    """

    def __init__(self, encoder, config, tokenizer, args):
        super(DefectModel, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = nn.Linear(config.hidden_size, 2)
        self.args = args

    def _last_eos_decoder_state(self, source_ids):
        """Return the last decoder-layer state at each row's final ``<eos>``.

        The encoder is called as a sequence-to-sequence model with
        ``source_ids`` used as both input and labels.

        Raises:
            ValueError: If rows contain differing numbers of ``<eos>`` tokens.
        """
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        outputs = self.encoder(input_ids=source_ids, attention_mask=attention_mask,
                               labels=source_ids, decoder_attention_mask=attention_mask,
                               output_hidden_states=True)
        hidden_states = outputs['decoder_hidden_states'][-1]
        eos_mask = source_ids.eq(self.config.eos_token_id)

        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError(
                "All examples must have the same number of <eos> tokens.")
        # Gather every <eos> state, regroup them per example and keep the last.
        per_example = hidden_states[eos_mask, :].view(
            hidden_states.size(0), -1, hidden_states.size(-1))
        return per_example[:, -1, :]

    def get_t5_vec(self, source_ids):
        """Sequence vector for T5-style models (state at the last ``<eos>``)."""
        return self._last_eos_decoder_state(source_ids)

    def get_bart_vec(self, source_ids):
        """Sequence vector for BART-style models (state at the last ``<eos>``)."""
        return self._last_eos_decoder_state(source_ids)

    def get_roberta_vec(self, source_ids):
        """Sequence vector for RoBERTa-style encoders (first-position state)."""
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        vec = self.encoder(input_ids=source_ids,
                           attention_mask=attention_mask)[0][:, 0, :]
        return vec

    def forward(self, source_ids=None, labels=None):
        """Classify the given sequences.

        Args:
            source_ids: Token ids; reshaped to rows of
                ``args.max_source_length`` tokens.
            labels: Optional class indices for the cross-entropy loss.

        Returns:
            ``(loss, prob)`` when ``labels`` is given, otherwise ``prob``,
            where ``prob`` is the softmax of the classifier logits.

        Raises:
            ValueError: If ``args.model_type`` is not ``'codet5'``, ``'bart'``
                or ``'roberta'``.
        """
        source_ids = source_ids.view(-1, self.args.max_source_length)

        model_type = self.args.model_type
        if model_type == 'codet5':
            vec = self.get_t5_vec(source_ids)
        elif model_type == 'bart':
            vec = self.get_bart_vec(source_ids)
        elif model_type == 'roberta':
            vec = self.get_roberta_vec(source_ids)
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                "Unsupported model_type: {!r}".format(model_type))

        logits = self.classifier(vec)
        # Explicit dim: the implicit-dimension form of softmax is deprecated.
        prob = nn.functional.softmax(logits, dim=-1)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, prob
        return prob
|
|
|
|
| |
class Seq2Seq(nn.Module):
    """
        Build Sequence-to-Sequence.

        Parameters:

        * `encoder`- encoder of seq2seq model. e.g. roberta
        * `decoder`- decoder of seq2seq model. e.g. transformer
        * `config`- configuration of encoder model.
        * `beam_size`- beam size for beam search.
        * `max_length`- max length of target for beam search.
        * `sos_id`- start of symbol ids in target for beam search.
        * `eos_id`- end of symbol ids in target for beam search.
    """

    def __init__(self, encoder, decoder, config, beam_size=None, max_length=None, sos_id=None, eos_id=None):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Lower-triangular ones; sliced per call to build the causal mask.
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()

        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.encoder.embeddings.word_embeddings)

    def _causal_mask(self, length):
        """Additive mask: 0 on and below the diagonal, -1e4 above it."""
        return -1e4 * (1 - self.bias[:length, :length])

    def forward(self, input_ids=None, attention_mask=None, labels=None, decoder_attention_mask=None, args=None):
        """Compute the training loss, or decode with beam search.

        Args:
            input_ids: Source token ids, one row per example.
            attention_mask: Source mask, non-zero / ``True`` on real tokens.
                Boolean and 0/1 integer masks are both accepted.
            labels: Target token ids. When given, the training loss is
                returned; when ``None``, beam search is run instead.
            decoder_attention_mask: Target mask; positions whose value is 0
                are excluded from the loss. Required when ``labels`` is given.
            args: Unused.

        Returns:
            With ``labels``: ``(loss, loss * num_active, num_active)`` where
            ``num_active`` counts the unmasked shifted target positions.
            Without ``labels``: a tensor stacking, per example, the decoded
            hypotheses, each zero-padded to ``max_length`` tokens.
        """
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        # The decoder is called with sequence-first tensors.
        encoder_output = outputs[0].permute(
            [1, 0, 2]).contiguous()
        if labels is not None:
            attn_mask = self._causal_mask(labels.shape[1])
            tgt_embeddings = self.encoder.embeddings(
                labels).permute([1, 0, 2]).contiguous()
            # ``eq(0)`` marks padding for bool and integer masks alike, whereas
            # ``~mask`` is a bitwise NOT that is only right for bool masks.
            out = self.decoder(tgt_embeddings, encoder_output, tgt_mask=attn_mask,
                               memory_key_padding_mask=attention_mask.eq(0))

            hidden_states = torch.tanh(self.dense(
                out)).permute([1, 0, 2]).contiguous()
            lm_logits = self.lm_head(hidden_states)

            # Position t predicts token t + 1; masked targets are dropped.
            active_loss = decoder_attention_mask[..., 1:].ne(0).view(-1) == 1
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                            shift_labels.view(-1)[active_loss])

            outputs = loss, loss * active_loss.sum(), active_loss.sum()
            return outputs

        # Inference: run an independent beam search for every example.
        preds = []
        # Padding token lives on the model's device, not hard-coded to CUDA.
        zero = torch.zeros(1, dtype=torch.long, device=encoder_output.device)
        for i in range(input_ids.shape[0]):
            context = encoder_output[:, i:i + 1]
            context_mask = attention_mask[i:i + 1, :]
            beam = Beam(self.beam_size, self.sos_id, self.eos_id)
            beam_ids = beam.getCurrentState()
            # Replicate this example's memory once per beam slot.
            context = context.repeat(1, self.beam_size, 1)
            context_padding = context_mask.repeat(self.beam_size, 1).eq(0)
            for _ in range(self.max_length):
                if beam.done():
                    break
                attn_mask = self._causal_mask(beam_ids.shape[1])
                tgt_embeddings = self.encoder.embeddings(
                    beam_ids).permute([1, 0, 2]).contiguous()
                out = self.decoder(tgt_embeddings, context, tgt_mask=attn_mask,
                                   memory_key_padding_mask=context_padding)

                out = torch.tanh(self.dense(out))
                hidden_states = out.permute(
                    [1, 0, 2]).contiguous()[:, -1, :]
                out = self.lsm(self.lm_head(hidden_states)).data
                beam.advance(out)
                # Reorder the prefixes to follow the surviving beams, then
                # append the newly chosen tokens.
                beam_ids.data.copy_(beam_ids.data.index_select(
                    0, beam.getCurrentOrigin()))
                beam_ids = torch.cat(
                    (beam_ids, beam.getCurrentState()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:self.beam_size]
            pred = [torch.cat([x.view(-1) for x in p] + [zero] * (self.max_length - len(p))).view(1, -1) for p in
                    pred]
            preds.append(torch.cat(pred, 0).unsqueeze(0))

        preds = torch.cat(preds, 0)
        return preds
|
|
|
|
class Beam(object):
    """Beam-search state for a single example.

    Attributes:
        size: Number of hypotheses kept per step.
        scores: Cumulative log-probability of each live hypothesis.
        prevKs: Per step, the backpointer (parent beam index) of each slot.
        nextYs: Per step, the token chosen for each slot.
        finished: ``(score, timestep, beam_index)`` tuples of ended hypotheses.
        eosTop: Whether the best slot of the latest step emitted ``eos``.
    """

    def __init__(self, size, sos, eos, device=None):
        """Create a beam whose first slot holds ``sos`` and the rest 0.

        Args:
            size: Beam width.
            sos: Start-of-sequence token id.
            eos: End-of-sequence token id.
            device: Device for the beam tensors. Defaults to CUDA when it is
                available and to the CPU otherwise, so the search no longer
                requires a GPU.
        """
        self.size = size
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.scores = torch.zeros(size, dtype=torch.float, device=self.device)

        self.prevKs = []

        self.nextYs = [torch.zeros(size, dtype=torch.long, device=self.device)]
        self.nextYs[0][0] = sos

        self._eos = eos
        self.eosTop = False

        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep as a ``(size, 1)`` tensor."
        # A copy, so in-place edits by the caller cannot alter beam history.
        return self.nextYs[-1].clone().view(-1, 1)

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Extend every live hypothesis by one token and keep the best ``size``.

        Parameters:

        * `wordLk`- log-probs of advancing from the last step (K x words)

        Updates ``scores``, ``prevKs``, ``nextYs``, ``finished`` and
        ``eosTop`` in place; nothing is returned.
        """
        numWords = wordLk.size(1)

        if len(self.prevKs) > 0:
            # Sum the previous scores.
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            # First step: every slot is identical, so expand only slot 0.
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId indexes the flattened (beam x word) array, so recover
        # which beam and which word each score came from.
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        """Return True once the top slot ended and ``size`` hypotheses finished."""
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        """Return up to ``size`` ``(score, timestep, beam_index)`` results.

        Finished hypotheses come first, best score first; when there are fewer
        than ``size`` of them, the best unfinished ones fill the rest.
        """
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size - len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.

        ``beam_res`` is a list of ``(score, timestep, beam_index)`` tuples as
        returned by ``getFinal``; one token list is returned per tuple.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        """Cut every hypothesis at its first ``eos`` token (exclusive)."""
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
|
|
|
|
|
|
class CodeChangeModel(T5ForConditionalGeneration):
    """T5 conditional-generation model with an extra classification head.

    ``cls_head`` maps the encoder state at the first position to two logits.
    ``forward`` selects classification, joint encoder/decoder training or the
    stock T5 forward pass depending on which keyword arguments are present.
    """

    def __init__(self, config):
        super().__init__(config)
        self.cls_head = nn.Linear(self.config.d_model, 2, bias=True)
        self.init()

    @staticmethod
    def from_pretrained(path):
        """Load T5 weights from ``path`` and attach a freshly initialised head."""
        model = T5ForConditionalGeneration.from_pretrained(path)
        # Re-tag the loaded instance so it gains this class's methods.
        model.__class__ = CodeChangeModel

        model.cls_head = nn.Linear(model.config.d_model, 2, bias=True)
        model.init()
        return model

    def init(self):
        """Initialise ``cls_head``: scaled-normal weights and a zero bias."""
        factor = self.config.initializer_factor
        self.cls_head.weight.data.normal_(mean=0.0,
                                          std=factor * ((self.config.d_model) ** -0.5))
        self.cls_head.bias.data.zero_()

    def init_classifier(self):
        """Replace ``cls_head`` with a new linear head on the encoder's device."""
        self.cls_head = nn.Linear(self.config.d_model, 2, bias=True)
        self.cls_head.to(self.encoder.device)

    def init_MF_classifier(self):
        """Replace ``cls_head`` with the manual-feature head (see ``cls_MF``)."""
        self.cls_head = RobertaClassificationHeadMF(self.config)
        self.cls_head.to(self.encoder.device)

    def forward(
        self, *argv, **kwargs
    ):
        r"""Dispatch on the keyword arguments that are present.

        * ``cls`` present: classification. Requires ``input_ids``, ``labels``
          and ``attention_mask``. The variant is chosen by the presence of
          ``manual_feature`` (``cls_MF``), ``SF`` (``cls_SF``) or ``DQE``
          (``cls_DQE``); otherwise ``cls`` is used. Only the presence of these
          keys is tested, not their values.
        * ``input_labels`` present: ``review_forward``. Requires
          ``input_ids``, ``decoder_input_ids``, ``attention_mask`` and
          ``decoder_attention_mask``; ``encoder_loss`` defaults to ``True``.
        * otherwise: the unmodified ``T5ForConditionalGeneration.forward``.
        """
        if "cls" in kwargs:
            assert (
                ("input_ids" in kwargs or "old_ids" in kwargs) and
                "labels" in kwargs and
                ("attention_mask" in kwargs or "old_attention_mask" in kwargs)
            )
            if "manual_feature" in kwargs:
                return self.cls_MF(
                    input_ids=kwargs["input_ids"],
                    manual_feature=kwargs["manual_feature"],
                    labels=kwargs["labels"],
                    attention_mask=kwargs["attention_mask"],
                )
            elif "SF" in kwargs:
                return self.cls_SF(
                    input_ids=kwargs["input_ids"],
                    labels=kwargs["labels"],
                    attention_mask=kwargs["attention_mask"],
                )
            elif "DQE" in kwargs:
                return self.cls_DQE(
                    input_ids=kwargs["input_ids"],
                    labels=kwargs["labels"],
                    attention_mask=kwargs["attention_mask"],
                )
            else:
                return self.cls(
                    input_ids=kwargs["input_ids"],
                    labels=kwargs["labels"],
                    attention_mask=kwargs["attention_mask"],
                )

        if "input_labels" in kwargs:
            assert (
                "input_ids" in kwargs and
                "input_labels" in kwargs and
                "decoder_input_ids" in kwargs and
                "attention_mask" in kwargs and
                "decoder_attention_mask" in kwargs
            ), "Please give these arg keys."
            input_ids = kwargs["input_ids"]
            input_labels = kwargs["input_labels"]
            decoder_input_ids = kwargs["decoder_input_ids"]
            attention_mask = kwargs["attention_mask"]
            decoder_attention_mask = kwargs["decoder_attention_mask"]
            encoder_loss = kwargs.get("encoder_loss", True)
            return self.review_forward(input_ids, input_labels, decoder_input_ids, attention_mask, decoder_attention_mask, encoder_loss)

        return super().forward(*argv, **kwargs)

    def _first_position_state(self, input_ids, attention_mask):
        """Encode the input and return the dropped-out first-position state."""
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False
        )
        first_hidden = encoder_outputs[0][:, 0, :]
        # Functional dropout tied to ``self.training``. A module built inline
        # (``nn.Dropout(0.3)(x)``) is always in training mode and would keep
        # dropping activations during evaluation.
        return nn.functional.dropout(first_hidden, p=0.3, training=self.training)

    def cls(
        self,
        input_ids,
        labels,
        attention_mask
    ):
        """Return the cross-entropy loss, or the logits when ``labels`` is None."""
        logits = self.cls_head(
            self._first_position_state(input_ids, attention_mask))

        if labels is not None:
            return CrossEntropyLoss()(logits, labels)

        return logits

    def cls_DQE(
        self,
        input_ids,
        labels,
        attention_mask
    ):
        """Same computation as ``cls``: cross-entropy loss, or logits without labels."""
        logits = self.cls_head(
            self._first_position_state(input_ids, attention_mask))

        if labels is not None:
            return CrossEntropyLoss()(logits, labels)

        return logits

    def cls_SF(
        self,
        input_ids,
        labels,
        attention_mask
    ):
        """Return the class-weighted focal loss, or the logits when ``labels`` is None."""
        logits = self.cls_head(
            self._first_position_state(input_ids, attention_mask))

        if labels is not None:
            # Class 1 is weighted three times as heavily as class 0.
            weights = torch.tensor([1.0, 3.0], device=logits.device)
            loss_fct = FocalLoss(gamma=2, weights=weights)
            # FocalLoss expects probabilities, not raw logits.
            return loss_fct(torch.softmax(logits, dim=-1), labels)

        return logits

    def cls_MF(
        self,
        input_ids,
        manual_feature,
        labels,
        attention_mask
    ):
        """Focal-loss classification that also consumes manual features.

        ``cls_head`` is called with the full encoder output and
        ``manual_feature``, so it must be the head installed by
        ``init_MF_classifier``. Returns the loss, or the logits when
        ``labels`` is None.
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False
        )

        logits = self.cls_head(encoder_outputs[0], manual_feature)

        if labels is not None:
            # Class 1 is weighted three times as heavily as class 0.
            weights = torch.tensor([1.0, 3.0], device=logits.device)
            loss_fct = FocalLoss(gamma=2, weights=weights)
            # FocalLoss expects probabilities, not raw logits.
            return loss_fct(torch.softmax(logits, dim=-1), labels)
        return logits

    def review_forward(
        self,
        input_ids,
        input_labels,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        encoder_loss=True
    ):
        """Joint decoder language-model loss plus optional encoder token loss.

        Args:
            input_ids: Encoder token ids.
            input_labels: Per-position encoder targets; ``-100`` is ignored.
            decoder_input_ids: Decoder targets. They are shifted right to
                form the decoder inputs; pad positions are ignored in the loss.
            attention_mask: Encoder attention mask.
            decoder_attention_mask: Decoder attention mask.
            encoder_loss: Whether to add the encoder token-prediction loss.

        Returns:
            The summed loss when ``decoder_input_ids`` is not None, otherwise
            ``(cls_logits, lm_logits)`` with ``cls_logits`` set to ``None``
            when ``encoder_loss`` is false.
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False
        )
        hidden_states = encoder_outputs[0]
        decoder_inputs = self._shift_right(decoder_input_ids)

        decoder_outputs = self.decoder(
            input_ids=decoder_inputs,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False
        )
        sequence_output = decoder_outputs[0]
        if self.config.tie_word_embeddings:
            # Rescale before projecting onto the tied vocabulary matrix.
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        # Bound up front so the final return can never hit an unbound name.
        cls_logits = None
        if encoder_loss:
            # Score every encoder position against the input embedding matrix.
            cls_logits = nn.functional.linear(
                hidden_states, self.encoder.get_input_embeddings().weight)
        lm_logits = self.lm_head(sequence_output)
        if decoder_input_ids is not None:
            lm_loss_fct = CrossEntropyLoss(
                ignore_index=self.config.pad_token_id)
            loss = lm_loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)), decoder_input_ids.view(-1))
            if encoder_loss and input_labels is not None:
                cls_loss_fct = CrossEntropyLoss(ignore_index=-100)
                loss = loss + cls_loss_fct(
                    cls_logits.view(-1, cls_logits.size(-1)), input_labels.view(-1))
            return loss
        return cls_logits, lm_logits
|
|
|
|
class FocalLoss(nn.Module):
    """Computes the focal loss between input and target
    as described here https://arxiv.org/abs/1708.02002v2

    Args:
        gamma (float): The focal loss focusing parameter.
        weights (Union[None, Tensor]): Rescaling weight given to each class.
        If given, has to be a Tensor of size C. optional.
        reduction (str): Specifies the reduction to apply to the output.
        it should be one of the following 'none', 'mean', or 'sum'.
        default 'mean'.
        ignore_index (int): Specifies a target value that is ignored and
        does not contribute to the input gradient. optional.
        eps (float): smoothing to prevent log from returning inf.
    """

    def __init__(
            self,
            gamma,
            weights: Union[None, Tensor] = None,
            reduction: str = 'mean',
            ignore_index=-100,
            eps=1e-16
    ) -> None:
        super().__init__()
        if reduction not in ['mean', 'none', 'sum']:
            raise NotImplementedError(
                'Reduction {} not implemented.'.format(reduction)
            )
        assert weights is None or isinstance(weights, Tensor), \
            'weights should be of type Tensor or None, but {} given'.format(
                type(weights))
        self.reduction = reduction
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.eps = eps
        self.weights = weights

    def _get_weights(self, target: Tensor) -> Tensor:
        """Return one weight per sample, chosen by its one-hot ``target`` row."""
        if self.weights is None:
            return torch.ones(target.shape[0], device=target.device)
        # Align devices so CPU-built class weights work with GPU targets.
        weights = target * self.weights.to(target.device)
        return weights.sum(dim=-1)

    def _process_target(
            self, target: Tensor, num_classes: int, mask: Tensor
    ) -> Tensor:
        """One-hot encode ``target``; ``mask`` is accepted but unused."""
        # Ignored positions are mapped to class 0 so one_hot accepts them;
        # they are zeroed out again later through the mask.
        target = target * (target != self.ignore_index)
        target = target.view(-1)
        return one_hot(target, num_classes=num_classes)

    def _process_preds(self, x: Tensor) -> Tensor:
        """Return predictions as a 2-D ``(samples, classes)`` tensor."""
        if x.dim() == 1:
            # Binary case: expand p into the two columns (1 - p, p).
            x = torch.vstack([1 - x, x])
            x = x.permute(1, 0)
            return x
        return x.view(-1, x.shape[-1])

    def _calc_pt(
            self, target: Tensor, x: Tensor, mask: Tensor
    ) -> Tensor:
        """Probability assigned to the true class; 0 at ignored positions."""
        p = target * x
        p = p.sum(dim=-1)
        p = p * ~mask
        return p

    def forward(self, x: Tensor, target: Tensor) -> Tensor:
        """Compute the focal loss.

        Args:
            x: Probabilities in ``[0, 1]``. Either a 1-D tensor of
                positive-class probabilities or a tensor whose last dimension
                enumerates the classes.
            target: Integer class indices; entries equal to ``ignore_index``
                do not contribute.

        Raises:
            ValueError: If any value of ``x`` lies outside ``[0, 1]``.
        """
        # An explicit raise survives ``python -O`` and surfaces the intended
        # ValueError rather than an AssertionError wrapping one.
        if not torch.all((x >= 0.0) & (x <= 1.0)):
            raise ValueError(
                'The predictions values should be between 0 and 1, '
                'make sure to pass the values to sigmoid for binary '
                'classification or softmax for multi-class classification'
            )
        mask = target == self.ignore_index
        mask = mask.view(-1)
        x = self._process_preds(x)
        num_classes = x.shape[-1]
        target = self._process_target(target, num_classes, mask)
        weights = self._get_weights(target).to(x.device)
        pt = self._calc_pt(target, x, mask)
        focal = 1 - pt
        nll = -torch.log(self.eps + pt)
        nll = nll.masked_fill(mask, 0)
        loss = weights * (focal ** self.gamma) * nll
        return self._reduce(loss, mask, weights)

    def _reduce(self, x: Tensor, mask: Tensor, weights: Tensor) -> Tensor:
        """Apply the configured reduction to the per-sample losses ``x``."""
        if self.reduction == 'mean':
            # Weighted mean over the positions that are not ignored.
            return x.sum() / (~mask * weights).sum()
        elif self.reduction == 'sum':
            return x.sum()
        else:
            return x
|
|
|
|
# Registry keyed by ``args.model_type``. Each value is the triple
# (config class, model class, tokenizer class) unpacked by
# ``build_or_load_gen_model``.
MODEL_CLASSES = {
    # Encoder-only model.
    "roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
    # T5 with its own tokenizer.
    "t5": (T5Config, T5ForConditionalGeneration, T5Tokenizer),
    # The T5-based entries below use the RoBERTa tokenizer class.
    "codet5": (T5Config, T5ForConditionalGeneration, RobertaTokenizer),
    "bart": (BartConfig, BartForConditionalGeneration, BartTokenizer),
    "myt5": (T5Config, T5Model, RobertaTokenizer),
    # T5 extended with the classification head defined in this file.
    "codet5_CC": (T5Config, CodeChangeModel, RobertaTokenizer),
}
|
|