heerjtdev commited on
Commit
660939a
·
1 Parent(s): 0208f98

Rename train_layoutlmv3.py to inference.py

Browse files
Files changed (2) hide show
  1. inference.py +394 -0
  2. train_layoutlmv3.py +0 -1
inference.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ import pdfplumber
6
+ import argparse
7
+ import time
8
+ import re
9
+ from typing import List, Dict, Any, Optional
10
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
11
+ from TorchCRF import CRF
12
+
13
+ # --- Configuration (Must match training) ---
14
+ BASE_MODEL_ID = "microsoft/layoutlmv3-base"
15
+ MAX_BBOX_DIMENSION = 1000
16
+ LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-SECTION_HEADING", "I-SECTION_HEADING", "B-PASSAGE", "I-PASSAGE"]
17
+ LABEL2ID = {l: i for i, l in enumerate(LABELS)}
18
+ ID2LABEL = {i: l for l, i in LABEL2ID.items()}
19
+
20
+ # -------------------------
21
+ # Part 1: Model Architecture
22
+ # (Must be identical to training script)
23
+ # -------------------------
24
+ class LayoutLMv3CRF(nn.Module):
25
+ def __init__(self, num_labels):
26
+ super().__init__()
27
+ self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
28
+ hidden_size = self.layoutlm.config.hidden_size
29
+
30
+ self.classifier = nn.Sequential(
31
+ nn.Linear(hidden_size, hidden_size),
32
+ nn.GELU(),
33
+ nn.LayerNorm(hidden_size),
34
+ nn.Dropout(0.1),
35
+ nn.Linear(hidden_size, num_labels)
36
+ )
37
+ self.crf = CRF(num_labels)
38
+
39
+ def forward(self, input_ids, bbox, attention_mask, labels=None):
40
+ # Note: Your training script did not use pixel_values, so we omit them here too
41
+ outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
42
+ sequence_output = outputs.last_hidden_state
43
+ emissions = self.classifier(sequence_output)
44
+
45
+ if labels is not None:
46
+ log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
47
+ return -log_likelihood.mean()
48
+ else:
49
+ return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
50
+
51
+ # -------------------------
52
+ # Part 2: PDF Extraction
53
+ # -------------------------
54
+ def extract_pdf_data(pdf_path):
55
+ """
56
+ Extracts words and normalized bounding boxes (0-1000) from a PDF.
57
+ """
58
+ extracted_pages = []
59
+ print(f"📄 Extracting content from: {pdf_path}")
60
+
61
+ with pdfplumber.open(pdf_path) as pdf:
62
+ for page_idx, page in enumerate(pdf.pages):
63
+ width, height = page.width, page.height
64
+ words_data = page.extract_words()
65
+
66
+ page_tokens = []
67
+ page_bboxes = []
68
+
69
+ for w in words_data:
70
+ text = w['text']
71
+ # Normalize bbox to 0-1000 scale
72
+ x0 = int((w['x0'] / width) * 1000)
73
+ top = int((w['top'] / height) * 1000)
74
+ x1 = int((w['x1'] / width) * 1000)
75
+ bottom = int((w['bottom'] / height) * 1000)
76
+
77
+ # Clamp
78
+ box = [
79
+ max(0, min(x0, 1000)),
80
+ max(0, min(top, 1000)),
81
+ max(0, min(x1, 1000)),
82
+ max(0, min(bottom, 1000))
83
+ ]
84
+
85
+ page_tokens.append(text)
86
+ page_bboxes.append(box)
87
+
88
+ extracted_pages.append({
89
+ "page_id": page_idx,
90
+ "tokens": page_tokens,
91
+ "bboxes": page_bboxes
92
+ })
93
+
94
+ print(f"✅ Extracted {len(extracted_pages)} pages.")
95
+ return extracted_pages
96
+
97
+ # -------------------------
98
+ # Part 3: Inference Logic
99
+ # -------------------------
100
+ def run_inference(model, tokenizer, pages_data, device):
101
+ results = []
102
+ model.eval()
103
+
104
+ print("🧠 Running Inference...")
105
+
106
+ for page in pages_data:
107
+ tokens = page['tokens']
108
+ bboxes = page['bboxes']
109
+
110
+ if not tokens:
111
+ continue
112
+
113
+ # Tokenize
114
+ encoding = tokenizer(
115
+ tokens,
116
+ boxes=bboxes,
117
+ return_tensors="pt",
118
+ padding="max_length",
119
+ truncation=True,
120
+ max_length=512,
121
+ return_offsets_mapping=True
122
+ )
123
+
124
+ input_ids = encoding.input_ids.to(device)
125
+ bbox = encoding.bbox.to(device)
126
+ attention_mask = encoding.attention_mask.to(device)
127
+
128
+ # Predict
129
+ with torch.no_grad():
130
+ # returns list of lists (batch_size=1)
131
+ preds = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
132
+ pred_tags = preds[0] # Take first item in batch
133
+
134
+ # Align sub-word predictions back to original words
135
+ word_ids = encoding.word_ids()
136
+ aligned_data = []
137
+
138
+ previous_word_idx = None
139
+
140
+ for i, word_idx in enumerate(word_ids):
141
+ # Special tokens (None) or padding (masked) are skipped
142
+ if word_idx is None:
143
+ continue
144
+
145
+ # If we are at the start of a new word (or the only token for that word)
146
+ if word_idx != previous_word_idx:
147
+ # Get the label ID
148
+ label_id = pred_tags[i]
149
+ label_str = ID2LABEL[label_id]
150
+
151
+ # Retrieve original word text
152
+ original_word = tokens[word_idx]
153
+
154
+ aligned_data.append({
155
+ "word": original_word,
156
+ "predicted_label": label_str
157
+ })
158
+
159
+ previous_word_idx = word_idx
160
+
161
+ results.append({
162
+ "page": page['page_id'],
163
+ "data": aligned_data
164
+ })
165
+
166
+ return results
167
+
168
+ # -------------------------
169
+ # Part 4: User's Conversion Function
170
+ # -------------------------
171
+ def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]:
172
+ print("\n" + "=" * 80)
173
+ print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---")
174
+ print(f"Source: {input_path}")
175
+ print("=" * 80)
176
+
177
+ start_time = time.time()
178
+
179
+ try:
180
+ with open(input_path, 'r', encoding='utf-8') as f:
181
+ predictions_by_page = json.load(f)
182
+ print(f"✅ Successfully loaded raw predictions ({len(predictions_by_page)} pages found)")
183
+ except Exception as e:
184
+ print(f"❌ Error loading raw prediction file: {e}")
185
+ return None
186
+
187
+ predictions = []
188
+ for page_item in predictions_by_page:
189
+ if isinstance(page_item, dict) and 'data' in page_item:
190
+ predictions.extend(page_item['data'])
191
+
192
+ total_words = len(predictions)
193
+ print(f"📋 Total words to process: {total_words}")
194
+
195
+ structured_data = []
196
+ current_item = None
197
+ current_option_key = None
198
+ current_passage_buffer = []
199
+ current_text_buffer = []
200
+ first_question_started = False
201
+ last_entity_type = None
202
+ just_finished_i_option = False
203
+ is_in_new_passage = False
204
+
205
+ def finalize_passage_to_item(item, passage_buffer):
206
+ if passage_buffer:
207
+ passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
208
+ # print(f" ↳ [Buffer] Finalizing passage ({len(passage_buffer)} words) into current item")
209
+ if item.get('passage'):
210
+ item['passage'] += ' ' + passage_text
211
+ else:
212
+ item['passage'] = passage_text
213
+ passage_buffer.clear()
214
+
215
+ # Iterate through every predicted word
216
+ for idx, item in enumerate(predictions):
217
+ word = item['word']
218
+ label = item['predicted_label']
219
+ entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
220
+ current_text_buffer.append(word)
221
+
222
+ previous_entity_type = last_entity_type
223
+ is_passage_label = (entity_type == 'PASSAGE')
224
+
225
+ if not first_question_started:
226
+ if label != 'B-QUESTION' and not is_passage_label:
227
+ just_finished_i_option = False
228
+ is_in_new_passage = False
229
+ continue
230
+ if is_passage_label:
231
+ current_passage_buffer.append(word)
232
+ last_entity_type = 'PASSAGE'
233
+ just_finished_i_option = False
234
+ is_in_new_passage = False
235
+ continue
236
+
237
+ if label == 'B-QUESTION':
238
+ # print(f"🔍 Detection: New Question Started at word {idx}")
239
+ if not first_question_started:
240
+ header_text = ' '.join(current_text_buffer[:-1]).strip()
241
+ if header_text or current_passage_buffer:
242
+ print(f" -> Creating METADATA item for text found before first question")
243
+ metadata_item = {'type': 'METADATA', 'passage': ''}
244
+ finalize_passage_to_item(metadata_item, current_passage_buffer)
245
+ if header_text: metadata_item['text'] = header_text
246
+ structured_data.append(metadata_item)
247
+ first_question_started = True
248
+ current_text_buffer = [word]
249
+
250
+ if current_item is not None:
251
+ finalize_passage_to_item(current_item, current_passage_buffer)
252
+ current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
253
+ structured_data.append(current_item)
254
+ # print(f" -> Saved Question. Total structured items so far: {len(structured_data)}")
255
+ current_text_buffer = [word]
256
+
257
+ current_item = {
258
+ 'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
259
+ }
260
+ current_option_key = None
261
+ last_entity_type = 'QUESTION'
262
+ just_finished_i_option = False
263
+ is_in_new_passage = False
264
+ continue
265
+
266
+ if current_item is not None:
267
+ if is_in_new_passage:
268
+ if 'new_passage' not in current_item:
269
+ current_item['new_passage'] = word
270
+ else:
271
+ current_item['new_passage'] += f' {word}'
272
+
273
+ if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
274
+ # print(f" ↳ [State] Exiting new_passage mode at label {label}")
275
+ is_in_new_passage = False
276
+
277
+ if label.startswith(('B-', 'I-')):
278
+ last_entity_type = entity_type
279
+ continue
280
+
281
+ is_in_new_passage = False
282
+
283
+ if label.startswith('B-'):
284
+ if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
285
+ finalize_passage_to_item(current_item, current_passage_buffer)
286
+ current_passage_buffer = []
287
+
288
+ last_entity_type = entity_type
289
+
290
+ if entity_type == 'PASSAGE':
291
+ if previous_entity_type == 'OPTION' and just_finished_i_option:
292
+ # print(f" ↳ [State] Transitioning to new_passage (Option -> Passage boundary)")
293
+ current_item['new_passage'] = word
294
+ is_in_new_passage = True
295
+ else:
296
+ current_passage_buffer.append(word)
297
+
298
+ elif entity_type == 'OPTION':
299
+ current_option_key = word
300
+ current_item['options'][current_option_key] = word
301
+ just_finished_i_option = False
302
+
303
+ elif entity_type == 'ANSWER':
304
+ current_item['answer'] = word
305
+ current_option_key = None
306
+ just_finished_i_option = False
307
+
308
+ elif entity_type == 'QUESTION':
309
+ current_item['question'] += f' {word}'
310
+ just_finished_i_option = False
311
+
312
+ elif label.startswith('I-'):
313
+ if entity_type == 'QUESTION':
314
+ current_item['question'] += f' {word}'
315
+ elif entity_type == 'PASSAGE':
316
+ if previous_entity_type == 'OPTION' and just_finished_i_option:
317
+ current_item['new_passage'] = word
318
+ is_in_new_passage = True
319
+ else:
320
+ if not current_passage_buffer: last_entity_type = 'PASSAGE'
321
+ current_passage_buffer.append(word)
322
+ elif entity_type == 'OPTION' and current_option_key is not None:
323
+ current_item['options'][current_option_key] += f' {word}'
324
+ just_finished_i_option = True
325
+ elif entity_type == 'ANSWER':
326
+ current_item['answer'] += f' {word}'
327
+
328
+ just_finished_i_option = (entity_type == 'OPTION')
329
+
330
+ elif label == 'O':
331
+ pass
332
+
333
+ # Final wrap up
334
+ if current_item is not None:
335
+ print(f"🏁 Finalizing the very last item...")
336
+ finalize_passage_to_item(current_item, current_passage_buffer)
337
+ current_item['text'] = ' '.join(current_text_buffer).strip()
338
+ structured_data.append(current_item)
339
+
340
+ # Clean up and regex replacement
341
+ for item in structured_data:
342
+ if 'text' in item:
343
+ item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
344
+ if 'new_passage' in item:
345
+ item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()
346
+
347
+ print(f"💾 Saving {len(structured_data)} items to {output_path}")
348
+ try:
349
+ with open(output_path, 'w', encoding='utf-8') as f:
350
+ json.dump(structured_data, f, indent=2, ensure_ascii=False)
351
+ print(f"✅ Decoding Complete. Total time: {time.time() - start_time:.2f}s")
352
+ except Exception as e:
353
+ print(f"⚠️ Error saving final JSON: {e}")
354
+
355
+ return structured_data
356
+
357
+ # -------------------------
358
+ # Part 5: Main Execution
359
+ # -------------------------
360
+ if __name__ == "__main__":
361
+ parser = argparse.ArgumentParser()
362
+ parser.add_argument("--pdf_path", type=str, required=True, help="Path to the PDF file")
363
+ parser.add_argument("--model_path", type=str, required=True, help="Path to the .pth checkpoint")
364
+ parser.add_argument("--output_json", type=str, default="final_output.json", help="Path for final structured JSON")
365
+ args = parser.parse_args()
366
+
367
+ # 1. Setup Device
368
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
369
+ print(f"⚙️ Using device: {device}")
370
+
371
+ # 2. Load Model
372
+ print(f"🔄 Loading model from {args.model_path}...")
373
+ model = LayoutLMv3CRF(num_labels=len(LABELS))
374
+ # Load state dict
375
+ state_dict = torch.load(args.model_path, map_location=device)
376
+ model.load_state_dict(state_dict)
377
+ model.to(device)
378
+
379
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
380
+
381
+ # 3. Extract PDF Data
382
+ pages_data = extract_pdf_data(args.pdf_path)
383
+
384
+ # 4. Run Inference
385
+ raw_predictions = run_inference(model, tokenizer, pages_data, device)
386
+
387
+ # 5. Save Intermediate (BIO tagged format)
388
+ intermediate_path = "temp_inference_bio.json"
389
+ with open(intermediate_path, "w", encoding="utf-8") as f:
390
+ json.dump(raw_predictions, f, indent=2, ensure_ascii=False)
391
+ print(f"💾 Intermediate BIO predictions saved to {intermediate_path}")
392
+
393
+ # 6. Convert to Structured JSON
394
+ convert_bio_to_structured_json_relaxed(intermediate_path, args.output_json)
train_layoutlmv3.py DELETED
@@ -1 +0,0 @@
1
-