heerjtdev commited on
Commit
37a91ca
·
verified ·
1 Parent(s): 4448236

Upload test_layout_yolo_columns_log.py

Browse files
Files changed (1) hide show
  1. test_layout_yolo_columns_log.py +714 -0
test_layout_yolo_columns_log.py ADDED
@@ -0,0 +1,714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ from TorchCRF import CRF
7
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
8
+ import pytesseract
9
+ from PIL import Image
10
+ import fitz # PyMuPDF
11
+ from typing import List, Dict, Any, Optional, Union, Tuple
12
+ import numpy as np
13
+ from scipy.signal import find_peaks
14
+ from scipy.ndimage import gaussian_filter1d
15
+ import sys
16
+ import io
17
+
18
# ============================================================================
# CONSTANTS & MODEL DEFINITION
# ============================================================================

# Labels must match the training labels! (Use the most detailed set)
# BIO tagging scheme: "B-*" opens an entity, "I-*" continues the current
# entity, and "O" marks tokens outside any entity.
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION", 2: "I-QUESTION",
    3: "B-OPTION", 4: "I-OPTION",
    5: "B-ANSWER", 6: "I-ANSWER",
    7: "B-SECTION_HEADING", 8: "I-SECTION_HEADING",
    9: "B-PASSAGE", 10: "I-PASSAGE"
}
# Number of distinct labels (11); must equal the classifier/CRF output size.
NUM_LABELS = len(ID_TO_LABEL)
32
+
33
+
34
class LayoutLMv3ForTokenClassification(nn.Module):
    """LayoutLMv3 model with a linear layer and a CRF layer on top.

    The LayoutLMv3 backbone produces per-token hidden states, the linear
    head maps them to per-label emission scores, and the CRF layer models
    label-transition structure (negative log-likelihood loss in training,
    Viterbi decoding at inference).
    """

    def __init__(self, num_labels: int = NUM_LABELS):
        super().__init__()
        self.num_labels = num_labels

        config = LayoutLMv3Config.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_labels)
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.crf = CRF(num_labels)
        # Only re-initializes the freshly created classifier head; the
        # pretrained backbone weights loaded above are left untouched.
        self.init_weights()

    def init_weights(self):
        # Xavier init for the classification head; zero the bias if present.
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        bbox: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[List[List[int]], Any]]:
        """Return the CRF loss (training) or decoded label ids (inference).

        Args:
            input_ids: Token ids, shape (batch, seq).
            bbox: Normalized word bounding boxes aligned with the tokens.
            attention_mask: 1 for real tokens, 0 for padding.
            labels: Gold label ids; when given, a scalar loss is returned.

        Returns:
            A scalar loss tensor when ``labels`` is provided, otherwise a
            list (one entry per batch item) of Viterbi-decoded label-id
            sequences.
        """
        outputs = self.layoutlmv3(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            return_dict=True
        )

        sequence_output = outputs.last_hidden_state
        emissions = self.classifier(sequence_output)
        # CRF expects a boolean mask to ignore padding positions.
        mask = attention_mask.bool()

        if labels is not None:
            # TorchCRF returns the per-sequence log-likelihood; negate for a loss.
            log_likelihood = self.crf(emissions, labels, mask=mask)
            loss = -log_likelihood.mean()
            return loss
        else:
            best_paths = self.crf.viterbi_decode(emissions, mask=mask)
            return best_paths
