heerjtdev commited on
Commit
fc8c0fc
·
verified ·
1 Parent(s): 37a91ca

Rename test_layout_yolo_columns_log.py to app.py

Browse files
Files changed (2) hide show
  1. app.py +59 -0
  2. test_layout_yolo_columns_log.py +0 -714
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Third-party imports: Gradio for the web UI, RapidOCR for the OCR engine.
import gradio as gr
from rapidocr import RapidOCR, OCRVersion

# 1. Initialize the OCR engine once with v5 defaults
# We use v5 for Detection/Recognition and v4 for Classification (most stable v5 setup)
# NOTE(review): constructed at import time, so model weights are loaded once per
# process and shared by every request — presumably RapidOCR is thread-safe here; verify.
engine = RapidOCR(params={
    "Det.ocr_version": OCRVersion.PPOCRV5,
    "Rec.ocr_version": OCRVersion.PPOCRV5,
    "Cls.ocr_version": OCRVersion.PPOCRV4,
})
11
+
12
def perform_ocr(img):
    """Run OCR on *img* and return data for the three Gradio output widgets.

    Args:
        img: numpy image from the Gradio input component, or ``None`` when
            nothing has been uploaded yet.

    Returns:
        tuple: ``(vis_img, word_list, elapsed)`` where ``vis_img`` is the
        annotated preview image (or ``None``), ``word_list`` is a list of
        ``[id, text, confidence]`` rows for the Dataframe, and ``elapsed``
        is the processing time formatted like ``"0.123s"``.
    """
    if img is None:
        # Fix: keep the empty-input return consistent with the success path —
        # an empty row list for the table and the same "<seconds>s" format.
        return None, [], "0.000s"

    # 2. Run OCR. return_word_box=True provides the word/char level detail
    ocr_result = engine(img, return_word_box=True)

    # 3. Get the annotated preview image
    vis_img = ocr_result.vis()

    # 4. Format word-level results for the Dataframe by flattening the
    # per-line tuples of (text, score, box) into sequential table rows.
    word_list = []
    if ocr_result.word_results:
        flat_results = sum(ocr_result.word_results, ())
        for i, (text, score, _) in enumerate(flat_results):
            word_list.append([i + 1, text, round(float(score), 3)])

    return vis_img, word_list, f"{ocr_result.elapse:.3f}s"
31
+
32
# 5. Build a clean, minimal UI
# Layout: inputs (image + run button) on the left, the annotated preview and
# timing on the right, and the word-level results table underneath.
with gr.Blocks(title="Rapid⚡OCR Simple") as demo:
    gr.Markdown("# Rapid⚡OCR v5")
    gr.Markdown("Upload an image to extract text with word-level bounding boxes.")

    with gr.Row():
        with gr.Column():
            # type="numpy" so perform_ocr receives an ndarray directly.
            input_img = gr.Image(label="Input Image", type="numpy")
            run_btn = gr.Button("Run OCR", variant="primary")

        with gr.Column():
            output_img = gr.Image(label="Preview (Bounding Boxes)")
            elapse_info = gr.Textbox(label="Processing Time")

    # Read-only table fed by the [id, text, confidence] rows from perform_ocr.
    result_table = gr.Dataframe(
        headers=["ID", "Text", "Confidence"],
        label="Detected Words",
        interactive=False
    )

    # Wire the button to the OCR callback; output order must match the
    # 3-tuple returned by perform_ocr.
    run_btn.click(
        fn=perform_ocr,
        inputs=[input_img],
        outputs=[output_img, result_table, elapse_info]
    )

if __name__ == "__main__":
    demo.launch()
