Kushalguptaiitb commited on
Commit
9568e27
·
verified ·
1 Parent(s): 8d374b9

Upload 3 files

Browse files
layout_detection_docling_heron.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import supervision as sv # pip install supervision
4
+ from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor
5
+ from pdf2image import convert_from_path
6
+ import numpy as np
7
+ from PIL import Image
8
+ import json
9
+ import pytesseract
10
+ import pandas as pd
11
+ from sentence_transformers import SentenceTransformer, util
12
+ from PyPDF2 import PdfReader
13
+ from datetime import datetime
14
+ import torch
15
+ import logging
16
+ from utils.utils_code import log_time_taken
17
+ from concurrent.futures import ProcessPoolExecutor, as_completed
18
+ import multiprocessing
19
+ import sys
20
+ import gc
21
+
22
+ from src.table_processing.tree_structured_json import tree_structured_headers_pipeline
23
+ from config.set_config import set_configuration
24
+ set_config_project = set_configuration()
25
+ layout_model_weights_path = set_config_project.layout_model_weights_path
26
+ no_of_threads = set_config_project.no_of_threads
27
+ from src.docling.ttsr_docling import tsr_inference_image, tsr_inference
28
+ from src.table_processing.table_classification_extraction import process_table_classification_extraction_pipeline
29
+ from src.table_processing.put_table_header import put_table_header_pipeline
30
+ import gc
31
+ from src.layout_detection.load_model import load_model_for_process
32
+
33
+ # Set multiprocessing start method
34
+ multiprocessing.set_start_method('spawn', force=True)
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # Configure logging
38
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
39
+
40
def load_torch(version):
    """Select a vendored torch build by prepending its directory to ``sys.path``.

    Supported versions are "2.2.2" and "2.6.0"; any other value leaves
    ``sys.path`` untouched and simply returns the already-loaded module.

    NOTE(review): ``torch`` is already imported at module top, so the
    ``import torch`` below returns the cached module and the ``sys.path``
    insertion has no effect on which build is actually used — confirm whether
    an ``importlib.reload`` or an earlier call order was intended.
    """
    if version == "2.2.2":
        sys.path.insert(0, "./torch_2_2_2")
    elif version == "2.6.0":
        sys.path.insert(0, "./torch_2_6_0")
    import torch
    logger.info(f"Using Torch Version: {torch.__version__}")
    return torch
48
+
49
+ torch = load_torch("2.2.2")
50
+
51
+ MODEL_NAME_DOCLING = "ds4sd/docling-layout-heron"
52
+
53
def get_file_name_without_extension(file_path):
    """Return the base name of *file_path* without its directory or extension.

    Only the last extension is stripped ("a.tar.gz" -> "a.tar").
    """
    # Single expression replaces the original's unused intermediate
    # variables (directory, extension).
    return os.path.splitext(os.path.basename(file_path))[0]
57
+
58
def convert_numpy(data):
    """Recursively convert NumPy scalars/arrays and pandas DataFrames into
    plain JSON-serializable Python objects; anything else passes through."""
    if isinstance(data, dict):
        return {key: convert_numpy(val) for key, val in data.items()}
    if isinstance(data, list):
        return [convert_numpy(elem) for elem in data]
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, np.floating):
        return float(data)
    if isinstance(data, np.ndarray):
        return data.tolist()
    if isinstance(data, pd.DataFrame):
        return data.to_dict(orient='records')
    return data
73
+
74
def filter_layout_blocks(input_data):
    """Flatten a ``{page: [block, ...]}`` mapping into one list of blocks.

    Dict insertion order is preserved, so blocks appear in page order.
    """
    flattened = []
    for blocks in input_data.values():
        # extend() the list directly; the original built a pointless
        # element-by-element copy ([block for block in blocks]) first.
        flattened.extend(blocks)
    return flattened
79
+
80
def convert_pdf_to_images(file_path, batch_size=20, dpi=100):
    """Rasterize a PDF and return a generator yielding pages in batches.

    Note: pdf2image renders every page up front; only the slicing into
    batches of *batch_size* is lazy.
    """
    pages = convert_from_path(file_path, dpi=dpi)
    page_count = len(pages)

    def batches():
        for offset in range(0, page_count, batch_size):
            yield pages[offset:offset + batch_size]

    return batches()
90
+
91
def read_json(json_file):
    """Load and return the JSON document stored at *json_file*."""
    with open(json_file, 'r') as fh:
        parsed = json.load(fh)
    return parsed
94
+
95
def filter_and_sort_headers(data, modified_json_output_filepath):
    """Reorder header blocks page-by-page into column reading order and persist.

    Blocks are first ordered left-to-right by xmin; consecutive horizontally
    overlapping blocks form a "column", which is then ordered top-to-bottom
    by ymin.  The reordered mapping is written to
    *modified_json_output_filepath* as indented JSON.

    Returns:
        (sorted_data, modified_json_output_filepath)
    """

    def by_xmin(blocks):
        return sorted(blocks, key=lambda b: b['bbox'][0])

    def by_ymin(blocks):
        return sorted(blocks, key=lambda b: b['bbox'][1])

    def group_into_columns(blocks):
        ordered = []
        column = []
        prev = None
        for block in blocks:
            if prev:
                prev_xmax = prev['bbox'][2]
                # A block starting strictly right of the previous block's xmax
                # opens a new column.  The int() comparison mirrors the
                # original; for non-negative coordinates it is implied by the
                # first check.
                if block['bbox'][0] > prev_xmax and block['bbox'][0] > int(prev_xmax):
                    if column:
                        ordered.extend(by_ymin(column))
                        column = []
            column.append(block)
            prev = block
        if column:
            ordered.extend(by_ymin(column))
        return ordered

    sorted_data = {page: group_into_columns(by_xmin(blocks))
                   for page, blocks in data.items()}

    with open(modified_json_output_filepath, 'w') as f:
        json.dump(sorted_data, f, indent=4)

    return sorted_data, modified_json_output_filepath