78
+
79
+
80
+ # ============================================================================
81
+ # COLUMN DETECTION MODULE (Re-included for completeness)
82
+ # ============================================================================
83
+
84
def get_word_data_for_detection(page: fitz.Page, top_margin_percent=0.10, bottom_margin_percent=0.10) -> list:
    """Collect ``(word, x1, y1, x2, y2)`` tuples for column-gutter detection.

    Words are taken from the PDF text layer; when the page has no extractable
    text, the page is rasterized at 3x and OCR'd with Tesseract instead
    (pixel coordinates scaled back by 3 to approximate PDF space). Words
    overlapping the top/bottom margin bands are discarded.
    """
    raw_words = page.get_text("words")
    if len(raw_words) == 0:
        # No text layer: fall back to Tesseract OCR on a 3x render.
        try:
            pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
            img = Image.open(io.BytesIO(pix.tobytes("png")))
            ocr = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)

            collected = []
            for idx in range(len(ocr['level'])):
                if not ocr['text'][idx].strip():
                    continue
                # Divide by 3 to undo the render scale factor.
                left = ocr['left'][idx] / 3
                top = ocr['top'][idx] / 3
                right = (ocr['left'][idx] + ocr['width'][idx]) / 3
                bottom = (ocr['top'][idx] + ocr['height'][idx]) / 3
                collected.append((ocr['text'][idx], left, top, right, bottom))

            word_tuples = collected
        except Exception:
            # OCR fallback failed as well -- nothing usable on this page.
            return []
    else:
        # fitz word tuples are (x0, y0, x1, y1, word, ...); reorder word-first.
        word_tuples = [(w[4], w[0], w[1], w[2], w[3]) for w in raw_words]

    page_height = page.rect.height
    y_min = page_height * top_margin_percent
    y_max = page_height * (1 - bottom_margin_percent)

    # Keep only words fully inside the vertical band [y_min, y_max].
    return [
        (word, x1, y1, x2, y2)
        for word, x1, y1, x2, y2 in word_tuples
        if y1 >= y_min and y2 <= y_max
    ]
122
+
123
+
124
def calculate_x_gutters(word_data: list, params: Dict) -> List[int]:
    """Find candidate column-gutter X positions from word-edge density.

    Builds a histogram of word left/right edge X coordinates, smooths it,
    and searches the inverted signal for deep valleys (low-density vertical
    bands). A valley is kept only when it is flanked on both sides by
    sufficiently dense text. Returns valley X positions sorted ascending.
    """
    if not word_data:
        return []

    # Every word contributes both of its horizontal edges to the histogram.
    edge_xs = [coord for _, x1, _, x2, _ in word_data for coord in (x1, x2)]

    max_x = max(edge_xs)
    bin_size = params['cluster_bin_size']
    num_bins = int(np.ceil(max_x / bin_size))

    hist, bin_edges = np.histogram(edge_xs, bins=num_bins, range=(0, max_x))
    smoothed = gaussian_filter1d(hist.astype(float), sigma=params['cluster_smoothing'])

    # Invert so that density valleys become peaks for find_peaks.
    inverted = np.max(smoothed) - smoothed
    peaks, props = find_peaks(
        inverted,
        height=0,
        distance=params['cluster_min_width'] / bin_size
    )

    if not peaks.size:
        return []

    # Only keep valleys whose density falls below the chosen percentile.
    threshold_value = np.percentile(smoothed, params['cluster_threshold_percentile'])
    inverted_threshold = np.max(smoothed) - threshold_value
    deep_valleys = peaks[props['peak_heights'] >= inverted_threshold]
    candidate_xs = [int(bin_edges[p]) for p in deep_valleys]

    min_flank_density = params['cluster_prominence'] * np.max(smoothed)
    accepted = []

    for x_coord in candidate_xs:
        bin_idx = np.searchsorted(bin_edges, x_coord) - 1
        window = int(params['cluster_min_width'] / bin_size)

        lo_start, lo_end = max(0, bin_idx - window), bin_idx
        hi_start, hi_end = bin_idx + 1, min(len(smoothed), bin_idx + 1 + window)

        # Skip valleys at the histogram border where a side window is empty.
        if lo_end <= lo_start or hi_end <= hi_start:
            continue

        # A true gutter must have dense text on BOTH of its sides.
        left_density = np.mean(smoothed[lo_start:lo_end])
        right_density = np.mean(smoothed[hi_start:hi_end])
        if left_density >= min_flank_density and right_density >= min_flank_density:
            accepted.append(x_coord)

    return sorted(accepted)