test_layout_yolo_columns_log.py DELETED
@@ -1,714 +0,0 @@
1
- import json
2
- import argparse
3
- import os
4
- import torch
5
- import torch.nn as nn
6
- from TorchCRF import CRF
7
- from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
8
- import pytesseract
9
- from PIL import Image
10
- import fitz # PyMuPDF
11
- from typing import List, Dict, Any, Optional, Union, Tuple
12
- import numpy as np
13
- from scipy.signal import find_peaks
14
- from scipy.ndimage import gaussian_filter1d
15
- import sys
16
- import io
17
-
18
- # ============================================================================
19
- # CONSTANTS & MODEL DEFINITION
20
- # ============================================================================
21
-
22
# Labels must match the training labels! (Use the most detailed set)
# BIO tagging scheme: "B-" opens an entity span, "I-" continues the current
# one, and "O" marks tokens outside any entity. Index 0 is deliberately "O"
# since downstream code treats prediction id 0 as the default/fallback.
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION", 2: "I-QUESTION",
    3: "B-OPTION", 4: "I-OPTION",
    5: "B-ANSWER", 6: "I-ANSWER",
    7: "B-SECTION_HEADING", 8: "I-SECTION_HEADING",
    9: "B-PASSAGE", 10: "I-PASSAGE"
}
# Number of distinct labels (11) — sizes both the linear classifier and the CRF.
NUM_LABELS = len(ID_TO_LABEL)
32
-
33
-
34
class LayoutLMv3ForTokenClassification(nn.Module):
    """LayoutLMv3 model with a linear layer and a CRF layer on top.

    The backbone encodes (token ids, layout bboxes); the linear head maps
    hidden states to per-label emission scores; the CRF scores/decodes label
    sequences over those emissions.
    """

    def __init__(self, num_labels: int = NUM_LABELS):
        super().__init__()
        self.num_labels = num_labels

        # Backbone and config are both pulled from the base checkpoint so
        # hidden_size stays consistent with the classifier head below.
        config = LayoutLMv3Config.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_labels)
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.crf = CRF(num_labels)
        # Only re-initializes the classifier head; backbone keeps pretrained weights.
        self.init_weights()

    def init_weights(self):
        """Xavier-initialize the classifier head (backbone is left untouched)."""
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        bbox: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[List[List[int]], Any]]:
        # Training mode (labels given): returns the scalar negative
        # log-likelihood loss. Inference mode: returns the Viterbi best paths
        # (list of label-id lists, one per batch item).
        # NOTE(review): the annotated return type does not exactly match the
        # inference return (a list of paths) — confirm against TorchCRF's API.

        outputs = self.layoutlmv3(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            return_dict=True
        )

        sequence_output = outputs.last_hidden_state
        emissions = self.classifier(sequence_output)
        # CRF expects a boolean mask over valid (non-padding) positions.
        mask = attention_mask.bool()

        if labels is not None:
            log_likelihood = self.crf(emissions, labels, mask=mask)
            loss = -log_likelihood.mean()
            return loss
        else:
            best_paths = self.crf.viterbi_decode(emissions, mask=mask)
            return best_paths
78
-
79
-
80
- # ============================================================================
81
- # COLUMN DETECTION MODULE (Re-included for completeness)
82
- # ============================================================================
83
-
84
def get_word_data_for_detection(page: fitz.Page, top_margin_percent=0.10, bottom_margin_percent=0.10) -> list:
    """Extracts word data for column detection with Y-axis filtering.

    Returns a list of ``(word, x1, y1, x2, y2)`` tuples in PDF coordinates,
    with words inside the top/bottom page margins removed (headers/footers
    would otherwise distort the column histogram). Falls back to Tesseract
    OCR when the page has no embedded text layer; returns [] if that fails.
    """
    word_data = page.get_text("words")
    if len(word_data) == 0:
        # Fallback to Tesseract if PyMuPDF finds no words
        try:
            # Render at 3x for OCR quality; coordinates are divided back by 3
            # below so the result stays in PDF (1x) space.
            pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
            img_bytes = pix.tobytes("png")
            img = Image.open(io.BytesIO(img_bytes))
            data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)

            full_word_data = []
            for i in range(len(data['level'])):
                if data['text'][i].strip():
                    x1 = data['left'][i] / 3
                    y1 = data['top'][i] / 3
                    x2 = (data['left'][i] + data['width'][i]) / 3
                    y2 = (data['top'][i] + data['height'][i]) / 3
                    word = data['text'][i]
                    full_word_data.append((word, x1, y1, x2, y2))

            word_data = full_word_data
        except Exception as e:
            # Best-effort fallback: any Tesseract/PIL failure yields "no words".
            # print(f"Tesseract fallback failed: {e}")
            return []
    else:
        # PyMuPDF "words" tuples are (x0, y0, x1, y1, text, ...); reorder to
        # the (word, x1, y1, x2, y2) shape used by the rest of this module.
        word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data]

    # Keep only words whose box lies fully inside the vertical band between
    # the top and bottom margins.
    page_height = page.rect.height
    y_min = page_height * top_margin_percent
    y_max = page_height * (1 - bottom_margin_percent)

    filtered_data = [
        (word, x1, y1, x2, y2)
        for word, x1, y1, x2, y2 in word_data
        if y1 >= y_min and y2 <= y_max
    ]
    return filtered_data
