heerjtdev commited on
Commit
144aace
·
verified ·
1 Parent(s): b13058c

Update working_yolo_pipeline.py

Browse files
Files changed (1) hide show
  1. working_yolo_pipeline.py +446 -4
working_yolo_pipeline.py CHANGED
@@ -96,8 +96,190 @@ except Exception as e:
96
 
97
 
98
 
 
 
99
 
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  from typing import Optional
102
 
103
  def sanitize_text(text: Optional[str]) -> str:
@@ -1634,6 +1816,130 @@ def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
1634
 
1635
 
1636
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1637
  # ============================================================================
1638
  # --- PHASE 3: BIO TO STRUCTURED JSON DECODER ---
1639
  # ============================================================================
@@ -2250,6 +2556,129 @@ import glob
2250
 
2251
 
2252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2253
  def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, structured_intermediate_output_path: Optional[str] = None) -> Optional[List[Dict[str, Any]]]:
2254
  if not os.path.exists(input_pdf_path):
2255
  print(f"❌ ERROR: File not found: {input_pdf_path}")
@@ -2284,12 +2713,25 @@ def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, struc
2284
  return None
2285
  print(f"βœ… Step 1 Complete ({time.time() - p1_start:.2f}s)")
2286
 
2287
- # --- Phase 2: Inference ---
2288
- print(f"\n[Step 2/5] Inference (LayoutLMv3)...")
2289
  p2_start = time.time()
2290
- page_raw_predictions_list = run_inference_and_get_raw_words(
2291
- input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out
 
 
 
 
 
 
 
 
 
 
 
2292
  )
 
 
2293
  if not page_raw_predictions_list:
2294
  print("❌ FAILED at Step 2: Inference returned no data.")
2295
  return None
 
96
 
97
 
98
 
99
+ #=====================================================================================================================
100
+ #=====================================================================================================================
101
 
102
 
