| import logging |
| import os |
| import torch |
| import torch.distributed as dist |
| import yaml |
|
|
| from fvcore.nn import FlopCountAnalysis |
| from fvcore.nn import flop_count_table |
| from fvcore.nn import flop_count_str |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
# Registry of torch normalization-layer classes. Extendable via
# register_norm_module(); presumably consumed elsewhere to special-case
# norm layers (e.g. excluding them from weight decay) — the consumer is
# not visible in this file, so confirm against callers.
NORM_MODULES = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]
|
|
def register_norm_module(cls):
    """Class decorator: append *cls* to the NORM_MODULES registry and return it unchanged."""
    NORM_MODULES.append(cls)
    return cls
|
|
|
|
def is_main_process():
    """Return True when this process is MPI rank 0 (or not running under MPI).

    Under OpenMPI, ``OMPI_COMM_WORLD_RANK`` carries the process rank; when
    the variable is absent (plain non-distributed run) the rank defaults
    to 0, so the process counts as the main one.
    """
    # Key off the RANK variable directly. The original tested for
    # OMPI_COMM_WORLD_SIZE but then read OMPI_COMM_WORLD_RANK, which
    # raises KeyError if SIZE is set while RANK is not.
    rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', 0))
    return rank == 0
|
|
|
|
@torch.no_grad()
def analysis_model(model, dump_input, verbose=False):
    """Run an fvcore FLOP analysis of *model* on *dump_input* and log a summary.

    Args:
        model (torch.nn.Module): model to analyze; its train/eval mode is
            restored to whatever it was before the call.
        dump_input: example input(s) accepted by ``FlopCountAnalysis``
            (tensor or tuple of tensors — see fvcore docs).
        verbose (bool): additionally log the full per-operator breakdown.

    Returns:
        tuple: (total_flops, flop_count_table string, flop_count_str string).
    """
    # Remember the caller's mode; the original unconditionally called
    # model.train() afterwards, clobbering eval mode for callers that
    # were not training.
    was_training = model.training
    model.eval()
    flops = FlopCountAnalysis(model, dump_input)
    total = flops.total()
    model.train(was_training)

    params_total = sum(p.numel() for p in model.parameters())
    params_learned = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )

    # Compute the report strings once — the original re-ran the (costly)
    # table/str generation both for logging and for the return value.
    table = flop_count_table(flops)
    detail = flop_count_str(flops)
    logger.info(f"flop count table:\n {table}")
    if verbose:
        logger.info(f"flop count str:\n {detail}")
    logger.info(f" Total flops: {total/1000/1000:.3f}M,")
    logger.info(f" Total params: {params_total/1000/1000:.3f}M,")
    logger.info(f" Learned params: {params_learned/1000/1000:.3f}M")

    return total, table, detail
|
|
|
|
def load_config_dict_to_opt(opt, config_dict, splitter='.'):
    """
    Load the key, value pairs from config_dict to opt, overriding existing values in opt
    if there is any.

    Args:
        opt (dict): target options dict, modified in place.
        config_dict (dict): flat dict whose keys may be splitter-joined
            paths (e.g. 'a.b.c') addressing nested entries of opt.
        splitter (str): separator used to split keys into nested parts.

    Raises:
        TypeError: if config_dict is not a dict.
        AssertionError: if an intermediate key addresses a non-dict value.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")
    for k, v in config_dict.items():
        k_parts = k.split(splitter)
        pointer = opt
        # Walk/create the nested dicts for all but the last key part.
        for k_part in k_parts[:-1]:
            if k_part not in pointer:
                pointer[k_part] = {}
            pointer = pointer[k_part]
            assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
        ori_value = pointer.get(k_parts[-1])
        pointer[k_parts[-1]] = v
        # Compare against None rather than truthiness so overrides of
        # falsy values (0, False, '') are still reported.
        if ori_value is not None:
            print(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}")
|
|
|
|
def load_opt_from_config_file(conf_file):
    """
    Load opt from the config file.

    Args:
        conf_file: config file path

    Returns:
        dict: a dictionary of opt settings
    """
    opt = {}
    with open(conf_file, encoding='utf-8') as f:
        # Parse the YAML and merge its (possibly dotted) keys into opt.
        load_config_dict_to_opt(opt, yaml.safe_load(f))
    return opt
|
|
def cast_batch_to_dtype(batch, dtype):
    """
    Cast the float32 tensors in a batch to a specified torch dtype.
    It should be called before feeding the batch to the FP16 DeepSpeed model.

    Args:
        batch (torch.tensor or container of torch.tensor): input batch
    Returns:
        return_batch: same type as the input batch with internal float32 tensors casted to the specified dtype.
    """
    # Guard-clause style: each container type recurses and returns early.
    if torch.is_tensor(batch):
        # Only floating-point tensors are converted; int/bool tensors pass through.
        return batch.to(dtype) if torch.is_floating_point(batch) else batch
    if isinstance(batch, list):
        return [cast_batch_to_dtype(item, dtype) for item in batch]
    if isinstance(batch, tuple):
        return tuple(cast_batch_to_dtype(item, dtype) for item in batch)
    if isinstance(batch, dict):
        return {key: cast_batch_to_dtype(val, dtype) for key, val in batch.items()}
    # Unknown leaf types are left untouched.
    logger.debug(f"Can not cast type {type(batch)} to {dtype}. Skipping it in the batch.")
    return batch
|
|
|
|
def cast_batch_to_half(batch):
    """
    Cast the float32 tensors in a batch to float16.
    It should be called before feeding the batch to the FP16 DeepSpeed model.

    Args:
        batch (torch.tensor or container of torch.tensor): input batch
    Returns:
        return_batch: same type as the input batch with internal float32 tensors casted to float16
    """
    # torch.half is an alias of torch.float16.
    return cast_batch_to_dtype(batch, torch.half)
|
|