122
-
123
-
124
- def calculate_x_gutters(word_data: list, params: Dict) -> List[int]:
125
- """Calculates the X-axis histogram and detects significant gutters."""
126
- if not word_data: return []
127
-
128
- x_points = []
129
- for _, x1, _, x2, _ in word_data:
130
- x_points.extend([x1, x2])
131
-
132
- max_x = max(x_points)
133
- bin_size = params['cluster_bin_size']
134
- num_bins = int(np.ceil(max_x / bin_size))
135
-
136
- hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x))
137
- smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=params['cluster_smoothing'])
138
- inverted_signal = np.max(smoothed_hist) - smoothed_hist
139
-
140
- peaks, properties = find_peaks(
141
- inverted_signal,
142
- height=0,
143
- distance=params['cluster_min_width'] / bin_size
144
- )
145
-
146
- if not peaks.size: return []
147
-
148
- threshold_value = np.percentile(smoothed_hist, params['cluster_threshold_percentile'])
149
- inverted_threshold = np.max(smoothed_hist) - threshold_value
150
- significant_peaks = peaks[properties['peak_heights'] >= inverted_threshold]
151
- separator_x_coords = [int(bin_edges[p]) for p in significant_peaks]
152
-
153
- final_separators = []
154
- prominence_threshold = params['cluster_prominence'] * np.max(smoothed_hist)
155
-
156
- for x_coord in separator_x_coords:
157
- bin_idx = np.searchsorted(bin_edges, x_coord) - 1
158
- window_size = int(params['cluster_min_width'] / bin_size)
159
-
160
- left_start, left_end = max(0, bin_idx - window_size), bin_idx
161
- right_start, right_end = bin_idx + 1, min(len(smoothed_hist), bin_idx + 1 + window_size)
162
-
163
- if left_end <= left_start or right_end <= right_start: continue
164
-
165
- avg_left_density = np.mean(smoothed_hist[left_start:left_end])
166
- avg_right_density = np.mean(smoothed_hist[right_start:right_end])
167
-
168
- if avg_left_density >= prominence_threshold and avg_right_density >= prominence_threshold:
169
- final_separators.append(x_coord)
170
-
171
- return sorted(final_separators)
172
-
173
-
174
def detect_column_gutters(pdf_path: str, page_num: int, **params) -> Optional[int]:
    """Detect a single vertical column gutter on one PDF page.

    Args:
        pdf_path: Path to the PDF file.
        page_num: 0-based page index to analyze.
        **params: Tuning knobs forwarded to get_word_data_for_detection /
            calculate_x_gutters (margins, bin size, smoothing, ...).

    Returns:
        The gutter X coordinate (PDF space) or ``None`` when the page has no
        usable words, no gutter is found, or any error occurs.
    """
    try:
        doc = fitz.open(pdf_path)
        try:
            page = doc.load_page(page_num)
            word_data = get_word_data_for_detection(page, params.get('top_margin_percent', 0.10),
                                                    params.get('bottom_margin_percent', 0.10))
            if not word_data:
                return None

            # BUG FIX: read the page geometry while the document is still
            # open. The original closed the doc first and then touched
            # page.rect, which raises on a closed PyMuPDF document.
            page_width = page.rect.width
            separators = calculate_x_gutters(word_data, params)
        finally:
            # Also fixes a leak: the doc is now closed even when an
            # exception fires mid-analysis.
            doc.close()

        if len(separators) == 1:
            return separators[0]
        elif len(separators) > 1:
            # Several candidates: prefer the one closest to the page center.
            center_x = page_width / 2
            best_separator = min(separators, key=lambda x: abs(x - center_x))
            return best_separator
        return None
    except Exception as e:
        print(f"DEBUG: Column detection failed for page {page_num}: {e}")
        return None
