#!/usr/bin/env python3
"""
Weight-only INT8 quantization — no calibration, no forward passes needed.
Uses torchao int8_weight_only which packs weights instantly.
Then re-exports to ExecuTorch XNNPACK .pte.
"""
import os, sys, time, gc, torch
sys.path.insert(0, ".")
MODEL_DIR = "./models/LightOnOCR-2-1B"
FIXED_H, FIXED_W = 1120, 1540
def quantize_vision(orig):
    """INT8 weight-only quantize the vision encoder and export it to a .pte.

    No calibration or forward passes are required: torchao's int8_weight_only
    packs the weights directly, after which the module is re-exported with
    torch.export and lowered to the ExecuTorch XNNPACK backend.

    Args:
        orig: the original (unquantized) model to build the vision module from.

    Returns:
        str: path of the written "vision_encoder_int8.pte" file.
    """
    from export_vision import build_vision_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== VISION ENCODER (INT8 weight-only) ===")
    encoder = build_vision_module(orig)
    # Quantization and export both run on CPU in fp32.
    encoder = encoder.to("cpu").to(torch.float32).eval()
    total_params = sum(p.numel() for p in encoder.parameters())
    print(f" Params: {total_params/1e6:.1f}M")

    # Weight-only quantization — instant, no forward pass needed.
    print(" Applying int8_weight_only...")
    started = time.time()
    quantize_(encoder, int8_weight_only())
    print(f" Quantization took {time.time()-started:.1f}s")

    # torch.export with a fixed-size dummy image as the single input.
    print(" torch.export...")
    sample_inputs = (torch.randn(1, 3, FIXED_H, FIXED_W),)
    started = time.time()
    exported = export(encoder, sample_inputs)
    print(f" Export took {time.time()-started:.1f}s")

    # Lower the exported program to the XNNPACK backend.
    print(" XNNPACK lowering...")
    started = time.time()
    lowered = to_edge_transform_and_lower(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    program = lowered.to_executorch()
    print(f" Lowering took {time.time()-started:.1f}s")

    path = "vision_encoder_int8.pte"
    with open(path, "wb") as f:
        f.write(program.buffer)
    print(f" ✅ {path}: {os.path.getsize(path)/1024/1024:.1f} MB")

    # Release the large intermediates before returning.
    del encoder, exported, lowered, program
    gc.collect()
    return path
def quantize_decoder(orig):
    """INT8 weight-only quantize the text decoder and export it to a .pte.

    Mirrors quantize_vision: weights are packed instantly by torchao's
    int8_weight_only (no calibration), then the decoder is exported with
    torch.export — using empty KV caches and a causal mask as example
    inputs — and lowered to the ExecuTorch XNNPACK backend.

    Args:
        orig: the original (unquantized) model to build the decoder from.

    Returns:
        str: path of the written "text_decoder_int8.pte" file.
    """
    import export_decoder as ed
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== TEXT DECODER (INT8 weight-only) ===")
    dec = build_decoder_module(orig)
    # Quantization and export both run on CPU in fp32.
    dec = dec.to("cpu").to(torch.float32).eval()
    total_params = sum(p.numel() for p in dec.parameters())
    print(f" Params: {total_params/1e6:.1f}M")

    # Weight-only quantization — instant, no forward pass needed.
    print(" Applying int8_weight_only...")
    started = time.time()
    quantize_(dec, int8_weight_only())
    print(f" Quantization took {time.time()-started:.1f}s")

    # Example inputs: an 8-token prompt, its causal mask over the full
    # sequence budget, positions, cache indices, and empty KV caches.
    print(" torch.export...")
    empty_caches = ed.create_empty_kv_caches(1, torch.float32, "cpu")
    sample_inputs = (
        torch.ones(1, 8, dtype=torch.long),
        ed.create_causal_mask(8, ed.MAX_SEQ_LEN, torch.float32),
        torch.arange(8).unsqueeze(0),
        torch.arange(8),
        *empty_caches,
    )
    started = time.time()
    exported = export(dec, sample_inputs)
    print(f" Export took {time.time()-started:.1f}s")

    # Lower the exported program to the XNNPACK backend.
    print(" XNNPACK lowering...")
    started = time.time()
    lowered = to_edge_transform_and_lower(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    program = lowered.to_executorch()
    print(f" Lowering took {time.time()-started:.1f}s")

    path = "text_decoder_int8.pte"
    with open(path, "wb") as f:
        f.write(program.buffer)
    print(f" ✅ {path}: {os.path.getsize(path)/1024/1024:.1f} MB")

    # Release the large intermediates before returning.
    del dec, exported, lowered, program
    gc.collect()
    return path
def main():
    """Quantize both submodules to INT8 weight-only and report size savings.

    Loads the original model once, runs the vision and decoder quantization
    passes, then compares each INT8 .pte against its fp32 baseline on disk.
    """
    from export_vision import load_original_model

    print("LightOnOCR INT8 Weight-Only Quantization")
    print("No calibration needed — weights quantized instantly\n")
    print("Loading model...")
    orig = load_original_model()
    vis_path = quantize_vision(orig)
    dec_path = quantize_decoder(orig)
    # Drop the fp32 model before the size report to lower peak memory.
    del orig
    gc.collect()

    print("\n=== RESULTS ===")
    for fp32, int8 in [("vision_encoder.pte", vis_path),
                       ("text_decoder_4096.pte", dec_path)]:
        if os.path.exists(fp32) and os.path.exists(int8):
            orig_mb = os.path.getsize(fp32) / 1024 / 1024
            quant_mb = os.path.getsize(int8) / 1024 / 1024
            ratio = quant_mb / orig_mb * 100
            print(f" {fp32}: {orig_mb:.0f} MB → {int8}: {quant_mb:.0f} MB ({ratio:.0f}%)")
        else:
            # Previously this pair was skipped silently; say why instead.
            missing = [p for p in (fp32, int8) if not os.path.exists(p)]
            print(f" (skipped comparison — missing: {', '.join(missing)})")


if __name__ == "__main__":
    main()