File size: 9,805 Bytes
0aebce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
#!/usr/bin/env python3
"""
E2E Validation for INT8 weight-only quantized models.
Compares: HF original vs INT8 quantized fixed modules.
"""

import os, sys, time, torch, torch.nn.functional as F
from PIL import Image
sys.path.insert(0, ".")

# Local checkpoint directory for the model under test.
MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed input resolution (height, width in pixels) used by both pipelines.
FIXED_H, FIXED_W = 1120, 1540
# Token id of the <image> placeholder in the prompt.
# NOTE(review): assumed to match the tokenizer config — confirm against tokenizer_config.json.
IMAGE_TOKEN_ID = 151655
# End-of-sequence token id that terminates the greedy decode loop.
EOS_TOKEN_ID = 151645
# Decoder geometry used to pre-allocate the static KV caches.
NUM_LAYERS = 28
NUM_KV_HEADS = 8
HEAD_DIM = 128
MAX_SEQ_LEN = 4096

def get_test_images():
    """Build the test image set: an optional on-disk receipt plus a synthetic invoice."""
    from PIL import ImageDraw

    images = {}
    receipt_path = "test_images/receipt.png"
    if os.path.exists(receipt_path):
        images["receipt"] = Image.open(receipt_path).convert("RGB")

    # Render a small synthetic invoice with PIL's default font on a white canvas.
    canvas = Image.new("RGB", (800, 600), "white")
    painter = ImageDraw.Draw(canvas)
    invoice_lines = [
        "Invoice #12345",
        "Date: 2024-01-15",
        "Item 1: Widget x5 @ $10.00 = $50.00",
        "Item 2: Gadget x2 @ $24.99 = $49.98",
        "Total: $99.98",
    ]
    for row, line in enumerate(invoice_lines):
        painter.text((50, 50 + row * 50), line, fill="black")
    images["synthetic"] = canvas
    return images


def preprocess_image_fixed(img, processor):
    """Resize *img* to the fixed model resolution and return its pixel-value tensor."""
    resized = img.resize((FIXED_W, FIXED_H), Image.LANCZOS)
    # Minimal image-only chat turn: only the image placeholder is needed here.
    prompt = processor.apply_chat_template(
        [{"role": "user", "content": [{"type": "image"}]}],
        add_generation_prompt=True, tokenize=False)
    batch = processor(text=prompt, images=[resized], return_tensors="pt")
    return batch["pixel_values"]


def build_fixed_input_ids(processor):
    """Return input ids for the fixed OCR prompt, with image placeholders expanded."""
    # A blank image at the fixed resolution yields the same placeholder expansion
    # as any real image of that size.
    blank = Image.new("RGB", (FIXED_W, FIXED_H), "white")
    conversation = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "OCR this document. Extract all text."},
    ]}]
    prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False)
    encoded = processor(text=prompt, images=[blank], return_tensors="pt")
    return encoded["input_ids"]


