# File size: 3,775 Bytes (7b95dc2)
import torch
from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert, QConfigMapping
from pointcept.engines.defaults import default_config_parser, default_argument_parser, default_setup
from pointcept.models.default import DefaultSegmentorV2
from pointcept.engines.launch import launch
from thop import profile
# 导入 spconv 库,以便检查模块类型
import spconv.pytorch as spconv
def main_worker(cfg):
    """Run partial post-training static quantization (PTQ) on a segmentor.

    Steps, in order:

    1. Build ``DefaultSegmentorV2`` from ``cfg.model`` and load the checkpoint
       found at ``cfg.weight`` (strict key matching).
    2. Wrap the model so that the ``feat`` input passes through a
       ``QuantStub`` and the segmentation logits pass through a
       ``DeQuantStub``.
    3. Exclude the spconv ``SubMConv3d`` / ``SparseConv3d`` layers from
       quantization, insert observers, and calibrate on random dummy input.
    4. Convert the calibrated model, save its ``state_dict`` below
       ``cfg.save_path`` and try to report FLOPs / parameter count with
       ``thop.profile``.

    Args:
        cfg: Parsed config object. After ``default_setup`` this function
            reads ``cfg.model`` (including
            ``cfg.model['backbone']['in_channels']``), ``cfg.weight`` and
            ``cfg.save_path``. The ``'type'`` key is popped from
            ``cfg.model``.

    Returns:
        None. Results are written to disk and printed.
    """
    import os  # local import: only needed to build the output path

    cfg = default_setup(cfg)
    device = torch.device("cuda")
    print(f"INFO: Running Partial PTQ on device: {device}")

    # DefaultSegmentorV2 is constructed directly below, so 'type' is dropped
    # before cfg.model is expanded into keyword arguments.
    cfg.model.pop('type', None)
    model = DefaultSegmentorV2(**cfg.model)

    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(cfg.weight, map_location=device)
    # Accept {'state_dict': ...}, {'model': ...} or a bare state dict.
    state_dict = checkpoint.get('state_dict', checkpoint.get('model', checkpoint))
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    class QuantWrapper(torch.nn.Module):
        """Wrap a float segmentor with quant / dequant stubs at its boundary."""

        def __init__(self, model_fp32):
            """Store the float model and create the boundary stubs.

            Args:
                model_fp32: Module exposing ``backbone`` and ``seg_head``.
            """
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()
            self.model_fp32 = model_fp32

        def forward(self, data_dict):
            """Return de-quantized segmentation logits for ``data_dict``.

            If ``data_dict`` has a ``'feat'`` entry it is replaced in place
            by the output of the quant stub before the backbone runs.
            """
            if 'feat' in data_dict:
                data_dict['feat'] = self.quant(data_dict['feat'])
            point = self.model_fp32.backbone(data_dict)
            seg_logits = self.model_fp32.seg_head(point["feat"])
            seg_logits = self.dequant(seg_logits)
            return seg_logits

    quant_model = QuantWrapper(model)
    quant_model.eval()

    # NOTE(review): 'fbgemm' is the x86 CPU quantized backend while the model
    # and the calibration inputs live on CUDA - confirm that the converted
    # model can actually run in this setup.
    backend = 'fbgemm'
    quant_model.qconfig = get_default_qconfig(backend)

    # Eager-mode ``prepare`` has no ``qconfig_mapping`` argument (that object
    # belongs to FX graph mode), so passing it raises TypeError. In eager
    # mode a module is skipped by setting its own ``qconfig`` to None, which
    # takes precedence over the qconfig inherited from the parent.
    skipped_types = (spconv.conv.SubMConv3d, spconv.conv.SparseConv3d)
    for submodule in quant_model.modules():
        if isinstance(submodule, skipped_types):
            submodule.qconfig = None

    print("INFO: Preparing model for Partial PTQ...")
    model_prepared = prepare(quant_model, inplace=False)

    # Random stand-in input describing a single sample of ``num_points``
    # points; grid coordinates are drawn from [0, 100).
    in_channels = cfg.model['backbone']['in_channels']
    num_points = 4096
    dummy_dict = {
        "feat": torch.randn(num_points, in_channels),
        "coord": torch.randn(num_points, 3),
        "grid_coord": torch.randint(0, 100, (num_points, 3), dtype=torch.int32),
        "offset": torch.tensor([num_points], dtype=torch.long),
        "batch": torch.zeros(num_points, dtype=torch.long),
    }
    dummy_dict = {key: value.to(device) for key, value in dummy_dict.items()}

    print("INFO: Calibrating model with dummy data...")
    with torch.no_grad():
        # A few forward passes let the inserted observers record statistics.
        for _ in range(5):
            model_prepared(dummy_dict)
    print("INFO: Calibration complete.")

    model_quantized = convert(model_prepared, inplace=False)
    print("INFO: Model conversion to quantized complete.")

    save_path = os.path.join(cfg.save_path, 'model_partial_ptq_quantized.pth')
    torch.save(model_quantized.state_dict(), save_path)
    print(f"Partial PTQ Quantized model state_dict saved to {save_path}")

    # Profiling is best-effort: a failure here is reported but must not
    # discard the model that has already been saved.
    try:
        flops, params = profile(model_quantized, inputs=(dummy_dict,))
        print(f"Partial PTQ Quantized FLOPs: {flops / 1e9:.2f} GFLOPs, Params: {params / 1e6:.2f} M")
    except Exception as e:
        print(f"FLOPs calculation failed: {e}")
def main():
    """Parse the command line, load the config and start the PTQ worker.

    The worker is always launched with exactly one GPU per machine,
    regardless of what was passed on the command line.
    """
    cli_args = default_argument_parser().parse_args()
    # Force a single-GPU launch.
    cli_args.num_gpus = 1
    config = default_config_parser(cli_args.config_file, cli_args.options)
    launch(
        main_worker,
        num_gpus_per_machine=cli_args.num_gpus,
        cfg=(config,),
    )
# Run the quantization pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()