heerjtdev commited on
Commit
71693a6
·
verified ·
1 Parent(s): a77096f

Update working_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. working_yolo_pipeline.py +40 -7
working_yolo_pipeline.py CHANGED
@@ -184,7 +184,7 @@ FIGURE_EXTRACTION_DIR = './figure_extraction'
184
  TEMP_IMAGE_DIR = './temp_pdf_images'
185
 
186
  # Detection parameters
187
- CONF_THRESHOLD = 0.2
188
  TARGET_CLASSES = ['figure', 'equation']
189
  IOU_MERGE_THRESHOLD = 0.4
190
  IOA_SUPPRESSION_THRESHOLD = 0.7
@@ -1094,22 +1094,55 @@ def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str,
1094
  # --- STEP 1: YOLO DETECTION ---
1095
  # ====================================================================
1096
  start_time_yolo = time.time()
1097
- results = model.predict(source=original_img, conf=CONF_THRESHOLD, imgsz=640, verbose=False)
 
 
 
1098
 
1099
  relevant_detections = []
 
 
 
 
 
 
 
 
 
1100
  if results and results[0].boxes:
1101
  for box in results[0].boxes:
1102
  class_id = int(box.cls[0])
1103
  class_name = model.names[class_id]
1104
- if class_name in TARGET_CLASSES:
1105
- x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
1106
- relevant_detections.append(
1107
- {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])}
1108
- )
 
 
 
 
 
 
 
1109
 
1110
  merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD)
1111
  print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.")
1112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1113
  # ====================================================================
1114
  # --- STEP 2: PREPARE DATA FOR COLUMN DETECTION (MASKING) ---
1115
  # ====================================================================
 
184
  TEMP_IMAGE_DIR = './temp_pdf_images'
185
 
186
  # Detection parameters
187
+ # CONF_THRESHOLD = 0.2
188
  TARGET_CLASSES = ['figure', 'equation']
189
  IOU_MERGE_THRESHOLD = 0.4
190
  IOA_SUPPRESSION_THRESHOLD = 0.7
 
1094
  # --- STEP 1: YOLO DETECTION ---
1095
  # ====================================================================
1096
  start_time_yolo = time.time()
1097
+ results = model.predict(source=original_img, conf=0.2, imgsz=640, verbose=False)
1098
+
1099
+
1100
+
1101
 
1102
  relevant_detections = []
1103
+
1104
+ THRESHOLDS = {
1105
+ 'figure': 0.75,
1106
+ 'equation': 0.20
1107
+ }
1108
+
1109
+
1110
+
1111
+
1112
  if results and results[0].boxes:
1113
  for box in results[0].boxes:
1114
  class_id = int(box.cls[0])
1115
  class_name = model.names[class_id]
1116
+ conf = float(box.conf[0])
1117
+
1118
+ # Logic: Check if class is in our list AND meets its specific threshold
1119
+ if class_name in THRESHOLDS:
1120
+ if conf >= THRESHOLDS[class_name]:
1121
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
1122
+ relevant_detections.append({
1123
+ 'coords': (x1, y1, x2, y2),
1124
+ 'y1': y1,
1125
+ 'class': class_name,
1126
+ 'conf': conf
1127
+ })
1128
 
1129
  merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD)
1130
  print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.")
1131
 
1132
+
1133
+ # if results and results[0].boxes:
1134
+ # for box in results[0].boxes:
1135
+ # class_id = int(box.cls[0])
1136
+ # class_name = model.names[class_id]
1137
+ # if class_name in TARGET_CLASSES:
1138
+ # x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
1139
+ # relevant_detections.append(
1140
+ # {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])}
1141
+ # )
1142
+
1143
+ # merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD)
1144
+ # print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.")
1145
+
1146
  # ====================================================================
1147
  # --- STEP 2: PREPARE DATA FOR COLUMN DETECTION (MASKING) ---
1148
  # ====================================================================