103
+
104
+ # ============================================================================
105
+ # --- CUSTOM MODEL DEFINITIONS (ADD THIS BLOCK) ---
106
+ # ============================================================================
107
+ import torch
108
+ import torch.nn as nn
109
+ import torch.nn.functional as F
110
+ from collections import Counter
111
+ import pickle
112
+
113
+ # --- CONSTANTS FOR CUSTOM MODEL ---
114
+ MODEL_FILE = "model_enhanced.pt" # Ensure this file is in your directory
115
+ VOCAB_FILE = "vocabs_enhanced.pkl" # Ensure this file is in your directory
116
+ DEVICE = torch.device("cpu") # Use "cuda" if available
117
+ MAX_CHAR_LEN = 16
118
+ EMBED_DIM = 128
119
+ CHAR_EMBED_DIM = 50
120
+ CHAR_CNN_OUT = 50
121
+ BBOX_DIM = 128
122
+ HIDDEN_SIZE = 768
123
+ SPATIAL_FEATURE_DIM = 64
124
+ POSITIONAL_DIM = 128
125
+ INFERENCE_CHUNK_SIZE = 450
126
+
127
+ LABELS = [
128
+ "O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION",
129
+ "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE",
130
+ "B-SECTION HEADING", "I-SECTION HEADING", "B-PASSAGE", "I-PASSAGE"
131
+ ]
132
+ IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
133
+
134
+ # --- CRF DEPENDENCY ---
135
+ try:
136
+ from torch_crf import CRF
137
+ except ImportError:
138
+ try:
139
+ from TorchCRF import CRF
140
+ except ImportError:
141
+ # Minimal fallback if CRF library is missing (though you should install it)
142
+ class CRF(nn.Module):
143
+ def __init__(self, *args, **kwargs): super().__init__()
144
+
145
+ # --- MODEL CLASSES ---
146
+ class Vocab:
147
+ def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
148
+ self.min_freq = min_freq
149
+ self.unk_token = unk_token
150
+ self.pad_token = pad_token
151
+ self.freq = Counter()
152
+ self.itos = []
153
+ self.stoi = {}
154
+ def __len__(self): return len(self.itos)
155
+ def __getitem__(self, token): return self.stoi.get(token, self.stoi.get(self.unk_token, 0))
156
+
157
+ class CharCNNEncoder(nn.Module):
158
+ def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(2, 3, 4, 5)):
159
+ super().__init__()
160
+ self.char_emb = nn.Embedding(char_vocab_size, char_emb_dim, padding_idx=0)
161
+ self.convs = nn.ModuleList([nn.Conv1d(char_emb_dim, out_dim, kernel_size=k) for k in kernel_sizes])
162
+ self.out_dim = out_dim * len(kernel_sizes)
163
+ def forward(self, char_ids):
164
+ B, L, C = char_ids.size()
165
+ emb = self.char_emb(char_ids.view(B * L, C)).transpose(1, 2)
166
+ outs = [torch.max(torch.relu(conv(emb)), dim=2)[0] for conv in self.convs]
167
+ return torch.cat(outs, dim=1).view(B, L, -1)
168
+
169
+ class SpatialAttention(nn.Module):
170
+ def __init__(self, hidden_dim):
171
+ super().__init__()
172
+ self.query = nn.Linear(hidden_dim, hidden_dim)
173
+ self.key = nn.Linear(hidden_dim, hidden_dim)
174
+ self.value = nn.Linear(hidden_dim, hidden_dim)
175
+ self.scale = hidden_dim ** 0.5
176
+ def forward(self, x, mask):
177
+ Q, K, V = self.query(x), self.key(x), self.value(x)
178
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
179
+ mask_expanded = mask.unsqueeze(1).expand_as(scores)
180
+ scores = scores.masked_fill(~mask_expanded, float('-inf'))
181
+ attn_weights = F.softmax(scores, dim=-1).masked_fill(torch.isnan(scores), 0.0)
182
+ return torch.matmul(attn_weights, V)
183
+
184
+ class MCQTagger(nn.Module):
185
+ def __init__(self, vocab_size, char_vocab_size, n_labels):
186
+ super().__init__()
187
+ self.word_emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=0)
188
+ self.char_enc = CharCNNEncoder(char_vocab_size, CHAR_EMBED_DIM, CHAR_CNN_OUT)
189
+ self.bbox_proj = nn.Sequential(nn.Linear(4, BBOX_DIM), nn.ReLU(), nn.Dropout(0.1), nn.Linear(BBOX_DIM, BBOX_DIM))
190
+ self.spatial_proj = nn.Sequential(nn.Linear(11, SPATIAL_FEATURE_DIM), nn.ReLU(), nn.Dropout(0.1))
191
+ self.context_proj = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Dropout(0.1))
192
+ self.positional_encoding = nn.Embedding(512, POSITIONAL_DIM)
193
+ in_dim = (EMBED_DIM + self.char_enc.out_dim + BBOX_DIM + SPATIAL_FEATURE_DIM + 32 + POSITIONAL_DIM)
194
+ self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=3, batch_first=True, bidirectional=True, dropout=0.3)
195
+ self.spatial_attention = SpatialAttention(HIDDEN_SIZE)
196
+ self.ff = nn.Sequential(nn.Linear(HIDDEN_SIZE * 2, HIDDEN_SIZE), nn.ReLU(), nn.Dropout(0.3), nn.Linear(HIDDEN_SIZE, n_labels))
197
+ self.crf = CRF(n_labels)
198
+ self.dropout = nn.Dropout(p=0.5)
199
+ def forward(self, words, chars, bboxes, spatial_feats, context_feats, mask):
200
+ B, L = words.size()
201
+ wemb = self.word_emb(words)
202
+ cenc = self.char_enc(chars)
203
+ benc = self.bbox_proj(bboxes)
204
+ senc = self.spatial_proj(spatial_feats)
205
+ cxt_enc = self.context_proj(context_feats)
206
+ pos = torch.arange(L, device=words.device).unsqueeze(0).expand(B, -1)
207
+ pos_enc = self.positional_encoding(pos.clamp(max=511))
208
+ enc_in = self.dropout(torch.cat([wemb, cenc, benc, senc, cxt_enc, pos_enc], dim=-1))
209
+ lengths = mask.sum(dim=1).cpu()
210
+ packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
211
+ packed_out, _ = self.bilstm(packed_in)
212
+ lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
213
+ attn_out = self.spatial_attention(lstm_out, mask)
214
+ emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
215
+ return self.crf.viterbi_decode(emissions, mask=mask)
216
+
217
+ # --- INJECT DEPENDENCIES FOR PICKLE LOADING ---
218
+ import sys
219
+ from types import ModuleType
220
+ train_mod = ModuleType("train_model")
221
+ sys.modules["train_model"] = train_mod
222
+ train_mod.Vocab = Vocab
223
+ train_mod.MCQTagger = MCQTagger
224
+ train_mod.CharCNNEncoder = CharCNNEncoder
225
+ train_mod.SpatialAttention = SpatialAttention
226
+
227
+
228
+
229
+
230
+
231
+
232
+ # ============================================================================
233
+ # --- CUSTOM FEATURE EXTRACTORS ---
234
+ # ============================================================================
235
+ def extract_spatial_features(tokens, idx):
236
+ curr = tokens[idx]
237
+ f = []
238
+ # Vertical distance to next
239
+ if idx < len(tokens)-1: f.append(min((tokens[idx+1]['y0'] - curr['y1'])/100.0, 1.0))
240
+ else: f.append(0.0)
241
+ # Vertical distance from prev
242
+ if idx > 0: f.append(min((curr['y0'] - tokens[idx-1]['y1'])/100.0, 1.0))
243
+ else: f.append(0.0)
244
+ # Geometry
245
+ f.extend([curr['x0']/1000.0, (curr['x1']-curr['x0'])/1000.0, (curr['y1']-curr['y0'])/1000.0])
246
+ f.extend([(curr['x0']+curr['x1'])/2000.0, (curr['y0']+curr['y1'])/2000.0, curr['x0']/1000.0])
247
+ # Aspect ratio
248
+ f.append(min(((curr['x1']-curr['x0'])/max((curr['y1']-curr['y0']),1.0))/10.0, 1.0))
249
+ # Alignment check
250
+ if idx > 0: f.append(float(abs(curr['x0'] - tokens[idx-1]['x0']) < 5))
251
+ else: f.append(0.0)
252
+ # Area
253
+ f.append(min(((curr['x1']-curr['x0'])*(curr['y1']-curr['y0']))/(1000.0**2), 1.0))
254
+ return f
255
+
256
+ def extract_context_features(tokens, idx, window=3):
257
+ f = []
258
+ def check_p(i):
259
+ t = str(tokens[i]['word']).lower().strip() # Changed 'text' to 'word' to match pipeline
260
+ return [float(bool(re.match(r'^q?\.?\d+[.:]', t))), float(bool(re.match(r'^[a-dA-D][.)]', t))), float(t.isupper() and len(t)>2)]
261
+
262
+ prev_res = [0.0, 0.0, 0.0]
263
+ for i in range(max(0, idx-window), idx):
264
+ res = check_p(i)
265
+ prev_res = [max(prev_res[j], res[j]) for j in range(3)]
266
+ f.extend(prev_res)
267
+ next_res = [0.0, 0.0, 0.0]
268
+ for i in range(idx+1, min(len(tokens), idx+window+1)):
269
+ res = check_p(i)
270
+ next_res = [max(next_res[j], res[j]) for j in range(3)]
271
+ f.extend(next_res)
272
+ dq, dopt = 1.0, 1.0
273
+ for i in range(idx+1, min(len(tokens), idx+window+1)):
274
+ t = str(tokens[i]['word']).lower().strip()
275
+ if re.match(r'^q?\.?\d+[.:]', t): dq = min(dq, (i-idx)/window)
276
+ if re.match(r'^[a-dA-D][.)]', t): dopt = min(dopt, (i-idx)/window)
277
+ f.extend([dq, dopt])
278
+ return f
279
+
280
+ #======================================================================================================================================================
281
+ #======================================================================================================================================================
282
+
283
  from typing import Optional
