# SmartContractMigrator / ai_mapping.py
# pavansuresh's picture
# Update ai_mapping.py
# 0cb4c94 verified
import logging
import os
import re
from typing import Dict, List, Optional

import fitz  # PyMuPDF
import torch
from PIL import Image
from transformers import (
    LayoutLMv3ForTokenClassification,
    LayoutLMv3ImageProcessor,
    LayoutLMv3Tokenizer,
)
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load pre-trained LayoutLMv3 models.
# The three names are bound to None up front so that code checking for their
# availability sees "not loaded" instead of hitting a NameError when the
# download/initialisation below fails.
tokenizer = None
feature_extractor = None
model = None
try:
    tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
    # apply_ocr=False: words and boxes are supplied by the caller, not by OCR.
    feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
    model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
    logger.info("LayoutLMv3 models loaded successfully.")
except Exception as e:
    # Drop any partially loaded component so availability is all-or-nothing.
    tokenizer = feature_extractor = model = None
    logger.error(f"Failed to load LayoutLMv3 models: {str(e)}")
def _extract_fields_with_regex(text_data: str, key_values: Dict[str, str]) -> None:
    """Fill *key_values* in place from label/value patterns found in *text_data*."""
    # Agreement name: explicit label first, otherwise infer from "between X and Y".
    name_context = re.findall(
        r'(?:Agreement\s+Name|Contract\s+Title|Agreement\s+Title)\s*[:\s]*([A-Za-z0-9\s]+?)(?=\s*(?:Exhibit|\n\n|\Z))',
        text_data,
        re.IGNORECASE,
    )
    if name_context:
        # Reject single-word hits and boilerplate headings that are not real names.
        key_values["Agreement Name"] = next(
            (
                name.strip()
                for name in name_context
                if len(name.split()) > 1
                and "MASTER SUBSCRIPTION AGREEMENT" not in name.upper()
                and "Customer" not in name
            ),
            "Unknown",
        )
    else:
        party_match = re.search(
            r'(?:between\s+([A-Za-z\s]+)\s+and\s+([A-Za-z\s]+))',
            text_data,
            re.IGNORECASE,
        )
        if party_match:
            key_values["Agreement Name"] = (
                f"{party_match.group(1).strip()} and {party_match.group(2).strip()}"
            )

    # Each date pattern is paired with the field it fills; first match in the text wins.
    date_patterns = (
        (
            "Agreement Start Date",
            r'(?:Agreement\s+Start\s+Date|Effective\s+Date|executed\s+as\s+of)\s*[:\s]*(\d{1,2}/\d{1,2}/\d{2,4})',
        ),
        (
            "Agreement End Date",
            r'(?:Agreement\s+End\s+Date|Termination\s+Date)\s*[:\s]*(\d{1,2}/\d{1,2}/\d{2,4})',
        ),
    )
    for field_name, pattern in date_patterns:
        date_match = re.search(pattern, text_data, re.IGNORECASE)
        if date_match and not key_values.get(field_name):
            key_values[field_name] = date_match.group(1)

    # The amount itself is captured, so the label never leaks into the value,
    # and both comma-grouped ("1,000") and plain ("50000") digits are accepted.
    amount_match = re.search(
        r'(?:Total\s+Agreement\s+Value|Total\s+Amount|Contract\s+Value|List\s+Price)\s*[:\s]*'
        r'(\$?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d{2})?)',
        text_data,
        re.IGNORECASE,
    )
    if amount_match:
        key_values["Total Agreement Value"] = amount_match.group(1)


def _loaded_layoutlm_components():
    """Return ``(tokenizer, feature_extractor, model)``, or ``None`` if any is missing.

    The names are looked up through ``globals()`` because the module-level
    loader may have failed before binding them; a direct reference would then
    raise ``NameError`` instead of letting the caller keep the regex results.
    """
    module_scope = globals()
    components = tuple(
        module_scope.get(name) for name in ("tokenizer", "feature_extractor", "model")
    )
    if any(component is None for component in components):
        return None
    return components