134
+
135
def filter_and_sort_layouts(data, modified_json_output_filepath):
    """Reorder layout blocks page-by-page into column reading order and persist.

    Same algorithm as ``filter_and_sort_headers``: sort left-to-right by
    xmin, then sort each run of horizontally overlapping blocks (a column)
    top-to-bottom by ymin.  The result is written to
    *modified_json_output_filepath* as indented JSON.

    Returns:
        (sorted_layout_data, modified_json_output_filepath)
    """

    def ordered_by_xmin(blocks):
        return sorted(blocks, key=lambda blk: blk['bbox'][0])

    def ordered_by_ymin(blocks):
        return sorted(blocks, key=lambda blk: blk['bbox'][1])

    def columnwise_order(blocks):
        final_order = []
        pending = []
        last_block = None
        for blk in blocks:
            if last_block:
                last_xmax = last_block['bbox'][2]
                # New column when this block starts right of the previous xmax
                # (int() threshold kept for parity with the original code).
                if blk['bbox'][0] > last_xmax and blk['bbox'][0] > int(last_xmax):
                    if pending:
                        final_order.extend(ordered_by_ymin(pending))
                        pending = []
            pending.append(blk)
            last_block = blk
        if pending:
            final_order.extend(ordered_by_ymin(pending))
        return final_order

    sorted_layout_data = {}
    for page_key, page_blocks in data.items():
        sorted_layout_data[page_key] = columnwise_order(ordered_by_xmin(page_blocks))

    with open(modified_json_output_filepath, 'w') as f:
        json.dump(sorted_layout_data, f, indent=4)

    return sorted_layout_data, modified_json_output_filepath
