"""
Main Testing Script

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

from pointcept.engines.defaults import (
    default_argument_parser,
    default_config_parser,
    default_setup,
)
from pointcept.engines.test import TESTERS
from pointcept.engines.launch import launch

from thop import profile
import torch

# Number of synthetic points fed through the model when counting operations.
PROFILE_NUM_POINTS = 4096
# Value passed as ``grid_size`` in the dummy input; 0.01 is a common voxel size
# for point-cloud tasks. NOTE(review): confirm it matches the dataset config.
PROFILE_GRID_SIZE = 0.01


def _is_main_process():
    """Return True on rank 0, or when torch.distributed is unavailable/uninitialized."""
    try:
        import torch.distributed as dist
    except ImportError:
        return True
    # is_available() must be checked first: builds without distributed support
    # may not provide a usable is_initialized().
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def _unwrap_model(model):
    """Return the underlying module if *model* is wrapped in (Distributed)DataParallel.

    Profiling the wrapper on a single rank would fail the attribute lookups in
    ``_infer_in_channels``. A DDP forward may also issue collectives (e.g. buffer
    broadcast) that the other ranks never join.
    """
    wrappers = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    return model.module if isinstance(model, wrappers) else model


def _infer_in_channels(model):
    """Read the input feature width from ``model.embedding`` or ``model.backbone.embedding``.

    Returns None when neither path exposes ``in_channels``.
    """
    for owner in (model, getattr(model, "backbone", None)):
        embedding = getattr(owner, "embedding", None)
        in_channels = getattr(embedding, "in_channels", None)
        if in_channels is not None:
            return in_channels
    return None


def _report_complexity(model, num_points=PROFILE_NUM_POINTS, grid_size=PROFILE_GRID_SIZE):
    """Print thop's MAC and parameter counts for *model* on a synthetic point cloud.

    thop's ``profile`` returns multiply-accumulate operations (MACs), not FLOPs.
    The ~2x MACs FLOPs estimate is printed alongside. Modules without a thop rule
    (presumably custom sparse-conv / attention kernels) may be counted as zero.
    Verify this before quoting the numbers.
    """
    model = _unwrap_model(model)
    model.eval()

    in_channels = _infer_in_channels(model)
    if in_channels is None:
        print(
            "Skipping complexity profiling: could not find `embedding.in_channels` "
            "on the model or its backbone."
        )
        return

    # Build the dummy inputs directly on the model's device (no CPU round-trip).
    device = next(model.parameters()).device
    dummy_dict = {
        "feat": torch.randn(num_points, in_channels, device=device),
        "coord": torch.randn(num_points, 3, device=device),
        "batch": torch.zeros(num_points, dtype=torch.long, device=device),
        "offset": torch.tensor([num_points], dtype=torch.long, device=device),
        "grid_size": torch.tensor(grid_size, device=device),
    }
    macs, params = profile(model, inputs=(dummy_dict,))
    print(
        f"MACs: {macs / 1e9:.2f} G (~{2 * macs / 1e9:.2f} GFLOPs), "
        f"Params: {params / 1e6:.2f} M"
    )


def main_worker(cfg):
    """Set up the run, report model complexity on the main process, then run the tester."""
    cfg = default_setup(cfg)
    test_cfg = dict(cfg=cfg, **cfg.test)
    tester = TESTERS.build(test_cfg)

    # Profile only on the main process so multi-GPU runs print (and pay) once.
    if _is_main_process():
        _report_complexity(tester.model)

    tester.test()


def main():
    """Parse CLI arguments and the config, then launch ``main_worker`` on each process."""
    args = default_argument_parser().parse_args()
    cfg = default_config_parser(args.config_file, args.options)

    launch(
        main_worker,
        num_gpus_per_machine=args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        cfg=(cfg,),
    )


if __name__ == "__main__":
    main()