""" Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. """ import copy from typing import Tuple # calflops is a training-time FLOPs profiler that pulls in `accelerate`. It is not # needed for inference, and the deployment image removes `accelerate`, so import it # lazily inside stats() to keep the package importable for inference-only use. def stats( cfg, input_shape: Tuple = (1, 3, 640, 640), ) -> Tuple[int, dict]: from calflops import calculate_flops base_size = cfg.train_dataloader.collate_fn.base_size input_shape = (1, 3, base_size, base_size) model_for_info = copy.deepcopy(cfg.model).deploy() flops, macs, _ = calculate_flops( model=model_for_info, input_shape=input_shape, output_as_string=True, output_precision=4, print_detailed=False, ) params = sum(p.numel() for p in model_for_info.parameters()) del model_for_info return params, {"Model FLOPs:%s MACs:%s Params:%s" % (flops, macs, params)}