pavansuresh committed on
Commit
5c08ef5
·
verified ·
1 Parent(s): 2b51034

Update ai_mapping.py

Browse files
Files changed (1) hide show
  1. ai_mapping.py +19 -14
ai_mapping.py CHANGED
@@ -1,17 +1,23 @@
1
- from transformers import LayoutLMv3Tokenizer, LayoutLMv3ForTokenClassification, LayoutLMv3FeatureExtractor
2
  import torch
3
  from PIL import Image
4
  import pdf2image
5
  from typing import Dict, List
 
 
6
 
7
- # Load pre-trained LayoutLMv3 models (adjust model names based on your fine-tuned models)
 
 
 
 
8
  tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
9
- feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False) # Set to True if OCR is needed
10
- model = LayoutLMv3ForTokenClassification.from_pretrained("path_to_finetuned_funsd_model") # Replace with your fine-tuned model
11
 
12
  def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str, str]:
13
  """
14
- Extract key-value pairs from PDF text using LayoutLMv3-finetuned-funsd.
15
  Args:
16
  text_data (str): Extracted text from PDF.
17
  pdf_path (str): Path to the PDF file.
@@ -26,14 +32,14 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
26
  key_values = {}
27
  for i, image in enumerate(images):
28
  # Preprocess image and text
29
- encoding = feature_extractor(image, text_data.splitlines(), return_tensors="pt")
30
  input_ids = encoding["input_ids"]
31
  attention_mask = encoding["attention_mask"]
32
- token_type_ids = encoding["token_type_ids"] if "token_type_ids" in encoding else None
33
 
34
  # Get model predictions
35
  with torch.no_grad():
36
- outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
37
  predictions = torch.argmax(outputs.logits, dim=2)
38
 
39
  # Post-process predictions to extract key-value pairs (simplified logic)
@@ -42,23 +48,23 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
42
  current_key = None
43
  current_value = []
44
  for token, label in zip(tokens, labels):
45
- if label == 1: # Assuming label 1 indicates a key start
46
  if current_key and current_value:
47
  key_values[current_key] = " ".join(current_value).strip()
48
  current_key = token
49
  current_value = []
50
- elif label == 2 and current_key: # Assuming label 2 indicates a value
51
  current_value.append(token)
52
  if current_key and current_value:
53
  key_values[current_key] = " ".join(current_value).strip()
54
 
55
- return key_values
56
  except Exception as e:
57
  return {"status": "failed", "error": str(e), "key_values": {}}
58
 
59
  def run_ai_mapping_with_layoutlm(key_values: Dict[str, str], object_field_names: List[str], pdf_path: str) -> Dict:
60
  """
61
- Map extracted key-values to Salesforce fields using a custom-trained Transformer.
62
  Args:
63
  key_values (dict): Extracted key-value pairs.
64
  object_field_names (list): List of Salesforce field names.
@@ -67,13 +73,12 @@ def run_ai_mapping_with_layoutlm(key_values: Dict[str, str], object_field_names:
67
  dict: Mapping results with status, mappings, unmapped fields, and error (if any).
68
  """
69
  try:
70
- # Placeholder for custom-trained Transformer logic (replace with your model)
71
  mappings = {}
72
  unmapped_fields = object_field_names.copy()
73
 
74
  for field in object_field_names:
75
  for key, value in key_values.items():
76
- if field.lower() in key.lower(): # Simple string matching (replace with model prediction)
77
  mappings[field] = value
78
  unmapped_fields.remove(field)
79
  break
 
1
+ from transformers import LayoutLMv3Tokenizer, LayoutLMv3ForTokenClassification, LayoutLMv3ImageProcessor
2
  import torch
3
  from PIL import Image
4
  import pdf2image
5
  from typing import Dict, List
6
+ import os
7
+ from huggingface_hub import login
8
 
9
+ # Optional: Log in to Hugging Face if using a private model
10
+ # Uncomment and replace with your token if needed
11
+ # login(token="your_hf_token")
12
+
13
+ # Load pre-trained LayoutLMv3 models
14
  tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
15
+ feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False) # Updated to ImageProcessor
16
+ model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base") # Public base model
17
 
18
  def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str, str]:
19
  """
20
+ Extract key-value pairs from PDF text using LayoutLMv3-base.
21
  Args:
22
  text_data (str): Extracted text from PDF.
23
  pdf_path (str): Path to the PDF file.
 
32
  key_values = {}
33
  for i, image in enumerate(images):
34
  # Preprocess image and text
35
+ encoding = feature_extractor(images=[image], text=text_data.splitlines(), return_tensors="pt")
36
  input_ids = encoding["input_ids"]
37
  attention_mask = encoding["attention_mask"]
38
+ # token_type_ids not needed for LayoutLMv3-base
39
 
40
  # Get model predictions
41
  with torch.no_grad():
42
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
43
  predictions = torch.argmax(outputs.logits, dim=2)
44
 
45
  # Post-process predictions to extract key-value pairs (simplified logic)
 
48
  current_key = None
49
  current_value = []
50
  for token, label in zip(tokens, labels):
51
+ if label == 1: # Assuming label 1 indicates a key start (adjust based on training)
52
  if current_key and current_value:
53
  key_values[current_key] = " ".join(current_value).strip()
54
  current_key = token
55
  current_value = []
56
+ elif label == 2 and current_key: # Assuming label 2 indicates a value (adjust based on training)
57
  current_value.append(token)
58
  if current_key and current_value:
59
  key_values[current_key] = " ".join(current_value).strip()
60
 
61
+ return key_values if key_values else {"status": "failed", "error": "No key-value pairs extracted", "key_values": {}}
62
  except Exception as e:
63
  return {"status": "failed", "error": str(e), "key_values": {}}
64
 
65
  def run_ai_mapping_with_layoutlm(key_values: Dict[str, str], object_field_names: List[str], pdf_path: str) -> Dict:
66
  """
67
+ Map extracted key-values to Salesforce fields using LayoutLMv3-base (simplified).
68
  Args:
69
  key_values (dict): Extracted key-value pairs.
70
  object_field_names (list): List of Salesforce field names.
 
73
  dict: Mapping results with status, mappings, unmapped fields, and error (if any).
74
  """
75
  try:
 
76
  mappings = {}
77
  unmapped_fields = object_field_names.copy()
78
 
79
  for field in object_field_names:
80
  for key, value in key_values.items():
81
+ if field.lower() in key.lower(): # Simple string matching
82
  mappings[field] = value
83
  unmapped_fields.remove(field)
84
  break