ningpp commited on
Commit
b595b4c
·
verified ·
1 Parent(s): eb7dfbf

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +717 -0
README.md CHANGED
@@ -21,6 +21,723 @@ base_model:
21
 
22
  Flux is a Java-based OCR
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  # GLM-OCR
 
21
 
22
  Flux is a Java-based OCR
23
 
24
+ ## Attention
25
+ **If you download model before 2026-03-07, you can download model again, current version of the model has better inference performance.**
26
+
27
+
28
+ ## ONNX Inference
29
+ ```
30
+ """
31
+ End-to-end ONNX inference for GLM-OCR model.
32
+
33
+ This script performs complete inference using exported ONNX models:
34
+ 1. Vision encoder (processes images)
35
+ 2. Embedding layer (converts token IDs to embeddings)
36
+ 3. Prefill model (processes prompt)
37
+ 4. Decode model (generates tokens autoregressively)
38
+
39
+ Usage:
40
+ python onnx_inference_e2e.py --image <path> --max-tokens 100
41
+ python onnx_inference_e2e.py --use-real-images --max-tokens 100
42
+ """
43
+
44
+ import os
45
+ import sys
46
+ import time
47
+ import argparse
48
+ from typing import List, Tuple, Optional
49
+ from PIL import Image
50
+ import numpy as np
51
+ import onnxruntime as ort
52
+ from transformers import AutoProcessor, AutoModelForImageTextToText, AutoConfig
53
+
54
+
55
+ class GLMOcrOnnxInference:
56
+ """End-to-end ONNX inference for GLM-OCR."""
57
+
58
+ def __init__(self, onnx_dir: str, device: str = "cpu"):
59
+ """
60
+ Initialize ONNX inference sessions.
61
+
62
+ Args:
63
+ onnx_dir: Directory containing exported ONNX models
64
+ device: "cpu" or "cuda"
65
+ """
66
+ self.onnx_dir = onnx_dir
67
+ self.device = device
68
+ self.providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
69
+
70
+ # Load processor for tokenization
71
+ print(f"Loading processor from {onnx_dir}...")
72
+ self.processor = AutoProcessor.from_pretrained(onnx_dir, trust_remote_code=True)
73
+
74
+ # Model config
75
+ self.config = self._load_config()
76
+
77
+ # Create ONNX sessions
78
+ self.sessions = self._create_sessions()
79
+
80
+ def _load_config(self):
81
+ """Load model configuration without loading the entire model."""
82
+ # Load config directly instead of the entire model
83
+ config = AutoConfig.from_pretrained(self.onnx_dir, trust_remote_code=True)
84
+ return config
85
+
86
+ def _create_sessions(self) -> dict:
87
+ """Create ONNX Runtime sessions for all models."""
88
+ print("Creating ONNX Runtime sessions...")
89
+
90
+ opts = ort.SessionOptions()
91
+ opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
92
+
93
+ if self.device == "cuda":
94
+ # CUDA-specific optimizations
95
+ opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
96
+ opts.enable_mem_pattern = True
97
+ opts.enable_mem_reuse = True
98
+ else:
99
+ # CPU optimizations
100
+ opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
101
+ import multiprocessing
102
+ num_cores = multiprocessing.cpu_count()
103
+ opts.intra_op_num_threads = num_cores
104
+ opts.inter_op_num_threads = 1
105
+
106
+ sessions = {}
107
+
108
+ # Get available providers and set up CUDA options
109
+ if self.device == "cuda":
110
+ available_providers = ort.get_available_providers()
111
+ providers = []
112
+
113
+ # Try TensorRT first if available (best performance)
114
+ if "TensorrtExecutionProvider" in available_providers:
115
+ print(" TensorRT is available but disabled temporarily due to shape inference requirements")
116
+ # Commented out until we run shape inference on the model
117
+ # providers.append(("TensorrtExecutionProvider", {
118
+ # "trt_engine_cache_enable": True,
119
+ # "trt_engine_cache_path": "./trt_cache",
120
+ # "trt_fp16_enable": True,
121
+ # }))
122
+ # print(" Using TensorRT Execution Provider")
123
+
124
+ # Always add CUDAExecutionProvider
125
+ providers.append(("CUDAExecutionProvider", {
126
+ "device_id": 0,
127
+ "arena_extend_strategy": "kNextPowerOfTwo",
128
+ "cudnn_conv_algo_search": "EXHAUSTIVE",
129
+ "do_copy_in_default_stream": True,
130
+ }))
131
+
132
+ # Fallback to CPU
133
+ providers.append("CPUExecutionProvider")
134
+ else:
135
+ providers = self.providers
136
+
137
+ # Vision encoder
138
+ vision_path = os.path.join(self.onnx_dir, "vision_encoder_fused.onnx")
139
+ if os.path.exists(vision_path):
140
+ sessions["vision"] = ort.InferenceSession(
141
+ vision_path, opts, providers=providers
142
+ )
143
+ print(f" ✓ Vision encoder loaded")
144
+
145
+ # Embedding layer
146
+ embedding_path = os.path.join(self.onnx_dir, "embedding.onnx")
147
+ if os.path.exists(embedding_path):
148
+ sessions["embedding"] = ort.InferenceSession(
149
+ embedding_path, opts, providers=providers
150
+ )
151
+ print(f" ✓ Embedding layer loaded")
152
+
153
+ # Prefill model
154
+ prefill_path = os.path.join(self.onnx_dir, "llm_prefill.onnx")
155
+ if os.path.exists(prefill_path):
156
+ sessions["prefill"] = ort.InferenceSession(
157
+ prefill_path, opts, providers=providers
158
+ )
159
+ print(f" ✓ Prefill model loaded")
160
+
161
+ # Decode model
162
+ decode_path = os.path.join(self.onnx_dir, "llm_decode.onnx")
163
+ if os.path.exists(decode_path):
164
+ sessions["decode"] = ort.InferenceSession(
165
+ decode_path, opts, providers=providers
166
+ )
167
+ print(f" ✓ Decode model loaded")
168
+
169
+ return sessions
170
+
171
+ def encode_image(self, image_path: str) -> np.ndarray:
172
+ """
173
+ Encode image using vision encoder.
174
+
175
+ Args:
176
+ image_path: Path to image file
177
+
178
+ Returns:
179
+ Image features as numpy array
180
+ """
181
+ if "vision" not in self.sessions:
182
+ raise RuntimeError("Vision encoder not available")
183
+
184
+ # Load and preprocess image
185
+ image = Image.open(image_path).convert("RGB")
186
+
187
+ # Use full processor to get all necessary inputs (pixel_values, grid_thw)
188
+ messages = [{'role': 'user', 'content': [{'type': 'image'}, {'type': 'text', 'text': 'test'}]}]
189
+ text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
190
+ inputs = self.processor(text=text, images=[image], return_tensors='pt')
191
+
192
+ pixel_values = inputs.pixel_values
193
+ grid_thw = inputs.image_grid_thw
194
+
195
+ # Compute pos_ids and max_grid_size
196
+ pos_ids, max_grid_size = self._compute_pos_ids(grid_thw)
197
+
198
+ # Convert to numpy arrays
199
+ pixel_values_np = pixel_values.numpy()
200
+ pos_ids_np = pos_ids.numpy()
201
+ max_grid_size_np = np.array(max_grid_size, dtype=np.int64)
202
+
203
+ # Run vision encoder
204
+ outputs = self.sessions["vision"].run(None, {
205
+ "pixel_values": pixel_values_np,
206
+ "pos_ids": pos_ids_np,
207
+ "max_grid_size": max_grid_size_np
208
+ })
209
+
210
+ return outputs[0] # image_features
211
+
212
+ def _compute_pos_ids(self, grid_thw, spatial_merge_size: int = 2):
213
+ """
214
+ Pre-compute position IDs for rotary embeddings.
215
+
216
+ Args:
217
+ grid_thw: [batch_size, 3] - (temporal, height_patches, width_patches) for each image
218
+ spatial_merge_size: The spatial merge factor (default 2)
219
+
220
+ Returns:
221
+ pos_ids: [total_patches, 2] - position indices for all patches
222
+ max_grid_size: int - maximum grid dimension
223
+ """
224
+ import torch
225
+ pos_ids_list = []
226
+ for t, h, w in grid_thw:
227
+ t, h, w = int(t), int(h), int(w)
228
+
229
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
230
+ hpos_ids = hpos_ids.reshape(
231
+ h // spatial_merge_size,
232
+ spatial_merge_size,
233
+ w // spatial_merge_size,
234
+ spatial_merge_size,
235
+ )
236
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()
237
+
238
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
239
+ wpos_ids = wpos_ids.reshape(
240
+ h // spatial_merge_size,
241
+ spatial_merge_size,
242
+ w // spatial_merge_size,
243
+ spatial_merge_size,
244
+ )
245
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
246
+
247
+ pos_ids_list.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
248
+
249
+ pos_ids = torch.cat(pos_ids_list, dim=0)
250
+ max_grid_size = int(grid_thw[:, 1:].max())
251
+
252
+ return pos_ids, max_grid_size
253
+
254
+ def _get_rope_index(self, input_ids_list, image_grid_thw, attention_mask_list=None):
255
+ """
256
+ Calculate position_ids for M-RoPE (same logic as PyTorch's get_rope_index).
257
+
258
+ Args:
259
+ input_ids_list: List of input token IDs
260
+ image_grid_thw: Tensor of [t, h, w] for image grid
261
+ attention_mask_list: List of attention mask values
262
+
263
+ Returns:
264
+ position_ids: numpy array of shape [3, seq_len]
265
+ rope_deltas: int, the delta for decode position calculation
266
+ """
267
+ import itertools
268
+
269
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
270
+ image_token_id = self.config.image_token_id
271
+
272
+ # Get image grid dimensions
273
+ t, h, w = image_grid_thw[0][0].item(), image_grid_thw[0][1].item(), image_grid_thw[0][2].item()
274
+ llm_grid_t = t
275
+ llm_grid_h = h // spatial_merge_size
276
+ llm_grid_w = w // spatial_merge_size
277
+
278
+ # Find image token positions
279
+ boi_token_id = 59256 #
280
+ eoi_token_id = 59257 #
281
+
282
+ # Build position_ids
283
+ seq_len = len(input_ids_list)
284
+ position_ids = np.zeros((3, seq_len), dtype=np.int64)
285
+
286
+ # Find BOI and EOI positions
287
+ boi_pos = None
288
+ eoi_pos = None
289
+ for i, tid in enumerate(input_ids_list):
290
+ if tid == boi_token_id:
291
+ boi_pos = i
292
+ elif tid == eoi_token_id:
293
+ eoi_pos = i
294
+
295
+ if boi_pos is None or eoi_pos is None:
296
+ # No image tokens, use simple position_ids
297
+ for i in range(seq_len):
298
+ position_ids[0, i] = i
299
+ position_ids[1, i] = i
300
+ position_ids[2, i] = i
301
+ return position_ids, 0
302
+
303
+ # Text tokens before image
304
+ for i in range(boi_pos):
305
+ position_ids[0, i] = i
306
+ position_ids[1, i] = i
307
+ position_ids[2, i] = i
308
+
309
+ # BOI token
310
+ st_idx = boi_pos
311
+ position_ids[0, boi_pos] = st_idx
312
+ position_ids[1, boi_pos] = st_idx
313
+ position_ids[2, boi_pos] = st_idx
314
+
315
+ # Image tokens - use 3D position encoding
316
+ # t_index, h_index, w_index for each image token
317
+ img_start = boi_pos + 1
318
+ img_end = eoi_pos
319
+
320
+ for idx, pos in enumerate(range(img_start, img_end)):
321
+ t_idx = idx // (llm_grid_h * llm_grid_w)
322
+ hw_idx = idx % (llm_grid_h * llm_grid_w)
323
+ h_idx = hw_idx // llm_grid_w
324
+ w_idx = hw_idx % llm_grid_w
325
+
326
+ position_ids[0, pos] = st_idx + t_idx
327
+ position_ids[1, pos] = st_idx + h_idx
328
+ position_ids[2, pos] = st_idx + w_idx
329
+
330
+ # EOI token and text after
331
+ max_img_pos = max(
332
+ position_ids[0, img_start:img_end].max(),
333
+ position_ids[1, img_start:img_end].max(),
334
+ position_ids[2, img_start:img_end].max()
335
+ )
336
+
337
+ for i, pos in enumerate(range(eoi_pos, seq_len)):
338
+ position_ids[0, pos] = max_img_pos + 1 + i
339
+ position_ids[1, pos] = max_img_pos + 1 + i
340
+ position_ids[2, pos] = max_img_pos + 1 + i
341
+
342
+ # Calculate rope_deltas
343
+ max_pos = max(
344
+ position_ids[0].max(),
345
+ position_ids[1].max(),
346
+ position_ids[2].max()
347
+ )
348
+ rope_deltas = max_pos + 1 - seq_len
349
+
350
+ return position_ids, rope_deltas
351
+
352
+ def _run_with_io_binding(self, session, inputs_dict, device="cuda"):
353
+ """
354
+ Run inference (IO Binding temporarily disabled to ensure correct outputs).
355
+
356
+ Args:
357
+ session: ONNX Runtime InferenceSession
358
+ inputs_dict: Dictionary of input name -> numpy array
359
+ device: "cuda" or "cpu"
360
+
361
+ Returns:
362
+ list of numpy arrays
363
+ """
364
+ # Disable IO Binding temporarily to avoid garbage outputs
365
+ return session.run(None, inputs_dict)
366
+
367
+ def generate(
368
+ self,
369
+ image_path: str,
370
+ prompt: str = "",
371
+ max_new_tokens: int = 100,
372
+ temperature: float = 0.7,
373
+ top_p: float = 0.9,
374
+ ) -> str:
375
+ """
376
+ Generate text from image.
377
+
378
+ Args:
379
+ image_path: Path to input image
380
+ prompt: Optional text prompt
381
+ max_new_tokens: Maximum number of tokens to generate
382
+ temperature: Sampling temperature
383
+ top_p: Top-p sampling parameter
384
+
385
+ Returns:
386
+ Generated text
387
+ """
388
+ print(f"\nGenerating for image: {image_path}")
389
+ print(f" Prompt: '{prompt}'")
390
+ print(f" Max tokens: {max_new_tokens}")
391
+ print(f" Device: {self.device}")
392
+
393
+ # Step 1: Encode image
394
+ print("\n[1/4] Encoding image...")
395
+ start_time = time.time()
396
+ image_features = self.encode_image(image_path)
397
+ print(f" Image features shape: {image_features.shape}")
398
+ print(f" Time: {time.time() - start_time:.2f}s")
399
+
400
+ # Step 2: Prepare input
401
+ print("\n[2/4] Preparing input...")
402
+ start_time = time.time()
403
+
404
+ # Load image for processor
405
+ image = Image.open(image_path).convert("RGB")
406
+
407
+ # Create messages for GLM-OCR chat template (same as transformers_infer.py)
408
+ messages = [
409
+ {
410
+ "role": "user",
411
+ "content": [
412
+ {"type": "image", "url": image_path},
413
+ {"type": "text", "text": prompt if prompt else "Describe this image."}
414
+ ]
415
+ }
416
+ ]
417
+
418
+ inputs = self.processor.apply_chat_template(
419
+ messages,
420
+ tokenize=True,
421
+ add_generation_prompt=True,
422
+ return_dict=True,
423
+ return_tensors="pt"
424
+ )
425
+ inputs.pop("token_type_ids", None)
426
+
427
+ input_ids = inputs["input_ids"].numpy()
428
+ attention_mask = inputs["attention_mask"].numpy()
429
+
430
+ print(f" Input IDs shape: {input_ids.shape}")
431
+ print(f" Time: {time.time() - start_time:.2f}s")
432
+
433
+ # Step 3: Embedding
434
+ print("\n[3/4] Getting embeddings...")
435
+ start_time = time.time()
436
+
437
+ image_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|image|>")
438
+ input_ids_list = input_ids[0].tolist()
439
+
440
+ # Get embeddings
441
+ embed_outputs = self._run_with_io_binding(
442
+ self.sessions["embedding"],
443
+ {"input_ids": input_ids},
444
+ device=self.device
445
+ )
446
+ inputs_embeds = embed_outputs[0]
447
+
448
+ # Replace image token embeddings with actual image features
449
+ image_positions = [i for i, tid in enumerate(input_ids_list) if tid == image_token_id]
450
+
451
+ if len(image_positions) > 0:
452
+ num_image_tokens = image_features.shape[0]
453
+
454
+ if len(image_positions) == num_image_tokens:
455
+ for i, pos in enumerate(image_positions):
456
+ inputs_embeds[0, pos] = image_features[i]
457
+ print(f" Replaced {num_image_tokens} image tokens")
458
+ else:
459
+ # Remove original <|image|> tokens from input_ids and get embeddings
460
+ non_image_mask = np.array([tid != image_token_id for tid in input_ids_list])
461
+ inputs_embeds = inputs_embeds[:, non_image_mask, :]
462
+
463
+ # Also update attention_mask to remove original image token
464
+ attention_mask = attention_mask[:, non_image_mask]
465
+
466
+ boi_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|begin_of_image|>")
467
+ if boi_token_id in input_ids_list:
468
+ boi_pos = input_ids_list.index(boi_token_id)
469
+ before = inputs_embeds[:, :boi_pos+1, :]
470
+ after = inputs_embeds[:, boi_pos+1:, :]
471
+ image_features_batch = image_features[np.newaxis, :, :]
472
+ inputs_embeds = np.concatenate([before, image_features_batch, after], axis=1)
473
+
474
+ before_mask = attention_mask[:, :boi_pos+1]
475
+ image_mask = np.ones((1, num_image_tokens), dtype=np.int64)
476
+ after_mask = attention_mask[:, boi_pos+1:]
477
+ attention_mask = np.concatenate([before_mask, image_mask, after_mask], axis=1)
478
+
479
+ print(f" Inserted {num_image_tokens} image tokens")
480
+
481
+ print(f" Embeddings shape: {inputs_embeds.shape}")
482
+ print(f" Time: {time.time() - start_time:.2f}s")
483
+
484
+ # Step 4: Prefill
485
+ print("\n[4/4] Running inference...")
486
+ start_time = time.time()
487
+
488
+ seq_len = inputs_embeds.shape[1]
489
+
490
+ # M-RoPE: Calculate position_ids with proper 3D positions for image tokens
491
+ # We need to use the same logic as PyTorch's get_rope_index
492
+ image_grid_thw = inputs.get("image_grid_thw")
493
+ if image_grid_thw is not None:
494
+ # Calculate position_ids using the same logic as PyTorch
495
+ position_ids, rope_deltas = self._get_rope_index(
496
+ input_ids[0].tolist(),
497
+ image_grid_thw,
498
+ attention_mask[0].tolist()
499
+ )
500
+ position_ids = position_ids[:, np.newaxis, :]
501
+ print(f" M-RoPE enabled: rope_deltas={rope_deltas}")
502
+ else:
503
+ # Fallback to simple position_ids
504
+ position_ids = np.arange(seq_len, dtype=np.int64)
505
+ position_ids = np.stack([position_ids, position_ids, position_ids], axis=0)
506
+ position_ids = position_ids[:, np.newaxis, :]
507
+ rope_deltas = 0
508
+
509
+ prefill_inputs = {
510
+ "inputs_embeds": inputs_embeds.astype(np.float32),
511
+ "attention_mask": attention_mask.astype(np.int64),
512
+ "position_ids": position_ids.astype(np.int64),
513
+ }
514
+ prefill_outputs = self._run_with_io_binding(
515
+ self.sessions["prefill"],
516
+ prefill_inputs,
517
+ device=self.device
518
+ )
519
+
520
+ logits = prefill_outputs[0]
521
+ past_key_values = prefill_outputs[1:]
522
+
523
+ print(f" Prefill logits shape: {logits.shape}")
524
+ print(f" KV cache tensors: {len(past_key_values)}")
525
+ print(f" Time: {time.time() - start_time:.2f}s")
526
+
527
+ print(f"\n[5/5] Generating tokens...", flush=True)
528
+ print(f" DEBUG: seq_len={seq_len}, prefill positions=[0..{seq_len-1}]")
529
+ generated_tokens = []
530
+
531
+ decode_attention_mask = attention_mask.copy()
532
+
533
+ for step in range(max_new_tokens):
534
+ next_token_logits = logits[:, -1, :]
535
+ next_token_id = int(np.argmax(next_token_logits, axis=-1)[0])
536
+ generated_tokens.append(next_token_id)
537
+
538
+ if step < 5:
539
+ print(f" DEBUG step={step}: token={next_token_id} ('{self.processor.tokenizer.decode([next_token_id])}')")
540
+
541
+ if next_token_id in [self.processor.tokenizer.eos_token_id, 59253]:
542
+ print(f" EOS token reached at step {step + 1}")
543
+ break
544
+
545
+ # Update attention mask BEFORE decode (to match PyTorch behavior)
546
+ decode_attention_mask = np.concatenate(
547
+ [decode_attention_mask, np.ones((1, 1), dtype=np.int64)], axis=1
548
+ )
549
+
550
+ # Get next token embedding
551
+ next_token_embeds = self._run_with_io_binding(
552
+ self.sessions["embedding"],
553
+ {"input_ids": np.array([[next_token_id]], dtype=np.int64)},
554
+ device=self.device
555
+ )[0]
556
+
557
+ # Position IDs for M-RoPE: position = cache_position + rope_deltas
558
+ # This ensures correct position encoding after image tokens
559
+ cache_position = seq_len + step
560
+ new_position = cache_position + rope_deltas
561
+ decode_position_ids = np.full((3, 1, 1), new_position, dtype=np.int64)
562
+
563
+ if step < 5:
564
+ print(f" DEBUG step={step}: cache_pos={cache_position}, rope_delta={rope_deltas}, position_id={new_position}")
565
+
566
+ # Prepare decode inputs
567
+ decode_inputs = {
568
+ "inputs_embeds": next_token_embeds.astype(np.float32),
569
+ "attention_mask": decode_attention_mask,
570
+ "position_ids": decode_position_ids,
571
+ }
572
+ for layer_idx in range(16):
573
+ decode_inputs[f"past_key_{layer_idx}"] = past_key_values[layer_idx * 2]
574
+ decode_inputs[f"past_value_{layer_idx}"] = past_key_values[layer_idx * 2 + 1]
575
+
576
+ # Run decode
577
+ decode_outputs = self._run_with_io_binding(
578
+ self.sessions["decode"],
579
+ decode_inputs,
580
+ device=self.device
581
+ )
582
+
583
+ logits = decode_outputs[0]
584
+ past_key_values = decode_outputs[1:]
585
+
586
+ if (step + 1) % 10 == 0:
587
+ print(f" Generated {step + 1} tokens...")
588
+
589
+ print(f"\n Total tokens generated: {len(generated_tokens)}")
590
+ print(f" Time: {time.time() - start_time:.2f}s")
591
+
592
+ # Save full token sequence (input + generated) to file for comparison
593
+ # Note: input_ids_list contains the original 237 tokens from processor
594
+ # The actual tokens fed to prefill model may differ due to image token handling
595
+ full_sequence = input_ids_list + generated_tokens
596
+ with open("result_token_ids_onnx.txt", "w", encoding="utf-8") as f:
597
+ f.write(f"ONNX Full Token IDs (including input)\n")
598
+ f.write(f"Total: {len(full_sequence)} tokens\n")
599
+ f.write(f"Input length: {len(input_ids_list)} tokens (from processor)\n")
600
+ f.write(f"Prefill seq_len: {seq_len} tokens (actual embeddings fed to model)\n")
601
+ f.write(f"Generated: {len(generated_tokens)} tokens\n")
602
+ f.write("="*80 + "\n\n")
603
+ f.write(f"Full sequence:\n")
604
+ f.write(f"{full_sequence}\n\n")
605
+ f.write(f"Input part (first {len(input_ids_list)}):\n")
606
+ f.write(f"{input_ids_list}\n\n")
607
+ f.write(f"Generated part (last {len(generated_tokens)}):\n")
608
+ f.write(f"{generated_tokens}\n")
609
+ print(f" Full token IDs saved to result_token_ids_onnx.txt")
610
+
611
+ generated_text = self.processor.tokenizer.decode(
612
+ generated_tokens, skip_special_tokens=True
613
+ )
614
+
615
+ return generated_text
616
+
617
+ def _remove_duplicate_branches(self, text: str) -> str:
618
+ """
619
+ Remove duplicate branches from LaTeX formula output.
620
+ This fixes the issue where ONNX model generates repeated formula branches.
621
+ """
622
+ import re
623
+
624
+ # Split by line breaks (\\ in LaTeX)
625
+ lines = text.split('\\\\')
626
+
627
+ seen = set()
628
+ unique_lines = []
629
+
630
+ for line in lines:
631
+ # Normalize for comparison (remove extra spaces)
632
+ normalized = re.sub(r'\s+', ' ', line.strip())
633
+
634
+ if not normalized or normalized not in seen:
635
+ if normalized:
636
+ seen.add(normalized)
637
+ unique_lines.append(line)
638
+
639
+ return '\\\\'.join(unique_lines)
640
+
641
+ def generate_batch(
642
+ self,
643
+ image_paths: List[str],
644
+ prompt: str = "",
645
+ max_new_tokens: int = 100,
646
+ ) -> List[str]:
647
+ """
648
+ Generate text for multiple images.
649
+
650
+ Args:
651
+ image_paths: List of image paths
652
+ prompt: Optional text prompt
653
+ max_new_tokens: Maximum number of tokens to generate
654
+
655
+ Returns:
656
+ List of generated texts
657
+ """
658
+ results = []
659
+ for image_path in image_paths:
660
+ text = self.generate(image_path, prompt, max_new_tokens)
661
+ results.append(text)
662
+ return results
663
+
664
+
665
+ def main():
666
+ parser = argparse.ArgumentParser(description="GLM-OCR ONNX End-to-End Inference")
667
+ parser.add_argument(
668
+ "--onnx-dir",
669
+ type=str,
670
+ default=r"D:\models\onnx-v5\GLM-OCR",
671
+ help="ONNX models directory",
672
+ )
673
+ parser.add_argument(
674
+ "--image",
675
+ type=str,
676
+ default=None,
677
+ help="Single image path",
678
+ )
679
+ parser.add_argument(
680
+ "--prompt",
681
+ type=str,
682
+ default="Formula Recognition:",
683
+ help="Text prompt",
684
+ )
685
+ parser.add_argument(
686
+ "--max-tokens",
687
+ type=int,
688
+ default=1024,
689
+ help="Maximum tokens to generate",
690
+ )
691
+ parser.add_argument(
692
+ "--device",
693
+ type=str,
694
+ default="cpu",
695
+ choices=["cpu", "cuda"],
696
+ help="Device to use",
697
+ )
698
+
699
+ args = parser.parse_args()
700
+
701
+ # Get image paths
702
+ if args.image:
703
+ image_paths = [args.image]
704
+ else:
705
+ print("Error: --image must be specified")
706
+ sys.exit(1)
707
+
708
+ # Initialize inference
709
+ inference = GLMOcrOnnxInference(
710
+ onnx_dir=args.onnx_dir,
711
+ device=args.device,
712
+ )
713
+
714
+ # Generate
715
+ print("\n" + "=" * 60)
716
+ print("GLM-OCR ONNX End-to-End Inference")
717
+ print("=" * 60)
718
+
719
+ results = inference.generate_batch(
720
+ image_paths=image_paths,
721
+ prompt=args.prompt,
722
+ max_new_tokens=args.max_tokens,
723
+ )
724
+
725
+ # Print results
726
+ print("\n" + "=" * 60)
727
+ print("Results")
728
+ print("=" * 60)
729
+
730
+ for i, (image_path, text) in enumerate(zip(image_paths, results)):
731
+ print(f"\nImage {i + 1}: {image_path}")
732
+ print(f"Generated text:\n{text}")
733
+ print("-" * 60)
734
+
735
+
736
+ if __name__ == "__main__":
737
+ main()
738
+
739
+ ```
740
+
741
 
742
 
743
  # GLM-OCR