File size: 4,553 Bytes
f7962bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
"""
Weight-only INT8 quantization — no calibration, no forward passes needed.
Uses torchao int8_weight_only which packs weights instantly.
Then re-exports to ExecuTorch XNNPACK .pte.
"""

import os, sys, time, gc, torch
# Make sibling modules (export_vision, export_decoder) importable when the
# script is run from the repository root.
sys.path.insert(0, ".")

# Local checkout of the model weights.
# NOTE(review): MODEL_DIR is not referenced in this file — presumably consumed
# by export_vision.load_original_model or kept for documentation; confirm.
MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed image size (height, width) used as the example input for torch.export
# of the vision encoder.
FIXED_H, FIXED_W = 1120, 1540

def quantize_vision(orig):
    """Quantize the vision encoder to INT8 (weight-only) and export it.

    Builds the vision module from *orig*, packs its weights to INT8 via
    torchao (no calibration pass), traces it with ``torch.export`` on a
    fixed-size example image, lowers the graph to the XNNPACK backend,
    and writes the result to ``vision_encoder_int8.pte``.

    Args:
        orig: The original loaded model, as returned by
            ``export_vision.load_original_model``.

    Returns:
        str: Path of the written ``.pte`` file.
    """
    from export_vision import build_vision_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== VISION ENCODER (INT8 weight-only) ===")
    model = build_vision_module(orig).to("cpu").to(torch.float32).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  Params: {n_params/1e6:.1f}M")

    # In-place weight packing — instant, no forward pass required.
    print("  Applying int8_weight_only...")
    started = time.time()
    quantize_(model, int8_weight_only())
    print(f"  Quantization took {time.time()-started:.1f}s")

    # Trace the graph with a single fixed-shape example input.
    print("  torch.export...")
    sample_inputs = (torch.randn(1, 3, FIXED_H, FIXED_W),)
    started = time.time()
    exported_program = export(model, sample_inputs)
    print(f"  Export took {time.time()-started:.1f}s")

    # Lower to the XNNPACK delegate and serialize to ExecuTorch format.
    print("  XNNPACK lowering...")
    started = time.time()
    edge_program = to_edge_transform_and_lower(
        exported_program,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    executorch_program = edge_program.to_executorch()
    print(f"  Lowering took {time.time()-started:.1f}s")

    path = "vision_encoder_int8.pte"
    with open(path, "wb") as f:
        f.write(executorch_program.buffer)
    print(f"  ✅ {path}: {os.path.getsize(path)/1024/1024:.1f} MB")

    # Drop the large intermediates before the caller moves on.
    del model, exported_program, edge_program, executorch_program
    gc.collect()
    return path


def quantize_decoder(orig):
    """Quantize the text decoder to INT8 (weight-only) and export it.

    Builds the decoder module from *orig*, packs its weights to INT8 via
    torchao (no calibration pass), traces it with ``torch.export`` using an
    8-token example plus empty KV caches, lowers the graph to the XNNPACK
    backend, and writes the result to ``text_decoder_int8.pte``.

    Args:
        orig: The original loaded model, as returned by
            ``export_vision.load_original_model``.

    Returns:
        str: Path of the written ``.pte`` file.
    """
    import export_decoder as ed
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== TEXT DECODER (INT8 weight-only) ===")
    model = build_decoder_module(orig).to("cpu").to(torch.float32).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  Params: {n_params/1e6:.1f}M")

    # In-place weight packing — instant, no forward pass required.
    print("  Applying int8_weight_only...")
    started = time.time()
    quantize_(model, int8_weight_only())
    print(f"  Quantization took {time.time()-started:.1f}s")

    # Example inputs: 8 token ids, causal mask, positions, cache positions,
    # followed by the flattened empty KV caches.
    print("  torch.export...")
    kv_caches = ed.create_empty_kv_caches(1, torch.float32, "cpu")
    sample_inputs = (
        torch.ones(1, 8, dtype=torch.long),
        ed.create_causal_mask(8, ed.MAX_SEQ_LEN, torch.float32),
        torch.arange(8).unsqueeze(0),
        torch.arange(8),
        *kv_caches,
    )
    started = time.time()
    exported_program = export(model, sample_inputs)
    print(f"  Export took {time.time()-started:.1f}s")

    # Lower to the XNNPACK delegate and serialize to ExecuTorch format.
    print("  XNNPACK lowering...")
    started = time.time()
    edge_program = to_edge_transform_and_lower(
        exported_program,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    executorch_program = edge_program.to_executorch()
    print(f"  Lowering took {time.time()-started:.1f}s")

    path = "text_decoder_int8.pte"
    with open(path, "wb") as f:
        f.write(executorch_program.buffer)
    print(f"  ✅ {path}: {os.path.getsize(path)/1024/1024:.1f} MB")

    # Drop the large intermediates before the caller moves on.
    del model, exported_program, edge_program, executorch_program
    gc.collect()
    return path


def main():
    """Quantize both model components to INT8 and print a size comparison.

    Loads the original model once, runs the vision-encoder and text-decoder
    quantization/export pipelines, then — when the FP32 ``.pte`` baselines
    exist on disk — prints the size of each INT8 artifact relative to its
    FP32 counterpart.
    """
    from export_vision import load_original_model

    print("LightOnOCR INT8 Weight-Only Quantization")
    print("No calibration needed — weights quantized instantly\n")

    print("Loading model...")
    model = load_original_model()

    vision_pte = quantize_vision(model)
    decoder_pte = quantize_decoder(model)
    # Free the full-precision model before reporting.
    del model
    gc.collect()

    print("\n=== RESULTS ===")
    comparisons = [
        ("vision_encoder.pte", vision_pte),
        ("text_decoder_4096.pte", decoder_pte),
    ]
    for fp32, int8 in comparisons:
        # Skip the comparison when the FP32 baseline was never exported.
        if not (os.path.exists(fp32) and os.path.exists(int8)):
            continue
        orig_mb = os.path.getsize(fp32) / 1024 / 1024
        quant_mb = os.path.getsize(int8) / 1024 / 1024
        ratio = quant_mb / orig_mb * 100
        print(f"  {fp32}: {orig_mb:.0f} MB → {int8}: {quant_mb:.0f} MB ({ratio:.0f}%)")

# Script entry point: run the full quantization pipeline when executed directly.
if __name__ == "__main__":
    main()