File size: 2,022 Bytes
ffa94c8
 
 
 
5ed6654
ffa94c8
 
 
5ed6654
ffa94c8
 
 
 
 
 
 
 
 
 
 
 
 
5ed6654
ffa94c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ed6654
ffa94c8
5ed6654
ffa94c8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
"""
Post-training weight quantization for PlaprePico CoreML model.

Applies int8 linear weight quantization to reduce model size
and improve inference speed (less memory bandwidth).

Usage:
    python quantize.py [--input PATH] [--output PATH]
"""

import argparse
from pathlib import Path

import coremltools as ct
from coremltools.optimize.coreml import (
    OpLinearQuantizerConfig,
    OptimizationConfig,
    linear_quantize_weights,
)


def quantize_model(input_path: Path, output_path: Path, nbits: int = 8) -> None:
    """Quantize a CoreML model's weights and report the resulting size change.

    Loads the .mlpackage at *input_path*, applies linear symmetric
    per-channel weight quantization at the requested bit width, saves the
    result to *output_path*, and prints before/after on-disk sizes.

    Args:
        input_path: Path to the source .mlpackage directory.
        output_path: Destination path for the quantized .mlpackage.
        nbits: Weight bit width; interpolated into the dtype string
            (e.g. 8 -> "int8").
    """
    print(f"Loading {input_path}...")
    # CPU_ONLY: we only need to read/transform weights, no ANE/GPU compile.
    model = ct.models.MLModel(str(input_path), compute_units=ct.ComputeUnit.CPU_ONLY)

    print(f"Quantizing weights to int{nbits}...")
    op_config = OpLinearQuantizerConfig(
        mode="linear_symmetric",
        dtype=f"int{nbits}",
        granularity="per_channel",
    )
    config = OptimizationConfig(global_config=op_config)
    quantized = linear_quantize_weights(model, config)

    print(f"Saving to {output_path}...")
    quantized.save(str(output_path))

    # Compare on-disk sizes (an .mlpackage is a directory, so walk it).
    def dir_size(p: Path) -> int:
        """Total size in bytes of all regular files under *p*, recursively."""
        return sum(f.stat().st_size for f in Path(p).rglob("*") if f.is_file())

    orig_mb = dir_size(input_path) / 1e6
    quant_mb = dir_size(output_path) / 1e6
    print(f"\nOriginal:   {orig_mb:.1f} MB")
    print(f"Quantized:  {quant_mb:.1f} MB ({quant_mb/orig_mb*100:.0f}%)")
    print("Done!")


def main():
    """CLI entry point: parse --input/--output and run int8 quantization."""
    default_input = Path(__file__).parent.parent / "PlaprePico.mlpackage"

    parser = argparse.ArgumentParser(description="Quantize PlaprePico model")
    parser.add_argument("--input", type=str, default=str(default_input))
    parser.add_argument("--output", type=str, default=None)
    opts = parser.parse_args()

    src = Path(opts.input)
    if opts.output:
        dst = Path(opts.output)
    else:
        # Default: sibling of the input, tagged with the quantization width.
        dst = src.parent / "PlaprePico_int8.mlpackage"

    quantize_model(src, dst, 8)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()