174
+
175
@log_time_taken
def layout_detection(img_path, model, image_processor, threshold=0.6, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Run the Docling Heron (RT-DETR) layout detector on one page image.

    Args:
        img_path: Path to the page image file (opened and converted to RGB).
        model: RTDetrV2ForObjectDetection instance, expected on *device*.
        image_processor: Matching RTDetrImageProcessor.
        threshold: Minimum confidence for a detection to be kept.
        device: Target device for inference inputs.
            NOTE(review): this default is evaluated once at import time, not
            per call — confirm that is intended.

    Returns:
        (annotated_image, detections, results):
        annotated_image — BGR numpy image with boxes and labels drawn;
        detections — supervision.Detections (xyxy/confidence/class_id/class_name);
        results — post-processed detection dict with tensors moved to CPU.

    Raises:
        Re-raises any exception after logging it.
    """
    try:
        image = Image.open(img_path).convert("RGB")

        # Process image with the Docling Heron model
        inputs = image_processor(images=[image], return_tensors="pt")

        # Move inputs to the same device as the model
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process the results; image.size is (w, h) so [::-1] gives (h, w)
        results = image_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([image.size[::-1]], device=device),
            threshold=threshold
        )[0]

        # Move results to CPU for further processing
        results = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in results.items()}

        # Convert to supervision Detections format for compatibility
        xyxy = results["boxes"].numpy()
        confidence = results["scores"].numpy()
        class_id = results["labels"].numpy()
        class_name = [model.config.id2label[label_id] for label_id in class_id]

        detections = sv.Detections(
            xyxy=xyxy,
            confidence=confidence,
            class_id=class_id,
            data={"class_name": class_name}
        )

        # Custom bounding box color (Red)
        bbox_color = sv.Color(r=255, g=0, b=0)
        bounding_box_annotator = sv.BoxAnnotator(color=bbox_color)
        label_annotator = sv.LabelAnnotator()

        # Annotate the image (converted to OpenCV's BGR channel order first)
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        annotated_image = bounding_box_annotator.annotate(scene=image_cv, detections=detections)
        annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)

        # Clean up GPU memory eagerly; pages are processed in worker processes
        del inputs, outputs
        torch.cuda.empty_cache() if device == 'cuda' else None
        gc.collect()

        return annotated_image, detections, results

    except Exception as e:
        logger.error(f"Error in layout_detection for {img_path}: {str(e)}")
        raise
232
+
233
def enhance_dpi(image, new_dpi=300, old_dpi=150):
    """Upscale *image* by the ratio new_dpi/old_dpi using Lanczos resampling."""
    scale = int(new_dpi) / int(old_dpi)
    target_size = (int(image.width * scale), int(image.height * scale))
    return image.resize(target_size, Image.LANCZOS)
240
+
241
def extract_text_from_bbox(image, bbox):
    """OCR the region of *image* described by *bbox* (dict with xmin/ymin/xmax/ymax).

    The crop is padded (5 px vertically, 20 px horizontally, clamped to the
    image), upscaled via ``enhance_dpi``, grayscaled, and run through
    Tesseract.

    NOTE(review): the crop is converted with COLOR_BGR2RGB, i.e. the input
    array is assumed BGR, but callers passing ``np.array(pil_image)`` supply
    RGB.  Harmless for the grayscale OCR path, but worth confirming.

    Returns:
        The text Tesseract extracted from the region (str).
    """
    if isinstance(image, Image.Image):
        image = np.array(image)
    elif isinstance(image, np.ndarray):
        pass
    else:
        raise TypeError("Unsupported image type. The image should be either a PIL Image or a NumPy array.")

    image_height, image_width = image.shape[:2]
    # Pad the box slightly so characters at the edge are not clipped.
    ymin = max(0, int(bbox['ymin'] - 5))
    ymax = min(image_height, int(bbox['ymax'] + 5))
    xmin = max(0, int(bbox['xmin'] - 20))
    xmax = min(image_width, int(bbox['xmax'] + 20))

    cropped_image = image[ymin:ymax, xmin:xmax]
    cropped_image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
    # Upscale to a higher effective DPI before OCR for better accuracy.
    high_dpi_image = enhance_dpi(cropped_image_pil)
    high_dpi_image_cv = cv2.cvtColor(np.array(high_dpi_image), cv2.COLOR_RGB2BGR)
    gray_image = cv2.cvtColor(high_dpi_image_cv, cv2.COLOR_BGR2GRAY)

    # PSM 6: assume a single uniform block of text; OEM 3: default engine.
    custom_config = r'--oem 3 --psm 6 -c tessedit_create_alto=1'
    extracted_text = pytesseract.image_to_string(gray_image, config=custom_config)

    return extracted_text
265
+
266
def check_extracted_text_headers(extracted_text, header_list, model_name='all-MiniLM-L6-v2', threshold=0.8):
    """Return True when any DataFrame column header is semantically close to an expected header.

    Non-DataFrame inputs are rejected immediately.  Similarity is the cosine
    similarity between sentence-transformer embeddings; the first pair above
    *threshold* short-circuits.  Note the embedding model is loaded on every
    call.
    """
    if not isinstance(extracted_text, pd.DataFrame):
        return False

    encoder = SentenceTransformer(model_name)
    found_headers = list(extracted_text.columns)
    found_embeddings = encoder.encode(found_headers, convert_to_tensor=True)
    expected_embeddings = encoder.encode(header_list, convert_to_tensor=True)

    similarity_matrix = util.pytorch_cos_sim(expected_embeddings, found_embeddings)

    for row, header in enumerate(header_list):
        for col, extracted_header in enumerate(found_headers):
            if similarity_matrix[row][col] > threshold:
                logger.info(f"Matching header found: {extracted_header} (similar to {header})")
                return True

    logger.info("No matching headers found.")
    return False
285
+
286
def process_page(args):
    """Worker entry point: detect layout on one PDF page and extract its content.

    Designed to run inside a spawned subprocess; the detection model is
    loaded per process to avoid sharing CUDA state.

    Args:
        args: Tuple of (page_img, current_page_num, file_name,
              pdf_images_path, bbox_images_path) — a PIL page image, its
              1-based page number, the document stem, and output directories.

    Returns:
        (page_number, page_information, class_names): the page number as a
        string, a list of per-detection dicts, and the model's label map.

    Raises:
        Re-raises any exception after logging it.
    """
    (page_img, current_page_num, file_name, pdf_images_path, bbox_images_path) = args
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    try:
        # Load the model inside the subprocess (spawn start method).
        model, image_processor, class_names = load_model_for_process(model_name=MODEL_NAME_DOCLING)
        model.to(device) # Ensure model is on the correct device
        image = np.array(page_img)

        h, w, _ = image.shape
        page_number = str(current_page_num)

        # Save the raw page image; layout_detection re-reads it from disk.
        img_output_filename = f"{file_name}_page_no_{page_number}.jpeg"
        img_output_filepath = os.path.join(pdf_images_path, img_output_filename)
        pil_image = Image.fromarray(image)
        pil_image.save(img_output_filepath)

        cropped_images_path = os.path.join(pdf_images_path, f"{file_name}_cropped_images")
        os.makedirs(cropped_images_path, exist_ok=True)

        bbox_image, page_detections_info, results_info = layout_detection(img_output_filepath, model, image_processor, device=device)
        logger.info(f"Processed layout detection for page {page_number}")

        # Persist the annotated (boxes drawn) version of the page.
        pil_bbox_image = Image.fromarray(bbox_image)
        bbox_output_filename = f"bbox_{file_name}_page_no_{page_number}.jpeg"
        bbox_output_filepath = os.path.join(bbox_images_path, bbox_output_filename)
        pil_bbox_image.save(bbox_output_filepath)
        page_information = []

        for idx, bbox in enumerate(page_detections_info.xyxy):
            label_name = page_detections_info.data['class_name'][idx]
            class_id = page_detections_info.class_id[idx]
            score = page_detections_info.confidence[idx]

            image_height = h
            image_width = w

            # Padded coordinates (10 px, clamped) — stored in the output
            # 'bbox' field below.
            ymin = max(0, bbox[1] - 10)
            ymax = min(image_height, bbox[3] + 10)
            xmin = max(0, bbox[0] - 10)
            xmax = min(image_width, bbox[2] + 10)

            # Unpadded integer coordinates — used for cropping/OCR.
            # NOTE(review): 'bbox' in the result uses the padded values while
            # crops use these unpadded ones; confirm the asymmetry is intended.
            new_bbox = {
                "xmin": int(bbox[0]),
                "ymin": int(bbox[1]),
                "xmax": int(bbox[2]),
                "ymax": int(bbox[3])
            }

            cropped_labels_images_path = os.path.join(cropped_images_path, f"{file_name}_{label_name}_cropped_images")
            os.makedirs(cropped_labels_images_path, exist_ok=True)

            crop_label_image_filename = f"{file_name}_label_name{label_name}_page_no_{page_number}_id_{idx + 1}.png"
            crop_label_image_filename_filepath = os.path.join(cropped_labels_images_path, crop_label_image_filename)

            crop_label_image_bbox = (new_bbox["xmin"], new_bbox["ymin"], new_bbox["xmax"], new_bbox["ymax"])
            cropped_label_pil_image = pil_image.crop(crop_label_image_bbox)
            cropped_label_pil_image.save(crop_label_image_filename_filepath)

            if label_name == 'Table':
                # Tables go through table-structure recognition instead of OCR.
                crop_bbox = (new_bbox["xmin"], new_bbox["ymin"], new_bbox["xmax"], new_bbox["ymax"])
                cropped_image = pil_image.crop(crop_bbox)
                df_post_processed, df_original = tsr_inference_image(cropped_image)
                extracted_df = df_post_processed
                extracted_text = extracted_df

                if isinstance(df_original, pd.DataFrame):
                    extracted_df_markdown = df_original.to_markdown()
                else:
                    extracted_df_markdown = df_original
            else:
                # All other labels: plain Tesseract OCR over the region.
                extracted_text = extract_text_from_bbox(image, new_bbox)
                extracted_df_markdown = ""

            # NOTE(review): concatenating (idx+1) and page number as strings
            # can collide across pages (e.g. idx 0 / page 22 vs idx 11 / page 2
            # both give 122) — confirm whether ids must be globally unique.
            page_block_id = f"{str(idx + 1) + str(current_page_num)}"
            page_block_id = int(page_block_id)

            page_information.append({
                'page_block_id': page_block_id,
                'label_name': label_name,
                'pdf_page_id': current_page_num,
                'pdf_name': file_name,
                'label_id': class_id,
                # Key name kept from the earlier YOLO pipeline for
                # downstream compatibility; the score comes from RT-DETR.
                'yolo_detection_confidence_score': score,
                'bbox': [xmin, ymin, xmax, ymax],
                'page_img_width': w,
                'page_img_height': h,
                'extracted_text': [extracted_text],
                "extracted_table_markdown": [extracted_df_markdown]
            })

        # Clean up
        del image, bbox_image, model, image_processor
        torch.cuda.empty_cache() if device == 'cuda' else None
        gc.collect()

        return page_number, page_information, class_names

    except Exception as e:
        logger.error(f"Error processing page {current_page_num}: {str(e)}")
        raise
386
+
387
@log_time_taken
def yolov10_layout_pipeline(file_name, file_path, directory_path):
    """Full layout-detection pipeline for one PDF.

    Converts the PDF to images, runs layout detection on every page in a
    process pool, then writes and post-processes several JSON artifacts
    (raw detections, user-modifiable copy, tree-structured headers, table
    headers).

    Args:
        file_name: Document stem (recomputed from *file_path* below, so the
            passed value is effectively ignored).
        file_path: Path to the input PDF; anything else raises ValueError.
        directory_path: Root directory for all generated artifacts.

    Returns:
        A 13-tuple of output paths and processed data structures (see the
        return statement).

    Raises:
        ValueError: if *file_path* does not end with ".pdf".
        Re-raises any processing exception after logging it.
    """
    if not file_path.lower().endswith('.pdf'):
        raise ValueError("Input file must be a PDF.")

    logger.info(f"Starting processing for {file_name}")
    start_time = datetime.now()
    # Recompute the stem from the path; overrides the caller-supplied name.
    file_name = get_file_name_without_extension(file_path)

    pdf_images_path = os.path.join(directory_path, f"{file_name}_images")
    os.makedirs(pdf_images_path, exist_ok=True)

    bbox_images_path = os.path.join(pdf_images_path, f"{file_name}_bbox_images")
    os.makedirs(bbox_images_path, exist_ok=True)

    json_output_path = os.path.join(directory_path, f"{file_name}_json_output")
    os.makedirs(json_output_path, exist_ok=True)

    total_pages_processed = 0
    data_pdf = {}

    try:
        page_generator = convert_pdf_to_images(file_path, batch_size=20, dpi=150)

        # Collect (page image, page number, paths) argument tuples for the pool.
        page_args = []
        for pages in page_generator:
            if not pages:
                break

            for page_num, page_img in enumerate(pages):
                current_page_num = total_pages_processed + page_num + 1
                logger.info(f"Processing file {file_name}, page {current_page_num}")

                page_args.append((
                    page_img,
                    current_page_num,
                    file_name,
                    pdf_images_path,
                    bbox_images_path
                ))

            total_pages_processed += len(pages)

        logger.info(f"Total pages to process: {total_pages_processed}")
        # Fan the pages out to worker processes; each worker loads its own model.
        with ProcessPoolExecutor(max_workers=no_of_threads) as executor:
            future_to_page = {executor.submit(process_page, arg): arg[1] for arg in page_args}
            for future in as_completed(future_to_page):
                page_number = future_to_page[future]
                try:
                    result = future.result()
                    # NOTE(review): class_names is only bound inside this loop;
                    # a zero-page PDF would raise NameError at the return below.
                    page_number, page_information, class_names = result
                    data_pdf[page_number] = page_information
                except Exception as e:
                    logger.error(f"Error processing page {page_number}: {str(e)}")
                    raise

        logger.info(f"Processed pages: {data_pdf.keys()}")
        layout_json_file_path = os.path.join(json_output_path, f"yolo_model_detections_{file_name}.json")
        user_modification_json_file_path = os.path.join(json_output_path, f"user_modified_{file_name}.json")
        tree_structured_json_output_path = os.path.join(json_output_path, f"tree_structured_headers_{file_name}.json")
        # Strip NumPy/pandas types so the data is JSON-serializable.
        data_pdf = convert_numpy(data_pdf)
        layout_list_data = filter_layout_blocks(data_pdf)

        with open(layout_json_file_path, 'w') as json_file:
            json.dump(data_pdf, json_file, indent=4)

        with open(user_modification_json_file_path, 'w') as json_file:
            json.dump(data_pdf, json_file, indent=4)

        # Post-processing stages; the two filter_and_sort_* calls rewrite the
        # JSON files they are given in column reading order.
        sorted_data, modified_json_output_filepath = filter_and_sort_headers(data_pdf, user_modification_json_file_path)
        tree_structured_organized_json_data = tree_structured_headers_pipeline(user_modification_json_file_path, tree_structured_json_output_path)
        sorted_layout_data, sorted_layout_json_filepath = filter_and_sort_layouts(data_pdf, layout_json_file_path)

        filtered_table_header_data, filtered_table_header_data_json_path = put_table_header_pipeline(user_modification_json_file_path, json_output_path, file_name)
        end_time = datetime.now()

        logger.info(f"Processed {file_name} from {start_time} to {end_time}, duration: {end_time - start_time}")
        logger.info(f"JSON file created at: {modified_json_output_filepath}")
        return (
            json_output_path,
            layout_list_data,
            class_names,
            sorted_data,
            modified_json_output_filepath,
            pdf_images_path,
            file_name,
            sorted_layout_data,
            sorted_layout_json_filepath,
            tree_structured_organized_json_data,
            tree_structured_json_output_path,
            filtered_table_header_data,
            filtered_table_header_data_json_path
        )

    except Exception as e:
        logger.error(f"Error in yolov10_layout_pipeline: {str(e)}")
        raise
    finally:
        # Ensure GPU memory is cleared
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        gc.collect()
488
+
489
# Example usage — hard-coded paths for a local test run; adjust for your
# environment before executing this module directly.
if __name__ == "__main__":
    pdf_path = "/shared_disk/kushal/db_str_chunking/new_ws_structured_code/Flexstone_Investor_Report_Test.pdf"
    output_directory = "/shared_disk/kushal/db_str_chunking/new_ws_structured_code/clearstreet_docs/iqeq_docling_heron_bbox_images"
    file_name = get_file_name_without_extension(pdf_path)
    yolov10_layout_pipeline(file_name, pdf_path, output_directory)
495
+
496
+
497
+
498
+
load_model.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # from ultralytics import YOLOv10
3
+ # import torch
4
+ # from config.set_config import set_configuration
5
+
6
+ # set_config_project = set_configuration()
7
+ # layout_model_weights_path = set_config_project.layout_model_weights_path
8
+ # no_of_threads = set_config_project.no_of_threads
9
+
10
+ # def load_model_for_process(detection_model_path=layout_model_weights_path):
11
+ # """
12
+ # Load model in each subprocess to avoid CUDA initialization issues
13
+
14
+ # Returns:
15
+ # Model loaded in appropriate device
16
+ # """
17
+ # # Your model loading logic
18
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ # # print(f"Using device: {device}")
20
+
21
+ # model = YOLOv10(detection_model_path).to(device)
22
+ # class_names = model.names
23
+ # class_names["11"] = "Table-header"
24
+ # class_names["12"] = "Portfolio-Company-Table"
25
+
26
+ # return model, class_names
27
+
28
+ import torch
29
+
30
+ from ultralytics import YOLO
31
+ layout_model_weights_path = "/shared_disk/kushal/db_str_chunking/new_ws_structured_code/db_structured_chunking/structure_chunking/model_weights/yolov12_epoch60.pt"
32
+
33
+
34
+ # def load_model_for_process(detection_model_path=layout_model_weights_path):
35
+ # """
36
+ # Load model in each subprocess to avoid CUDA initialization issues
37
+
38
+ # Returns:
39
+ # Model loaded in appropriate device
40
+ # """
41
+ # # Your model loading logic
42
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
43
+ # # print(f"Using device: {device}")
44
+
45
+ # model = YOLO(detection_model_path).to(device)
46
+ # class_names = model.names
47
+ # class_names["11"] = "Table-header"
48
+ # class_names["12"] = "Portfolio-Company-Table"
49
+ # print("YOLOV12"*10)
50
+
51
+ # return model, class_names
52
+
53
+
54
+ '''Below code for docling heron model'''
55
+
56
+ from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor
57
+
58
def load_model_for_process(model_name="ds4sd/docling-layout-heron"):
    """
    Load the Docling Heron RT-DETR detector and its image processor in the
    current process.  Loading per subprocess avoids CUDA initialization
    issues when workers are spawned.

    Returns:
        Tuple of (model, image_processor, class_names)
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Fetch the processor and the detector, moving the model to the device.
    processor = RTDetrImageProcessor.from_pretrained(model_name)
    detector = RTDetrV2ForObjectDetection.from_pretrained(model_name).to(device)

    # Label-id -> human-readable class name for the Heron layout model.
    class_names = {
        0: "Caption",
        1: "Footnote",
        2: "Formula",
        3: "List-item",
        4: "Page-footer",
        5: "Page-header",
        6: "Picture",
        7: "Section-header",
        8: "Table",
        9: "Text",
        10: "Title",
        11: "Document Index",
        12: "Code",
        13: "Checkbox-Selected",
        14: "Checkbox-Unselected",
        15: "Form",
        16: "Key-Value Region",
        # Additional classes for compatibility with existing pipeline
        17: "Table-header",
        18: "Portfolio-Company-Table",
    }

    return detector, processor, class_names
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
+
post_process_portfolio_company_json.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from fuzzywuzzy import fuzz
4
+ from typing import List, Dict, Any
5
+ import yaml
6
+ import warnings
7
+ import pandas as pd
8
+
9
+ # Constants
10
+ # PORTFOLIO_COMPANY_LIST_IDENTIFIER = ["portfolio company or platforms", "portfolio company"]
11
+ PORTFOLIO_COMPANY_LIST_IDENTIFIER = ["portfolio company or platforms","\u20acm","$m","Unrealised fair market valuation","Realised proceeds in the period","Portfolio Company or Platforms","portfolio company", "active investment", "realized/unrealized company","Realized Company","Unrealized Company", "quoted/unquoted company", "portfolio investment", "portfolio company"]
12
+ FUZZY_MATCH_THRESHOLD = 70
13
+ EXCLUDE_COMPANY_NAMES = ["total", "subtotal","Total","Investments","Fund"]
14
+
15
+
16
def get_file_name_without_extension(file_path: str) -> str:
    """Strip directory and extension from *file_path*, returning the bare name."""
    base = os.path.basename(file_path)
    stem, _ext = os.path.splitext(base)
    return stem
19
+
20
def fuzzy_match(text: str, patterns: List[str], threshold: int = FUZZY_MATCH_THRESHOLD) -> bool:
    """Return True when *text* fuzzy-matches at least one of *patterns* at or above *threshold*."""
    needle = str(text).lower()
    # partial_ratio tolerates substrings; short-circuit on first match.
    return any(fuzz.partial_ratio(needle, pattern.lower()) >= threshold
               for pattern in patterns)
27
+
28
def extract_portfolio_companies_from_table(table_data: Dict) -> List[str]:
    """Extract company names from a portfolio company table.

    Locates the first column whose header fuzzy-matches a known
    portfolio-company identifier, then collects that column's non-empty
    values, skipping totals/subtotals.
    """
    names: List[str] = []
    if not table_data.get("table_info"):
        return names

    # Find the company column among the headers.
    headers = table_data.get("table_column_header", [])
    company_column = next(
        (i for i, header in enumerate(headers)
         if fuzzy_match(header, PORTFOLIO_COMPANY_LIST_IDENTIFIER)),
        None,
    )

    if company_column is None:
        return names

    # Get the column name that contains companies
    company_column_name = table_data["table_column_header"][company_column]
    print("company_column::", company_column)
    print("cpmpany_column_name::", company_column_name)

    # Extract companies, filtering out summary rows.
    for row in table_data["table_info"]:
        if not isinstance(row, dict):
            continue
        candidate = str(row.get(company_column_name, "")).strip()
        if candidate and not fuzzy_match(candidate, EXCLUDE_COMPANY_NAMES):
            names.append(candidate)

    return names
58
+
59
def get_portfolio_company_list(intermediate_data: List[Dict]) -> List[str]:
    """Collect the unique portfolio companies named in any table of the document."""
    found = set()
    for entry in intermediate_data:
        # Entries without tables contribute nothing.
        for table in entry.get("table_content", []):
            found.update(extract_portfolio_companies_from_table(table))
    return list(found)
71
+
72
def merge_content_under_same_header(
    intermediate_data: List[Dict],
    portfolio_company_list: List[str],
    start_index: int
) -> tuple:
    """
    Merge consecutive entries that share the header of
    ``intermediate_data[start_index]``.

    Merging stops at the first entry with a different header, or the first
    entry that mentions a portfolio company (in its free text or in one of
    its tables), which marks the start of the next company section.

    Args:
        intermediate_data: Ordered document sections.
        portfolio_company_list: Known company names used to detect section breaks.
        start_index: Index of the section to start merging from.

    Returns:
        (merged_entry, next_index): the merged section dict and the index of
        the first unmerged entry.  (The original annotation claimed ``Dict``,
        but a 2-tuple has always been returned — annotation fixed.)
    """
    first = intermediate_data[start_index]
    # NOTE: table_content aliases the first entry's list (as in the original),
    # so extending it below also mutates intermediate_data[start_index].
    merged_entry = {
        "header": first["header"],
        "content": first.get("content", ""),
        "table_content": first.get("table_content", []),
        "label_name": first["label_name"],
        "page_number": first["page_number"],
        "pdf_name": first["pdf_name"]
    }

    current_index = start_index + 1
    while current_index < len(intermediate_data):
        current_entry = intermediate_data[current_index]

        # Stop when the header changes.
        if current_entry["header"] != merged_entry["header"]:
            break

        # Stop when this entry introduces a portfolio company, either in its
        # free text or in one of its tables (any() short-circuits like the
        # original's break-on-first-match loop).
        content_match = any(company in current_entry.get("content", "")
                            for company in portfolio_company_list)
        table_match = any(
            extract_portfolio_companies_from_table(table)
            for table in current_entry.get("table_content", [])
        )
        if content_match or table_match:
            break

        # Accumulate free text, newline-separated.
        if "content" in current_entry:
            if merged_entry["content"]:
                merged_entry["content"] += "\n" + current_entry["content"]
            else:
                merged_entry["content"] = current_entry["content"]

        # Accumulate tables.
        if "table_content" in current_entry:
            merged_entry["table_content"].extend(current_entry["table_content"])

        current_index += 1

    return merged_entry, current_index
124
+
125
+
126
+
127
def process_table_page_ids(merged_output):
    """
    Extend each section's comma-separated ``page_number`` string with the
    unique page ids found in its tables' metadata.

    Only sections that carry a ``table_content`` key are touched.

    Args:
        merged_output (list[dict]): Merged document sections.

    Returns:
        list[dict]: The same list, with ``page_number`` rewritten as a
        de-duplicated, numerically sorted, comma-joined string.
    """
    for section in merged_output:
        # Sections without tables keep their page_number untouched.
        if 'table_content' not in section:
            continue

        raw = section.get('page_number')
        pages = set(raw.split(',')) if raw else set()

        # Pull page ids recorded in each table's metadata.
        for table in section['table_content']:
            if 'metadata' in table and 'table_page_id' in table['metadata']:
                pages.add(str(table['metadata']['table_page_id']))

        if pages:
            section['page_number'] = ','.join(sorted(pages, key=int))

    return merged_output
155
+
156
+
157
+ ################################################################################################################
158
+ ## Below function for more than one occurence of underlying_assets
159
+
160
def merge_portfolio_company_sections(intermediate_data: List[Dict]) -> tuple[List[Dict], List[str], List[str]]:
    """Merge all content and tables under the same portfolio-company header.

    Walks the intermediate entries in order. When an entry's header matches one
    (or more) portfolio companies a new chunk is opened for that company, and
    every following entry *without* a company header is folded into the open
    chunk. Entries seen before any company section pass through unchanged
    (with de-duplicated page numbers).

    Args:
        intermediate_data: Ordered list of section dicts. Each entry is assumed
            to carry "header", "page_number", "pdf_name", "label_name" and
            optionally "content" / "table_content" — TODO confirm against the
            upstream h2h extraction output.

    Returns:
        Tuple of:
        - merged_output: merged document sections (page ids reconciled by
          ``process_table_page_ids``)
        - fuzzy_matched_companies: companies whose names matched a header
        - portfolio_companies: all portfolio companies found in tables
    """

    def _dedup_pages(raw) -> str:
        # Collapse e.g. "3,4, 3,5" -> "3,4,5", preserving first-seen order.
        parts = list(dict.fromkeys(str(raw).split(",")))
        return ",".join(p.strip() for p in parts if p.strip())

    portfolio_companies = get_portfolio_company_list(intermediate_data)
    print(f"Extracted portfolio companies: {portfolio_companies}")

    merged_output = []
    fuzzy_matched_companies = set()
    current_chunk = None
    active_company = None

    for entry in intermediate_data:
        header_companies = match_company_names(entry["header"], portfolio_companies)

        if header_companies:
            # BUGFIX: the set was declared and returned but never populated
            # (the populating code was commented out), so callers always saw
            # an empty fuzzy-matched list.
            fuzzy_matched_companies.update(header_companies)

            print("&"*100)
            print("*"*100)
            print("entry_header::", entry["header"])
            print("page number of header::", entry["page_number"])

            print("*"*100)
            print("header_companies::", header_companies)
            print("*"*100)

            # If we have an active chunk, finalize it before starting a new one.
            if current_chunk:
                merged_output.append(current_chunk)
                current_chunk = None
                active_company = None

            # Start a new chunk with the first matched company
            # (in case multiple companies matched, we take the first one).
            active_company = header_companies[0]
            current_chunk = {
                "page_number": entry["page_number"],
                "pdf_name": entry["pdf_name"],
                "header": entry["header"],
                "label_name": entry["label_name"],
                "content": entry.get("content", ""),
                "table_content": entry.get("table_content", []),
                "matched_company": active_company
            }

            # If multiple companies matched, create separate chunks for the others.
            for additional_company in header_companies[1:]:
                merged_output.append({
                    "page_number": entry["page_number"],
                    "pdf_name": entry["pdf_name"],
                    "header": entry["header"],
                    "label_name": entry["label_name"],
                    "content": entry.get("content", ""),
                    "table_content": entry.get("table_content", []),
                    "matched_company": additional_company
                })

        elif current_chunk:
            # No new company detected: fold this entry into the open chunk.
            if "content" in entry:
                if current_chunk["content"]:
                    current_chunk["content"] += "\n\n" + entry["content"]
                    current_chunk["page_number"] = _dedup_pages(
                        str(current_chunk["page_number"]) + "," + str(entry["page_number"])
                    )
                else:
                    current_chunk["content"] = entry["content"]
                    current_chunk["page_number"] = str(entry["page_number"])

            if "table_content" in entry:
                current_chunk["table_content"].extend(entry["table_content"])
                if current_chunk["page_number"]:
                    # BUGFIX: table_content is a *list* of table dicts; the old
                    # code did a dict-style `"metadata" in entry["table_content"]`
                    # lookup on the list, so table page ids were never collected.
                    for table in entry["table_content"]:
                        if isinstance(table, dict) and "metadata" in table:
                            if "table_page_id" in table["metadata"]:
                                current_chunk["page_number"] += "," + str(table["metadata"]["table_page_id"])

                    current_chunk["page_number"] = _dedup_pages(
                        str(current_chunk["page_number"]) + "," + str(entry["page_number"])
                    )

        else:
            # Content before any company section: pass through with unique pages.
            entry_copy = entry.copy()
            if "page_number" in entry_copy:
                entry_copy["page_number"] = _dedup_pages(entry_copy["page_number"])
            merged_output.append(entry_copy)

    # Flush the last active chunk if it exists.
    if current_chunk:
        # BUGFIX: the old code wrote the de-duplicated page numbers onto the
        # stale loop variable `entry_copy` instead of `current_chunk`, leaving
        # the final chunk un-deduplicated and corrupting a pass-through entry.
        current_chunk["page_number"] = _dedup_pages(current_chunk["page_number"])
        merged_output.append(current_chunk)

    merged_output_new = process_table_page_ids(merged_output=merged_output)

    return merged_output_new, list(fuzzy_matched_companies), portfolio_companies
281
+
282
+ ################################################################################################
283
+
284
+ ## Below code for using abbreviation funcnality
285
+
286
+ import re
287
+
288
def match_company_names(header_text: str, companies: List[str], threshold: int = FUZZY_MATCH_THRESHOLD) -> List[str]:
    """Match company names against a header, trying full names then abbreviations.

    First compares the header (and its abbreviated forms) to each full company
    name; if that fails, compares the header to the company's abbreviated forms.
    Returns matched companies in first-match order, without duplicates.
    """
    normalized_header = str(header_text).lower().strip()

    def _abbreviations(text: str) -> List[str]:
        # Three candidate short forms: first-letter acronym,
        # vowel/space-stripped form, and the space-free form.
        return [
            ''.join(tok[0] for tok in text.split() if tok),
            re.sub(r'[aeiou\s]', '', text),
            text.replace(' ', ''),
        ]

    header_forms = [normalized_header] + _abbreviations(normalized_header)
    hits = []

    for candidate in companies:
        candidate_lower = candidate.lower()

        # First check: header text (full or abbreviated) vs. company full name.
        matched_directly = any(
            fuzz.partial_ratio(form, candidate_lower) >= threshold
            for form in header_forms
        )
        if matched_directly:
            hits.append(candidate)
            continue

        # Second check: header text vs. company abbreviations.
        if any(
            fuzz.partial_ratio(normalized_header, short_form) >= threshold
            for short_form in _abbreviations(candidate_lower)
        ):
            hits.append(candidate)

    # Remove duplicates while preserving order.
    return list(dict.fromkeys(hits))
321
+
322
+
323
+ ################################################################################################################
324
+
325
def process_document_company_wise(
    intermediate_str_chunk_json: List[Dict],
    output_directory: str,
    file_name: str
) -> List[Dict]:
    """Process the document and return merged content in the original format.

    Merges sections company-wise, annotates the first section with the matched
    and extracted company lists, and writes the result to
    ``<output_directory>/<file_name>_h2h_merged_output.json``.

    Args:
        intermediate_str_chunk_json: Parsed section list, or a JSON string of it.
        output_directory: Directory for the merged-output JSON (created if missing).
        file_name: Base name used for the output file.

    Returns:
        The merged section list.
    """
    # Convert string input to a parsed structure if needed.
    if isinstance(intermediate_str_chunk_json, str):
        intermediate_str_chunk_json = json.loads(intermediate_str_chunk_json)

    # Merge content by company sections.
    merged_content, matched_company_list, portfolio_company_list = merge_portfolio_company_sections(intermediate_str_chunk_json)

    # BUGFIX: an empty document produced no sections, and the unconditional
    # merged_content[0] writes raised IndexError.
    if merged_content:
        merged_content[0]["portfolio_companies_list_fuzzy_matched"] = matched_company_list
        merged_content[0]["portfolio_companies_list_before"] = portfolio_company_list

    # Ensure output directory exists.
    os.makedirs(output_directory, exist_ok=True)

    # Save output.
    output_path = os.path.join(output_directory, f"{file_name}_h2h_merged_output.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(merged_content, f, indent=4, ensure_ascii=False)
    print(f"Saved merged output to {output_path}")

    return merged_content
352
+
353
+
354
def read_json(file_path):
    """Read the JSON file at *file_path* and return the parsed data."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
359
+
360
+
361
# Example usage: run the company-wise merge on a sample h2h extraction file.
if __name__ == "__main__":
    sample_extraction_path = "/shared_disk/kushal/db_str_chunking/new_ws_structured_code/Triton2023Q4_patria_sample_output/Triton2023Q4_patria_sample_json_output/Triton2023Q4_patria_sample_final_h2h_extraction.json"
    parsed_sections = read_json(sample_extraction_path)

    # Process the data.
    merged_result = process_document_company_wise(
        intermediate_str_chunk_json=parsed_sections,
        output_directory="db_structured_chunking/structure_chunking/src/iqeq_modification/testing_sample/output",
        file_name="sample_report",
    )

    print("Processing complete.")
375
+