172
+
173
+
174
def detect_column_gutters(pdf_path: str, page_num: int, **params) -> Optional[int]:
    """Detect a single vertical column gutter on one PDF page.

    Args:
        pdf_path: Path to the PDF document.
        page_num: 0-based page index.
        **params: Tuning knobs forwarded to the word extraction and
            histogram analysis (see ``calculate_x_gutters``).

    Returns:
        The gutter's X coordinate in PDF points, or ``None`` when the page
        looks single-column or detection fails for any reason.
    """
    try:
        doc = fitz.open(pdf_path)
        try:
            page = doc.load_page(page_num)
            # BUG FIX: capture geometry while the document is still open --
            # the original read page.rect.width after doc.close(), and a
            # page of a closed document is invalid in PyMuPDF.
            page_width = page.rect.width
            word_data = get_word_data_for_detection(page, params.get('top_margin_percent', 0.10),
                                                    params.get('bottom_margin_percent', 0.10))
            if not word_data:
                return None
            separators = calculate_x_gutters(word_data, params)
        finally:
            # Always release the document, including on error paths
            # (the original leaked it when an exception was raised).
            doc.close()

        if len(separators) == 1:
            return separators[0]
        elif len(separators) > 1:
            # Several candidate gutters: prefer the one closest to the
            # horizontal center of the page.
            center_x = page_width / 2
            return min(separators, key=lambda x: abs(x - center_x))
        return None
    except Exception as e:
        print(f"DEBUG: Column detection failed for page {page_num}: {e}")
        return None
199
+
200
+
201
+ def _merge_integrity(all_words_by_page: List[str], all_bboxes_raw: List[List[int]],
202
+ column_separator_x: Optional[int]) -> List[List[str]]:
203
+ """Splits the words/bboxes into two columns if a separator is present."""
204
+ if column_separator_x is None:
205
+ return [all_words_by_page]
206
+
207
+ left_column_words = []
208
+ right_column_words = []
209
+ gutter_min_x = column_separator_x - 10
210
+ gutter_max_x = column_separator_x + 10
211
+
212
+ for i, (word, bbox_raw) in enumerate(zip(all_words_by_page, all_bboxes_raw)):
213
+ x1_raw, _, x2_raw, _ = bbox_raw
214
+ center_x = (x1_raw + x2_raw) / 2
215
+
216
+ if center_x < column_separator_x:
217
+ left_column_words.append(word)
218
+ else:
219
+ right_column_words.append(word)
220
+
221
+ return [c for c in [left_column_words, right_column_words] if c]
222
+
223
+
224
def post_process_predictions(words: List[str], bboxes: List[List[int]], predictions: List[str]) -> List[Dict[str, Any]]:
    """Group word-level BIO predictions into contiguous labeled blocks.

    Walks the (word, bbox, label) triples in order. ``B-*`` opens a new
    block, ``I-*`` with the matching tag extends the open block (growing its
    union bbox), and anything else closes it. Stray labels still produce
    blocks: a plain ``"O"`` becomes a single-word ``'OTHER'`` block, and an
    ``I-*`` without a preceding ``B-*`` becomes a single-word block of its tag.

    Args:
        words: Words in reading order.
        bboxes: Matching ``[x1, y1, x2, y2]`` boxes.
        predictions: Matching BIO label strings (e.g. ``'B-QUESTION'``).

    Returns:
        A list of dicts with keys ``text``, ``tag``, ``words`` (per-word
        detail) and ``bbox`` (union box). Empty input yields an empty list.
    """

    def _new_block(word: str, bbox: List[int], label: str, tag: Optional[str]) -> Dict[str, Any]:
        # Single-word seed block; 'bbox' is copied so later mutation of the
        # union box never aliases the caller's list.
        return {
            'text': word,
            'tag': tag,
            'words': [{'text': word, 'bbox': bbox, 'label': label}],
            'bbox': list(bbox)
        }

    structured_blocks: List[Dict[str, Any]] = []
    current_block: Optional[Dict[str, Any]] = None

    for word, bbox, label in zip(words, bboxes, predictions):
        # "B-QUESTION" -> ("B", "QUESTION"); "O" -> ("O", None).
        prefix, tag = (label.split('-', 1) + [None])[:2]

        if prefix == 'B':
            if current_block:
                structured_blocks.append(current_block)
            current_block = _new_block(word, bbox, label, tag)

        elif prefix == 'I' and current_block and current_block['tag'] == tag:
            current_block['text'] += ' ' + word
            current_block['words'].append({'text': word, 'bbox': bbox, 'label': label})
            # Grow the union bounding box to cover the new word.
            current_block['bbox'][0] = min(current_block['bbox'][0], bbox[0])
            current_block['bbox'][1] = min(current_block['bbox'][1], bbox[1])
            current_block['bbox'][2] = max(current_block['bbox'][2], bbox[2])
            current_block['bbox'][3] = max(current_block['bbox'][3], bbox[3])

        else:  # 'O' or an I-* that does not match the open block's tag
            if current_block:
                structured_blocks.append(current_block)
            current_block = None

            if label == 'O':
                # Keep 'O' words as individual blocks; callers may filter them.
                structured_blocks.append(_new_block(word, bbox, label, 'OTHER'))
            elif prefix == 'I':
                # Isolated I-* that missed its B-*: start a one-word block.
                structured_blocks.append(_new_block(word, bbox, label, tag))

    if current_block:
        structured_blocks.append(current_block)

    return structured_blocks