284
 
285
  def sanitize_text(text: Optional[str]) -> str:
 
1816
 
1817
 
1818
 
1819
+
1820
+
1821
+
1822
+
1823
+
1824
+ # ============================================================================
1825
+ # --- PHASE 2 REPLACEMENT: CUSTOM INFERENCE PIPELINE ---
1826
+ # ============================================================================
1827
+ def run_custom_inference_and_get_raw_words(preprocessed_json_path: str) -> List[Dict[str, Any]]:
1828
+ print("\n" + "=" * 80)
1829
+ print("--- 2. STARTING CUSTOM MODEL INFERENCE PIPELINE ---")
1830
+ print("=" * 80)
1831
+
1832
+ # 1. Load Resources
1833
+ if not os.path.exists(MODEL_FILE) or not os.path.exists(VOCAB_FILE):
1834
+ print("❌ Error: Missing custom model or vocab files.")
1835
+ return []
1836
+
1837
+ try:
1838
+ print(" -> Loading Vocab and Model...")
1839
+ with open(VOCAB_FILE, "rb") as f:
1840
+ word_vocab, char_vocab = pickle.load(f)
1841
+
1842
+ model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
1843
+
1844
+ # Load state dict safe
1845
+ state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
1846
+ model.load_state_dict(state_dict if isinstance(state_dict, dict) else state_dict.state_dict())
1847
+ model.eval()
1848
+ print("βœ… Custom Model loaded successfully.")
1849
+ except Exception as e:
1850
+ print(f"❌ Error loading custom model: {e}")
1851
+ return []
1852
+
1853
+ # 2. Load Preprocessed Data
1854
+ try:
1855
+ with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
1856
+ preprocessed_data = json.load(f)
1857
+ print(f"βœ… Loaded preprocessed data for {len(preprocessed_data)} pages.")
1858
+ except Exception:
1859
+ print("❌ Error loading preprocessed JSON.")
1860
+ return []
1861
+
1862
+ final_page_predictions = []
1863
+ scale_factor = 2.0 # The pipeline scales PDF points to 2.0 for YOLO. We need to reverse this.
1864
+
1865
+ for page_data in preprocessed_data:
1866
+ page_num = page_data['page_number']
1867
+ raw_items = page_data['data']
1868
+
1869
+ if not raw_items: continue
1870
+
1871
+ # --- A. ADAPTER: Convert Pipeline Data format to Custom Model format ---
1872
+ # Pipeline Data: {'word': 'Text', 'bbox': [x1, y1, x2, y2]} (scaled by 2.0)
1873
+ # Custom Model Needed: {'word': 'Text', 'x0': x, 'y0': y, 'x1': x, 'y1': y} (PDF points)
1874
+
1875
+ tokens_for_inference = []
1876
+ for item in raw_items:
1877
+ bbox = item['bbox']
1878
+ # Revert scale to get native PDF coordinates
1879
+ x0 = bbox[0] / scale_factor
1880
+ y0 = bbox[1] / scale_factor
1881
+ x1 = bbox[2] / scale_factor
1882
+ y1 = bbox[3] / scale_factor
1883
+
1884
+ tokens_for_inference.append({
1885
+ 'word': str(item['word']), # Ensure string
1886
+ 'x0': x0, 'y0': y0, 'x1': x1, 'y1': y1,
1887
+ 'original_bbox': bbox # Keep for output
1888
+ })
1889
+
1890
+ # --- B. FEATURE EXTRACTION ---
1891
+ for i in range(len(tokens_for_inference)):
1892
+ tokens_for_inference[i]['spatial_features'] = extract_spatial_features(tokens_for_inference, i)
1893
+ tokens_for_inference[i]['context_features'] = extract_context_features(tokens_for_inference, i)
1894
+
1895
+ # --- C. BATCH INFERENCE ---
1896
+ page_raw_predictions = []
1897
+
1898
+ # Process in chunks
1899
+ for i in range(0, len(tokens_for_inference), INFERENCE_CHUNK_SIZE):
1900
+ chunk = tokens_for_inference[i : i + INFERENCE_CHUNK_SIZE]
1901
+
1902
+ # Prepare Tensors
1903
+ w_ids = torch.LongTensor([[word_vocab[t['word']] for t in chunk]]).to(DEVICE)
1904
+
1905
+ c_ids_list = []
1906
+ for t in chunk:
1907
+ chars = [char_vocab[c] for c in t['word'][:MAX_CHAR_LEN]]
1908
+ chars += [0] * (MAX_CHAR_LEN - len(chars))
1909
+ c_ids_list.append(chars)
1910
+ c_ids = torch.LongTensor([c_ids_list]).to(DEVICE)
1911
+
1912
+ bboxes = torch.FloatTensor([[[t['x0']/1000.0, t['y0']/1000.0, t['x1']/1000.0, t['y1']/1000.0] for t in chunk]]).to(DEVICE)
1913
+ s_feats = torch.FloatTensor([[t['spatial_features'] for t in chunk]]).to(DEVICE)
1914
+ c_feats = torch.FloatTensor([[t['context_features'] for t in chunk]]).to(DEVICE)
1915
+ mask = torch.ones(w_ids.size(), dtype=torch.bool).to(DEVICE)
1916
+
1917
+ # Predict
1918
+ with torch.no_grad():
1919
+ preds = model(w_ids, c_ids, bboxes, s_feats, c_feats, mask)[0]
1920
+
1921
+ # --- D. FORMAT OUTPUT ---
1922
+ for t, p in zip(chunk, preds):
1923
+ label = IDX2LABEL[p]
1924
+ # Create the exact dictionary structure expected by the rest of the pipeline
1925
+ page_raw_predictions.append({
1926
+ "word": t['word'],
1927
+ "bbox": t['original_bbox'], # Pass back the scaled bbox the pipeline uses
1928
+ "predicted_label": label,
1929
+ "page_number": page_num
1930
+ })
1931
+
1932
+ if page_raw_predictions:
1933
+ final_page_predictions.append({
1934
+ "page_number": page_num,
1935
+ "data": page_raw_predictions
1936
+ })
1937
+ print(f" -> Page {page_num} Inference Complete: {len(page_raw_predictions)} labeled words.")
1938
+
1939
+ return final_page_predictions
1940
+
1941
+
1942
+
1943
  # ============================================================================
