# ------------------------------------------------------------------------
# Modified from OFA (https://github.com/OFA-Sys/OFA)
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
# ------------------------------------------------------------------------
# Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
import logging
import os
import math
import torch
from typing import Dict, Optional

from fairseq import search
from fairseq.data import FairseqDataset, iterators, Dictionary
from fairseq.optim.amp_optimizer import AMPOptimizer
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import DictConfig
from torch import Tensor, device, dtype, nn

logger = logging.getLogger(__name__)

def load_bert_pretrained_weights(model, ckpt_path):
    try:
        state_dict = torch.load(ckpt_path, map_location="cpu")
    except Exception:
        raise OSError(
            "Unable to load weights from pytorch checkpoint file. "
            "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
        )

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
    # so we need to apply the function recursively.
    def load(module: nn.Module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    # Make sure we are able to load base models as well as derived models (with heads)
    start_prefix = "bert."
    load(model, prefix=start_prefix)

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {ckpt_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {ckpt_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {ckpt_path}.\n"
            f"If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )
    if len(error_msgs) > 0:
        raise RuntimeError(
            "Error(s) in loading state_dict for {}:\n\t{}".format(
                model.__class__.__name__, "\n\t".join(error_msgs)
            )
        )
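
# Illustrative usage sketch, not part of the original file: warm-starting a model's
# text stream from a BERT checkpoint with the helper above. The model class and
# checkpoint path are hypothetical placeholders.
#
#     model = MyModel(cfg)  # any nn.Module whose submodule names match "bert.*" keys
#     load_bert_pretrained_weights(model, "/path/to/bert_checkpoint.pt")
#
# Parameters are copied wherever names line up under the "bert." prefix; mismatches
# surface through the missing/unexpected-key warnings logged by the helper.
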
@dataclass
class BaseConfig(FairseqDataclass):
    data: Optional[str] = field(
        default=None,
        metadata={
            "help": "comma separated path to data list, will be iterated upon during epochs "
                    "in round-robin manner; valid data are always in the last"
        },
    )
    selected_cols: Optional[str] = field(
        default=None,
        metadata={"help": "selected cols"},
    )
    bpe_dir: Optional[str] = field(
        default=None,
        metadata={"help": "bpe dir"},
    )
    max_source_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the target sequence"}
    )
    max_src_length: int = field(
        default=128, metadata={"help": "the maximum src sequence length"}
    )
    max_tgt_length: int = field(
        default=30, metadata={"help": "the maximum target sequence length"}
    )
    code_dict_size: int = field(
        default=8192, metadata={"help": "code dict size"}
    )
    patch_image_size: int = field(
        default=480, metadata={"help": "patch image size"}
    )
    num_bins: int = field(
        default=1000, metadata={"help": "number of quantization bins"}
    )
    imagenet_default_mean_and_std: bool = field(
        default=False,
        metadata={"help": "imagenet normalize"},
    )
    constraint_range: Optional[str] = field(
        default=None,
        metadata={"help": "constraint range"}
    )
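
# Illustrative sketch, not part of the original file: because BaseConfig is a fairseq
# dataclass, it can be instantiated directly for quick experiments (the values below
# are placeholders); in training runs these fields normally come from the CLI/Hydra
# config system.
#
#     cfg = BaseConfig(bpe_dir="utils/BPE", num_bins=64, patch_image_size=512)
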
class BaseTask(FairseqTask):
    def __init__(self, cfg: BaseConfig, src_dict, tgt_dict):
        super().__init__(cfg)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Setup the task."""
        # Define dictionaries
        src_dict = Dictionary()
        tgt_dict = Dictionary()
        # Add 2D bin tokens
        for i in range(cfg.num_bins):
            for j in range(cfg.num_bins):
                src_dict.add_symbol("<bin_{}_{}>".format(i, j))
                tgt_dict.add_symbol("<bin_{}_{}>".format(i, j))

        logger.info("source dictionary: {} types".format(len(src_dict)))
        logger.info("target dictionary: {} types".format(len(tgt_dict)))
        return cls(cfg, src_dict, tgt_dict)
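
    # Illustrative sketch, not part of the original file: the <bin_i_j> symbols added
    # in setup_task quantize a normalized 2D point into a single token. A mapping along
    # these lines (an assumption; downstream datasets may quantize differently) would
    # index the vocabulary built above:
    #
    #     def point_to_bin_token(x, y, num_bins):
    #         i = min(int(x * num_bins), num_bins - 1)
    #         j = min(int(y * num_bins), num_bins - 1)
    #         return "<bin_{}_{}>".format(i, j)
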
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # create mini-batches with given size constraints
        batch_sampler = [
            [j for j in range(i, min(i + max_sentences, len(dataset)))]
            for i in range(0, len(dataset), max_sentences)
        ]
        total_row_count = dataset.dataset.get_total_row_count()
        num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
        if len(batch_sampler) < num_batches:
            batch_sampler.append([])

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=1,
            shard_id=0,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
        )
        return epoch_iter

    def build_model(self, cfg: FairseqDataclass):
        model = super().build_model(cfg)
        bpe_dict = {
            "_name": "gpt2",
            "gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
            "gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe"),
        }
        bpe_dict = DictConfig(bpe_dict)
        self.bpe = self.build_bpe(bpe_dict)
        return model

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
                loss, sample_size, logging_output = criterion(model, sample, update_num=update_num)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict
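
# Illustrative sketch, not part of the original file: BaseTask is meant to be subclassed
# and registered with fairseq, roughly as below. The task name and dataset construction
# are hypothetical placeholders; `register_task` is imported at the top of this file for
# exactly this kind of use by derived tasks.
#
#     @register_task("my_task", dataclass=BaseConfig)
#     class MyTask(BaseTask):
#         def load_dataset(self, split, epoch=1, combine=False, **kwargs):
#             # build a FairseqDataset for `split` and store it in self.datasets[split]
#             ...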