"""
Main Testing Script
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""
from pointcept.engines.defaults import (
default_argument_parser,
default_config_parser,
default_setup,
)
from pointcept.engines.test import TESTERS
from pointcept.engines.launch import launch
from thop import profile
import torch
def _is_main_process():
    """Return True when not running distributed, or when this is rank 0."""
    try:
        import torch.distributed as dist
    except ImportError:
        return True
    # is_available() guards builds without distributed support, where the
    # remaining torch.distributed API may be missing.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0


def _unwrap_model(model):
    """Return the wrapped module if *model* is a DataParallel/DDP wrapper."""
    wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    return model.module if isinstance(model, wrappers) else model


def _report_flops(model, num_points=4096):
    """Profile *model* with thop on a random dummy input and print the result.

    Args:
        model: The network to profile; a DataParallel/DDP wrapper is
            unwrapped first so the attribute lookups below reach the network.
        num_points: Number of points in the synthetic input.

    Raises:
        AttributeError: If neither ``model.embedding.in_channels`` nor
            ``model.backbone.embedding.in_channels`` exists.
    """
    model = _unwrap_model(model)
    model.eval()

    # Read the input feature width from the model itself.
    try:
        in_channels = model.embedding.in_channels
    except AttributeError:
        in_channels = model.backbone.embedding.in_channels

    # Build the dummy input on whatever device the model lives on instead of
    # hard-coding CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A private generator keeps the global RNG state untouched, so the rank
    # that profiles does not diverge from the ranks that do not.
    rng = torch.Generator().manual_seed(0)
    dummy_dict = {
        "feat": torch.randn(num_points, in_channels, generator=rng).to(device),
        "coord": torch.randn(num_points, 3, generator=rng).to(device),
        "batch": torch.zeros(num_points, dtype=torch.long, device=device),
        "offset": torch.tensor([num_points], dtype=torch.long, device=device),
        # 0.01 is the value the original author describes as typical for
        # point-cloud tasks.
        "grid_size": torch.tensor(0.01, device=device),
    }
    # NOTE(review): thop attaches temporary hooks/buffers to the profiled
    # modules; confirm they are fully removed for every module type used here.
    flops, params = profile(model, inputs=(dummy_dict,))
    print(f"FLOPs: {flops / 1e9:.2f} GFLOPs, Params: {params / 1e6:.2f} M")


def main_worker(cfg):
    """Build the tester from *cfg*, report FLOPs on the main process, then test.

    Args:
        cfg: Configuration object; it is passed through ``default_setup`` and
            its ``test`` mapping is expanded into the tester build arguments.
    """
    cfg = default_setup(cfg)
    test_cfg = dict(cfg=cfg, **cfg.test)
    tester = TESTERS.build(test_cfg)

    # Only the main process reports FLOPs (distributed / multi-GPU runs).
    if _is_main_process():
        try:
            _report_flops(tester.model)
        except Exception as err:
            # Profiling is an optional diagnostic: a failure here must not
            # abort the evaluation (or leave the other ranks waiting).
            print(f"FLOPs profiling skipped: {type(err).__name__}: {err}")

    tester.test()
def main():
    """Parse the command line, load the config, and launch ``main_worker``."""
    cli = default_argument_parser().parse_args()
    config = default_config_parser(cli.config_file, cli.options)

    # Collect the launcher settings first so the launch call reads as one line.
    launch_kwargs = {
        "num_gpus_per_machine": cli.num_gpus,
        "num_machines": cli.num_machines,
        "machine_rank": cli.machine_rank,
        "dist_url": cli.dist_url,
        "cfg": (config,),
    }
    launch(main_worker, **launch_kwargs)
# Run the launcher only when this file is executed as a script.
if __name__ == "__main__":
    main()