Spaces:
Sleeping
Sleeping
Update working_yolo_pipeline.py
Browse files- working_yolo_pipeline.py +19 -26
working_yolo_pipeline.py
CHANGED
|
@@ -2075,10 +2075,6 @@ def load_image_as_fitz_page(image_path: str) -> Tuple[fitz.Document, fitz.Page]:
|
|
| 2075 |
doc = fitz.open("pdf", pdf_stream.read())
|
| 2076 |
return doc, doc[0]
|
| 2077 |
|
| 2078 |
-
|
| 2079 |
-
|
| 2080 |
-
|
| 2081 |
-
|
| 2082 |
def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
|
| 2083 |
"""
|
| 2084 |
Modified pipeline that handles both PDFs and Images, running YOLO,
|
|
@@ -2088,7 +2084,6 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
|
|
| 2088 |
yolo_model = YOLO(WEIGHTS_PATH)
|
| 2089 |
|
| 2090 |
# 2. DETECT FILE TYPE
|
| 2091 |
-
# FIX: [1] added to get the extension string from the (root, ext) tuple
|
| 2092 |
ext = os.path.splitext(input_path)[1].lower()
|
| 2093 |
is_image = ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
|
| 2094 |
|
|
@@ -2098,10 +2093,8 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
|
|
| 2098 |
try:
|
| 2099 |
if is_image:
|
| 2100 |
print(f"📸 Image detected: {input_path}. Processing with YOLO + Tesseract.")
|
| 2101 |
-
# Use the corrected helper function defined above
|
| 2102 |
doc, page = load_image_as_fitz_page(input_path)
|
| 2103 |
|
| 2104 |
-
# Render for YOLO
|
| 2105 |
pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
|
| 2106 |
img_np = pixmap_to_numpy(pix)
|
| 2107 |
|
|
@@ -2112,7 +2105,6 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
|
|
| 2112 |
all_pages_data.append(page_data)
|
| 2113 |
doc.close()
|
| 2114 |
else:
|
| 2115 |
-
# --- ORIGINAL PDF LOGIC ---
|
| 2116 |
doc = fitz.open(input_path)
|
| 2117 |
print(f"π Processing PDF: {pdf_name} ({len(doc)} pages)")
|
| 2118 |
for page_index in range(len(doc)):
|
|
@@ -2131,26 +2123,14 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
|
|
| 2131 |
print("❌ No data extracted.")
|
| 2132 |
return None
|
| 2133 |
|
| 2134 |
-
#
|
| 2135 |
-
# sequential_blocks = []
|
| 2136 |
-
# for p_data in all_pages_data:
|
| 2137 |
-
# sequential_blocks.extend(p_data.get('blocks', []))
|
| 2138 |
-
|
| 2139 |
-
# 3. CONSOLIDATE BLOCKS FOR INFERENCE
|
| 2140 |
sequential_blocks = []
|
| 2141 |
for p_data in all_pages_data:
|
| 2142 |
if isinstance(p_data, dict):
|
| 2143 |
-
# If it's a dictionary, extract the 'blocks' key
|
| 2144 |
blocks = p_data.get('blocks', [])
|
| 2145 |
sequential_blocks.extend(blocks)
|
| 2146 |
elif isinstance(p_data, list):
|
| 2147 |
-
# If it's already a list, add it directly
|
| 2148 |
sequential_blocks.extend(p_data)
|
| 2149 |
-
else:
|
| 2150 |
-
print(f"⚠️ Warning: Unexpected data type in all_pages_data: {type(p_data)}")
|
| 2151 |
-
|
| 2152 |
-
|
| 2153 |
-
|
| 2154 |
|
| 2155 |
# --- 4. STARTING LAYOUTLMV3 INFERENCE ---
|
| 2156 |
print("\n" + "=" * 80)
|
|
@@ -2160,10 +2140,26 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
|
|
| 2160 |
tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
|
| 2161 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 2162 |
|
| 2163 |
-
# Note: Ensure LayoutLMv3ForTokenClassification is defined in your script
|
| 2164 |
model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
|
|
|
|
|
|
|
|
|
|
| 2165 |
checkpoint = torch.load(layoutlmv3_model_path, map_location=device)
|
| 2166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2167 |
model.to(device)
|
| 2168 |
model.eval()
|
| 2169 |
|
|
@@ -2178,7 +2174,6 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
|
|
| 2178 |
return final_result
|
| 2179 |
|
| 2180 |
except Exception as e:
|
| 2181 |
-
# Improved error logging to catch exactly where it fails
|
| 2182 |
import traceback
|
| 2183 |
traceback.print_exc()
|
| 2184 |
print(f"❌ FATAL ERROR in pipeline: {e}")
|
|
@@ -2186,8 +2181,6 @@ def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
|
|
| 2186 |
|
| 2187 |
|
| 2188 |
|
| 2189 |
-
|
| 2190 |
-
|
| 2191 |
|
| 2192 |
# #================================================================================
|
| 2193 |
# # --- NEW FINAL STEP: HIERARCHICAL CLASSIFICATION TAGGING ---
|
|
|
|
| 2075 |
doc = fitz.open("pdf", pdf_stream.read())
|
| 2076 |
return doc, doc[0]
|
| 2077 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2078 |
def run_document_pipeline(input_path: str, layoutlmv3_model_path: str):
|
| 2079 |
"""
|
| 2080 |
Modified pipeline that handles both PDFs and Images, running YOLO,
|
|
|
|
| 2084 |
yolo_model = YOLO(WEIGHTS_PATH)
|
| 2085 |
|
| 2086 |
# 2. DETECT FILE TYPE
|
|
|
|
| 2087 |
ext = os.path.splitext(input_path)[1].lower()
|
| 2088 |
is_image = ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
|
| 2089 |
|
|
|
|
| 2093 |
try:
|
| 2094 |
if is_image:
|
| 2095 |
print(f"📸 Image detected: {input_path}. Processing with YOLO + Tesseract.")
|
|
|
|
| 2096 |
doc, page = load_image_as_fitz_page(input_path)
|
| 2097 |
|
|
|
|
| 2098 |
pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
|
| 2099 |
img_np = pixmap_to_numpy(pix)
|
| 2100 |
|
|
|
|
| 2105 |
all_pages_data.append(page_data)
|
| 2106 |
doc.close()
|
| 2107 |
else:
|
|
|
|
| 2108 |
doc = fitz.open(input_path)
|
| 2109 |
print(f"π Processing PDF: {pdf_name} ({len(doc)} pages)")
|
| 2110 |
for page_index in range(len(doc)):
|
|
|
|
| 2123 |
print("❌ No data extracted.")
|
| 2124 |
return None
|
| 2125 |
|
| 2126 |
+
# 3. CONSOLIDATE BLOCKS FOR INFERENCE (Safe against List vs Dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2127 |
sequential_blocks = []
|
| 2128 |
for p_data in all_pages_data:
|
| 2129 |
if isinstance(p_data, dict):
|
|
|
|
| 2130 |
blocks = p_data.get('blocks', [])
|
| 2131 |
sequential_blocks.extend(blocks)
|
| 2132 |
elif isinstance(p_data, list):
|
|
|
|
| 2133 |
sequential_blocks.extend(p_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2134 |
|
| 2135 |
# --- 4. STARTING LAYOUTLMV3 INFERENCE ---
|
| 2136 |
print("\n" + "=" * 80)
|
|
|
|
| 2140 |
tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
|
| 2141 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 2142 |
|
|
|
|
| 2143 |
model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
|
| 2144 |
+
|
| 2145 |
+
# --- FIX: ROBUST KEY REMAPPING FOR LAYOUTLMV3 ---
|
| 2146 |
+
|
| 2147 |
checkpoint = torch.load(layoutlmv3_model_path, map_location=device)
|
| 2148 |
+
state_dict = checkpoint.get('model_state_dict', checkpoint)
|
| 2149 |
+
|
| 2150 |
+
# Rename keys from 'layoutlm.xxx' to 'layoutlmv3.xxx' if necessary
|
| 2151 |
+
new_state_dict = {}
|
| 2152 |
+
for key, value in state_dict.items():
|
| 2153 |
+
if key.startswith("layoutlm."):
|
| 2154 |
+
new_key = key.replace("layoutlm.", "layoutlmv3.", 1)
|
| 2155 |
+
new_state_dict[new_key] = value
|
| 2156 |
+
else:
|
| 2157 |
+
new_state_dict[key] = value
|
| 2158 |
+
|
| 2159 |
+
# Load with strict=False to handle minor metadata differences
|
| 2160 |
+
model.load_state_dict(new_state_dict, strict=False)
|
| 2161 |
+
# -----------------------------------------------
|
| 2162 |
+
|
| 2163 |
model.to(device)
|
| 2164 |
model.eval()
|
| 2165 |
|
|
|
|
| 2174 |
return final_result
|
| 2175 |
|
| 2176 |
except Exception as e:
|
|
|
|
| 2177 |
import traceback
|
| 2178 |
traceback.print_exc()
|
| 2179 |
print(f"❌ FATAL ERROR in pipeline: {e}")
|
|
|
|
| 2181 |
|
| 2182 |
|
| 2183 |
|
|
|
|
|
|
|
| 2184 |
|
| 2185 |
# #================================================================================
|
| 2186 |
# # --- NEW FINAL STEP: HIERARCHICAL CLASSIFICATION TAGGING ---
|