def _store_layoutlm_pair(key_values: Dict[str, str], key_token: str, value_tokens: list) -> None:
    """Write one predicted key/value pair into the matching target field, if any."""
    value = " ".join(value_tokens).strip()
    key_lower = key_token.lower()
    if (
        "agreement name" in key_lower
        and "MASTER SUBSCRIPTION AGREEMENT" not in value.upper()
        and "Customer" not in value
    ):
        key_values["Agreement Name"] = value
    elif "start date" in key_lower or "effective date" in key_lower or "executed as of" in key_lower:
        key_values["Agreement Start Date"] = value
    elif "end date" in key_lower or "termination date" in key_lower:
        key_values["Agreement End Date"] = value
    elif "total agreement value" in key_lower or "amount" in key_lower or "price" in key_lower:
        key_values["Total Agreement Value"] = value


def _extract_fields_with_layoutlm(page_data, pdf_path, key_values, layout_tokenizer, image_processor, layout_model) -> None:
    """Run LayoutLMv3 token classification per page and update *key_values* in place.

    NOTE(review): label ids 1 ("key start") and 2 ("value") are the original
    author's placeholders ("hypothetical label, adjust based on training");
    they only carry meaning once the model has a classification head trained
    with that scheme.
    """
    doc = fitz.open(pdf_path)
    try:
        for page_num, page_info in enumerate(page_data):
            page_text = page_info.get("text") or ""
            if not page_text.strip() or "No text detected" in page_text:
                continue
            words = page_info.get("words", [])
            bboxes = page_info.get("bbox", [])
            # Check inputs before rendering: a 300 DPI pixmap is expensive.
            if not (words and bboxes):
                continue

            pix = doc[page_num].get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), alpha=False)  # 300 DPI
            # Build the PIL image in memory; no temp PNG is written next to the
            # source PDF, so nothing can be left behind when a later step fails.
            image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

            encoding = layout_tokenizer(
                words,
                boxes=bboxes,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            )
            input_ids = encoding["input_ids"]
            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
            with torch.no_grad():
                outputs = layout_model(
                    input_ids=input_ids,
                    attention_mask=encoding["attention_mask"],
                    bbox=encoding["bbox"],
                    pixel_values=pixel_values,
                )
            labels = torch.argmax(outputs.logits, dim=2)[0].tolist()
            tokens = layout_tokenizer.convert_ids_to_tokens(input_ids[0])

            current_key = None
            current_value = []
            for token, label in zip(tokens, labels):
                if label == 1:  # Key start (hypothetical label, adjust based on training)
                    if current_key and current_value:
                        _store_layoutlm_pair(key_values, current_key, current_value)
                    current_key = token
                    current_value = []
                elif label == 2 and current_key:  # Value (hypothetical label, adjust based on training)
                    current_value.append(token)
            # Flush the pair that was still open when the page ended.
            if current_key and current_value:
                _store_layoutlm_pair(key_values, current_key, current_value)
    finally:
        # Always release the document handle, including on mid-loop errors.
        doc.close()


def extract_key_values_with_layoutlm(page_data: list, pdf_path: str) -> Dict[str, str]:
    """
    Extract key-value pairs from PDF text, focusing on Agreement Name,
    Agreement Start Date, Agreement End Date, and Total Agreement Value.

    Regex extraction over the concatenated page text runs first. When the
    LayoutLMv3 components are loaded, a token-classification pass runs
    afterwards and may overwrite the regex results.

    Args:
        page_data (list): List of dictionaries with 'text' (str), 'words' (list of str),
            'bbox' (list of [x0, y0, x1, y1] normalized to 0-1000), and
            'image_dims' ([width, height]) per page.
        pdf_path (str): Path to the PDF file; only opened when LayoutLMv3 is available.

    Returns:
        dict: The four target fields on success. On any exception, a dict with
        "status": "failed", "error" and the partially filled "key_values".
    """
    key_values = {
        "Agreement Name": "Unknown",
        "Agreement Start Date": "",
        "Agreement End Date": "",
        "Total Agreement Value": "",
    }
    try:
        text_data = " ".join([page.get("text") or "" for page in page_data])
        logger.info("Starting regex-based extraction.")
        _extract_fields_with_regex(text_data, key_values)

        # Attempt LayoutLMv3 processing for enhanced extraction
        components = _loaded_layoutlm_components()
        if components is None:
            logger.warning("LayoutLMv3 model components not available, skipping advanced extraction.")
        else:
            _extract_fields_with_layoutlm(page_data, pdf_path, key_values, *components)

        # NOTE(review): "Agreement Name" defaults to "Unknown", so this check is
        # truthy in practice and the failure dict is rarely, if ever, returned.
        return key_values if any(key_values.values()) else {"status": "failed", "error": "No key-value pairs extracted", "key_values": {}}
    except Exception as e:
        logger.error(f"Error in extract_key_values_with_layoutlm: {str(e)}")
        return {"status": "failed", "error": str(e), "key_values": key_values}
