pavansuresh committed on
Commit
2b51034
·
verified ·
1 Parent(s): 1c6f1bf

Update ai_mapping.py

Browse files
Files changed (1) hide show
  1. ai_mapping.py +79 -25
ai_mapping.py CHANGED
@@ -1,39 +1,93 @@
1
- from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
2
- import os
 
 
 
3
 
4
- def run_ai_mapping(text_data, pdf_path, object_fields):
 
 
 
 
 
5
  """
6
- Map extracted PDF text to Salesforce fields using LayoutLMv3.
7
- Returns mappings with confidence scores and flags unmapped fields.
 
 
 
 
8
  """
9
  try:
10
- # Placeholder for LayoutLMv3-based key-value pair extraction
11
- # In a real implementation, load the model and processor
12
- # processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
13
- # model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base-finetuned-funsd")
14
- # Process pdf_path, extract key-value pairs, and map to object_fields
 
 
 
 
 
 
15
 
16
- # Mock implementation for demonstration
17
- mappings = {
18
- "Customer_Name__c": {"value": "Acme Corp", "confidence": 0.95},
19
- "Start_Date__c": {"value": "2023-01-01", "confidence": 0.90},
20
- "End_Date__c": {"value": "2024-01-01", "confidence": 0.90},
21
- "Amount__c": {"value": "50000", "confidence": 0.85}
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # Flag unmapped fields
25
- unmapped_fields = [field for field in object_fields if field not in mappings]
26
- result = {
27
  "mappings": mappings,
28
  "unmapped_fields": unmapped_fields,
29
- "status": "success" if not unmapped_fields else "partial",
30
  "error": None
31
  }
32
- return result
33
  except Exception as e:
34
  return {
35
- "mappings": {},
36
- "unmapped_fields": object_fields,
37
  "status": "failed",
38
- "error": f"AI mapping failed: {str(e)}"
 
 
39
  }
 
1
+ from transformers import LayoutLMv3Tokenizer, LayoutLMv3ForTokenClassification, LayoutLMv3FeatureExtractor
2
+ import torch
3
+ from PIL import Image
4
+ import pdf2image
5
+ from typing import Dict, List
6
 
7
+ # Load pre-trained LayoutLMv3 models (adjust model names based on your fine-tuned models)
8
+ tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
9
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False) # Set to True if OCR is needed
10
+ model = LayoutLMv3ForTokenClassification.from_pretrained("path_to_finetuned_funsd_model") # Replace with your fine-tuned model
11
+
12
+ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str, str]:
13
  """
14
+ Extract key-value pairs from PDF text using LayoutLMv3-finetuned-funsd.
15
+ Args:
16
+ text_data (str): Extracted text from PDF.
17
+ pdf_path (str): Path to the PDF file.
18
+ Returns:
19
+ dict: Key-value pairs extracted from the document.
20
  """
21
  try:
22
+ # Convert PDF to images (one per page)
23
+ images = pdf2image.convert_from_path(pdf_path)
24
+
25
+ # Process each page
26
+ key_values = {}
27
+ for i, image in enumerate(images):
28
+ # Preprocess image and text
29
+ encoding = feature_extractor(image, text_data.splitlines(), return_tensors="pt")
30
+ input_ids = encoding["input_ids"]
31
+ attention_mask = encoding["attention_mask"]
32
+ token_type_ids = encoding["token_type_ids"] if "token_type_ids" in encoding else None
33
 
34
+ # Get model predictions
35
+ with torch.no_grad():
36
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
37
+ predictions = torch.argmax(outputs.logits, dim=2)
38
+
39
+ # Post-process predictions to extract key-value pairs (simplified logic)
40
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
41
+ labels = predictions[0].tolist()
42
+ current_key = None
43
+ current_value = []
44
+ for token, label in zip(tokens, labels):
45
+ if label == 1: # Assuming label 1 indicates a key start
46
+ if current_key and current_value:
47
+ key_values[current_key] = " ".join(current_value).strip()
48
+ current_key = token
49
+ current_value = []
50
+ elif label == 2 and current_key: # Assuming label 2 indicates a value
51
+ current_value.append(token)
52
+ if current_key and current_value:
53
+ key_values[current_key] = " ".join(current_value).strip()
54
+
55
+ return key_values
56
+ except Exception as e:
57
+ return {"status": "failed", "error": str(e), "key_values": {}}
58
+
59
+ def run_ai_mapping_with_layoutlm(key_values: Dict[str, str], object_field_names: List[str], pdf_path: str) -> Dict:
60
+ """
61
+ Map extracted key-values to Salesforce fields using a custom-trained Transformer.
62
+ Args:
63
+ key_values (dict): Extracted key-value pairs.
64
+ object_field_names (list): List of Salesforce field names.
65
+ pdf_path (str): Path to the PDF file (for context if needed).
66
+ Returns:
67
+ dict: Mapping results with status, mappings, unmapped fields, and error (if any).
68
+ """
69
+ try:
70
+ # Placeholder for custom-trained Transformer logic (replace with your model)
71
+ mappings = {}
72
+ unmapped_fields = object_field_names.copy()
73
+
74
+ for field in object_field_names:
75
+ for key, value in key_values.items():
76
+ if field.lower() in key.lower(): # Simple string matching (replace with model prediction)
77
+ mappings[field] = value
78
+ unmapped_fields.remove(field)
79
+ break
80
 
81
+ return {
82
+ "status": "success",
 
83
  "mappings": mappings,
84
  "unmapped_fields": unmapped_fields,
 
85
  "error": None
86
  }
 
87
  except Exception as e:
88
  return {
 
 
89
  "status": "failed",
90
+ "error": str(e),
91
+ "mappings": {},
92
+ "unmapped_fields": object_field_names
93
  }