heerjtdev committed on
Commit
bbc046a
·
verified ·
1 Parent(s): 66a0ed8

Update working_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. working_yolo_pipeline.py +19 -26
working_yolo_pipeline.py CHANGED
@@ -2075,10 +2075,6 @@ def load_image_as_fitz_page(image_path: str) -> Tuple[fitz.Document, fitz.Page]:
2075
  doc = fitz.open("pdf", pdf_stream.read())
2076
  return doc, doc[0]
2077
 
2078
-
2079
-
2080
-
2081
-
2082
  def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
2083
  """
2084
  Modified pipeline that handles both PDFs and Images, running YOLO,
@@ -2088,7 +2084,6 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
2088
  yolo_model = YOLO(WEIGHTS_PATH)
2089
 
2090
  # 2. DETECT FILE TYPE
2091
- # FIX: [1] added to get the extension string from the (root, ext) tuple
2092
  ext = os.path.splitext(input_path)[1].lower()
2093
  is_image = ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
2094
 
@@ -2098,10 +2093,8 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
2098
  try:
2099
  if is_image:
2100
  print(f"πŸ“Έ Image detected: {input_path}. Processing with YOLO + Tesseract.")
2101
- # Use the corrected helper function defined above
2102
  doc, page = load_image_as_fitz_page(input_path)
2103
 
2104
- # Render for YOLO
2105
  pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
2106
  img_np = pixmap_to_numpy(pix)
2107
 
@@ -2112,7 +2105,6 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
2112
  all_pages_data.append(page_data)
2113
  doc.close()
2114
  else:
2115
- # --- ORIGINAL PDF LOGIC ---
2116
  doc = fitz.open(input_path)
2117
  print(f"πŸ“„ Processing PDF: {pdf_name} ({len(doc)} pages)")
2118
  for page_index in range(len(doc)):
@@ -2131,26 +2123,14 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
2131
  print("❌ No data extracted.")
2132
  return None
2133
 
2134
- # # 3. CONSOLIDATE BLOCKS FOR INFERENCE
2135
- # sequential_blocks = []
2136
- # for p_data in all_pages_data:
2137
- # sequential_blocks.extend(p_data.get('blocks', []))
2138
-
2139
- # 3. CONSOLIDATE BLOCKS FOR INFERENCE
2140
  sequential_blocks = []
2141
  for p_data in all_pages_data:
2142
  if isinstance(p_data, dict):
2143
- # If it's a dictionary, extract the 'blocks' key
2144
  blocks = p_data.get('blocks', [])
2145
  sequential_blocks.extend(blocks)
2146
  elif isinstance(p_data, list):
2147
- # If it's already a list, add it directly
2148
  sequential_blocks.extend(p_data)
2149
- else:
2150
- print(f"⚠️ Warning: Unexpected data type in all_pages_data: {type(p_data)}")
2151
-
2152
-
2153
-
2154
 
2155
  # --- 4. STARTING LAYOUTLMV3 INFERENCE ---
2156
  print("\n" + "=" * 80)
@@ -2160,10 +2140,26 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
2160
  tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
2161
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2162
 
2163
- # Note: Ensure LayoutLMv3ForTokenClassification is defined in your script
2164
  model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
 
 
 
2165
  checkpoint = torch.load(layoutlmv3_model_path, map_location=device)
2166
- model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2167
  model.to(device)
2168
  model.eval()
2169
 
@@ -2178,7 +2174,6 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
2178
  return final_result
2179
 
2180
  except Exception as e:
2181
- # Improved error logging to catch exactly where it fails
2182
  import traceback
2183
  traceback.print_exc()
2184
  print(f"❌ FATAL ERROR in pipeline: {e}")
@@ -2186,8 +2181,6 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
2186
 
2187
 
2188
 
2189
-
2190
-
2191
 
2192
  # #================================================================================
2193
  # # --- NEW FINAL STEP: HIERARCHICAL CLASSIFICATION TAGGING ---
 
2075
  doc = fitz.open("pdf", pdf_stream.read())
2076
  return doc, doc[0]
2077
 
 
 
 
 
2078
  def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
2079
  """
2080
  Modified pipeline that handles both PDFs and Images, running YOLO,
 
2084
  yolo_model = YOLO(WEIGHTS_PATH)
2085
 
2086
  # 2. DETECT FILE TYPE
 
2087
  ext = os.path.splitext(input_path)[1].lower()
2088
  is_image = ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
2089
 
 
2093
  try:
2094
  if is_image:
2095
  print(f"πŸ“Έ Image detected: {input_path}. Processing with YOLO + Tesseract.")
 
2096
  doc, page = load_image_as_fitz_page(input_path)
2097
 
 
2098
  pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
2099
  img_np = pixmap_to_numpy(pix)
2100
 
 
2105
  all_pages_data.append(page_data)
2106
  doc.close()
2107
  else:
 
2108
  doc = fitz.open(input_path)
2109
  print(f"πŸ“„ Processing PDF: {pdf_name} ({len(doc)} pages)")
2110
  for page_index in range(len(doc)):
 
2123
  print("❌ No data extracted.")
2124
  return None
2125
 
2126
+ # 3. CONSOLIDATE BLOCKS FOR INFERENCE (Safe against List vs Dict)
 
 
 
 
 
2127
  sequential_blocks = []
2128
  for p_data in all_pages_data:
2129
  if isinstance(p_data, dict):
 
2130
  blocks = p_data.get('blocks', [])
2131
  sequential_blocks.extend(blocks)
2132
  elif isinstance(p_data, list):
 
2133
  sequential_blocks.extend(p_data)
 
 
 
 
 
2134
 
2135
  # --- 4. STARTING LAYOUTLMV3 INFERENCE ---
2136
  print("\n" + "=" * 80)
 
2140
  tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
2141
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2142
 
 
2143
  model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
2144
+
2145
+ # --- FIX: ROBUST KEY REMAPPING FOR LAYOUTLMV3 ---
2146
+
2147
  checkpoint = torch.load(layoutlmv3_model_path, map_location=device)
2148
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
2149
+
2150
+ # Rename keys from 'layoutlm.xxx' to 'layoutlmv3.xxx' if necessary
2151
+ new_state_dict = {}
2152
+ for key, value in state_dict.items():
2153
+ if key.startswith("layoutlm."):
2154
+ new_key = key.replace("layoutlm.", "layoutlmv3.", 1)
2155
+ new_state_dict[new_key] = value
2156
+ else:
2157
+ new_state_dict[key] = value
2158
+
2159
+ # Load with strict=False to handle minor metadata differences
2160
+ model.load_state_dict(new_state_dict, strict=False)
2161
+ # -----------------------------------------------
2162
+
2163
  model.to(device)
2164
  model.eval()
2165
 
 
2174
  return final_result
2175
 
2176
  except Exception as e:
 
2177
  import traceback
2178
  traceback.print_exc()
2179
  print(f"❌ FATAL ERROR in pipeline: {e}")
 
2181
 
2182
 
2183
 
 
 
2184
 
2185
  # #================================================================================
2186
  # # --- NEW FINAL STEP: HIERARCHICAL CLASSIFICATION TAGGING ---