289
+
290
+
291
+ # ============================================================================
292
+ # CORE INFERENCE FUNCTION (WITH DEBUGGING LOGS)
293
+ # ============================================================================
294
+
295
def run_inference_and_structure(pdf_path: str, model_path: str, inference_output_path: str,
                                preprocessed_json_path: str,
                                column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
    """
    Runs LayoutLMv3-CRF inference with extensive debugging logs.

    Pipeline per page: read YOLO/OCR words from the preprocessed JSON,
    normalize bboxes to the 0-1000 LayoutLM space, detect a column gutter,
    split words into column chunks, run the model on <=CHUNK_SIZE-word
    sub-chunks, map token predictions back to words, and group them into
    structured blocks. Results are written to ``inference_output_path`` as
    JSON and also returned.

    Returns an empty list on any fatal setup error (model, JSON, or PDF
    failing to load).
    """
    print("--- 1. MODEL SETUP ---")
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"DEBUG: Using device: {device}")

    try:
        model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
        checkpoint = torch.load(model_path, map_location=device)

        # Checkpoint may be a raw state dict or wrapped under
        # 'model_state_dict'; older checkpoints used the 'layoutlm.' prefix,
        # so rename keys to the current 'layoutlmv3.' attribute name.
        model_state = checkpoint.get('model_state_dict', checkpoint)
        fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()}
        model.load_state_dict(fixed_state_dict)
        model.to(device)
        model.eval()
        print(f"✅ Model loaded successfully from {model_path}. Total {len(fixed_state_dict)} keys loaded.")
    except Exception as e:
        print(f"❌ FATAL ERROR during model loading: {e}")
        return []

    # --------------------------------------------------------------------------
    # 2. DATA LOADING & PREPARATION
    # --------------------------------------------------------------------------
    print("\n--- 2. DATA LOADING ---")
    try:
        with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
            preprocessed_data = json.load(f)
        print(f"✅ Loaded preprocessed data with {len(preprocessed_data)} pages.")
    except Exception as e:
        print(f"❌ Error loading preprocessed JSON: {e}")
        return []

    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        print(f"❌ Error loading PDF: {e}")
        return []

    all_pages_data = []
    # Words per model call. NOTE(review): 500 words can exceed 512 tokens
    # after subword splitting, so tails may be truncated -- confirm intent.
    CHUNK_SIZE = 500

    for page_data in preprocessed_data:
        page_num_1_based = page_data['page_number']
        page_num_0_based = page_num_1_based - 1
        print(f"\nProcessing Page {page_num_1_based}...")

        fitz_page = doc.load_page(page_num_0_based)
        page_width, page_height = fitz_page.rect.width, fitz_page.rect.height

        words = []
        bboxes_raw_pdf_space = []
        normalized_bboxes_list = []
        # YOLO boxes were produced at 2x render scale -- assumes the
        # preprocessing script rendered at 2x; TODO confirm.
        scale_factor = 2.0

        for item in page_data['data']:
            word = item['word']
            raw_yolo_bbox = item['bbox']

            # Scale YOLO pixel coords back to PDF point space.
            bbox_pdf = [
                int(raw_yolo_bbox[0] / scale_factor),
                int(raw_yolo_bbox[1] / scale_factor),
                int(raw_yolo_bbox[2] / scale_factor),
                int(raw_yolo_bbox[3] / scale_factor)
            ]

            # Normalize to the 0-1000 grid LayoutLMv3 expects, clamped.
            normalized_bbox = [
                max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))),
                max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[3] / page_height)))
            ]

            words.append(word)
            bboxes_raw_pdf_space.append(bbox_pdf)
            normalized_bboxes_list.append(normalized_bbox)

        if not words:
            print(f" DEBUG: Page {page_num_1_based} has no words in preprocessed data. Skipping.")
            continue

        print(f" DEBUG: Page {page_num_1_based} extracted {len(words)} words.")

        # --------------------------------------------------------------------------
        # 3. COLUMN DETECTION & CHUNKING
        # --------------------------------------------------------------------------
        column_detection_params = column_detection_params or {}
        column_separator_x = detect_column_gutters(pdf_path, page_num_0_based, **column_detection_params)

        if column_separator_x is not None:
            print(f" DEBUG: Column detected at X={column_separator_x}. Splitting.")
        else:
            print(f" DEBUG: No column detected. Processing as a single chunk.")

        word_chunks = _merge_integrity(words, bboxes_raw_pdf_space, column_separator_x)
        print(f" DEBUG: Split into {len(word_chunks)} column/chunks.")

        page_structured_data = {'page_number': page_num_1_based, 'structured_blocks': []}

        # --------------------------------------------------------------------------
        # 4. INFERENCE LOOP
        # --------------------------------------------------------------------------

        # Re-alignment is simplified and potentially slow (O(n^2) over words).
        # A proper way would be to split all three lists (words, bboxes_pdf,
        # bboxes_norm) at the same time; kept as-is for minimal changes.
        # NOTE(review): duplicate words can mis-align here because matching
        # is by text with a forward-only search -- verify against real pages.

        current_word_idx = 0

        for chunk_idx, chunk_words in enumerate(word_chunks):
            if not chunk_words: continue

            # Reconstruct the aligned chunk data (bboxes were not split by
            # _merge_integrity, so look each chunk word up in the page lists).
            current_original_index = 0
            temp_chunk_norm_bboxes = []
            temp_chunk_pdf_bboxes = []
            found_words = []

            # Simple, but slow, way to re-align data for the chunk:
            for word_to_find in chunk_words:
                try:
                    # Find the index of the word in the master list, starting search from the last found position
                    i = words[current_original_index:].index(word_to_find) + current_original_index
                    temp_chunk_norm_bboxes.append(normalized_bboxes_list[i])
                    temp_chunk_pdf_bboxes.append(bboxes_raw_pdf_space[i])
                    found_words.append(words[i])
                    current_original_index = i + 1
                except ValueError:
                    pass  # Skip words that cannot be re-located in the master list

            chunk_words = found_words
            chunk_normalized_bboxes = temp_chunk_norm_bboxes
            chunk_bboxes_pdf = temp_chunk_pdf_bboxes

            print(f" DEBUG: Column/Chunk {chunk_idx + 1} has {len(chunk_words)} words.")

            # Sub-chunking for max_seq_len (512)
            for i in range(0, len(chunk_words), CHUNK_SIZE):
                sub_words = chunk_words[i:i + CHUNK_SIZE]
                sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE]
                sub_bboxes_pdf = chunk_bboxes_pdf[i:i + CHUNK_SIZE]

                encoded_input = tokenizer(
                    sub_words,
                    boxes=sub_bboxes,
                    truncation=True,
                    padding="max_length",
                    max_length=512,
                    # is_split_into_words=True,
                    return_tensors="pt"
                )

                input_ids = encoded_input['input_ids'].to(device)
                bbox = encoded_input['bbox'].to(device)
                attention_mask = encoded_input['attention_mask'].to(device)

                print(f" DEBUG INFER: Sub-chunk size: {len(sub_words)} words. Input shape: {input_ids.shape}")

                with torch.no_grad():
                    # No labels -> Viterbi decode; one id list per batch item.
                    predictions_int_list = model(input_ids, bbox, attention_mask)

                if not predictions_int_list:
                    print(" ❌ INFERENCE FAILED: Model returned empty list of predictions.")
                    continue

                predictions_int = predictions_int_list[0]

                # --- CHECK FOR NON-'O' PREDICTIONS ---
                non_o_count = sum(1 for p in predictions_int if p != 0)
                print(
                    f" DEBUG INFER: Raw predictions (tokens): Total {len(predictions_int)}. Non-'O' tokens: {non_o_count}.")
                if non_o_count == 0:
                    print(" ⚠️ WARNING: Model is predicting 'O' for all tokens. Check training or input quality.")
                # -----------------------------------

                # Map token predictions back to original words
                word_ids = encoded_input.word_ids()
                word_idx_to_pred_id = {}

                for token_idx, word_idx in enumerate(word_ids):
                    if word_idx is not None and word_idx < len(sub_words):
                        # Only take the prediction of the FIRST sub-token for a word
                        if word_idx not in word_idx_to_pred_id:
                            word_idx_to_pred_id[word_idx] = predictions_int[token_idx]

                final_predictions_str = []
                # Map integer IDs back to string labels (words truncated away
                # by the 512-token limit default to id 0, i.e. 'O').
                for current_word_idx in range(len(sub_words)):
                    pred_id_or_tensor = word_idx_to_pred_id.get(current_word_idx, 0)
                    pred_id = pred_id_or_tensor.item() if torch.is_tensor(pred_id_or_tensor) else pred_id_or_tensor

                    # This is the final word-level prediction. If it's always 0, post-processing fails.
                    final_predictions_str.append(ID_TO_LABEL[pred_id])

                # --- POST-PROCESSING ---
                structured_blocks = post_process_predictions(sub_words, sub_bboxes_pdf, final_predictions_str)

                print(f" DEBUG POST: Created {len(structured_blocks)} structured blocks from this sub-chunk.")

                page_structured_data['structured_blocks'].extend(structured_blocks)

        print(
            f" DEBUG: Page {page_num_1_based} final total structured blocks: {len(page_structured_data['structured_blocks'])}")
        all_pages_data.append(page_structured_data)

    doc.close()

    # Save final structured predictions
    with open(inference_output_path, 'w', encoding='utf-8') as f:
        json.dump(all_pages_data, f, indent=4)

    print(f"\n✅ All pages processed. Structured data saved to {os.path.basename(inference_output_path)}")

    return all_pages_data