199
-
200
-
201
- def _merge_integrity(all_words_by_page: List[str], all_bboxes_raw: List[List[int]],
202
- column_separator_x: Optional[int]) -> List[List[str]]:
203
- """Splits the words/bboxes into two columns if a separator is present."""
204
- if column_separator_x is None:
205
- return [all_words_by_page]
206
-
207
- left_column_words = []
208
- right_column_words = []
209
- gutter_min_x = column_separator_x - 10
210
- gutter_max_x = column_separator_x + 10
211
-
212
- for i, (word, bbox_raw) in enumerate(zip(all_words_by_page, all_bboxes_raw)):
213
- x1_raw, _, x2_raw, _ = bbox_raw
214
- center_x = (x1_raw + x2_raw) / 2
215
-
216
- if center_x < column_separator_x:
217
- left_column_words.append(word)
218
- else:
219
- right_column_words.append(word)
220
-
221
- return [c for c in [left_column_words, right_column_words] if c]
222
-
223
-
224
def post_process_predictions(words: List[str], bboxes: List[List[int]], predictions: List[str]) -> List[Dict[str, Any]]:
    """Group word-level BIO predictions into structured entity blocks.

    Args:
        words: Word strings, aligned index-for-index with the other two lists.
        bboxes: One [x1, y1, x2, y2] box per word (PDF space).
        predictions: One label per word ("B-TAG", "I-TAG", or "O").

    Returns:
        A list of block dicts, each with 'text' (space-joined words), 'tag'
        ("QUESTION", ..., or "OTHER" for 'O' words), 'words' (per-word
        detail), and 'bbox' (union of member boxes). 'O' words and 'I-'
        labels with no open block each become single-word blocks.

    Cleanup vs. original: removed the dead `block_count` debug counter and the
    commented-out debug prints; logic is unchanged.
    """
    structured_blocks = []
    current_block = None

    for word, bbox, label in zip(words, bboxes, predictions):
        # "B-QUESTION" -> ("B", "QUESTION"); a bare "O" -> ("O", None).
        prefix, tag = (label.split('-', 1) + [None])[:2]

        if prefix == 'B':
            if current_block:
                structured_blocks.append(current_block)

            current_block = {
                'text': word,
                'tag': tag,
                'words': [{'text': word, 'bbox': bbox, 'label': label}],
                'bbox': list(bbox)
            }

        elif prefix == 'I' and current_block and current_block['tag'] == tag:
            current_block['text'] += ' ' + word
            current_block['words'].append({'text': word, 'bbox': bbox, 'label': label})
            # Grow the block's bbox to cover the newly appended word.
            current_block['bbox'][0] = min(current_block['bbox'][0], bbox[0])
            current_block['bbox'][1] = min(current_block['bbox'][1], bbox[1])
            current_block['bbox'][2] = max(current_block['bbox'][2], bbox[2])
            current_block['bbox'][3] = max(current_block['bbox'][3], bbox[3])

        else:  # 'O' or an 'I-' tag that does not continue the open block
            if current_block:
                structured_blocks.append(current_block)
            current_block = None

            # Handle 'O' or isolated 'I'. We include 'O' for completeness,
            # but they might be filtered later.
            if label == 'O':
                structured_blocks.append({
                    'text': word,
                    'tag': 'OTHER',
                    'words': [{'text': word, 'bbox': bbox, 'label': label}],
                    'bbox': list(bbox)
                })
            elif prefix == 'I':
                # An 'I-' with no matching open block starts an isolated
                # single-word block (the model missed the 'B-').
                structured_blocks.append({
                    'text': word,
                    'tag': tag,
                    'words': [{'text': word, 'bbox': bbox, 'label': label}],
                    'bbox': list(bbox)
                })

    # Flush the last open block at end-of-sequence.
    if current_block:
        structured_blocks.append(current_block)

    return structured_blocks
