from collections import OrderedDict

from fairseq import utils
from fairseq.models import (
    FairseqMultiModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.multilingual_transformer import (
    MultilingualTransformerModel,
    base_multilingual_architecture,
)
from fairseq.models.transformer import (
    Embedding,
    base_architecture,
)
from fairseq.utils import safe_hasattr


@register_model("multilingual_transformer_from_mbart")
class MultilingualTransformerModelFromMbart(MultilingualTransformerModel):
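    """Multilingual Transformer that can be initialized from an mBART checkpoint.

    A thin wrapper around fairseq's MultilingualTransformerModel: ``build_model``
    mirrors the parent implementation (one encoder/decoder per language pair,
    with optional sharing), while ``load_state_dict`` additionally remaps the
    flat keys of a pretrained BART/mBART checkpoint onto every language pair
    and trims embedding matrices whose vocabulary sizes do not match.
    """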

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

        assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_multilingual_architecture(args)

        if not safe_hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not safe_hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]

        # sharing whole encoders/decoders implies sharing their embeddings
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=src_langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                )
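
        # Note: with --share-all-embeddings a single embedding table is built
        # over the union of all task dictionaries (task.langs), so encoder and
        # decoder inputs and the decoder output projection all share weights;
        # otherwise encoder and decoder tables can be shared independently,
        # across just the source or just the target languages.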

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.encoder_embed_dim,
                        args.encoder_embed_path,
                    )
                lang_encoders[lang] = MultilingualTransformerModel._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs
                )
            return lang_encoders[lang]

        def get_decoder(lang):
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang],
                        args.decoder_embed_dim,
                        args.decoder_embed_path,
                    )
                lang_decoders[lang] = MultilingualTransformerModel._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
                )
            return lang_decoders[lang]
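
        # get_encoder/get_decoder build each per-language module at most once
        # and cache it, so two language pairs with the same source (or target)
        # language transparently reuse the same encoder (or decoder) instance.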

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = (
                shared_encoder if shared_encoder is not None else get_encoder(src)
            )
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(tgt)
            )

        return MultilingualTransformerModelFromMbart(encoders, decoders)

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
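        """Load either a flat (m)BART checkpoint or one of this model's own.

        If no language pair in the incoming ``state_dict`` mentions "neutral",
        the checkpoint is assumed to be a flat pretrained BART/mBART state
        dict: every key is duplicated across all language pairs of this model,
        and embedding matrices with mismatched vocabulary sizes are trimmed.
        Otherwise it is treated as a pretrained emotion translation model and
        only the language pairs present in this model are kept.
        """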
        state_dict_subset = state_dict.copy()
        lang_pairs = {x.split(".")[1] for x in state_dict.keys()}
        finetune_mode = not any("neutral" in lp for lp in lang_pairs)

        if finetune_mode:
            # flat (m)BART checkpoint: replicate every parameter once per
            # language pair of this model, then drop the original flat key
            print("loading pre-trained BART")
            self_state_dict = self.state_dict()
            for k, v in state_dict.items():
                for lang_pair in self.models:
                    new_key = k if "models." in k else f"models.{lang_pair}.{k}"
                    if self_state_dict[new_key].shape == v.shape:
                        state_dict_subset[new_key] = v
                    elif any(
                        w in k
                        for w in [
                            "encoder.embed_tokens.weight",
                            "decoder.embed_tokens.weight",
                            "decoder.output_projection.weight",
                        ]
                    ):
                        # embedding/vocabulary size mismatch: the checkpoint
                        # carries extra symbols at the end of its vocabulary,
                        # so keep only the leading rows that fit this model
                        print(
                            f"{k}: {self_state_dict[new_key].shape} != {v.shape}",
                            end="",
                            flush=True,
                        )
                        vocab_size = v.shape[0] - 5
                        state_dict_subset[new_key] = v[: vocab_size + 4]
                        print(f" => fixed by using first {vocab_size + 4} dims")
                    else:
                        raise ValueError("unable to load model due to mismatched dims!")
                del state_dict_subset[k]
        else:
            print("loading pre-trained emotion translation model")
            for k, _ in state_dict.items():
                assert k.startswith("models.")
                lang_pair = k.split(".")[1]
                if lang_pair not in self.models:
                    del state_dict_subset[k]

        super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)


@register_model_architecture("transformer", "transformer_small")
def transformer_small(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 3)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 3)
    base_architecture(args)


@register_model_architecture(
    "multilingual_transformer_from_mbart", "multilingual_small"
)
def multilingual_small(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 3)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 3)
    base_multilingual_architecture(args)
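
# Example invocation (a sketch, not taken from this repository): the data path,
# language pairs, and checkpoint name below are assumptions, but the flags are
# standard fairseq options for the multilingual_translation task:
#
#   fairseq-train data-bin \
#       --task multilingual_translation \
#       --lang-pairs neutral-happy,neutral-sad \
#       --arch multilingual_small \
#       --finetune-from-model mbart/model.pt
#
# When fairseq restores the model weights (e.g. via --finetune-from-model or
# when resuming from a saved checkpoint), it ends up calling load_state_dict
# above, which selects between the flat-mBART and emotion-translation paths.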