515
+
516
+
517
+ # --- 5. Label Studio Conversion Utility (Included for completeness) ---
518
+
519
def create_label_studio_span(all_results, start_idx, end_idx, label):
    """Create one Label Studio "labels" result span with char offsets.

    Args:
        all_results: Word-level dicts with 'word' and 'bbox'
            (``[x0, y0, x1, y1]``) keys, in reading order for the page.
        start_idx: Index of the first word of the entity (inclusive).
        end_idx: Index of the last word of the entity (inclusive).
        label: Entity tag, e.g. ``'QUESTION'``.

    Returns:
        A Label Studio result dict whose character offsets are computed
        against the page text formed by joining all words with single
        spaces, with the union bounding box of the entity's words attached.
    """
    entity_words = [all_results[i]['word'] for i in range(start_idx, end_idx + 1)]
    entity_bboxes = [all_results[i]['bbox'] for i in range(start_idx, end_idx + 1)]

    # Union bounding box over all words in the span.
    x0 = min(bbox[0] for bbox in entity_bboxes)
    y0 = min(bbox[1] for bbox in entity_bboxes)
    x1 = max(bbox[2] for bbox in entity_bboxes)
    y1 = max(bbox[3] for bbox in entity_bboxes)

    all_words = [r['word'] for r in all_results]

    # Character offset of the span inside " ".join(all_words): length of the
    # joined prefix plus one separating space (absent when the span starts
    # at word 0). The original also built the full page string here but
    # never used it; that dead local has been removed.
    prefix_words = all_words[:start_idx]
    start_char = len(" ".join(prefix_words)) + (1 if prefix_words else 0)
    span_text = " ".join(entity_words)
    end_char = start_char + len(span_text)

    return {
        "from_name": "label",
        "to_name": "text",
        "type": "labels",
        "value": {
            "start": start_char,
            "end": end_char,
            "text": span_text,
            "labels": [label],
            "bbox": {
                "x": x0,
                "y": y0,
                "width": x1 - x0,
                "height": y1 - y0
            }
        },
        "score": 0.99
    }
