aagamjtdev committed on
Commit
6deed2e
·
1 Parent(s): 44ea3cf

correction

Browse files
Files changed (1) hide show
  1. HF_LayoutLM_with_Passage.py +1 -120
HF_LayoutLM_with_Passage.py CHANGED
@@ -8,14 +8,10 @@ import torch.nn as nn
8
  from torch.utils.data import Dataset, DataLoader, random_split
9
  from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
10
  from TorchCRF import CRF
11
-
12
  from torch.optim import AdamW
13
  from tqdm import tqdm
14
  from sklearn.metrics import precision_recall_fscore_support
15
- import fitz # PyMuPDF
16
- import pytesseract
17
- from PIL import Image
18
- from pdf2image import convert_from_path
19
 
20
  # --- Configuration for Augmentation ---
21
  MAX_BBOX_DIMENSION = 999
@@ -347,117 +343,6 @@ def main(args):
347
  print(f"💾 Model saved at {ckpt_path}")
348
 
349
 
350
- def run_inference(pdf_path, model_path, output_path):
351
- # LABELS UPDATED: Added SECTION_HEADING and PASSAGE (Must match main)
352
- labels = [
353
- "O",
354
- "B-QUESTION", "I-QUESTION",
355
- "B-OPTION", "I-OPTION",
356
- "B-ANSWER", "I-ANSWER",
357
- "B-SECTION_HEADING", "I-SECTION_HEADING",
358
- "B-PASSAGE", "I-PASSAGE"
359
- ]
360
- label2id = {l: i for i, l in enumerate(labels)}
361
- id2label = {i: l for l, i in label2id.items()}
362
- tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
363
-
364
- # Load the trained model
365
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
366
- model = LayoutLMv3CRF("microsoft/layoutlmv3-base", num_labels=len(labels)).to(device)
367
- try:
368
- model.load_state_dict(torch.load(model_path, map_location=device))
369
- except Exception as e:
370
- print(
371
- f"❌ Error loading model state: {e}. Ensure the model at {model_path} has been successfully trained with the new labels.")
372
- return
373
-
374
- model.eval()
375
-
376
- # Process PDF with OCR
377
- try:
378
- doc = fitz.open(pdf_path)
379
- except Exception as e:
380
- print(f"❌ Error opening PDF: {e}")
381
- return
382
-
383
- all_predictions = []
384
- tesseract_config = '--psm 6'
385
-
386
- for page_num in range(len(doc)):
387
- page = doc.load_page(page_num)
388
-
389
- # Get a high-resolution image of the page for Tesseract
390
- pix = page.get_pixmap(dpi=300)
391
- img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
392
-
393
- # Get page dimensions from PyMuPDF
394
- page_width, page_height = page.bound().width, page.bound().height
395
-
396
- # Get OCR data (words and bboxes)
397
- ocr_data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT, config=tesseract_config)
398
- words = [word for word in ocr_data['text'] if word.strip()]
399
-
400
- # Skip empty pages
401
- if not words:
402
- continue
403
-
404
- # Get the scaling factors from the image resolution to the PDF's native resolution
405
- x_scale = page_width / pix.width
406
- y_scale = page_height / pix.height
407
-
408
- # Create original pixel bboxes
409
- bboxes_raw = [[
410
- ocr_data['left'][i],
411
- ocr_data['top'][i],
412
- ocr_data['left'][i] + ocr_data['width'][i],
413
- ocr_data['top'][i] + ocr_data['height'][i]
414
- ] for i in range(len(ocr_data['text'])) if ocr_data['text'][i].strip()]
415
-
416
- # Normalize bboxes to 0-1000 scale using the correct scaling factors
417
- normalized_bboxes = [[
418
- int(1000 * (b[0] * x_scale) / page_width),
419
- int(1000 * (b[1] * y_scale) / page_height),
420
- int(1000 * (b[2] * x_scale) / page_width),
421
- int(1000 * (b[3] * y_scale) / page_height)
422
- ] for b in bboxes_raw]
423
-
424
- # Tokenize and run inference
425
- inputs = tokenizer(words, boxes=normalized_bboxes, return_tensors="pt", truncation=True).to(device)
426
-
427
- with torch.no_grad():
428
- # The model is run on the normalized bboxes
429
- preds = model(**inputs)
430
-
431
- # Align predictions back to words
432
- word_ids = inputs.word_ids(batch_index=0)
433
- final_preds = []
434
- previous_word_idx = None
435
- for idx, word_id in enumerate(word_ids):
436
- if word_id is not None and word_id != previous_word_idx:
437
- # The model returns a list of predicted classes for each token
438
- final_preds.append(id2label[preds[0][idx]])
439
- previous_word_idx = word_id
440
-
441
- # Prepare structured output
442
- page_results = []
443
- # Tesseract returns word list that is shorter than ocr_data if it contains empty strings.
444
- # We need to use the cleaned 'words' list and its corresponding filtered bboxes.
445
- # Note: We must ensure that the word and bbox lists passed to tokenizer and the filtered
446
- # final_preds list are all correctly aligned with the original ocr_data indices.
447
- # Since 'words' and 'bboxes_raw' are filtered exactly the same way (by word.strip()),
448
- # and 'final_preds' is aligned back to 'words', we can zip them.
449
- for word, bbox, label in zip(words, bboxes_raw, final_preds):
450
- page_results.append({
451
- "word": word,
452
- "bbox": bbox,
453
- "predicted_label": label
454
- })
455
- all_predictions.extend(page_results)
456
-
457
- doc.close()
458
- with open(output_path, "w") as f:
459
- json.dump(all_predictions, f, indent=2, ensure_ascii=False)
460
- print(f"✅ Inference complete. Predictions saved to {output_path}")
461
 
462
 
463
  # -------------------------
@@ -478,7 +363,3 @@ if __name__ == "__main__":
478
  if not args.input:
479
  parser.error("--input is required for 'train' mode.")
480
  main(args)
481
- elif args.mode == "infer":
482
- if not args.input:
483
- parser.error("--input is required for 'infer' mode.")
484
- run_inference(args.input, "checkpoints/layoutlmv3_crf_new_passage.pth", "inference_predictions.json")
 
8
  from torch.utils.data import Dataset, DataLoader, random_split
9
  from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
10
  from TorchCRF import CRF
 
11
  from torch.optim import AdamW
12
  from tqdm import tqdm
13
  from sklearn.metrics import precision_recall_fscore_support
14
+
 
 
 
15
 
16
  # --- Configuration for Augmentation ---
17
  MAX_BBOX_DIMENSION = 999
 
343
  print(f"💾 Model saved at {ckpt_path}")
344
 
345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
 
347
 
348
  # -------------------------
 
363
  if not args.input:
364
  parser.error("--input is required for 'train' mode.")
365
  main(args)