def extract_clauses(page_data: list) -> Dict[str, str]:
    """
    Extract the NO WAIVER and Termination clauses from the page text.

    A clause is taken to run from its heading up to the next blank line
    (or the end of the text).

    Args:
        page_data (list): List of dictionaries with 'text' (str) per page.

    Returns:
        dict: Mapping of clause names to their text content. When nothing is
        found, a single "No clauses extracted" entry. On an exception, whatever
        was collected so far (possibly empty).
    """
    clauses = {}
    try:
        text_data = "\n".join([page.get("text") or "" for page in page_data])
        logger.info("Starting clause extraction.")

        # Preferred form: NO WAIVER located after a "General Provisions" heading;
        # only the clause body (without the heading) is captured.
        no_waiver_match = re.search(
            r'(?:General\s+Provisions\s*[\s\S]*?NO\s+WAIVER\s*[:\s]*)([\s\S]*?)(?=\n\n|\Z)',
            text_data,
            re.IGNORECASE,
        )
        if no_waiver_match:
            clause_text = no_waiver_match.group(1).strip()
            clauses["NO WAIVER"] = clause_text if clause_text else "NO WAIVER clause found but no content extracted"
        else:
            # Fallback: a standalone NO WAIVER heading; the capture includes the heading.
            # The regex is the only gate here: a literal "NO WAIVER" substring
            # test would miss headings wrapped across lines ("NO\nWAIVER").
            heading_match = re.search(
                r'(NO\s+WAIVER\s*[:\s]*[\s\S]*?)(?=\n\n|\Z)',
                text_data,
                re.IGNORECASE,
            )
            if heading_match:
                clauses["NO WAIVER"] = heading_match.group(1).strip()

        # Search for Termination clause (first case-insensitive occurrence of the word).
        termination_match = re.search(
            r'(?:Termination\s*[:\s]*)([\s\S]*?)(?=\n\n|\Z)',
            text_data,
            re.IGNORECASE,
        )
        if termination_match:
            clauses["Termination"] = termination_match.group(1).strip()

        return clauses if clauses else {"No clauses extracted": "No relevant clauses found in the document"}
    except Exception as e:
        logger.error(f"Error in extract_clauses: {str(e)}")
        return clauses
def run_ai_mapping_with_layoutlm(
    key_values: Dict[str, str],
    object_field_names: List[str],
    pdf_path: str,
    page_data: Optional[list] = None,
) -> Dict:
    """
    Map extracted key-values to object fields, prioritizing Agreement Name,
    Agreement Start Date, Agreement End Date, and Total Agreement Value.

    A field is mapped to the first key (in dict order) whose name contains the
    field name, case-insensitively, and whose value is non-empty.

    Args:
        key_values (dict): Extracted key-value pairs.
        object_field_names (list): List of object field names.
        pdf_path (str): Path to the PDF file (not used by the mapping itself).
        page_data (list, optional): Per-page dictionaries with 'text'. When
            given, clauses are extracted from it and included in the result;
            when omitted, "clauses" is an empty dict.

    Returns:
        dict: Mapping results with status, mappings, unmapped fields, clauses,
        and error (if any).
    """
    try:
        logger.info("Starting mapping process.")
        mappings = {}
        for field in object_field_names:
            field_lower = field.lower()
            for key, value in key_values.items():
                if field_lower in key.lower() and value:
                    mappings[field] = value
                    break
        unmapped_fields = [field for field in object_field_names if field not in mappings]
        # page_data is an explicit argument: this function has no other access
        # to the page text, so clauses can only be produced when it is supplied.
        clauses = extract_clauses(page_data) if page_data else {}
        return {
            "status": "success",
            "mappings": mappings,
            "unmapped_fields": unmapped_fields,
            "error": None,
            "clauses": clauses,  # Include clauses in the output
        }
    except Exception as e:
        logger.error(f"Error in run_ai_mapping_with_layoutlm: {str(e)}")
        return {
            "status": "failed",
            "error": str(e),
            "mappings": {},
            "unmapped_fields": object_field_names,
            "clauses": {},
        }