555
+
556
+
557
def convert_to_label_studio_format(structured_data: List[Dict[str, Any]],
                                   output_path: str,
                                   pdf_file_name: str) -> None:
    """Convert structured predictions to Label Studio format.

    Flattens each page's blocks back to a word stream, rebuilds contiguous
    BIO entities from the word labels, and emits one Label Studio task per
    page with character-offset spans. The tasks are written to
    ``output_path`` as JSON.
    """
    final_tasks = []

    for page_data in structured_data:
        page_num = page_data['page_number']
        if 'structured_blocks' not in page_data:
            continue

        # Flatten every block back into a word-level stream.
        page_results = [
            {
                'word': word_info['text'],
                'bbox': word_info['bbox'],
                # FIX: Use the full label string (e.g., 'B-QUESTION')
                'predicted_label': word_info['label'],
            }
            for block in page_data['structured_blocks']
            if 'words' in block
            for word_info in block['words']
        ]

        if not page_results:
            print(f"DEBUG LS: Page {page_num} has no word-level results. Skipping.")
            continue

        original_words = [r['word'] for r in page_results]
        original_bboxes = [r['bbox'] for r in page_results]
        text_string = " ".join(original_words)

        results = []
        open_label = None
        open_start = None

        for idx, item in enumerate(page_results):
            label = item['predicted_label']
            # Tag without the BIO prefix (e.g. 'QUESTION' from 'B-QUESTION').
            tag_only = label.split('-', 1)[-1] if '-' in label else label

            if label.startswith('B-'):
                # Close any running entity, then open a new one here.
                if open_label:
                    results.append(create_label_studio_span(
                        page_results, open_start, idx - 1, open_label
                    ))
                open_label = tag_only
                open_start = idx
            elif label.startswith('I-') and open_label == tag_only:
                # Continuation of the running entity -- nothing to emit yet.
                continue
            else:
                # 'O' or a mismatching tag terminates the running entity.
                if open_label:
                    results.append(create_label_studio_span(
                        page_results, open_start, idx - 1, open_label
                    ))
                open_label = None
                open_start = None

        # Flush an entity that runs to the end of the page.
        if open_label:
            results.append(create_label_studio_span(
                page_results, open_start, len(page_results) - 1, open_label
            ))

        print(f"DEBUG LS: Page {page_num} generated {len(results)} Label Studio spans.")

        final_tasks.append({
            "data": {
                "text": text_string,
                "original_words": original_words,
                "original_bboxes": original_bboxes
            },
            "annotations": [{"result": results}],
            "meta": {"page_number": page_num, "column_index": 1}
        })

    with open(output_path, "w", encoding='utf-8') as f:
        json.dump(final_tasks, f, indent=2, ensure_ascii=False)
    print(f"\n✅ Label Studio tasks created and saved to {output_path}")
