File size: 1,973 Bytes
7b95dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Main Testing Script

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""
from pointcept.engines.defaults import (
    default_argument_parser,
    default_config_parser,
    default_setup,
)
from pointcept.engines.test import TESTERS
from pointcept.engines.launch import launch
from thop import profile
import torch

def _is_main_process():
    """Return True when this process should emit one-off diagnostics.

    The process counts as "main" when ``torch.distributed`` is unavailable,
    has not been initialized, or reports rank 0.
    """
    try:
        import torch.distributed as dist
    except ImportError:
        return True

    # PyTorch builds without distributed support still import the package but
    # do not expose ``is_initialized``; ``is_available`` must be checked first
    # or the call below raises AttributeError instead of ImportError.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0


def _report_model_complexity(model, num_points=4096):
    """Profile one forward pass on random input and print FLOPs / params.

    Args:
        model: The tester's model, optionally wrapped in ``DataParallel`` or
            ``DistributedDataParallel``.
        num_points: Number of random points in the dummy input.

    Raises:
        AttributeError: If neither ``model.embedding`` nor
            ``model.backbone.embedding`` provides ``in_channels``.
    """
    model.eval()

    # Parallel wrappers hide the real network behind ``.module``; unwrap so the
    # attribute lookups below work and so profiling on a single rank does not
    # go through the wrapper's forward.
    if isinstance(
        model,
        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
    ):
        model = model.module

    # Infer the input feature width from the embedding layer.
    try:
        in_channels = model.embedding.in_channels
    except AttributeError:
        in_channels = model.backbone.embedding.in_channels

    # Put the dummy input wherever the model lives instead of assuming CUDA;
    # fall back to CUDA (the previous behavior) for parameter-less models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # Tensors are created on the CPU and then moved, so the CPU random
    # generator is consumed exactly as before.
    dummy_dict = {
        "feat": torch.randn(num_points, in_channels).to(device),
        "coord": torch.randn(num_points, 3).to(device),
        "batch": torch.zeros(num_points, dtype=torch.long).to(device),
        "offset": torch.tensor([num_points], dtype=torch.long).to(device),
        # 0.01 is presumably a typical grid size for point-cloud tasks.
        "grid_size": torch.tensor(0.01).to(device),
    }

    # NOTE(review): thop's README calls the first return value MACs; the
    # printed label is kept as "FLOPs" so existing log output does not change.
    flops, params = profile(model, inputs=(dummy_dict,))
    print(f"FLOPs: {flops / 1e9:.2f} GFLOPs, Params: {params / 1e6:.2f} M")


def main_worker(cfg):
    """Build the configured tester, report model complexity, and run testing.

    Args:
        cfg: Configuration object. It is passed through ``default_setup`` and
            its ``test`` entry is expanded into the tester build arguments.
    """
    cfg = default_setup(cfg)
    test_cfg = dict(cfg=cfg, **cfg.test)
    tester = TESTERS.build(test_cfg)

    # Count FLOPs once only, on the main process, in distributed/multi-GPU runs.
    if _is_main_process():
        _report_model_complexity(tester.model)

    tester.test()

def main():
    """Parse the command line, load the config, and launch ``main_worker``.

    The GPU count, machine count, machine rank, and distributed URL from the
    parsed arguments are forwarded to ``launch``; the parsed config is passed
    as the single worker argument.
    """
    parser = default_argument_parser()
    cli_args = parser.parse_args()
    config = default_config_parser(cli_args.config_file, cli_args.options)

    # Gather the launch settings in one place before handing off.
    launch_settings = {
        "num_gpus_per_machine": cli_args.num_gpus,
        "num_machines": cli_args.num_machines,
        "machine_rank": cli_args.machine_rank,
        "dist_url": cli_args.dist_url,
        "cfg": (config,),
    }
    launch(main_worker, **launch_settings)

# Run the entry point only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()