Kushalguptaiitb commited on
Commit
608282b
·
verified ·
1 Parent(s): e150363

Delete layout_detection_docling_heron.py

Browse files
Files changed (1) hide show
  1. layout_detection_docling_heron.py +0 -497
layout_detection_docling_heron.py DELETED
@@ -1,497 +0,0 @@
1
- import cv2
2
- import os
3
- import supervision as sv # pip install supervision
4
- from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor
5
- from pdf2image import convert_from_path
6
- import numpy as np
7
- from PIL import Image
8
- import json
9
- import pytesseract
10
- import pandas as pd
11
- from sentence_transformers import SentenceTransformer, util
12
- from PyPDF2 import PdfReader
13
- from datetime import datetime
14
- import torch
15
- import logging
16
- from utils.utils_code import log_time_taken
17
- from concurrent.futures import ProcessPoolExecutor, as_completed
18
- import multiprocessing
19
- import sys
20
- import gc
21
-
22
- from src.table_processing.tree_structured_json import tree_structured_headers_pipeline
23
- from config.set_config import set_configuration
24
- set_config_project = set_configuration()
25
- layout_model_weights_path = set_config_project.layout_model_weights_path
26
- no_of_threads = set_config_project.no_of_threads
27
- from src.docling.ttsr_docling import tsr_inference_image, tsr_inference
28
- from src.table_processing.table_classification_extraction import process_table_classification_extraction_pipeline
29
- from src.table_processing.put_table_header import put_table_header_pipeline
30
- import gc
31
- from src.layout_detection.load_model import load_model_for_process
32
-
33
- # Set multiprocessing start method
34
- multiprocessing.set_start_method('spawn', force=True)
35
- logger = logging.getLogger(__name__)
36
-
37
- # Configure logging
38
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
39
-
40
- def load_torch(version):
41
- if version == "2.2.2":
42
- sys.path.insert(0, "./torch_2_2_2")
43
- elif version == "2.6.0":
44
- sys.path.insert(0, "./torch_2_6_0")
45
- import torch
46
- logger.info(f"Using Torch Version: {torch.__version__}")
47
- return torch
48
-
49
- torch = load_torch("2.2.2")
50
-
51
-
52
-
53
- def get_file_name_without_extension(file_path):
54
- directory, file_name = os.path.split(file_path)
55
- name, extension = os.path.splitext(file_name)
56
- return name
57
-
58
- def convert_numpy(data):
59
- if isinstance(data, dict):
60
- return {key: convert_numpy(value) for key, value in data.items()}
61
- elif isinstance(data, list):
62
- return [convert_numpy(item) for item in data]
63
- elif isinstance(data, np.integer):
64
- return int(data)
65
- elif isinstance(data, np.floating):
66
- return float(data)
67
- elif isinstance(data, np.ndarray):
68
- return data.tolist()
69
- elif isinstance(data, pd.DataFrame):
70
- return data.to_dict(orient='records')
71
- else:
72
- return data
73
-
74
- def filter_layout_blocks(input_data):
75
- filtered_layout_blocks = []
76
- for blocks in input_data.values():
77
- filtered_layout_blocks.extend([block for block in blocks])
78
- return filtered_layout_blocks
79
-
80
- def convert_pdf_to_images(file_path, batch_size=20, dpi=100):
81
- images = convert_from_path(file_path, dpi=dpi)
82
- total_pages = len(images)
83
-
84
- def page_generator():
85
- for start_page in range(1, total_pages + 1, batch_size):
86
- end_page = min(start_page + batch_size - 1, total_pages)
87
- yield images[start_page-1:end_page]
88
-
89
- return page_generator()
90
-
91
- def read_json(json_file):
92
- with open(json_file, 'r') as file:
93
- return json.load(file)
94
-
95
- def filter_and_sort_headers(data, modified_json_output_filepath):
96
- def sort_blocks_by_min_x(blocks):
97
- return sorted(blocks, key=lambda block: block['bbox'][0])
98
-
99
- def sort_blocks_by_min_y(blocks):
100
- return sorted(blocks, key=lambda block: block['bbox'][1])
101
-
102
- def find_headers_and_group(sorted_blocks):
103
- headers_list = []
104
- current_group = []
105
- previous_block = None
106
-
107
- for i, block in enumerate(sorted_blocks):
108
- if previous_block:
109
- prev_xmax = previous_block['bbox'][2]
110
- prev_xmax_threshold = int(previous_block['bbox'][2])
111
- if block['bbox'][0] > prev_xmax and block['bbox'][0] > prev_xmax_threshold:
112
- if current_group:
113
- headers_list.extend(sort_blocks_by_min_y(current_group))
114
- current_group = []
115
- current_group.append(block)
116
- previous_block = block
117
-
118
- if current_group:
119
- headers_list.extend(sort_blocks_by_min_y(current_group))
120
-
121
- return headers_list
122
-
123
- result = {}
124
- for key, blocks in data.items():
125
- sorted_blocks = sort_blocks_by_min_x(blocks)
126
- sorted_headers = find_headers_and_group(sorted_blocks)
127
- result[key] = sorted_headers
128
-
129
- sorted_data = result
130
- with open(modified_json_output_filepath, 'w') as f:
131
- json.dump(sorted_data, f, indent=4)
132
-
133
- return sorted_data, modified_json_output_filepath
134
-
135
- def filter_and_sort_layouts(data, modified_json_output_filepath):
136
- def sort_blocks_by_min_x(blocks):
137
- return sorted(blocks, key=lambda block: block['bbox'][0])
138
-
139
- def sort_blocks_by_min_y(blocks):
140
- return sorted(blocks, key=lambda block: block['bbox'][1])
141
-
142
- def find_classes_and_group(sorted_blocks):
143
- classes_list = []
144
- current_group = []
145
- previous_block = None
146
-
147
- for i, block in enumerate(sorted_blocks):
148
- if previous_block:
149
- prev_xmax = previous_block['bbox'][2]
150
- prev_xmax_threshold = int(previous_block['bbox'][2])
151
- if block['bbox'][0] > prev_xmax and block['bbox'][0] > prev_xmax_threshold:
152
- if current_group:
153
- classes_list.extend(sort_blocks_by_min_y(current_group))
154
- current_group = []
155
- current_group.append(block)
156
- previous_block = block
157
-
158
- if current_group:
159
- classes_list.extend(sort_blocks_by_min_y(current_group))
160
-
161
- return classes_list
162
-
163
- result = {}
164
- for key, blocks in data.items():
165
- sorted_blocks = sort_blocks_by_min_x(blocks)
166
- sorted_layouts = find_classes_and_group(sorted_blocks)
167
- result[key] = sorted_layouts
168
-
169
- sorted_layout_data = result
170
- with open(modified_json_output_filepath, 'w') as f:
171
- json.dump(sorted_layout_data, f, indent=4)
172
-
173
- return sorted_layout_data, modified_json_output_filepath
174
-
175
- @log_time_taken
176
- def layout_detection(img_path, model, image_processor, threshold=0.6, device='cuda' if torch.cuda.is_available() else 'cpu'):
177
- try:
178
- image = Image.open(img_path).convert("RGB")
179
-
180
- # Process image with the Docling Heron model
181
- inputs = image_processor(images=[image], return_tensors="pt")
182
-
183
- # Move inputs to the same device as the model
184
- inputs = {k: v.to(device) for k, v in inputs.items()}
185
-
186
- with torch.no_grad():
187
- outputs = model(**inputs)
188
-
189
- # Post-process the results
190
- results = image_processor.post_process_object_detection(
191
- outputs,
192
- target_sizes=torch.tensor([image.size[::-1]], device=device),
193
- threshold=threshold
194
- )[0]
195
-
196
- # Move results to CPU for further processing
197
- results = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in results.items()}
198
-
199
- # Convert to supervision Detections format for compatibility
200
- xyxy = results["boxes"].numpy()
201
- confidence = results["scores"].numpy()
202
- class_id = results["labels"].numpy()
203
- class_name = [model.config.id2label[label_id] for label_id in class_id]
204
-
205
- detections = sv.Detections(
206
- xyxy=xyxy,
207
- confidence=confidence,
208
- class_id=class_id,
209
- data={"class_name": class_name}
210
- )
211
-
212
- # Custom bounding box color (Red)
213
- bbox_color = sv.Color(r=255, g=0, b=0)
214
- bounding_box_annotator = sv.BoxAnnotator(color=bbox_color)
215
- label_annotator = sv.LabelAnnotator()
216
-
217
- # Annotate the image
218
- image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
219
- annotated_image = bounding_box_annotator.annotate(scene=image_cv, detections=detections)
220
- annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
221
-
222
- # Clean up
223
- del inputs, outputs
224
- torch.cuda.empty_cache() if device == 'cuda' else None
225
- gc.collect()
226
-
227
- return annotated_image, detections, results
228
-
229
- except Exception as e:
230
- logger.error(f"Error in layout_detection for {img_path}: {str(e)}")
231
- raise
232
-
233
- def enhance_dpi(image, new_dpi=300, old_dpi=150):
234
- old_dpi = int(old_dpi)
235
- new_dpi = int(new_dpi)
236
- scaling_factor = new_dpi / old_dpi
237
- new_size = (int(image.width * scaling_factor), int(image.height * scaling_factor))
238
- resized_image = image.resize(new_size, Image.LANCZOS)
239
- return resized_image
240
-
241
- def extract_text_from_bbox(image, bbox):
242
- if isinstance(image, Image.Image):
243
- image = np.array(image)
244
- elif isinstance(image, np.ndarray):
245
- pass
246
- else:
247
- raise TypeError("Unsupported image type. The image should be either a PIL Image or a NumPy array.")
248
-
249
- image_height, image_width = image.shape[:2]
250
- ymin = max(0, int(bbox['ymin'] - 5))
251
- ymax = min(image_height, int(bbox['ymax'] + 5))
252
- xmin = max(0, int(bbox['xmin'] - 20))
253
- xmax = min(image_width, int(bbox['xmax'] + 20))
254
-
255
- cropped_image = image[ymin:ymax, xmin:xmax]
256
- cropped_image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
257
- high_dpi_image = enhance_dpi(cropped_image_pil)
258
- high_dpi_image_cv = cv2.cvtColor(np.array(high_dpi_image), cv2.COLOR_RGB2BGR)
259
- gray_image = cv2.cvtColor(high_dpi_image_cv, cv2.COLOR_BGR2GRAY)
260
-
261
- custom_config = r'--oem 3 --psm 6 -c tessedit_create_alto=1'
262
- extracted_text = pytesseract.image_to_string(gray_image, config=custom_config)
263
-
264
- return extracted_text
265
-
266
- def check_extracted_text_headers(extracted_text, header_list, model_name='all-MiniLM-L6-v2', threshold=0.8):
267
- if not isinstance(extracted_text, pd.DataFrame):
268
- return False
269
-
270
- model = SentenceTransformer(model_name)
271
- extracted_headers = list(extracted_text.columns)
272
- extracted_embeddings = model.encode(extracted_headers, convert_to_tensor=True)
273
- header_embeddings = model.encode(header_list, convert_to_tensor=True)
274
-
275
- similarity_matrix = util.pytorch_cos_sim(header_embeddings, extracted_embeddings)
276
-
277
- for i, header in enumerate(header_list):
278
- for j, extracted_header in enumerate(extracted_headers):
279
- if similarity_matrix[i][j] > threshold:
280
- logger.info(f"Matching header found: {extracted_header} (similar to {header})")
281
- return True
282
-
283
- logger.info("No matching headers found.")
284
- return False
285
-
286
- def process_page(args):
287
- (page_img, current_page_num, file_name, pdf_images_path, bbox_images_path) = args
288
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
289
- try:
290
- model, image_processor, class_names = load_model_for_process()
291
- model.to(device) # Ensure model is on the correct device
292
- image = np.array(page_img)
293
-
294
- h, w, _ = image.shape
295
- page_number = str(current_page_num)
296
-
297
- img_output_filename = f"{file_name}_page_no_{page_number}.jpeg"
298
- img_output_filepath = os.path.join(pdf_images_path, img_output_filename)
299
- pil_image = Image.fromarray(image)
300
- pil_image.save(img_output_filepath)
301
-
302
- cropped_images_path = os.path.join(pdf_images_path, f"{file_name}_cropped_images")
303
- os.makedirs(cropped_images_path, exist_ok=True)
304
-
305
- bbox_image, page_detections_info, results_info = layout_detection(img_output_filepath, model, image_processor, device=device)
306
- logger.info(f"Processed layout detection for page {page_number}")
307
-
308
- pil_bbox_image = Image.fromarray(bbox_image)
309
- bbox_output_filename = f"bbox_{file_name}_page_no_{page_number}.jpeg"
310
- bbox_output_filepath = os.path.join(bbox_images_path, bbox_output_filename)
311
- pil_bbox_image.save(bbox_output_filepath)
312
- page_information = []
313
-
314
- for idx, bbox in enumerate(page_detections_info.xyxy):
315
- label_name = page_detections_info.data['class_name'][idx]
316
- class_id = page_detections_info.class_id[idx]
317
- score = page_detections_info.confidence[idx]
318
-
319
- image_height = h
320
- image_width = w
321
-
322
- ymin = max(0, bbox[1] - 10)
323
- ymax = min(image_height, bbox[3] + 10)
324
- xmin = max(0, bbox[0] - 10)
325
- xmax = min(image_width, bbox[2] + 10)
326
-
327
- new_bbox = {
328
- "xmin": int(bbox[0]),
329
- "ymin": int(bbox[1]),
330
- "xmax": int(bbox[2]),
331
- "ymax": int(bbox[3])
332
- }
333
-
334
- cropped_labels_images_path = os.path.join(cropped_images_path, f"{file_name}_{label_name}_cropped_images")
335
- os.makedirs(cropped_labels_images_path, exist_ok=True)
336
-
337
- crop_label_image_filename = f"{file_name}_label_name{label_name}_page_no_{page_number}_id_{idx + 1}.png"
338
- crop_label_image_filename_filepath = os.path.join(cropped_labels_images_path, crop_label_image_filename)
339
-
340
- crop_label_image_bbox = (new_bbox["xmin"], new_bbox["ymin"], new_bbox["xmax"], new_bbox["ymax"])
341
- cropped_label_pil_image = pil_image.crop(crop_label_image_bbox)
342
- cropped_label_pil_image.save(crop_label_image_filename_filepath)
343
-
344
- if label_name == 'Table':
345
- crop_bbox = (new_bbox["xmin"], new_bbox["ymin"], new_bbox["xmax"], new_bbox["ymax"])
346
- cropped_image = pil_image.crop(crop_bbox)
347
- df_post_processed, df_original = tsr_inference_image(cropped_image)
348
- extracted_df = df_post_processed
349
- extracted_text = extracted_df
350
-
351
- if isinstance(df_original, pd.DataFrame):
352
- extracted_df_markdown = df_original.to_markdown()
353
- else:
354
- extracted_df_markdown = df_original
355
- else:
356
- extracted_text = extract_text_from_bbox(image, new_bbox)
357
- extracted_df_markdown = ""
358
-
359
- page_block_id = f"{str(idx + 1) + str(current_page_num)}"
360
- page_block_id = int(page_block_id)
361
-
362
- page_information.append({
363
- 'page_block_id': page_block_id,
364
- 'label_name': label_name,
365
- 'pdf_page_id': current_page_num,
366
- 'pdf_name': file_name,
367
- 'label_id': class_id,
368
- 'yolo_detection_confidence_score': score,
369
- 'bbox': [xmin, ymin, xmax, ymax],
370
- 'page_img_width': w,
371
- 'page_img_height': h,
372
- 'extracted_text': [extracted_text],
373
- "extracted_table_markdown": [extracted_df_markdown]
374
- })
375
-
376
- # Clean up
377
- del image, bbox_image, model, image_processor
378
- torch.cuda.empty_cache() if device == 'cuda' else None
379
- gc.collect()
380
-
381
- return page_number, page_information, class_names
382
-
383
- except Exception as e:
384
- logger.error(f"Error processing page {current_page_num}: {str(e)}")
385
- raise
386
-
387
- @log_time_taken
388
- def yolov10_layout_pipeline(file_name, file_path, directory_path):
389
- if not file_path.lower().endswith('.pdf'):
390
- raise ValueError("Input file must be a PDF.")
391
-
392
- logger.info(f"Starting processing for {file_name}")
393
- start_time = datetime.now()
394
- file_name = get_file_name_without_extension(file_path)
395
-
396
- pdf_images_path = os.path.join(directory_path, f"{file_name}_images")
397
- os.makedirs(pdf_images_path, exist_ok=True)
398
-
399
- bbox_images_path = os.path.join(pdf_images_path, f"{file_name}_bbox_images")
400
- os.makedirs(bbox_images_path, exist_ok=True)
401
-
402
- json_output_path = os.path.join(directory_path, f"{file_name}_json_output")
403
- os.makedirs(json_output_path, exist_ok=True)
404
-
405
- total_pages_processed = 0
406
- data_pdf = {}
407
-
408
- try:
409
- page_generator = convert_pdf_to_images(file_path, batch_size=20, dpi=150)
410
-
411
- page_args = []
412
- for pages in page_generator:
413
- if not pages:
414
- break
415
-
416
- for page_num, page_img in enumerate(pages):
417
- current_page_num = total_pages_processed + page_num + 1
418
- logger.info(f"Processing file {file_name}, page {current_page_num}")
419
-
420
- page_args.append((
421
- page_img,
422
- current_page_num,
423
- file_name,
424
- pdf_images_path,
425
- bbox_images_path
426
- ))
427
-
428
- total_pages_processed += len(pages)
429
-
430
- logger.info(f"Total pages to process: {total_pages_processed}")
431
- with ProcessPoolExecutor(max_workers=no_of_threads) as executor:
432
- future_to_page = {executor.submit(process_page, arg): arg[1] for arg in page_args}
433
- for future in as_completed(future_to_page):
434
- page_number = future_to_page[future]
435
- try:
436
- result = future.result()
437
- page_number, page_information, class_names = result
438
- data_pdf[page_number] = page_information
439
- except Exception as e:
440
- logger.error(f"Error processing page {page_number}: {str(e)}")
441
- raise
442
-
443
- logger.info(f"Processed pages: {data_pdf.keys()}")
444
- layout_json_file_path = os.path.join(json_output_path, f"yolo_model_detections_{file_name}.json")
445
- user_modification_json_file_path = os.path.join(json_output_path, f"user_modified_{file_name}.json")
446
- tree_structured_json_output_path = os.path.join(json_output_path, f"tree_structured_headers_{file_name}.json")
447
- data_pdf = convert_numpy(data_pdf)
448
- layout_list_data = filter_layout_blocks(data_pdf)
449
-
450
- with open(layout_json_file_path, 'w') as json_file:
451
- json.dump(data_pdf, json_file, indent=4)
452
-
453
- with open(user_modification_json_file_path, 'w') as json_file:
454
- json.dump(data_pdf, json_file, indent=4)
455
-
456
- sorted_data, modified_json_output_filepath = filter_and_sort_headers(data_pdf, user_modification_json_file_path)
457
- tree_structured_organized_json_data = tree_structured_headers_pipeline(user_modification_json_file_path, tree_structured_json_output_path)
458
- sorted_layout_data, sorted_layout_json_filepath = filter_and_sort_layouts(data_pdf, layout_json_file_path)
459
-
460
- filtered_table_header_data, filtered_table_header_data_json_path = put_table_header_pipeline(user_modification_json_file_path, json_output_path, file_name)
461
- end_time = datetime.now()
462
-
463
- logger.info(f"Processed {file_name} from {start_time} to {end_time}, duration: {end_time - start_time}")
464
- logger.info(f"JSON file created at: {modified_json_output_filepath}")
465
- return (
466
- json_output_path,
467
- layout_list_data,
468
- class_names,
469
- sorted_data,
470
- modified_json_output_filepath,
471
- pdf_images_path,
472
- file_name,
473
- sorted_layout_data,
474
- sorted_layout_json_filepath,
475
- tree_structured_organized_json_data,
476
- tree_structured_json_output_path,
477
- filtered_table_header_data,
478
- filtered_table_header_data_json_path
479
- )
480
-
481
- except Exception as e:
482
- logger.error(f"Error in yolov10_layout_pipeline: {str(e)}")
483
- raise
484
- finally:
485
- # Ensure GPU memory is cleared
486
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
487
- gc.collect()
488
-
489
- # Example usage
490
- if __name__ == "__main__":
491
- pdf_path = "/shared_disk/kushal/db_str_chunking/new_ws_structured_code/Flexstone_Investor_Report_Test.pdf"
492
- output_directory = "/shared_disk/kushal/db_str_chunking/new_ws_structured_code/clearstreet_docs/iqeq_docling_heron_bbox_images"
493
- file_name = get_file_name_without_extension(pdf_path)
494
- yolov10_layout_pipeline(file_name, pdf_path, output_directory)
495
-
496
-
497
-