636
+
637
+
638
if __name__ == "__main__":
    # CLI driver: pre-flight checks, inference, optional Label Studio export.
    parser = argparse.ArgumentParser(
        description="LayoutLMv3 Inference Pipeline for PDF and Label Studio OCR Conversion.")
    parser.add_argument("--input_pdf", type=str, required=True,
                        help="Path to the input PDF file for inference.")
    parser.add_argument("--model_path", type=str,
                        default="checkpoints/layoutlmv3_trained_20251031_102846_recovered.pth",
                        help="Path to the saved LayoutLMv3-CRF PyTorch model checkpoint.")
    parser.add_argument("--inference_output", type=str, default="structured_yolo_predictions.json",
                        help="Path to save the intermediate structured predictions.")
    parser.add_argument("--label_studio_output", type=str, default="label_studio_import.json",
                        help="Path to save the final Label Studio import JSON.")
    parser.add_argument("--preprocessed_json", type=str, required=True,
                        help="Path to the combined JSON output from the YOLO/OCR script.")
    parser.add_argument("--no_labelstudio", action="store_true",
                        help="If set, skip creating the Label Studio import JSON and only write structured predictions.")
    # NOTE(review): --verbose is parsed but never read anywhere in this script.
    parser.add_argument("--verbose", action="store_true",
                        help="Enable verbose printing.")
    args = parser.parse_args()

    # 1. Check for required files
    print("--- 0. PRE-CHECK ---")
    if not os.path.exists(args.model_path):
        print(f"❌ FATAL ERROR: Model checkpoint not found at {args.model_path}.")
        sys.exit(1)
    if not os.path.exists(args.input_pdf):
        print(f"❌ FATAL ERROR: Input PDF not found at {args.input_pdf}.")
        sys.exit(1)
    if not os.path.exists(args.preprocessed_json):
        print(f"❌ FATAL ERROR: Preprocessed JSON not found at {args.preprocessed_json}. Run the YOLO/OCR script first.")
        sys.exit(1)
    print("✅ All required files found.")

    # 2. Column Detection Parameters (Tuning required)
    # NOTE(review): empirical defaults; presumably need re-tuning per corpus.
    column_params = {
        'top_margin_percent': 0.10,
        'bottom_margin_percent': 0.10,
        'cluster_prominence': 0.70,
        'cluster_bin_size': 5,
        'cluster_smoothing': 2,
        'cluster_threshold_percentile': 30,
        'cluster_min_width': 25,
    }

    # 3. Run inference
    try:
        structured_data = run_inference_and_structure(
            args.input_pdf,
            args.model_path,
            args.inference_output,
            args.preprocessed_json,
            column_detection_params=column_params
        )
    except Exception as e:
        # Keep going so the final status section still runs.
        print(f"❌ Fatal error while running inference: {e}")
        structured_data = []

    # 4. If requested, convert to Label Studio format
    if structured_data and not args.no_labelstudio:
        try:
            convert_to_label_studio_format(
                structured_data=structured_data,
                output_path=args.label_studio_output,
                pdf_file_name=args.input_pdf
            )
        except Exception as e:
            print(f"❌ Error while converting to Label Studio format: {e}")
    elif not structured_data:
        print("⚠️ No structured data produced — skipping Label Studio conversion.")
    else:
        print("ℹ️ Skipped Label Studio conversion as requested (--no_labelstudio).")

    # 5. Final status message
    print("\n--- 5. FINAL STATUS ---")
    print(f"Finished. Structured predictions file: {os.path.abspath(args.inference_output)}")
    if structured_data and not args.no_labelstudio:
        print(f"Label Studio import file: {os.path.abspath(args.label_studio_output)}")