289
-
290
-
291
- # ============================================================================
292
- # CORE INFERENCE FUNCTION (WITH DEBUGGING LOGS)
293
- # ============================================================================
294
-
295
def run_inference_and_structure(pdf_path: str, model_path: str, inference_output_path: str,
                                preprocessed_json_path: str,
                                column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
    """
    Runs LayoutLMv3-CRF inference with extensive debugging logs.

    Pipeline: load model/tokenizer -> load preprocessed word/bbox JSON and the
    PDF -> per page: rescale YOLO bboxes to PDF space, normalize to 0-1000 for
    LayoutLMv3, split into columns, run CRF decoding per 500-word sub-chunk,
    and group word predictions into structured blocks.

    Args:
        pdf_path: Source PDF (used for page geometry and column detection).
        model_path: Path to the trained LayoutLMv3-CRF checkpoint (.pth).
        inference_output_path: Where the structured-blocks JSON is written.
        preprocessed_json_path: YOLO/OCR output; a list of pages, each with
            'page_number' and 'data' = [{'word', 'bbox'}, ...].
        column_detection_params: Optional kwargs for detect_column_gutters.

    Returns:
        The per-page structured data (also written to inference_output_path),
        or [] on any fatal loading error.
    """
    print("--- 1. MODEL SETUP ---")
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"DEBUG: Using device: {device}")

    try:
        model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
        checkpoint = torch.load(model_path, map_location=device)

        # Fixed Loading Logic
        # Accept either a raw state dict or a {'model_state_dict': ...} wrapper,
        # and rename legacy 'layoutlm.' prefixes to the current 'layoutlmv3.'.
        model_state = checkpoint.get('model_state_dict', checkpoint)
        fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()}
        model.load_state_dict(fixed_state_dict)
        model.to(device)
        model.eval()
        print(f"✅ Model loaded successfully from {model_path}. Total {len(fixed_state_dict)} keys loaded.")
    except Exception as e:
        print(f"❌ FATAL ERROR during model loading: {e}")
        return []

    # --------------------------------------------------------------------------
    # 2. DATA LOADING & PREPARATION
    # --------------------------------------------------------------------------
    print("\n--- 2. DATA LOADING ---")
    try:
        with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
            preprocessed_data = json.load(f)
        print(f"✅ Loaded preprocessed data with {len(preprocessed_data)} pages.")
    except Exception as e:
        print(f"❌ Error loading preprocessed JSON: {e}")
        return []

    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        print(f"❌ Error loading PDF: {e}")
        return []

    all_pages_data = []
    # Words per tokenizer call; well under the 512-token model limit for
    # typical words, though long words can still be truncated.
    CHUNK_SIZE = 500

    for page_data in preprocessed_data:
        page_num_1_based = page_data['page_number']
        page_num_0_based = page_num_1_based - 1
        print(f"\nProcessing Page {page_num_1_based}...")

        fitz_page = doc.load_page(page_num_0_based)
        page_width, page_height = fitz_page.rect.width, fitz_page.rect.height

        words = []
        bboxes_raw_pdf_space = []
        normalized_bboxes_list = []
        # NOTE(review): assumes the YOLO/OCR step rendered pages at 2x PDF
        # resolution, so its bboxes are divided by 2 — confirm upstream.
        scale_factor = 2.0

        for item in page_data['data']:
            word = item['word']
            raw_yolo_bbox = item['bbox']

            # YOLO pixel coords -> PDF point coords.
            bbox_pdf = [
                int(raw_yolo_bbox[0] / scale_factor),
                int(raw_yolo_bbox[1] / scale_factor),
                int(raw_yolo_bbox[2] / scale_factor),
                int(raw_yolo_bbox[3] / scale_factor)
            ]

            # PDF coords -> LayoutLMv3's required 0-1000 normalized space,
            # clamped to the valid range.
            normalized_bbox = [
                max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))),
                max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[3] / page_height)))
            ]

            words.append(word)
            bboxes_raw_pdf_space.append(bbox_pdf)
            normalized_bboxes_list.append(normalized_bbox)

        if not words:
            print(f" DEBUG: Page {page_num_1_based} has no words in preprocessed data. Skipping.")
            continue

        print(f" DEBUG: Page {page_num_1_based} extracted {len(words)} words.")

        # --------------------------------------------------------------------------
        # 3. COLUMN DETECTION & CHUNKING
        # --------------------------------------------------------------------------
        column_detection_params = column_detection_params or {}
        column_separator_x = detect_column_gutters(pdf_path, page_num_0_based, **column_detection_params)

        if column_separator_x is not None:
            print(f" DEBUG: Column detected at X={column_separator_x}. Splitting.")
        else:
            print(f" DEBUG: No column detected. Processing as a single chunk.")

        word_chunks = _merge_integrity(words, bboxes_raw_pdf_space, column_separator_x)
        print(f" DEBUG: Split into {len(word_chunks)} column/chunks.")

        page_structured_data = {'page_number': page_num_1_based, 'structured_blocks': []}

        # --------------------------------------------------------------------------
        # 4. INFERENCE LOOP
        # --------------------------------------------------------------------------

        # Re-alignment is simplified and potentially slow. A proper way would be to
        # split all three lists (words, bboxes_pdf, bboxes_norm) at the same time.
        # But we stick to your original approach for minimal changes.

        # NOTE(review): this variable is never read before being shadowed by
        # the label-mapping loop below — it looks vestigial.
        current_word_idx = 0

        for chunk_idx, chunk_words in enumerate(word_chunks):
            if not chunk_words: continue

            # Reconstruct the aligned chunk data (Simplified version of your complex loop)
            current_original_index = 0
            temp_chunk_norm_bboxes = []
            temp_chunk_pdf_bboxes = []
            found_words = []

            # Simple, but slow, way to re-align data for the chunk:
            # NOTE(review): .index() matches by word TEXT from the last found
            # position — repeated words can mis-align if column order differs
            # from reading order.
            for word_to_find in chunk_words:
                try:
                    # Find the index of the word in the master list, starting search from the last found position
                    i = words[current_original_index:].index(word_to_find) + current_original_index
                    temp_chunk_norm_bboxes.append(normalized_bboxes_list[i])
                    temp_chunk_pdf_bboxes.append(bboxes_raw_pdf_space[i])
                    found_words.append(words[i])
                    current_original_index = i + 1
                except ValueError:
                    # print(f" WARNING: Word '{word_to_find}' not found during re-alignment.")
                    pass  # Skip missing words

            chunk_words = found_words
            chunk_normalized_bboxes = temp_chunk_norm_bboxes
            chunk_bboxes_pdf = temp_chunk_pdf_bboxes

            print(f" DEBUG: Column/Chunk {chunk_idx + 1} has {len(chunk_words)} words.")

            # Sub-chunking for max_seq_len (512)
            for i in range(0, len(chunk_words), CHUNK_SIZE):
                sub_words = chunk_words[i:i + CHUNK_SIZE]
                sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE]
                sub_bboxes_pdf = chunk_bboxes_pdf[i:i + CHUNK_SIZE]

                encoded_input = tokenizer(
                    sub_words,
                    boxes=sub_bboxes,
                    truncation=True,
                    padding="max_length",
                    max_length=512,
                    # is_split_into_words=True,
                    return_tensors="pt"
                )

                input_ids = encoded_input['input_ids'].to(device)
                bbox = encoded_input['bbox'].to(device)
                attention_mask = encoded_input['attention_mask'].to(device)

                print(f" DEBUG INFER: Sub-chunk size: {len(sub_words)} words. Input shape: {input_ids.shape}")

                # No labels passed -> model returns Viterbi-decoded label ids.
                with torch.no_grad():
                    predictions_int_list = model(input_ids, bbox, attention_mask)

                if not predictions_int_list:
                    print(" ❌ INFERENCE FAILED: Model returned empty list of predictions.")
                    continue

                # Batch size is 1, so take the single decoded path.
                predictions_int = predictions_int_list[0]

                # --- CHECK FOR NON-'O' PREDICTIONS ---
                non_o_count = sum(1 for p in predictions_int if p != 0)
                print(
                    f" DEBUG INFER: Raw predictions (tokens): Total {len(predictions_int)}. Non-'O' tokens: {non_o_count}.")
                if non_o_count == 0:
                    print(" ⚠️ WARNING: Model is predicting 'O' for all tokens. Check training or input quality.")
                # -----------------------------------

                # Map token predictions back to original words
                word_ids = encoded_input.word_ids()
                word_idx_to_pred_id = {}

                for token_idx, word_idx in enumerate(word_ids):
                    if word_idx is not None and word_idx < len(sub_words):
                        # Only take the prediction of the FIRST sub-token for a word
                        if word_idx not in word_idx_to_pred_id:
                            word_idx_to_pred_id[word_idx] = predictions_int[token_idx]

                final_predictions_str = []
                # Map integer IDs back to string labels
                for current_word_idx in range(len(sub_words)):
                    # Words with no token prediction (e.g. truncated) default to 0 = 'O'.
                    pred_id_or_tensor = word_idx_to_pred_id.get(current_word_idx, 0)
                    pred_id = pred_id_or_tensor.item() if torch.is_tensor(pred_id_or_tensor) else pred_id_or_tensor

                    # This is the final word-level prediction. If it's always 0, post-processing fails.
                    final_predictions_str.append(ID_TO_LABEL[pred_id])

                # --- POST-PROCESSING ---
                structured_blocks = post_process_predictions(sub_words, sub_bboxes_pdf, final_predictions_str)

                print(f" DEBUG POST: Created {len(structured_blocks)} structured blocks from this sub-chunk.")

                page_structured_data['structured_blocks'].extend(structured_blocks)

        print(
            f" DEBUG: Page {page_num_1_based} final total structured blocks: {len(page_structured_data['structured_blocks'])}")
        all_pages_data.append(page_structured_data)

    doc.close()

    # Save final structured predictions
    with open(inference_output_path, 'w', encoding='utf-8') as f:
        json.dump(all_pages_data, f, indent=4)

    print(f"\n✅ All pages processed. Structured data saved to {os.path.basename(inference_output_path)}")

    return all_pages_data
