File size: 3,775 Bytes
7b95dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert, QConfigMapping
from pointcept.engines.defaults import default_config_parser, default_argument_parser, default_setup
from pointcept.models.default import DefaultSegmentorV2
from pointcept.engines.launch import launch
from thop import profile
# Import spconv so that its sparse-convolution module types can be referenced
import spconv.pytorch as spconv

def main_worker(cfg):
    """Run partial post-training quantization (PTQ) on a segmentor and save it.

    Steps: build ``DefaultSegmentorV2`` from ``cfg.model``, load the weights
    found at ``cfg.weight``, wrap the model with quant/dequant stubs, exclude
    the spconv convolution layers from quantization, calibrate on random
    dummy input, convert, save the resulting ``state_dict`` under
    ``cfg.save_path`` and finally try to report FLOPs/params with ``thop``.

    Args:
        cfg: Parsed config object. This function reads ``cfg.model``
            (and removes its ``'type'`` key), ``cfg.weight`` and
            ``cfg.save_path``.
    """
    import os  # local import: only needed for building the save path

    cfg = default_setup(cfg)

    device = torch.device("cuda")
    print(f"INFO: Running Partial PTQ on device: {device}")

    # 'type' is a registry key, not a constructor argument of the segmentor.
    cfg.model.pop('type', None)
    model = DefaultSegmentorV2(**cfg.model)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(cfg.weight, map_location=device)
    # Accept checkpoints that store weights under 'state_dict', under 'model',
    # or as a bare state dict.
    state_dict = checkpoint.get('state_dict', checkpoint.get('model', checkpoint))
    model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()

    class QuantWrapper(torch.nn.Module):
        """Wrap the FP32 segmentor with quant/dequant stubs at its boundaries."""

        def __init__(self, model_fp32):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()
            self.model_fp32 = model_fp32

        def forward(self, data_dict):
            # NOTE: this overwrites data_dict['feat'] in the caller's dict.
            # NOTE(review): after convert() the features are quantized when
            # they reach the (float) spconv layers — confirm the backbone
            # dequantizes before them.
            if 'feat' in data_dict:
                data_dict['feat'] = self.quant(data_dict['feat'])
            point = self.model_fp32.backbone(data_dict)
            seg_logits = self.model_fp32.seg_head(point["feat"])
            seg_logits = self.dequant(seg_logits)
            return seg_logits

    quant_model = QuantWrapper(model)
    quant_model.eval()

    # NOTE(review): 'fbgemm' is a CPU quantized backend while the model lives
    # on CUDA — confirm the converted model can actually run on this device.
    backend = 'fbgemm'
    quant_model.qconfig = get_default_qconfig(backend)

    # Eager-mode ``prepare`` has no ``qconfig_mapping`` argument (that belongs
    # to the FX graph-mode API). In eager mode a module is excluded from
    # quantization by setting its ``qconfig`` to None, which also stops the
    # parent's qconfig from being propagated into it.
    skipped_types = (spconv.conv.SubMConv3d, spconv.conv.SparseConv3d)
    for module in quant_model.modules():
        if isinstance(module, skipped_types):
            module.qconfig = None

    print("INFO: Preparing model for Partial PTQ...")
    model_prepared = prepare(quant_model, inplace=False)

    in_channels = cfg.model['backbone']['in_channels']
    num_points = 4096
    # Random single-sample point cloud used only to drive the observers.
    dummy_dict = {
        "feat": torch.randn(num_points, in_channels),
        "coord": torch.randn(num_points, 3),
        "grid_coord": torch.randint(0, 100, (num_points, 3), dtype=torch.int32),
        "offset": torch.tensor([num_points], dtype=torch.long),
        "batch": torch.zeros(num_points, dtype=torch.long),
    }
    dummy_dict = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in dummy_dict.items()
    }

    print("INFO: Calibrating model with dummy data...")
    with torch.no_grad():
        for _ in range(5):
            model_prepared(dummy_dict)
    print("INFO: Calibration complete.")

    model_quantized = convert(model_prepared, inplace=False)
    print("INFO: Model conversion to quantized complete.")

    # Make sure the target directory exists before writing the checkpoint.
    os.makedirs(cfg.save_path, exist_ok=True)
    save_path = os.path.join(cfg.save_path, 'model_partial_ptq_quantized.pth')
    torch.save(model_quantized.state_dict(), save_path)
    print(f"Partial PTQ Quantized model state_dict saved to {save_path}")

    # Profiling is best-effort: a failure is reported but does not abort the
    # run, since the quantized checkpoint has already been written.
    try:
        flops, params = profile(model_quantized, inputs=(dummy_dict,))
        print(f"Partial PTQ Quantized FLOPs: {flops / 1e9:.2f} GFLOPs, Params: {params / 1e6:.2f} M")
    except Exception as e:
        print(f"FLOPs calculation failed: {e}")

def main():
    """Parse CLI arguments, load the config and launch ``main_worker`` on one GPU."""
    parsed_args = default_argument_parser().parse_args()
    config = default_config_parser(parsed_args.config_file, parsed_args.options)
    # The worker is always launched on a single GPU, whatever the CLI says.
    parsed_args.num_gpus = 1
    launch(
        main_worker,
        num_gpus_per_machine=parsed_args.num_gpus,
        cfg=(config,),
    )

# Run the quantization entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()