pavansuresh commited on
Commit
b9ae2ff
·
verified ·
1 Parent(s): b2e3ca0

Update ai_mapping.py

Browse files
Files changed (1) hide show
  1. ai_mapping.py +17 -17
ai_mapping.py CHANGED
@@ -5,19 +5,19 @@ import pdf2image
5
  from typing import Dict, List
6
  import os
7
  from huggingface_hub import login
 
8
 
9
  # Optional: Log in to Hugging Face if using a private model
10
- # Uncomment and replace with your token if needed
11
  # login(token="your_hf_token")
12
 
13
  # Load pre-trained LayoutLMv3 models
14
  tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
15
- feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False) # Updated to ImageProcessor
16
- model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base") # Public base model
17
 
18
  def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str, str]:
19
  """
20
- Extract key-value pairs from PDF text using LayoutLMv3-base.
21
  Args:
22
  text_data (str): Extracted text from PDF.
23
  pdf_path (str): Path to the PDF file.
@@ -25,35 +25,35 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
25
  dict: Key-value pairs extracted from the document.
26
  """
27
  try:
28
- # Convert PDF to images (one per page)
29
- images = pdf2image.convert_from_path(pdf_path)
30
-
31
- # Process each page
32
  key_values = {}
 
 
 
 
 
 
 
33
  for i, image in enumerate(images):
34
- # Preprocess image and text
35
  encoding = feature_extractor(images=[image], text=text_data.splitlines(), return_tensors="pt")
36
  input_ids = encoding["input_ids"]
37
  attention_mask = encoding["attention_mask"]
38
- # token_type_ids not needed for LayoutLMv3-base
39
 
40
- # Get model predictions
41
  with torch.no_grad():
42
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
43
  predictions = torch.argmax(outputs.logits, dim=2)
44
 
45
- # Post-process predictions to extract key-value pairs (simplified logic)
46
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
47
  labels = predictions[0].tolist()
48
  current_key = None
49
  current_value = []
50
  for token, label in zip(tokens, labels):
51
- if label == 1: # Assuming label 1 indicates a key start (adjust based on training)
52
  if current_key and current_value:
53
  key_values[current_key] = " ".join(current_value).strip()
54
  current_key = token
55
  current_value = []
56
- elif label == 2 and current_key: # Assuming label 2 indicates a value (adjust based on training)
57
  current_value.append(token)
58
  if current_key and current_value:
59
  key_values[current_key] = " ".join(current_value).strip()
@@ -64,10 +64,10 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
64
 
65
  def run_ai_mapping_with_layoutlm(key_values: Dict[str, str], object_field_names: List[str], pdf_path: str) -> Dict:
66
  """
67
- Map extracted key-values to Salesforce fields using LayoutLMv3-base (simplified).
68
  Args:
69
  key_values (dict): Extracted key-value pairs.
70
- object_field_names (list): List of Salesforce field names.
71
  pdf_path (str): Path to the PDF file (for context if needed).
72
  Returns:
73
  dict: Mapping results with status, mappings, unmapped fields, and error (if any).
@@ -78,7 +78,7 @@ def run_ai_mapping_with_layoutlm(key_values: Dict[str, str], object_field_names:
78
 
79
  for field in object_field_names:
80
  for key, value in key_values.items():
81
- if field.lower() in key.lower(): # Simple string matching
82
  mappings[field] = value
83
  unmapped_fields.remove(field)
84
  break
 
5
  from typing import Dict, List
6
  import os
7
  from huggingface_hub import login
8
+ import re
9
 
10
  # Optional: Log in to Hugging Face if using a private model
 
11
  # login(token="your_hf_token")
12
 
13
  # Load pre-trained LayoutLMv3 models
14
  tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
15
+ feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
16
+ model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
17
 
18
  def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str, str]:
19
  """
20
+ Extract key-value pairs from PDF text using LayoutLMv3-base or fallback to regex.
21
  Args:
22
  text_data (str): Extracted text from PDF.
23
  pdf_path (str): Path to the PDF file.
 
25
  dict: Key-value pairs extracted from the document.
26
  """
27
  try:
28
+ # Fallback to regex if model is untrained
 
 
 
29
  key_values = {}
30
+ dates = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', text_data)
31
+ amounts = re.findall(r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?', text_data)
32
+ if dates or amounts:
33
+ key_values.update({"Date": dates[0] if dates else "", "Amount": amounts[0] if amounts else ""})
34
+
35
+ # Attempt LayoutLMv3 processing
36
+ images = pdf2image.convert_from_path(pdf_path)
37
  for i, image in enumerate(images):
 
38
  encoding = feature_extractor(images=[image], text=text_data.splitlines(), return_tensors="pt")
39
  input_ids = encoding["input_ids"]
40
  attention_mask = encoding["attention_mask"]
 
41
 
 
42
  with torch.no_grad():
43
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
44
  predictions = torch.argmax(outputs.logits, dim=2)
45
 
 
46
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
47
  labels = predictions[0].tolist()
48
  current_key = None
49
  current_value = []
50
  for token, label in zip(tokens, labels):
51
+ if label == 1: # Key start (adjust based on training)
52
  if current_key and current_value:
53
  key_values[current_key] = " ".join(current_value).strip()
54
  current_key = token
55
  current_value = []
56
+ elif label == 2 and current_key: # Value (adjust based on training)
57
  current_value.append(token)
58
  if current_key and current_value:
59
  key_values[current_key] = " ".join(current_value).strip()
 
64
 
65
  def run_ai_mapping_with_layoutlm(key_values: Dict[str, str], object_field_names: List[str], pdf_path: str) -> Dict:
66
  """
67
+ Map extracted key-values to object fields using LayoutLMv3-base (simplified).
68
  Args:
69
  key_values (dict): Extracted key-value pairs.
70
+ object_field_names (list): List of object field names.
71
  pdf_path (str): Path to the PDF file (for context if needed).
72
  Returns:
73
  dict: Mapping results with status, mappings, unmapped fields, and error (if any).
 
78
 
79
  for field in object_field_names:
80
  for key, value in key_values.items():
81
+ if field.lower() in key.lower() or any(k.lower() in field.lower() for k in key_values.keys()):
82
  mappings[field] = value
83
  unmapped_fields.remove(field)
84
  break