def run_hf_model(images, processor):
    """Run the original HF model on every image; return per-image text/token/time stats."""
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("\n[HF Model]")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR, dtype=torch.bfloat16, attn_implementation="sdpa", device_map="cpu")
    # Remap checkpoint keys from the export naming to the HF module naming.
    raw_weights = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    renamed = {}
    for key, tensor in raw_weights.items():
        key = key.replace("model.vision_encoder.", "model.vision_tower.")
        key = key.replace("model.vision_projection.", "model.multi_modal_projector.")
        renamed[key] = tensor
    model.load_state_dict(renamed, strict=False)
    model = model.to("cuda").eval()

    results = {}
    for name, img in images.items():
        print(f"  [{name}] HF generate...")
        pixel_values = preprocess_image_fixed(img, processor).to("cuda")
        prompt_ids = build_fixed_input_ids(processor).to("cuda")
        prompt_len = prompt_ids.shape[1]
        start = time.time()
        with torch.no_grad():
            out = model.generate(
                input_ids=prompt_ids, pixel_values=pixel_values,
                attention_mask=torch.ones_like(prompt_ids),
                image_sizes=torch.tensor([[FIXED_H, FIXED_W]], device="cuda"),
                max_new_tokens=512, do_sample=False, temperature=None, top_p=None)
        elapsed = time.time() - start
        # Decode only the newly generated suffix, skipping the prompt tokens.
        text = processor.tokenizer.decode(out[0, prompt_len:], skip_special_tokens=True)
        n = len(out[0]) - prompt_len
        print(f"    {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
        print(f"    {text[:150]}...")
        results[name] = {"text": text, "tokens": n, "time": elapsed}
    del model
    torch.cuda.empty_cache()
    return results


def run_int8_modules(images, processor):
    """Run INT8 weight-only quantized fixed modules E2E.

    Rebuilds the vision/decoder modules, applies torchao int8 weight-only
    quantization (mirroring what was exported to .pte), then performs a manual
    prefill + greedy decode loop with explicit pre-allocated KV caches.

    Returns a dict keyed by image name with "text", "tokens" and "time";
    a failing image produces an "ERROR: ..." text entry instead of raising.
    """
    from export_vision import build_vision_module, load_original_model
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only

    print("\n[INT8 Quantized Modules]")
    orig = load_original_model()
    vision = build_vision_module(orig)
    decoder = build_decoder_module(orig)
    embed_tokens = orig.model.language_model.embed_tokens

    device = "cuda"
    dtype = torch.bfloat16

    # Apply INT8 weight-only quantization (same as what we exported to .pte).
    # Quantization runs on CPU in fp32; modules move back to CUDA/bf16 after.
    print("  Applying int8_weight_only to vision...")
    vision = vision.to("cpu").to(torch.float32)
    quantize_(vision, int8_weight_only())
    vision = vision.to(device).to(dtype).eval()

    print("  Applying int8_weight_only to decoder...")
    decoder = decoder.to("cpu").to(torch.float32)
    quantize_(decoder, int8_weight_only())
    decoder = decoder.to(device).to(dtype).eval()

    embed_tokens = embed_tokens.to(device).to(dtype)
    del orig; torch.cuda.empty_cache()

    results = {}
    for name, img in images.items():
        print(f"  [{name}] INT8 E2E...")
        try:
            pv = preprocess_image_fixed(img, processor).to(device).to(dtype)
            input_ids = build_fixed_input_ids(processor).to(device)

            with torch.no_grad():
                image_features = vision(pv)
            print(f"    Vision: {image_features.shape}")

            with torch.no_grad():
                text_embeds = embed_tokens(input_ids)

            # Scatter the vision features into the <image> placeholder slots.
            ids_list = input_ids[0].tolist()
            img_positions = [i for i, t in enumerate(ids_list) if t == IMAGE_TOKEN_ID]

            combined = text_embeds.clone()
            indices = torch.tensor(img_positions, device=device)
            combined[0, indices] = image_features[0]

            seq_len = combined.shape[1]

            # Pre-allocated static KV caches, interleaved [k0, v0, k1, v1, ...].
            kv_caches = []
            for _ in range(NUM_LAYERS):
                k = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
                v = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
                kv_caches.extend([k, v])

            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            cache_position = torch.arange(seq_len, device=device)
            # Causal mask over the full static cache width: row i may attend to
            # positions 0..i, everything else stays -inf.
            mask = torch.full((1, 1, seq_len, MAX_SEQ_LEN), float("-inf"), dtype=dtype, device=device)
            for i in range(seq_len):
                mask[0, 0, i, :i+1] = 0.0

            # The decoder only accepts token ids, so temporarily swap its
            # embedding layer for one that returns the precomputed
            # text+image embeddings during the prefill pass.
            orig_embed = decoder.embed_tokens

            class PrefillEmbed(torch.nn.Module):
                """Ignores its input ids; returns the fixed prefill embeddings."""
                def __init__(self, e): super().__init__(); self.e = e
                def forward(self, x): return self.e

            decoder.embed_tokens = PrefillEmbed(combined)
            t0 = time.time()
            try:
                with torch.no_grad():
                    result = decoder(input_ids[:, :seq_len], mask, position_ids, cache_position, *kv_caches)
            finally:
                # Restore unconditionally: if the prefill call raises, leaving
                # the patched embedding in place would corrupt every subsequent
                # image handled by this loop (the except below would swallow it).
                decoder.embed_tokens = orig_embed

            logits = result[0]
            kv_caches = list(result[1:])
            next_token = logits[0, -1].argmax().item()
            generated = [next_token]
            cur_pos = seq_len

            # Greedy decode, one token per step (511 steps + 1 prefill token = 512 max).
            for step in range(511):
                if next_token == EOS_TOKEN_ID or cur_pos >= MAX_SEQ_LEN:
                    break
                token_input = torch.tensor([[next_token]], device=device)
                pos_ids = torch.tensor([[cur_pos]], device=device)
                cache_pos = torch.tensor([cur_pos], device=device)
                # Single-row mask: attend to cache slots 0..cur_pos, block the rest.
                dmask = torch.zeros(1, 1, 1, MAX_SEQ_LEN, dtype=dtype, device=device)
                dmask[0, 0, 0, cur_pos+1:] = float("-inf")
                with torch.no_grad():
                    result = decoder(token_input, dmask, pos_ids, cache_pos, *kv_caches)
                logits = result[0]
                kv_caches = list(result[1:])
                next_token = logits[0, -1].argmax().item()
                generated.append(next_token)
                cur_pos += 1

            elapsed = time.time() - t0
            text = processor.tokenizer.decode(generated, skip_special_tokens=True)
            n = len(generated)
            print(f"    {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
            print(f"    {text[:150]}...")
            results[name] = {"text": text, "tokens": n, "time": elapsed}

        except Exception as e:
            import traceback; traceback.print_exc()
            results[name] = {"text": f"ERROR: {e}", "tokens": 0, "time": 0}

    return results


def levenshtein(s1, s2):
    """Return the Levenshtein edit distance between two strings."""
    # Keep the shorter string as the inner (column) dimension to minimize row size.
    if len(s2) > len(s1):
        s1, s2 = s2, s1
    if not s2:
        return len(s1)
    row = list(range(len(s2) + 1))
    for i, a in enumerate(s1, start=1):
        new_row = [i]
        for j, b in enumerate(s2):
            cost = 0 if a == b else 1
            new_row.append(min(row[j + 1] + 1,   # deletion
                               new_row[j] + 1,   # insertion
                               row[j] + cost))   # substitution
        row = new_row
    return row[-1]


def main():
    """Entry point: run both pipelines on the test images and print a comparison."""
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(MODEL_DIR)

    banner = "=" * 60
    print(banner)
    print("LightOnOCR E2E: HF vs INT8 Quantized")
    print(banner)

    images = get_test_images()
    hf = run_hf_model(images, processor)
    torch.cuda.empty_cache()
    q8 = run_int8_modules(images, processor)

    print("\n" + banner)
    print("COMPARISON: HF (FP32) vs INT8 Weight-Only")
    print(banner)

    for name in images:
        hf_text = hf[name]["text"]
        q8_text = q8[name]["text"]
        exact = hf_text.strip() == q8_text.strip()
        dist = levenshtein(hf_text, q8_text)
        # Normalize edit distance by the longer text (floor 1 to avoid /0).
        denom = max(len(hf_text), len(q8_text), 1)
        char_acc = 1.0 - dist / denom
        # Word accuracy = Jaccard overlap of lowercased whitespace tokens.
        ref_words = set(hf_text.lower().split())
        hyp_words = set(q8_text.lower().split())
        union = ref_words | hyp_words
        word_acc = len(ref_words & hyp_words) / len(union) if union else 1.0

        print(f"\n{'─'*60}")
        print(f"  [{name}]")
        print(f"  HF   ({hf[name]['tokens']} tok): {hf_text[:200]}")
        print(f"  INT8 ({q8[name]['tokens']} tok): {q8_text[:200]}")
        print(f"  Exact: {'✅' if exact else '❌'} | Edit dist: {dist} | Char acc: {char_acc:.4f} | Word acc: {word_acc:.4f}")


# Script entry point.
if __name__ == "__main__":
    main()