acul3 commited on
Commit
f7962bb
·
verified ·
1 Parent(s): 77cf118

Upload scripts/quantize_wo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/quantize_wo.py +135 -0
scripts/quantize_wo.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Weight-only INT8 quantization — no calibration, no forward passes needed.
Uses torchao int8_weight_only which packs weights instantly.
Then re-exports to ExecuTorch XNNPACK .pte.
"""

# Standard library (one import per line, PEP 8).
import gc
import os
import sys
import time

# Third-party.
import torch

# Make sibling export scripts (export_vision.py, export_decoder.py)
# importable when this script is run from the repository root.
sys.path.insert(0, ".")

# Local checkpoint directory of the model to quantize.
MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed input image size (height, width) baked into the vision export.
FIXED_H, FIXED_W = 1120, 1540
13
+
14
+
15
def quantize_vision(orig, out_path="vision_encoder_int8.pte"):
    """Quantize the vision encoder to INT8 weight-only and export to ExecuTorch.

    Weight-only quantization packs the weights in place — no calibration
    data or forward passes are needed.

    Args:
        orig: Loaded full-precision model from which the vision module is built.
        out_path: Destination file for the lowered XNNPACK ``.pte`` program
            (default preserves the original hard-coded name).

    Returns:
        The path of the written ``.pte`` file (``out_path``).
    """
    # Heavy project/third-party imports stay function-local so importing
    # this script remains cheap.
    from export_vision import build_vision_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== VISION ENCODER (INT8 weight-only) ===")
    vision = build_vision_module(orig)
    # Export pipeline expects fp32 weights on CPU in eval mode.
    vision = vision.to("cpu").to(torch.float32).eval()
    print(f" Params: {sum(p.numel() for p in vision.parameters())/1e6:.1f}M")

    # Weight-only quantization — instant, no forward pass.
    print(" Applying int8_weight_only...")
    t0 = time.time()
    quantize_(vision, int8_weight_only())
    print(f" Quantization took {time.time()-t0:.1f}s")

    # torch.export with a fixed-size dummy image (see FIXED_H / FIXED_W).
    print(" torch.export...")
    example = (torch.randn(1, 3, FIXED_H, FIXED_W),)
    t0 = time.time()
    ep = export(vision, example)
    print(f" Export took {time.time()-t0:.1f}s")

    # Lower to the XNNPACK backend.
    # NOTE(review): _check_ir_validity=False presumably works around IR checks
    # that reject the torchao-quantized graph — confirm before removing.
    print(" XNNPACK lowering...")
    t0 = time.time()
    edge = to_edge_transform_and_lower(
        ep,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    et = edge.to_executorch()
    print(f" Lowering took {time.time()-t0:.1f}s")

    path = out_path
    with open(path, "wb") as f:
        f.write(et.buffer)
    print(f" ✅ {path}: {os.path.getsize(path)/1024/1024:.1f} MB")

    # Free the large intermediates before the decoder export runs.
    del vision, ep, edge, et
    gc.collect()
    return path
57
+
58
+
59
def quantize_decoder(orig, out_path="text_decoder_int8.pte"):
    """Quantize the text decoder to INT8 weight-only and export to ExecuTorch.

    Mirrors :func:`quantize_vision`, but the exported graph additionally
    takes the causal mask, position indices, and empty KV caches built by
    ``export_decoder`` helpers.

    Args:
        orig: Loaded full-precision model from which the decoder is built.
        out_path: Destination file for the lowered XNNPACK ``.pte`` program
            (default preserves the original hard-coded name).

    Returns:
        The path of the written ``.pte`` file (``out_path``).
    """
    # Function-local imports keep module import cheap.
    import export_decoder as ed
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== TEXT DECODER (INT8 weight-only) ===")
    decoder = build_decoder_module(orig)
    # Export pipeline expects fp32 weights on CPU in eval mode.
    decoder = decoder.to("cpu").to(torch.float32).eval()
    print(f" Params: {sum(p.numel() for p in decoder.parameters())/1e6:.1f}M")

    # Weight-only quantization — instant, no forward pass.
    print(" Applying int8_weight_only...")
    t0 = time.time()
    quantize_(decoder, int8_weight_only())
    print(f" Quantization took {time.time()-t0:.1f}s")

    # torch.export with a dummy 8-token prefill: token ids, causal mask,
    # position ids, cache positions, then the empty KV caches.
    print(" torch.export...")
    kv = ed.create_empty_kv_caches(1, torch.float32, "cpu")
    example = (
        torch.ones(1, 8, dtype=torch.long),
        ed.create_causal_mask(8, ed.MAX_SEQ_LEN, torch.float32),
        torch.arange(8).unsqueeze(0),
        torch.arange(8),
        *kv,
    )
    t0 = time.time()
    ep = export(decoder, example)
    print(f" Export took {time.time()-t0:.1f}s")

    # Lower to the XNNPACK backend.
    # NOTE(review): _check_ir_validity=False presumably works around IR checks
    # that reject the torchao-quantized graph — confirm before removing.
    print(" XNNPACK lowering...")
    t0 = time.time()
    edge = to_edge_transform_and_lower(
        ep,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    et = edge.to_executorch()
    print(f" Lowering took {time.time()-t0:.1f}s")

    path = out_path
    with open(path, "wb") as f:
        f.write(et.buffer)
    print(f" ✅ {path}: {os.path.getsize(path)/1024/1024:.1f} MB")

    # Free the large intermediates before returning.
    del decoder, ep, edge, et
    gc.collect()
    return path
109
+
110
+
111
def main():
    """Quantize both submodules, then report size savings vs. fp32 exports."""
    from export_vision import load_original_model

    print("LightOnOCR INT8 Weight-Only Quantization")
    print("No calibration needed — weights quantized instantly\n")

    print("Loading model...")
    model = load_original_model()

    vision_pte = quantize_vision(model)
    decoder_pte = quantize_decoder(model)
    # Drop the full-precision model before printing the summary.
    del model
    gc.collect()

    print("\n=== RESULTS ===")
    comparisons = (
        ("vision_encoder.pte", vision_pte),
        ("text_decoder_4096.pte", decoder_pte),
    )
    for baseline, quantized in comparisons:
        # Only compare when both the fp32 and int8 artifacts are present.
        if not (os.path.exists(baseline) and os.path.exists(quantized)):
            continue
        orig_mb = os.path.getsize(baseline) / 1024 / 1024
        quant_mb = os.path.getsize(quantized) / 1024 / 1024
        ratio = quant_mb / orig_mb * 100
        print(f" {baseline}: {orig_mb:.0f} MB → {quantized}: {quant_mb:.0f} MB ({ratio:.0f}%)")
132
+
133
+
134
# Script entry point: run the full quantization pipeline.
if __name__ == "__main__":
    main()