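# NOTE(review): the lines below appear to be residue copied from the web page
# that hosted this file (uploader caption, commit message/hash, file-view
# controls, file size). They are kept as comments so the module parses.
# YYYYYYUUU's picture
# Add core reproduction code (binarization layers, PTv3, superpoint ops, min-repro pack)
# 7b95dc2 verified
# Raw
# History Blame Contribute Delete
# 1.97 kB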
"""
Main testing script: reports model complexity (via thop) on the main process, then runs the configured tester.
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""
from pointcept.engines.defaults import (
default_argument_parser,
default_config_parser,
default_setup,
)
from pointcept.engines.test import TESTERS
from pointcept.engines.launch import launch
from thop import profile
import torch
def _is_main_process():
    """Return True on rank 0, or when torch.distributed is unavailable/uninitialized."""
    try:
        import torch.distributed as dist
    except ImportError:
        return True
    # `is_available()` guards PyTorch builds compiled without distributed support.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def _unwrap_model(model):
    """Return the underlying network if *model* is a (Distributed)DataParallel wrapper."""
    # DP/DDP keep the real network under `.module` and do not forward attribute
    # access, so e.g. `wrapper.embedding` raises AttributeError on them.
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, wrappers) else model


def _infer_in_channels(model):
    """Read the input feature width from the model's (or its backbone's) embedding."""
    try:
        return model.embedding.in_channels
    except AttributeError:
        return model.backbone.embedding.in_channels


def _profile_complexity(model, num_points=4096, grid_size=0.01):
    """Print thop MAC and parameter counts for *model* on one random point cloud.

    Args:
        model: the model to profile; DP/DDP wrappers are unwrapped first.
        num_points: number of random points in the dummy cloud.
        grid_size: voxel size passed as ``grid_size`` (0.01 is a common default
            for point-cloud tasks).
    """
    model.eval()  # propagates to the wrapped module when model is DP/DDP
    model = _unwrap_model(model)
    in_channels = _infer_in_channels(model)
    # Put the dummy inputs on the model's device rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    # NOTE(review): the dict layout (feat/coord/batch/offset/grid_size) assumes a
    # Pointcept-style model input — confirm for models other than PTv3.
    dummy_dict = {
        "feat": torch.randn(num_points, in_channels).to(device),
        "coord": torch.randn(num_points, 3).to(device),
        "batch": torch.zeros(num_points, dtype=torch.long).to(device),
        "offset": torch.tensor([num_points], dtype=torch.long).to(device),
        "grid_size": torch.tensor(grid_size).to(device),
    }
    macs, params = profile(model, inputs=(dummy_dict,))
    # thop counts multiply-accumulates (MACs), not FLOPs; FLOPs ~= 2 * MACs.
    # NOTE(review): layers thop has no counting rule for (e.g. custom sparse-conv
    # or fused attention ops) likely contribute 0, so treat this as a lower bound.
    print(
        f"MACs: {macs / 1e9:.2f} G (~{2 * macs / 1e9:.2f} GFLOPs), "
        f"Params: {params / 1e6:.2f} M"
    )


def main_worker(cfg):
    """Set up the run, report model complexity on the main process, then test.

    Args:
        cfg: parsed config; ``cfg.test`` supplies the tester's keyword arguments.
    """
    cfg = default_setup(cfg)
    test_cfg = dict(cfg=cfg, **cfg.test)
    tester = TESTERS.build(test_cfg)
    # Profile only on the main process so multi-GPU runs print a single report.
    if _is_main_process():
        try:
            _profile_complexity(tester.model)
        except Exception as exc:
            # Profiling is diagnostic only. Crashing here would abort testing, and in
            # multi-GPU runs would leave the other ranks waiting in tester.test().
            print(f"Warning: model complexity profiling skipped ({type(exc).__name__}: {exc})")
    tester.test()
def main():
    """Parse the command line, build the config and launch the test workers."""
    parser = default_argument_parser()
    cli_args = parser.parse_args()
    config = default_config_parser(cli_args.config_file, cli_args.options)
    # `launch` takes the worker's positional arguments as a tuple via `cfg`.
    launch(
        main_worker,
        num_gpus_per_machine=cli_args.num_gpus,
        num_machines=cli_args.num_machines,
        machine_rank=cli_args.machine_rank,
        dist_url=cli_args.dist_url,
        cfg=(config,),
    )
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()