515
-
516
-
517
- # --- 5. Label Studio Conversion Utility (Included for completeness) ---
518
-
519
def create_label_studio_span(all_results, start_idx, end_idx, label):
    """Create a Label Studio span with character-level offsets.

    Args:
        all_results: Page-wide list of {'word', 'bbox', ...} dicts; character
            offsets are computed against " ".join of all their words.
        start_idx: Index of the first word in the span (inclusive).
        end_idx: Index of the last word in the span (inclusive).
        label: Entity tag (e.g. "QUESTION") to attach to the span.

    Returns:
        A Label Studio "labels" result dict with start/end character offsets,
        the span text, and the union bbox of the member words.
    """
    entity_words = [all_results[i]['word'] for i in range(start_idx, end_idx + 1)]
    entity_bboxes = [all_results[i]['bbox'] for i in range(start_idx, end_idx + 1)]

    # Union bbox over the span's words.
    x0 = min(bbox[0] for bbox in entity_bboxes)
    y0 = min(bbox[1] for bbox in entity_bboxes)
    x1 = max(bbox[2] for bbox in entity_bboxes)
    y1 = max(bbox[3] for bbox in entity_bboxes)

    # Character offset of the span within the space-joined page text: length
    # of the joined prefix plus one separating space (when a prefix exists).
    # (Cleanup vs. original: removed the unused `text_string` local.)
    all_words = [r['word'] for r in all_results]
    prefix_words = all_words[:start_idx]
    start_char = len(" ".join(prefix_words)) + (1 if prefix_words else 0)
    span_text = " ".join(entity_words)
    end_char = start_char + len(span_text)

    return {
        "from_name": "label",
        "to_name": "text",
        "type": "labels",
        "value": {
            "start": start_char,
            "end": end_char,
            "text": span_text,
            "labels": [label],
            "bbox": {
                "x": x0,
                "y": y0,
                "width": x1 - x0,
                "height": y1 - y0
            }
        },
        "score": 0.99
    }
