Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass, field | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| from fairseq import utils | |
| from fairseq.dataclass import ChoiceEnum, FairseqDataclass | |
| from fairseq.models import ( | |
| BaseFairseqModel, | |
| register_model, | |
| ) | |
| from fairseq.models.roberta.model import RobertaClassificationHead | |
| from fairseq.modules import ( | |
| LayerNorm, | |
| TransformerSentenceEncoder, | |
| TransformerSentenceEncoderLayer, | |
| ) | |
| ACTIVATION_FN_CHOICES = ChoiceEnum(utils.get_available_activation_fns()) | |
| JOINT_CLASSIFICATION_CHOICES = ChoiceEnum(["none", "sent"]) | |
| SENTENCE_REP_CHOICES = ChoiceEnum(["head", "meanpool", "maxpool"]) | |
| def update_init_roberta_model_state(state): | |
| """ | |
| update the state_dict of a Roberta model for initializing | |
| weights of the BertRanker | |
| """ | |
| for k in list(state.keys()): | |
| if ".lm_head." in k or "version" in k: | |
| del state[k] | |
| continue | |
| # remove 'encoder/decoder.sentence_encoder.' from the key | |
| assert k.startswith("encoder.sentence_encoder.") or k.startswith( | |
| "decoder.sentence_encoder." | |
| ), f"Cannot recognize parameter name {k}" | |
| if "layernorm_embedding" in k: | |
| new_k = k.replace(".layernorm_embedding.", ".emb_layer_norm.") | |
| state[new_k[25:]] = state[k] | |
| else: | |
| state[k[25:]] = state[k] | |
| del state[k] | |
| class BaseRanker(nn.Module): | |
| def __init__(self, args, task): | |
| super().__init__() | |
| self.separator_token = task.dictionary.eos() | |
| self.padding_idx = task.dictionary.pad() | |
| def forward(self, src_tokens): | |
| raise NotImplementedError | |
| def get_segment_labels(self, src_tokens): | |
| segment_boundary = (src_tokens == self.separator_token).long() | |
| segment_labels = ( | |
| segment_boundary.cumsum(dim=1) | |
| - segment_boundary | |
| - (src_tokens == self.padding_idx).long() | |
| ) | |
| return segment_labels | |
| def get_positions(self, src_tokens, segment_labels): | |
| segment_positions = ( | |
| torch.arange(src_tokens.shape[1]) | |
| .to(src_tokens.device) | |
| .repeat(src_tokens.shape[0], 1) | |
| ) | |
| segment_boundary = (src_tokens == self.separator_token).long() | |
| _, col_idx = (segment_positions * segment_boundary).nonzero(as_tuple=True) | |
| col_idx = torch.cat([torch.zeros(1).type_as(col_idx), col_idx]) | |
| offset = torch.cat( | |
| [ | |
| torch.zeros(1).type_as(segment_boundary), | |
| segment_boundary.sum(dim=1).cumsum(dim=0)[:-1], | |
| ] | |
| ) | |
| segment_positions -= col_idx[segment_labels + offset.unsqueeze(1)] * ( | |
| segment_labels != 0 | |
| ) | |
| padding_mask = src_tokens.ne(self.padding_idx) | |
| segment_positions = (segment_positions + 1) * padding_mask.type_as( | |
| segment_positions | |
| ) + self.padding_idx | |
| return segment_positions | |
| class BertRanker(BaseRanker): | |
| def __init__(self, args, task): | |
| super(BertRanker, self).__init__(args, task) | |
| init_model = getattr(args, "pretrained_model", "") | |
| self.joint_layers = nn.ModuleList() | |
| if os.path.isfile(init_model): | |
| print(f"initialize weight from {init_model}") | |
| from fairseq import hub_utils | |
| x = hub_utils.from_pretrained( | |
| os.path.dirname(init_model), | |
| checkpoint_file=os.path.basename(init_model), | |
| ) | |
| in_state_dict = x["models"][0].state_dict() | |
| init_args = x["args"].model | |
| num_positional_emb = init_args.max_positions + task.dictionary.pad() + 1 | |
| # follow the setup in roberta | |
| self.model = TransformerSentenceEncoder( | |
| padding_idx=task.dictionary.pad(), | |
| vocab_size=len(task.dictionary), | |
| num_encoder_layers=getattr( | |
| args, "encoder_layers", init_args.encoder_layers | |
| ), | |
| embedding_dim=init_args.encoder_embed_dim, | |
| ffn_embedding_dim=init_args.encoder_ffn_embed_dim, | |
| num_attention_heads=init_args.encoder_attention_heads, | |
| dropout=init_args.dropout, | |
| attention_dropout=init_args.attention_dropout, | |
| activation_dropout=init_args.activation_dropout, | |
| num_segments=2, # add language embeddings | |
| max_seq_len=num_positional_emb, | |
| offset_positions_by_padding=False, | |
| encoder_normalize_before=True, | |
| apply_bert_init=True, | |
| activation_fn=init_args.activation_fn, | |
| freeze_embeddings=args.freeze_embeddings, | |
| n_trans_layers_to_freeze=args.n_trans_layers_to_freeze, | |
| ) | |
| # still need to learn segment embeddings as we added a second language embedding | |
| if args.freeze_embeddings: | |
| for p in self.model.segment_embeddings.parameters(): | |
| p.requires_grad = False | |
| update_init_roberta_model_state(in_state_dict) | |
| print("loading weights from the pretrained model") | |
| self.model.load_state_dict( | |
| in_state_dict, strict=False | |
| ) # ignore mismatch in language embeddings | |
| ffn_embedding_dim = init_args.encoder_ffn_embed_dim | |
| num_attention_heads = init_args.encoder_attention_heads | |
| dropout = init_args.dropout | |
| attention_dropout = init_args.attention_dropout | |
| activation_dropout = init_args.activation_dropout | |
| activation_fn = init_args.activation_fn | |
| classifier_embed_dim = getattr( | |
| args, "embed_dim", init_args.encoder_embed_dim | |
| ) | |
| if classifier_embed_dim != init_args.encoder_embed_dim: | |
| self.transform_layer = nn.Linear( | |
| init_args.encoder_embed_dim, classifier_embed_dim | |
| ) | |
| else: | |
| self.model = TransformerSentenceEncoder( | |
| padding_idx=task.dictionary.pad(), | |
| vocab_size=len(task.dictionary), | |
| num_encoder_layers=args.encoder_layers, | |
| embedding_dim=args.embed_dim, | |
| ffn_embedding_dim=args.ffn_embed_dim, | |
| num_attention_heads=args.attention_heads, | |
| dropout=args.dropout, | |
| attention_dropout=args.attention_dropout, | |
| activation_dropout=args.activation_dropout, | |
| max_seq_len=task.max_positions() | |
| if task.max_positions() | |
| else args.tokens_per_sample, | |
| num_segments=2, | |
| offset_positions_by_padding=False, | |
| encoder_normalize_before=args.encoder_normalize_before, | |
| apply_bert_init=args.apply_bert_init, | |
| activation_fn=args.activation_fn, | |
| ) | |
| classifier_embed_dim = args.embed_dim | |
| ffn_embedding_dim = args.ffn_embed_dim | |
| num_attention_heads = args.attention_heads | |
| dropout = args.dropout | |
| attention_dropout = args.attention_dropout | |
| activation_dropout = args.activation_dropout | |
| activation_fn = args.activation_fn | |
| self.joint_classification = args.joint_classification | |
| if args.joint_classification == "sent": | |
| if args.joint_normalize_before: | |
| self.joint_layer_norm = LayerNorm(classifier_embed_dim) | |
| else: | |
| self.joint_layer_norm = None | |
| self.joint_layers = nn.ModuleList( | |
| [ | |
| TransformerSentenceEncoderLayer( | |
| embedding_dim=classifier_embed_dim, | |
| ffn_embedding_dim=ffn_embedding_dim, | |
| num_attention_heads=num_attention_heads, | |
| dropout=dropout, | |
| attention_dropout=attention_dropout, | |
| activation_dropout=activation_dropout, | |
| activation_fn=activation_fn, | |
| ) | |
| for _ in range(args.num_joint_layers) | |
| ] | |
| ) | |
| self.classifier = RobertaClassificationHead( | |
| classifier_embed_dim, | |
| classifier_embed_dim, | |
| 1, # num_classes | |
| "tanh", | |
| args.classifier_dropout, | |
| ) | |
| def forward(self, src_tokens, src_lengths): | |
| segment_labels = self.get_segment_labels(src_tokens) | |
| positions = self.get_positions(src_tokens, segment_labels) | |
| inner_states, _ = self.model( | |
| tokens=src_tokens, | |
| segment_labels=segment_labels, | |
| last_state_only=True, | |
| positions=positions, | |
| ) | |
| return inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C | |
| def sentence_forward(self, encoder_out, src_tokens=None, sentence_rep="head"): | |
| # encoder_out: B x T x C | |
| if sentence_rep == "head": | |
| x = encoder_out[:, :1, :] | |
| else: # 'meanpool', 'maxpool' | |
| assert src_tokens is not None, "meanpool requires src_tokens input" | |
| segment_labels = self.get_segment_labels(src_tokens) | |
| padding_mask = src_tokens.ne(self.padding_idx) | |
| encoder_mask = segment_labels * padding_mask.type_as(segment_labels) | |
| if sentence_rep == "meanpool": | |
| ntokens = torch.sum(encoder_mask, dim=1, keepdim=True) | |
| x = torch.sum( | |
| encoder_out * encoder_mask.unsqueeze(2), dim=1, keepdim=True | |
| ) / ntokens.unsqueeze(2).type_as(encoder_out) | |
| else: # 'maxpool' | |
| encoder_out[ | |
| (encoder_mask == 0).unsqueeze(2).repeat(1, 1, encoder_out.shape[-1]) | |
| ] = -float("inf") | |
| x, _ = torch.max(encoder_out, dim=1, keepdim=True) | |
| if hasattr(self, "transform_layer"): | |
| x = self.transform_layer(x) | |
| return x # B x 1 x C | |
| def joint_forward(self, x): | |
| # x: T x B x C | |
| if self.joint_layer_norm: | |
| x = self.joint_layer_norm(x.transpose(0, 1)) | |
| x = x.transpose(0, 1) | |
| for layer in self.joint_layers: | |
| x, _ = layer(x, self_attn_padding_mask=None) | |
| return x | |
| def classification_forward(self, x): | |
| # x: B x T x C | |
| return self.classifier(x) | |
| class DiscriminativeNMTRerankerConfig(FairseqDataclass): | |
| pretrained_model: str = field( | |
| default="", metadata={"help": "pretrained model to load"} | |
| ) | |
| sentence_rep: SENTENCE_REP_CHOICES = field( | |
| default="head", | |
| metadata={ | |
| "help": "method to transform the output of the transformer stack to a sentence-level representation" | |
| }, | |
| ) | |
| dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) | |
| attention_dropout: float = field( | |
| default=0.0, metadata={"help": "dropout probability for attention weights"} | |
| ) | |
| activation_dropout: float = field( | |
| default=0.0, metadata={"help": "dropout probability after activation in FFN"} | |
| ) | |
| classifier_dropout: float = field( | |
| default=0.0, metadata={"help": "classifier dropout probability"} | |
| ) | |
| embed_dim: int = field(default=768, metadata={"help": "embedding dimension"}) | |
| ffn_embed_dim: int = field( | |
| default=2048, metadata={"help": "embedding dimension for FFN"} | |
| ) | |
| encoder_layers: int = field(default=12, metadata={"help": "num encoder layers"}) | |
| attention_heads: int = field(default=8, metadata={"help": "num attention heads"}) | |
| encoder_normalize_before: bool = field( | |
| default=False, metadata={"help": "apply layernorm before each encoder block"} | |
| ) | |
| apply_bert_init: bool = field( | |
| default=False, metadata={"help": "use custom param initialization for BERT"} | |
| ) | |
| activation_fn: ACTIVATION_FN_CHOICES = field( | |
| default="relu", metadata={"help": "activation function to use"} | |
| ) | |
| freeze_embeddings: bool = field( | |
| default=False, metadata={"help": "freeze embeddings in the pretrained model"} | |
| ) | |
| n_trans_layers_to_freeze: int = field( | |
| default=0, | |
| metadata={ | |
| "help": "number of layers to freeze in the pretrained transformer model" | |
| }, | |
| ) | |
| # joint classfication | |
| joint_classification: JOINT_CLASSIFICATION_CHOICES = field( | |
| default="none", | |
| metadata={"help": "method to compute joint features for classification"}, | |
| ) | |
| num_joint_layers: int = field( | |
| default=1, metadata={"help": "number of joint layers"} | |
| ) | |
| joint_normalize_before: bool = field( | |
| default=False, | |
| metadata={"help": "apply layer norm on the input to the joint layer"}, | |
| ) | |
| class DiscriminativeNMTReranker(BaseFairseqModel): | |
| def build_model(cls, args, task): | |
| model = BertRanker(args, task) | |
| return DiscriminativeNMTReranker(args, model) | |
| def __init__(self, args, model): | |
| super().__init__() | |
| self.model = model | |
| self.sentence_rep = args.sentence_rep | |
| self.joint_classification = args.joint_classification | |
| def forward(self, src_tokens, src_lengths, **kwargs): | |
| return self.model(src_tokens, src_lengths) | |
| def sentence_forward(self, encoder_out, src_tokens): | |
| return self.model.sentence_forward(encoder_out, src_tokens, self.sentence_rep) | |
| def joint_forward(self, x): | |
| return self.model.joint_forward(x) | |
| def classification_forward(self, x): | |
| return self.model.classification_forward(x) | |