acul3 commited on
Commit
0aebce7
·
verified ·
1 Parent(s): 02acc80

Upload scripts/test_e2e_int8.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/test_e2e_int8.py +251 -0
scripts/test_e2e_int8.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ E2E Validation for INT8 weight-only quantized models.
4
+ Compares: HF original vs INT8 quantized fixed modules.
5
+ """
6
+
7
+ import os, sys, time, torch, torch.nn.functional as F
8
+ from PIL import Image
9
+ sys.path.insert(0, ".")
10
+
11
+ MODEL_DIR = "./models/LightOnOCR-2-1B"
12
+ FIXED_H, FIXED_W = 1120, 1540
13
+ IMAGE_TOKEN_ID = 151655
14
+ EOS_TOKEN_ID = 151645
15
+ NUM_LAYERS = 28
16
+ NUM_KV_HEADS = 8
17
+ HEAD_DIM = 128
18
+ MAX_SEQ_LEN = 4096
19
+
20
+
21
def get_test_images():
    """Collect the images used for the end-to-end comparison.

    Returns:
        dict mapping a short name to an RGB PIL image. ``"receipt"`` is
        present only when ``test_images/receipt.png`` exists on disk;
        ``"synthetic"`` (a generated 800x600 invoice mock-up) is always
        present.
    """
    from PIL import ImageDraw

    images = {}

    receipt_path = "test_images/receipt.png"
    if os.path.exists(receipt_path):
        # convert() hands back a fully loaded copy, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(receipt_path) as src:
            images["receipt"] = src.convert("RGB")

    # Synthetic fallback so the script has something to OCR even without
    # any fixture files.
    img = Image.new("RGB", (800, 600), "white")
    draw = ImageDraw.Draw(img)
    lines = (
        "Invoice #12345",
        "Date: 2024-01-15",
        "Item 1: Widget x5 @ $10.00 = $50.00",
        "Item 2: Gadget x2 @ $24.99 = $49.98",
        "Total: $99.98",
    )
    # One text line every 50 px, starting at y=50.
    for row, line in enumerate(lines, start=1):
        draw.text((50, 50 * row), line, fill="black")
    images["synthetic"] = img
    return images
35
+
36
+
37
def preprocess_image_fixed(img, processor):
    """Resize ``img`` to the fixed resolution and return its pixel values.

    The image is resampled to ``(FIXED_W, FIXED_H)`` with Lanczos filtering
    and run through ``processor`` together with an image-only chat prompt.
    Only the ``"pixel_values"`` entry of the processor output is returned.
    """
    target_size = (FIXED_W, FIXED_H)
    resized = img.resize(target_size, Image.LANCZOS)

    # The processor needs a prompt alongside the image; an image-only user
    # turn is enough because only the pixel values are kept.
    image_only_turn = {"role": "user", "content": [{"type": "image"}]}
    prompt = processor.apply_chat_template(
        [image_only_turn], add_generation_prompt=True, tokenize=False)

    batch = processor(text=prompt, images=[resized], return_tensors="pt")
    return batch["pixel_values"]
43
+
44
+
45
def build_fixed_input_ids(processor):
    """Return the tokenized OCR prompt for an image of the fixed resolution.

    A blank white ``FIXED_W`` x ``FIXED_H`` placeholder image is passed to
    ``processor`` together with the chat-formatted OCR instruction, and the
    ``"input_ids"`` entry of the processor output is returned.
    """
    placeholder = Image.new("RGB", (FIXED_W, FIXED_H), "white")

    user_content = [
        {"type": "image"},
        {"type": "text", "text": "OCR this document. Extract all text."},
    ]
    conversation = [{"role": "user", "content": user_content}]
    prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False)

    encoded = processor(text=prompt, images=[placeholder], return_tensors="pt")
    return encoded["input_ids"]
53
+
54
+
55
def run_hf_model(images, processor):
    """Generate reference OCR text with the Hugging Face model on CUDA.

    The model is instantiated from ``MODEL_DIR`` in bfloat16, its weights are
    reloaded from ``model.safetensors`` with the vision-encoder and
    vision-projection key prefixes remapped, and greedy generation
    (``do_sample=False``, up to 512 new tokens) is run for every image.

    Args:
        images: dict mapping a name to a PIL image.
        processor: processor providing ``apply_chat_template``, ``__call__``
            and a ``tokenizer``.

    Returns:
        dict mapping each image name to
        ``{"text": str, "tokens": int, "time": float}``.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("\n[HF Model]")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR, dtype=torch.bfloat16, attn_implementation="sdpa", device_map="cpu")
    state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {
        k.replace("model.vision_encoder.", "model.vision_tower.")
         .replace("model.vision_projection.", "model.multi_modal_projector."): v
        for k, v in state_dict.items()
    }
    # strict=False would otherwise hide a wrong key remap; surface any
    # mismatch so a badly loaded reference model does not go unnoticed.
    load_report = model.load_state_dict(remapped, strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        print(f" WARNING: state_dict mismatch: "
              f"{len(load_report.missing_keys)} missing, "
              f"{len(load_report.unexpected_keys)} unexpected keys")
    model = model.to("cuda").eval()

    # The prompt, its attention mask and the image size do not depend on the
    # image being processed, so build them once.
    input_ids = build_fixed_input_ids(processor).to("cuda")
    input_len = input_ids.shape[1]
    attention_mask = torch.ones_like(input_ids)
    image_sizes = torch.tensor([[FIXED_H, FIXED_W]], device="cuda")

    results = {}
    for name, img in images.items():
        print(f" [{name}] HF generate...")
        pv = preprocess_image_fixed(img, processor).to("cuda")
        t0 = time.perf_counter()
        with torch.no_grad():
            out = model.generate(
                input_ids=input_ids, pixel_values=pv,
                attention_mask=attention_mask,
                image_sizes=image_sizes,
                max_new_tokens=512, do_sample=False, temperature=None, top_p=None)
        elapsed = time.perf_counter() - t0
        # Everything after the prompt is newly generated text.
        text = processor.tokenizer.decode(out[0, input_len:], skip_special_tokens=True)
        n = len(out[0]) - input_len
        print(f" {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
        print(f" {text[:150]}...")
        results[name] = {"text": text, "tokens": n, "time": elapsed}

    # Drop the reference model before the quantized modules are loaded.
    del model
    torch.cuda.empty_cache()
    return results
90
+
91
+
92
def run_int8_modules(images, processor):
    """Run INT8 weight-only quantized fixed modules E2E.

    The vision and decoder modules built from the original model are
    quantized with torchao ``int8_weight_only`` on CPU, moved to CUDA in
    bfloat16, and then used to greedily decode every image.

    Args:
        images: dict mapping a name to a PIL image.
        processor: processor providing ``apply_chat_template``, ``__call__``
            and a ``tokenizer``.

    Returns:
        dict mapping each image name to
        ``{"text": str, "tokens": int, "time": float}``. When an image fails,
        the traceback is printed and its entry holds ``"ERROR: ..."`` text
        with zero tokens and zero time.
    """
    import traceback
    from export_vision import build_vision_module, load_original_model
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only

    print("\n[INT8 Quantized Modules]")
    orig = load_original_model()
    vision = build_vision_module(orig)
    decoder = build_decoder_module(orig)
    embed_tokens = orig.model.language_model.embed_tokens

    device = "cuda"
    dtype = torch.bfloat16

    # Apply INT8 weight-only quantization (same as what we exported to .pte)
    print(" Applying int8_weight_only to vision...")
    vision = vision.to("cpu").to(torch.float32)
    quantize_(vision, int8_weight_only())
    vision = vision.to(device).to(dtype).eval()

    print(" Applying int8_weight_only to decoder...")
    decoder = decoder.to("cpu").to(torch.float32)
    quantize_(decoder, int8_weight_only())
    decoder = decoder.to(device).to(dtype).eval()

    embed_tokens = embed_tokens.to(device).to(dtype)
    del orig
    torch.cuda.empty_cache()

    results = {}
    for name, img in images.items():
        print(f" [{name}] INT8 E2E...")
        try:
            pv = preprocess_image_fixed(img, processor).to(device).to(dtype)
            input_ids = build_fixed_input_ids(processor).to(device)
            generated, elapsed = _int8_greedy_generate(
                vision, decoder, embed_tokens, pv, input_ids, dtype, device)
            text = processor.tokenizer.decode(generated, skip_special_tokens=True)
            n = len(generated)
            print(f" {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
            print(f" {text[:150]}...")
            results[name] = {"text": text, "tokens": n, "time": elapsed}
        except Exception as e:
            # Best effort per image: record the failure and keep going so
            # the remaining images are still compared.
            traceback.print_exc()
            results[name] = {"text": f"ERROR: {e}", "tokens": 0, "time": 0}

    return results


class _PrefillEmbed(torch.nn.Module):
    """Embedding stand-in that ignores its input and returns fixed embeddings."""

    def __init__(self, e):
        super().__init__()
        self.e = e

    def forward(self, x):
        return self.e


def _prefill_causal_mask(seq_len, dtype, device):
    """Build the (1, 1, seq_len, MAX_SEQ_LEN) additive causal mask for prefill.

    Row ``i`` is 0 for key positions ``0..i`` and ``-inf`` everywhere else.
    """
    rows = torch.arange(seq_len, device=device).unsqueeze(1)
    cols = torch.arange(MAX_SEQ_LEN, device=device).unsqueeze(0)
    mask = torch.zeros(1, 1, seq_len, MAX_SEQ_LEN, dtype=dtype, device=device)
    return mask.masked_fill_(cols > rows, float("-inf"))


def _int8_greedy_generate(vision, decoder, embed_tokens, pv, input_ids,
                          dtype, device, max_new_tokens=512):
    """Encode one image, prefill the decoder and greedily decode tokens.

    Returns ``(generated_token_ids, elapsed_seconds)``. The timer starts
    right before the prefill call, so the vision encoder and the embedding
    merge are not included in ``elapsed_seconds``.
    """
    with torch.no_grad():
        image_features = vision(pv)
        print(f" Vision: {image_features.shape}")

        text_embeds = embed_tokens(input_ids)

        # Overwrite the embeddings at the image placeholder positions with
        # the vision features.
        img_positions = [i for i, t in enumerate(input_ids[0].tolist())
                         if t == IMAGE_TOKEN_ID]
        combined = text_embeds.clone()
        indices = torch.tensor(img_positions, device=device)
        combined[0, indices] = image_features[0]

        seq_len = combined.shape[1]
        if seq_len > MAX_SEQ_LEN:
            raise ValueError(
                f"prompt length {seq_len} exceeds MAX_SEQ_LEN {MAX_SEQ_LEN}")

        # Zero-initialised caches passed flat as k0, v0, k1, v1, ...
        kv_caches = []
        for _ in range(NUM_LAYERS):
            k = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
            v = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
            kv_caches.extend([k, v])

        cache_position = torch.arange(seq_len, device=device)
        position_ids = cache_position.unsqueeze(0)
        mask = _prefill_causal_mask(seq_len, dtype, device)

        # Prefill: temporarily swap the decoder's embedding layer so the
        # merged text+image embeddings are used. Always restore it, even if
        # the decoder raises, so later images are not run with stale
        # embeddings.
        orig_embed = decoder.embed_tokens
        decoder.embed_tokens = _PrefillEmbed(combined)
        t0 = time.perf_counter()
        try:
            result = decoder(input_ids, mask, position_ids, cache_position, *kv_caches)
        finally:
            decoder.embed_tokens = orig_embed

        logits = result[0]
        kv_caches = list(result[1:])
        next_token = logits[0, -1].argmax().item()
        generated = [next_token]
        cur_pos = seq_len

        # The first token came from prefill, so at most max_new_tokens - 1
        # single-token decode steps remain.
        for _ in range(max_new_tokens - 1):
            if next_token == EOS_TOKEN_ID or cur_pos >= MAX_SEQ_LEN:
                break
            token_input = torch.tensor([[next_token]], device=device)
            pos_ids = torch.tensor([[cur_pos]], device=device)
            cache_pos = torch.tensor([cur_pos], device=device)
            # Attend to positions 0..cur_pos only.
            dmask = torch.zeros(1, 1, 1, MAX_SEQ_LEN, dtype=dtype, device=device)
            dmask[0, 0, 0, cur_pos+1:] = float("-inf")
            result = decoder(token_input, dmask, pos_ids, cache_pos, *kv_caches)
            logits = result[0]
            kv_caches = list(result[1:])
            next_token = logits[0, -1].argmax().item()
            generated.append(next_token)
            cur_pos += 1

        elapsed = time.perf_counter() - t0
    return generated, elapsed
201
+
202
+
203
def levenshtein(s1, s2):
    """Return the Levenshtein edit distance between two sequences.

    Insertions, deletions and substitutions each cost 1. Identical inputs
    return immediately, and any shared prefix and suffix are trimmed before
    the dynamic-programming pass; neither shortcut changes the distance, but
    both make comparing near-identical texts much cheaper.

    Args:
        s1: first sequence (typically a ``str``).
        s2: second sequence (typically a ``str``).

    Returns:
        The minimum number of single-element edits turning ``s1`` into ``s2``.
    """
    if s1 == s2:
        return 0
    # Keep s2 as the shorter sequence so the DP rows stay as small as possible.
    if len(s1) < len(s2):
        s1, s2 = s2, s1

    # Trim the common prefix.
    start = 0
    shorter = len(s2)
    while start < shorter and s1[start] == s2[start]:
        start += 1

    # Trim the common suffix without re-consuming the trimmed prefix.
    end1, end2 = len(s1), len(s2)
    while end2 > start and s1[end1 - 1] == s2[end2 - 1]:
        end1 -= 1
        end2 -= 1

    s1 = s1[start:end1]
    s2 = s2[start:end2]
    if not s2:
        return len(s1)

    # Two-row DP: prev[j] is the distance between the processed part of s1
    # and the first j elements of s2.
    prev = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1, start=1):
        curr = [i]
        for j, c2 in enumerate(s2):
            curr.append(min(prev[j + 1] + 1,        # deletion
                            curr[j] + 1,            # insertion
                            prev[j] + (c1 != c2)))  # substitution / match
        prev = curr
    return prev[-1]
213
+
214
+
215
def main():
    """Run the HF reference and the INT8 modules, then print a comparison.

    For every test image the report shows both transcriptions with their
    token counts, whether they match exactly, the character-level edit
    distance, the character accuracy (``1 - edit_distance / longer_length``)
    and a word-level score (Jaccard overlap of the lower-cased word sets).
    All metrics are computed on whitespace-stripped text so that the "Exact"
    flag and the edit distance always agree.
    """
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(MODEL_DIR)

    print("="*60)
    print("LightOnOCR E2E: HF vs INT8 Quantized")
    print("="*60)

    images = get_test_images()
    hf = run_hf_model(images, processor)
    torch.cuda.empty_cache()
    q8 = run_int8_modules(images, processor)

    print("\n" + "="*60)
    # The HF reference is loaded in bfloat16, not float32.
    print("COMPARISON: HF (BF16) vs INT8 Weight-Only")
    print("="*60)

    for name in images:
        hf_t = hf[name]["text"]
        q8_t = q8[name]["text"]

        # Compare stripped text everywhere so surrounding whitespace alone
        # cannot make the metrics disagree with each other.
        hf_cmp = hf_t.strip()
        q8_cmp = q8_t.strip()
        exact = hf_cmp == q8_cmp
        ed = levenshtein(hf_cmp, q8_cmp)
        max_len = max(len(hf_cmp), len(q8_cmp), 1)  # the 1 guards against 0/0
        char_acc = 1.0 - ed / max_len
        ref_w = set(hf_cmp.lower().split())
        hyp_w = set(q8_cmp.lower().split())
        all_w = ref_w | hyp_w
        word_acc = len(ref_w & hyp_w) / len(all_w) if all_w else 1.0

        print(f"\n{'─'*60}")
        print(f" [{name}]")
        print(f" HF ({hf[name]['tokens']} tok): {hf_t[:200]}")
        print(f" INT8 ({q8[name]['tokens']} tok): {q8_t[:200]}")
        print(f" Exact: {'✅' if exact else '❌'} | Edit dist: {ed} | Char acc: {char_acc:.4f} | Word acc: {word_acc:.4f}")
249
+
250
+ if __name__ == "__main__":
251
+ main()