555
-
556
-
557
def convert_to_label_studio_format(structured_data: List[Dict[str, Any]],
                                   output_path: str,
                                   pdf_file_name: str) -> None:
    """Convert structured predictions to Label Studio format.

    Flattens each page's blocks back to a word stream, re-runs BIO grouping
    over the per-word labels, and emits one Label Studio task per page with
    character-offset spans. Writes the task list as JSON to *output_path*.

    NOTE(review): *pdf_file_name* is accepted but never used in this body —
    confirm whether it was meant to go into the task metadata.
    """
    final_tasks = []

    for page_data in structured_data:
        page_num = page_data['page_number']
        if 'structured_blocks' not in page_data: continue

        # Flatten blocks back to a single per-word list for this page.
        page_results = []
        for block in page_data['structured_blocks']:
            if 'words' in block:
                for word_info in block['words']:
                    page_results.append({
                        'word': word_info['text'],
                        'bbox': word_info['bbox'],
                        # FIX: Use the full label string (e.g., 'B-QUESTION')
                        'predicted_label': word_info['label']
                    })

        if not page_results:
            print(f"DEBUG LS: Page {page_num} has no word-level results. Skipping.")
            continue

        original_words = [r['word'] for r in page_results]
        original_bboxes = [r['bbox'] for r in page_results]
        text_string = " ".join(original_words)

        # Re-run BIO span grouping over the word labels to build LS results.
        results = []
        current_entity_label = None
        current_entity_start_word_index = None

        for i, pred_item in enumerate(page_results):
            label = pred_item['predicted_label']

            # Get the tag (e.g., 'QUESTION' from 'B-QUESTION')
            tag_only = label.split('-', 1)[-1] if '-' in label else label

            if label.startswith('B-'):
                # A new entity begins; close any span that is still open.
                if current_entity_label:
                    results.append(create_label_studio_span(
                        page_results, current_entity_start_word_index, i - 1, current_entity_label
                    ))
                current_entity_label = tag_only
                current_entity_start_word_index = i

            elif label.startswith('I-') and current_entity_label == tag_only:
                # Continuation of the open entity: nothing to emit yet.
                continue

            else:  # Label is 'O' or doesn't match current entity
                if current_entity_label:
                    results.append(create_label_studio_span(
                        page_results, current_entity_start_word_index, i - 1, current_entity_label
                    ))
                current_entity_label = None
                current_entity_start_word_index = None

        # Close a span left open at the end of the page.
        if current_entity_label:
            results.append(create_label_studio_span(
                page_results, current_entity_start_word_index, len(page_results) - 1, current_entity_label
            ))

        print(f"DEBUG LS: Page {page_num} generated {len(results)} Label Studio spans.")

        task = {
            "data": {
                "text": text_string,
                "original_words": original_words,
                "original_bboxes": original_bboxes
            },
            "annotations": [{"result": results}],
            "meta": {"page_number": page_num, "column_index": 1}
        }
        final_tasks.append(task)

    with open(output_path, "w", encoding='utf-8') as f:
        json.dump(final_tasks, f, indent=2, ensure_ascii=False)
    print(f"\n✅ Label Studio tasks created and saved to {output_path}")
