"""Partial post-training static quantization (eager mode) for a Pointcept segmentor.

spconv sparse-convolution layers are excluded from quantization; everything else
under the QuantStub/DeQuantStub pair receives the default fbgemm qconfig.
"""
import os

import torch
from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert, QConfigMapping
from pointcept.engines.defaults import default_config_parser, default_argument_parser, default_setup
from pointcept.models.default import DefaultSegmentorV2
from pointcept.engines.launch import launch
from thop import profile

# Import spconv so its module types can be recognised and excluded from quantization.
import spconv.pytorch as spconv

# Module types kept in floating point: eager-mode quantization has no quantized
# replacement for spconv's sparse convolutions.
_FLOAT_ONLY_MODULE_TYPES = (spconv.conv.SubMConv3d, spconv.conv.SparseConv3d)


class QuantWrapper(torch.nn.Module):
    """Wrap a segmentor so its backbone + seg_head run between quant/dequant stubs."""

    def __init__(self, model_fp32):
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, data_dict):
        """Quantize ``data_dict['feat']`` (if present), run backbone and seg_head, dequantize logits."""
        # Shallow copy so replacing 'feat' never overwrites the caller's input dict.
        data_dict = dict(data_dict)
        if "feat" in data_dict:
            data_dict["feat"] = self.quant(data_dict["feat"])
        point = self.model_fp32.backbone(data_dict)
        seg_logits = self.model_fp32.seg_head(point["feat"])
        return self.dequant(seg_logits)


def _exclude_from_quantization(model, module_types=_FLOAT_ONLY_MODULE_TYPES):
    """Set ``qconfig = None`` on every submodule of ``model`` that is an instance of ``module_types``.

    Eager-mode ``prepare`` takes a module's own ``qconfig`` attribute in
    preference to its parent's, so an explicit ``None`` keeps that module (and
    its children) in floating point.

    Returns:
        The number of modules that were excluded.
    """
    excluded = 0
    for module in model.modules():
        if isinstance(module, module_types):
            module.qconfig = None
            excluded += 1
    return excluded


def _make_dummy_input(in_channels, num_points, device):
    """Build a random single-batch point-cloud input dict with all tensors on ``device``."""
    dummy = {
        "feat": torch.randn(num_points, in_channels),
        "coord": torch.randn(num_points, 3),
        "grid_coord": torch.randint(0, 100, (num_points, 3), dtype=torch.int32),
        "offset": torch.tensor([num_points], dtype=torch.long),
        "batch": torch.zeros(num_points, dtype=torch.long),
    }
    return {key: value.to(device) for key, value in dummy.items()}


def main_worker(cfg):
    """Quantize the configured model with partial PTQ and save the quantized state_dict.

    Steps:
      1. Build ``DefaultSegmentorV2`` from ``cfg.model`` and load ``cfg.weight``.
      2. Wrap it with quant/dequant stubs and exclude spconv layers.
      3. Calibrate observers with random dummy input on CUDA.
      4. Convert on CPU, where the fbgemm quantized kernels run.
      5. Save ``model_partial_ptq_quantized.pth`` under ``cfg.save_path``.
      6. Try, best effort, to report FLOPs/params with ``thop``.
    """
    cfg = default_setup(cfg)
    device = torch.device("cuda")
    print(f"INFO: Running Partial PTQ on device: {device}")

    cfg.model.pop("type", None)
    model = DefaultSegmentorV2(**cfg.model)
    checkpoint = torch.load(cfg.weight, map_location=device)
    state_dict = checkpoint.get("state_dict", checkpoint.get("model", checkpoint))
    # Checkpoints saved from a DistributedDataParallel-wrapped model prefix every
    # key with "module."; strip it so the strict load below matches.
    if any(key.startswith("module.") for key in state_dict):
        state_dict = {
            (key[len("module."):] if key.startswith("module.") else key): value
            for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    quant_model = QuantWrapper(model)
    quant_model.eval()

    backend = "fbgemm"
    # Keep the quantized engine consistent with the backend the qconfig targets.
    torch.backends.quantized.engine = backend
    quant_model.qconfig = get_default_qconfig(backend)

    # Eager-mode prepare() has no qconfig_mapping argument (QConfigMapping is an
    # FX-mode API for prepare_fx); modules are skipped by giving them qconfig=None.
    num_excluded = _exclude_from_quantization(quant_model)
    print(f"INFO: Excluded {num_excluded} spconv modules from quantization.")
    # NOTE(review): in eager mode, after convert() every quantized module between
    # self.quant and self.dequant expects quantized input, while the excluded
    # spconv layers expect float tensors. Without DeQuantStub/QuantStub pairs at
    # those boundaries, running model_quantized will likely fail. Confirm the
    # intended inference path.

    print("INFO: Preparing model for Partial PTQ...")
    model_prepared = prepare(quant_model, inplace=False)

    in_channels = cfg.model["backbone"]["in_channels"]
    num_points = 4096
    dummy_dict = _make_dummy_input(in_channels, num_points, device)

    # NOTE(review): calibration uses random data, so activation ranges will not
    # reflect real scenes. Consider calibrating with samples from the dataset.
    print("INFO: Calibrating model with dummy data...")
    with torch.no_grad():
        for _ in range(5):
            model_prepared(dummy_dict)
    print("INFO: Calibration complete.")

    # fbgemm weight prepacking and quantized kernels are CPU-only, so the
    # calibrated model must be on CPU before modules are swapped.
    model_prepared.to("cpu")
    model_quantized = convert(model_prepared, inplace=False)
    print("INFO: Model conversion to quantized complete.")

    os.makedirs(cfg.save_path, exist_ok=True)
    save_path = os.path.join(cfg.save_path, "model_partial_ptq_quantized.pth")
    torch.save(model_quantized.state_dict(), save_path)
    print(f"Partial PTQ Quantized model state_dict saved to {save_path}")

    # The quantized model now lives on CPU; feed it CPU inputs.
    cpu_inputs = {key: value.cpu() for key, value in dummy_dict.items()}
    try:
        flops, params = profile(model_quantized, inputs=(cpu_inputs,))
        print(f"Partial PTQ Quantized FLOPs: {flops / 1e9:.2f} GFLOPs, Params: {params / 1e6:.2f} M")
    except Exception as e:  # deliberate best effort: profiling must not abort the run
        print(f"FLOPs calculation failed: {e}")


def main():
    """Parse CLI arguments and config, then launch ``main_worker`` on a single GPU."""
    args = default_argument_parser().parse_args()
    cfg = default_config_parser(args.config_file, args.options)
    # Quantization is run single-process; force one GPU regardless of CLI input.
    args.num_gpus = 1
    launch(main_worker, num_gpus_per_machine=args.num_gpus, cfg=(cfg,))


if __name__ == "__main__":
    main()