pavansuresh committed on
Commit
8822f53
·
verified ·
1 Parent(s): 27937fa

Update ai_mapping.py

Browse files
Files changed (1) hide show
  1. ai_mapping.py +21 -13
ai_mapping.py CHANGED
@@ -15,17 +15,18 @@ tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
15
  feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
16
  model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
17
 
18
- def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str, str]:
19
  """
20
  Extract key-value pairs from PDF text using LayoutLMv3-base or fallback to regex.
21
  Args:
22
- text_data (str): Extracted text from PDF.
23
  pdf_path (str): Path to the PDF file.
24
  Returns:
25
  dict: Key-value pairs extracted from the document.
26
  """
27
  try:
28
- # Fallback to regex if model is untrained
 
29
  key_values = {}
30
  dates = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', text_data)
31
  amounts = re.findall(r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?', text_data)
@@ -34,17 +35,23 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
34
 
35
  # Attempt LayoutLMv3 processing
36
  doc = fitz.open(pdf_path)
37
- for page_num in range(len(doc)):
 
 
 
 
38
  page = doc[page_num]
39
  pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72)) # 300 DPI
40
  img_path = f"{pdf_path}_page_{page_num}.png"
41
  pix.save(img_path)
42
  image = Image.open(img_path).convert("RGB")
43
 
44
- # Tokenize text
45
- words = text_data.splitlines()
 
46
  encoding = tokenizer(
47
  words,
 
48
  return_tensors="pt",
49
  truncation=True,
50
  padding=True,
@@ -52,16 +59,17 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
52
  )
53
  input_ids = encoding["input_ids"]
54
  attention_mask = encoding["attention_mask"]
 
55
 
56
- # Process image to get bounding boxes
57
  image_encoding = feature_extractor(image, return_tensors="pt")
58
- bbox = image_encoding["bbox"][0] # Shape: (num_tokens, 4)
59
 
60
  # Ensure bbox length matches input_ids
61
- if len(bbox) < len(input_ids[0]):
62
- bbox = torch.cat([bbox, torch.zeros((len(input_ids[0]) - len(bbox), 4), dtype=torch.int64)])
63
- elif len(bbox) > len(input_ids[0]):
64
- bbox = bbox[:len(input_ids[0])]
65
 
66
  # Pass inputs to the model
67
  with torch.no_grad():
@@ -69,7 +77,7 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
69
  input_ids=input_ids,
70
  attention_mask=attention_mask,
71
  bbox=bbox,
72
- pixel_values=image_encoding["pixel_values"]
73
  )
74
  predictions = torch.argmax(outputs.logits, dim=2)
75
 
 
15
  feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
16
  model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
17
 
18
+ def extract_key_values_with_layoutlm(page_data: list, pdf_path: str) -> Dict[str, str]:
19
  """
20
  Extract key-value pairs from PDF text using LayoutLMv3-base or fallback to regex.
21
  Args:
22
+ page_data (list): List of dictionaries with 'text' (str) and 'bbox' (list of [x0, y0, x1, y1]) per page.
23
  pdf_path (str): Path to the PDF file.
24
  Returns:
25
  dict: Key-value pairs extracted from the document.
26
  """
27
  try:
28
+ # Fallback to regex using concatenated text from all pages
29
+ text_data = " ".join([page["text"] for page in page_data])
30
  key_values = {}
31
  dates = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', text_data)
32
  amounts = re.findall(r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?', text_data)
 
35
 
36
  # Attempt LayoutLMv3 processing
37
  doc = fitz.open(pdf_path)
38
+ for page_num, page_info in enumerate(page_data):
39
+ if not page_info["text"].strip() or "No text detected" in page_info["text"]:
40
+ continue
41
+
42
+ # Load image for the page
43
  page = doc[page_num]
44
  pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72)) # 300 DPI
45
  img_path = f"{pdf_path}_page_{page_num}.png"
46
  pix.save(img_path)
47
  image = Image.open(img_path).convert("RGB")
48
 
49
+ # Tokenize text and prepare bounding boxes
50
+ words = page_info["text"].split()
51
+ bboxes = page_info["bbox"]
52
  encoding = tokenizer(
53
  words,
54
+ boxes=bboxes,
55
  return_tensors="pt",
56
  truncation=True,
57
  padding=True,
 
59
  )
60
  input_ids = encoding["input_ids"]
61
  attention_mask = encoding["attention_mask"]
62
+ bbox = encoding["bbox"]
63
 
64
+ # Process image for pixel values
65
  image_encoding = feature_extractor(image, return_tensors="pt")
66
+ pixel_values = image_encoding["pixel_values"]
67
 
68
  # Ensure bbox length matches input_ids
69
+ if len(bbox[0]) < len(input_ids[0]):
70
+ bbox = torch.cat([bbox, torch.zeros((1, len(input_ids[0]) - len(bbox[0]), 4), dtype=torch.int64)])
71
+ elif len(bbox[0]) > len(input_ids[0]):
72
+ bbox = bbox[:, :len(input_ids[0])]
73
 
74
  # Pass inputs to the model
75
  with torch.no_grad():
 
77
  input_ids=input_ids,
78
  attention_mask=attention_mask,
79
  bbox=bbox,
80
+ pixel_values=pixel_values
81
  )
82
  predictions = torch.argmax(outputs.logits, dim=2)
83