636
-
637
-
638
if __name__ == "__main__":
    # CLI driver: validate inputs, run inference, optionally export to Label Studio.
    parser = argparse.ArgumentParser(
        description="LayoutLMv3 Inference Pipeline for PDF and Label Studio OCR Conversion.")
    parser.add_argument("--input_pdf", type=str, required=True,
                        help="Path to the input PDF file for inference.")
    parser.add_argument("--model_path", type=str,
                        default="checkpoints/layoutlmv3_trained_20251031_102846_recovered.pth",
                        help="Path to the saved LayoutLMv3-CRF PyTorch model checkpoint.")
    parser.add_argument("--inference_output", type=str, default="structured_yolo_predictions.json",
                        help="Path to save the intermediate structured predictions.")
    parser.add_argument("--label_studio_output", type=str, default="label_studio_import.json",
                        help="Path to save the final Label Studio import JSON.")
    parser.add_argument("--preprocessed_json", type=str, required=True,
                        help="Path to the combined JSON output from the YOLO/OCR script.")
    parser.add_argument("--no_labelstudio", action="store_true",
                        help="If set, skip creating the Label Studio import JSON and only write structured predictions.")
    # NOTE(review): --verbose is parsed but never read anywhere in this file.
    parser.add_argument("--verbose", action="store_true",
                        help="Enable verbose printing.")
    args = parser.parse_args()

    # 1. Check for required files
    print("--- 0. PRE-CHECK ---")
    if not os.path.exists(args.model_path):
        print(f"❌ FATAL ERROR: Model checkpoint not found at {args.model_path}.")
        sys.exit(1)
    if not os.path.exists(args.input_pdf):
        print(f"❌ FATAL ERROR: Input PDF not found at {args.input_pdf}.")
        sys.exit(1)
    if not os.path.exists(args.preprocessed_json):
        print(f"❌ FATAL ERROR: Preprocessed JSON not found at {args.preprocessed_json}. Run the YOLO/OCR script first.")
        sys.exit(1)
    print("✅ All required files found.")

    # 2. Column Detection Parameters (Tuning required)
    # Forwarded as **kwargs to detect_column_gutters / calculate_x_gutters.
    column_params = {
        'top_margin_percent': 0.10,
        'bottom_margin_percent': 0.10,
        'cluster_prominence': 0.70,
        'cluster_bin_size': 5,
        'cluster_smoothing': 2,
        'cluster_threshold_percentile': 30,
        'cluster_min_width': 25,
    }

    # 3. Run inference
    try:
        structured_data = run_inference_and_structure(
            args.input_pdf,
            args.model_path,
            args.inference_output,
            args.preprocessed_json,
            column_detection_params=column_params
        )
    except Exception as e:
        # Any unexpected failure degrades to "no data" so the status
        # reporting below still runs.
        print(f"❌ Fatal error while running inference: {e}")
        structured_data = []

    # 4. If requested, convert to Label Studio format
    if structured_data and not args.no_labelstudio:
        try:
            convert_to_label_studio_format(
                structured_data=structured_data,
                output_path=args.label_studio_output,
                pdf_file_name=args.input_pdf
            )
        except Exception as e:
            print(f"❌ Error while converting to Label Studio format: {e}")
    elif not structured_data:
        print("⚠️ No structured data produced — skipping Label Studio conversion.")
    else:
        print("ℹ️ Skipped Label Studio conversion as requested (--no_labelstudio).")

    # 5. Final status message
    print("\n--- 5. FINAL STATUS ---")
    print(f"Finished. Structured predictions file: {os.path.abspath(args.inference_output)}")
    if structured_data and not args.no_labelstudio:
        print(f"Label Studio import file: {os.path.abspath(args.label_studio_output)}")