import os

import spconv.pytorch as spconv
import torch
from thop import profile
from torch.ao.quantization import (
    DeQuantStub,
    QConfigMapping,
    QuantStub,
    convert,
    get_default_qconfig,
    prepare,
)

from pointcept.engines.defaults import (
    default_argument_parser,
    default_config_parser,
    default_setup,
)
from pointcept.engines.launch import launch
from pointcept.models.default import DefaultSegmentorV2
|
|
class _PartialQuantWrapper(torch.nn.Module):
    """Wrap a segmentor with eager-mode quant/dequant stubs.

    The input ``feat`` tensor goes through a ``QuantStub`` before the backbone.
    The segmentation-head output goes through a ``DeQuantStub``.
    """

    def __init__(self, model_fp32):
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, data_dict):
        # Shallow-copy so the caller's dict is not overwritten with the stub
        # output (a quantized tensor after ``convert``) on every call.
        data_dict = dict(data_dict)
        if "feat" in data_dict:
            data_dict["feat"] = self.quant(data_dict["feat"])
        point = self.model_fp32.backbone(data_dict)
        # Some backbones return a dict-like Point, others a bare feature
        # tensor. Indexing a tensor with "feat" would raise, so handle both.
        feat = point["feat"] if isinstance(point, dict) else point
        seg_logits = self.model_fp32.seg_head(feat)
        return self.dequant(seg_logits)


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Weights saved from a ``DistributedDataParallel``-wrapped model carry this
    prefix. Without stripping it, ``load_state_dict(strict=True)`` on the bare
    model would fail.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_fp32_model(cfg, device):
    """Build ``DefaultSegmentorV2`` from ``cfg.model``, load ``cfg.weight``, move it to *device* and set eval mode."""
    # Copy before dropping "type" so the shared config is not mutated.
    model_cfg = dict(cfg.model)
    model_cfg.pop("type", None)
    model = DefaultSegmentorV2(**model_cfg)

    # Load on CPU: the checkpoint may also hold optimizer state, which would
    # otherwise occupy GPU memory for no reason. The model is moved below.
    checkpoint = torch.load(cfg.weight, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint.get("model", checkpoint))
    model.load_state_dict(_strip_module_prefix(state_dict), strict=True)

    model.to(device)
    model.eval()
    return model


def _make_dummy_input(in_channels, device, num_points=4096):
    """Return a random single-sample point-cloud dict on *device*.

    Keys are ``feat``, ``coord``, ``grid_coord``, ``offset`` and ``batch``.
    """
    dummy = {
        "feat": torch.randn(num_points, in_channels),
        "coord": torch.randn(num_points, 3),
        "grid_coord": torch.randint(0, 100, (num_points, 3), dtype=torch.int32),
        "offset": torch.tensor([num_points], dtype=torch.long),
        "batch": torch.zeros(num_points, dtype=torch.long),
    }
    return {key: value.to(device) for key, value in dummy.items()}


def main_worker(cfg):
    """Run partial post-training static quantization (PTQ) on a segmentor.

    Steps:

    1. Load FP32 weights from ``cfg.weight`` and wrap the model with quant and
       dequant stubs.
    2. Exclude spconv sparse convolutions from quantization.
    3. Calibrate on random dummy input, then convert.
    4. Save the quantized ``state_dict`` to
       ``<cfg.save_path>/model_partial_ptq_quantized.pth``.
    5. Report FLOPs and parameter count via ``thop`` (best effort).

    Args:
        cfg: Pointcept config. Reads ``model``, ``model.backbone.in_channels``,
            ``weight`` and ``save_path``.
    """
    cfg = default_setup(cfg)

    device = torch.device("cuda")
    print(f"INFO: Running Partial PTQ on device: {device}")

    model = _load_fp32_model(cfg, device)

    quant_model = _PartialQuantWrapper(model)
    quant_model.eval()

    backend = "fbgemm"
    # Keep the quantized-kernel engine consistent with the qconfig's backend.
    if backend in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = backend
    quant_model.qconfig = get_default_qconfig(backend)

    # Eager-mode ``prepare`` has no ``qconfig_mapping`` argument; that is the
    # FX-mode API. In eager mode, a module is opted out by setting its
    # ``qconfig`` to None, and its children inherit that.
    fp32_only_types = (spconv.conv.SubMConv3d, spconv.conv.SparseConv3d)
    for module in quant_model.modules():
        if isinstance(module, fp32_only_types):
            module.qconfig = None

    print("INFO: Preparing model for Partial PTQ...")
    model_prepared = prepare(quant_model, inplace=False)

    in_channels = cfg.model["backbone"]["in_channels"]
    dummy_dict = _make_dummy_input(in_channels, device)

    # NOTE(review): random inputs give unrepresentative observer ranges.
    # Real calibration data should be used for meaningful scales/zero-points.
    print("INFO: Calibrating model with dummy data...")
    with torch.no_grad():
        for _ in range(5):
            model_prepared(dummy_dict)
    print("INFO: Calibration complete.")

    # NOTE(review): fbgemm quantized kernels target CPU. After conversion the
    # QuantStub also emits a quantized ``feat`` that FP32 layers (e.g. the
    # spconv stem) are unlikely to accept. Confirm that the converted model
    # runs on the intended device. The stubs may need to wrap only the
    # regions that are meant to be quantized.
    model_quantized = convert(model_prepared, inplace=False)
    print("INFO: Model conversion to quantized complete.")

    os.makedirs(cfg.save_path, exist_ok=True)
    save_path = os.path.join(cfg.save_path, "model_partial_ptq_quantized.pth")
    torch.save(model_quantized.state_dict(), save_path)
    print(f"Partial PTQ Quantized model state_dict saved to {save_path}")

    # Best effort: thop may not support every (quantized/sparse) module, so a
    # failure here is reported rather than aborting after the model is saved.
    try:
        flops, params = profile(model_quantized, inputs=(dummy_dict,))
        print(f"Partial PTQ Quantized FLOPs: {flops / 1e9:.2f} GFLOPs, Params: {params / 1e6:.2f} M")
    except Exception as e:
        print(f"FLOPs calculation failed: {e}")
|
|
def main():
    """Parse command-line options, build the config and launch one PTQ worker."""
    cli_args = default_argument_parser().parse_args()
    config = default_config_parser(cli_args.config_file, cli_args.options)
    # Force a single-process launch; this script quantizes on one device.
    cli_args.num_gpus = 1
    launch(
        main_worker,
        num_gpus_per_machine=cli_args.num_gpus,
        cfg=(config,),
    )
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()