1944
  # --- PHASE 3: BIO TO STRUCTURED JSON DECODER ---
1945
  # ============================================================================
 
2556
 
2557
 
2558
 
2559
+ # def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, structured_intermediate_output_path: Optional[str] = None) -> Optional[List[Dict[str, Any]]]:
2560
+ # if not os.path.exists(input_pdf_path):
2561
+ # print(f"❌ ERROR: File not found: {input_pdf_path}")
2562
+ # return None
2563
+
2564
+ # print("\n" + "#" * 80)
2565
+ # print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###")
2566
+ # print(f"Input: {input_pdf_path}")
2567
+ # print("#" * 80)
2568
+
2569
+ # overall_start = time.time()
2570
+ # pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
2571
+ # temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}")
2572
+ # os.makedirs(temp_pipeline_dir, exist_ok=True)
2573
+
2574
+ # preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json")
2575
+ # raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json")
2576
+
2577
+ # if structured_intermediate_output_path is None:
2578
+ # structured_intermediate_output_path = os.path.join(
2579
+ # temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json"
2580
+ # )
2581
+
2582
+ # final_result = None
2583
+ # try:
2584
+ # # --- Phase 1: Preprocessing ---
2585
+ # print(f"\n[Step 1/5] Preprocessing (YOLO + Masking)...")
2586
+ # p1_start = time.time()
2587
+ # preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)
2588
+ # if not preprocessed_json_path_out:
2589
+ # print("❌ FAILED at Step 1: Preprocessing returned None.")
2590
+ # return None
2591
+ # print(f"βœ… Step 1 Complete ({time.time() - p1_start:.2f}s)")
2592
+
2593
+ # # --- Phase 2: Inference ---
2594
+ # print(f"\n[Step 2/5] Inference (LayoutLMv3)...")
2595
+ # p2_start = time.time()
2596
+ # page_raw_predictions_list = run_inference_and_get_raw_words(
2597
+ # input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out
2598
+ # )
2599
+ # if not page_raw_predictions_list:
2600
+ # print("❌ FAILED at Step 2: Inference returned no data.")
2601
+ # return None
2602
+
2603
+ # with open(raw_output_path, 'w', encoding='utf-8') as f:
2604
+ # json.dump(page_raw_predictions_list, f, indent=4)
2605
+ # print(f"βœ… Step 2 Complete ({time.time() - p2_start:.2f}s)")
2606
+
2607
+ # # --- Phase 3: Decoding ---
2608
+ # print(f"\n[Step 3/5] Decoding (BIO to Structured JSON)...")
2609
+ # p3_start = time.time()
2610
+ # structured_data_list = convert_bio_to_structured_json_relaxed(
2611
+ # raw_output_path, structured_intermediate_output_path
2612
+ # )
2613
+ # if not structured_data_list:
2614
+ # print("❌ FAILED at Step 3: BIO conversion failed.")
2615
+ # return None
2616
+
2617
+ # print("... Correcting misalignments and linking context ...")
2618
+ # structured_data_list = correct_misaligned_options(structured_data_list)
2619
+ # structured_data_list = process_context_linking(structured_data_list)
2620
+ # print(f"βœ… Step 3 Complete ({time.time() - p3_start:.2f}s)")
2621
+
2622
+ # # --- Phase 4: Base64 & LaTeX ---
2623
+ # print(f"\n[Step 4/5] Finalizing Layout (Base64 Images & LaTeX)...")
2624
+ # p4_start = time.time()
2625
+ # final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR)
2626
+ # if not final_result:
2627
+ # print("❌ FAILED at Step 4: Final formatting failed.")
2628
+ # return None
2629
+ # print(f"βœ… Step 4 Complete ({time.time() - p4_start:.2f}s)")
2630
+
2631
+ # # --- Phase 4.5: Question Type Classification ---
2632
+ # print(f"\n[Step 4.5/5] Adding Question Type Classification...")
2633
+ # p4_5_start = time.time()
2634
+ # final_result = add_question_type_validation(final_result)
2635
+ # print(f"βœ… Step 4.5 Complete ({time.time() - p4_5_start:.2f}s)")
2636
+
2637
+ # # --- Phase 5: Hierarchical Tagging ---
2638
+ # print(f"\n[Step 5/5] AI Classification (Subject/Concept Tagging)...")
2639
+ # p5_start = time.time()
2640
+ # classifier = HierarchicalClassifier()
2641
+ # if classifier.load_models():
2642
+ # final_result = post_process_json_with_inference(final_result, classifier)
2643
+ # print(f"βœ… Step 5 Complete: Tags added ({time.time() - p5_start:.2f}s)")
2644
+ # else:
2645
+ # print("⚠️ WARNING: Classifier models failed to load. Skipping tags.")
2646
+
2647
+ # # ============================================================
2648
+ # # πŸ”§ NEW STEP: FILTER OUT METADATA ENTRIES
2649
+ # # ============================================================
2650
+ # print(f"\n[Post-Processing] Removing METADATA entries...")
2651
+ # initial_count = len(final_result)
2652
+ # final_result = [item for item in final_result if item.get('type') != 'METADATA']
2653
+ # removed_count = initial_count - len(final_result)
2654
+ # print(f"βœ… Removed {removed_count} METADATA entries. {len(final_result)} questions remain.")
2655
+ # # ============================================================
2656
+
2657
+ # except Exception as e:
2658
+ # print(f"\n‼️ FATAL PIPELINE EXCEPTION:")
2659
+ # print(f"Error Message: {str(e)}")
2660
+ # traceback.print_exc()
2661
+ # return None
2662
+
2663
+ # # finally:
2664
+ # # print(f"\nCleaning up temporary files in {temp_pipeline_dir}...")
2665
+ # # try:
2666
+ # # for f in glob.glob(os.path.join(temp_pipeline_dir, '*')):
2667
+ # # os.remove(f)
2668
+ # # os.rmdir(temp_pipeline_dir)
2669
+ # # print("🧹 Cleanup successful.")
2670
+ # # except Exception as e:
2671
+ # # print(f"⚠️ Cleanup failed: {e}")
2672
+
2673
+ # total_time = time.time() - overall_start
2674
+ # print("\n" + "#" * 80)
2675
+ # print(f"### PIPELINE COMPLETE | Total Time: {total_time:.2f}s ###")
2676
+ # print("#" * 80)
2677
+
2678
+ # return final_result
2679
+
2680
+
2681
+
2682
  def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, structured_intermediate_output_path: Optional[str] = None) -> Optional[List[Dict[str, Any]]]:
2683
  if not os.path.exists(input_pdf_path):
2684
  print(f"❌ ERROR: File not found: {input_pdf_path}")
 
2713
  return None
2714
  print(f"βœ… Step 1 Complete ({time.time() - p1_start:.2f}s)")
2715
 
2716
+ # --- Phase 2: Inference (MODIFIED) ---
2717
+ print(f"\n[Step 2/5] Inference (Custom Model)...")
2718
  p2_start = time.time()
2719
+
2720
+ # -------------------------------------------------------------------------
2721
+ # --- COMMENTED OUT OLD LAYOUTLMV3 CALL FOR REVERSION ---
2722
+ # page_raw_predictions_list = run_inference_and_get_raw_words(
2723
+ # input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out
2724
+ # )
2725
+ # -------------------------------------------------------------------------
2726
+
2727
+ # --- NEW CUSTOM MODEL CALL ---
2728
+ # Note: We only pass the JSON path because the custom function
2729
+ # doesn't need to re-read the PDF or use the layoutlmv3 model path.
2730
+ page_raw_predictions_list = run_custom_inference_and_get_raw_words(
2731
+ preprocessed_json_path_out
2732
  )
2733
+ # -----------------------------
2734
+
2735
  if not page_raw_predictions_list:
2736
  print("❌ FAILED at Step 2: Inference returned no data.")
2737
  return None