"""Extract agreement metadata and clauses from PDF page text.

Regex extraction is always performed. When the LayoutLMv3 stack
(transformers, torch, Pillow, PyMuPDF) is importable and the pre-trained
weights load, a token-classification pass is run on top of the regex result.
"""

import logging
import os
import re
from typing import Dict, List, Optional

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ML stack is optional: the regex path below needs none of it, so a
# missing package degrades to regex-only extraction instead of making the
# whole module unimportable.
try:
    from transformers import (
        LayoutLMv3ForTokenClassification,
        LayoutLMv3ImageProcessor,
        LayoutLMv3Tokenizer,
    )
    import torch
    from PIL import Image
except ImportError as exc:
    logger.warning("LayoutLMv3 dependencies not importable: %s", exc)
    LayoutLMv3ForTokenClassification = None
    LayoutLMv3ImageProcessor = None
    LayoutLMv3Tokenizer = None
    torch = None
    Image = None

try:
    import fitz  # PyMuPDF
except ImportError as exc:
    logger.warning("PyMuPDF (fitz) not importable: %s", exc)
    fitz = None

# Load pre-trained LayoutLMv3 models.
# The three names are bound to None up front so that a failed load leads to
# the "components not available" branch rather than a NameError later on.
tokenizer = None
feature_extractor = None
model = None
if LayoutLMv3Tokenizer is not None:
    try:
        tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
        feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
        model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
        logger.info("LayoutLMv3 models loaded successfully.")
    except Exception as e:
        # Never leave a half-loaded set of components behind.
        tokenizer = feature_extractor = model = None
        logger.error("Failed to load LayoutLMv3 models: %s", e)

# --- Regex patterns for the four target fields -----------------------------
_NAME_PATTERN = re.compile(
    r'(?:Agreement\s+Name|Contract\s+Title|Agreement\s+Title)\s*[:\s]*'
    r'([A-Za-z0-9\s]+?)(?=\s*(?:Exhibit|\n\n|\Z))',
    re.IGNORECASE,
)
_PARTIES_PATTERN = re.compile(
    r'between\s+([A-Za-z\s]+)\s+and\s+([A-Za-z\s]+)', re.IGNORECASE
)
_DATE_PATTERNS = {
    "Agreement Start Date": re.compile(
        r'(?:Agreement\s+Start\s+Date|Effective\s+Date|executed\s+as\s+of)'
        r'\s*[:\s]*(\d{1,2}/\d{1,2}/\d{2,4})',
        re.IGNORECASE,
    ),
    "Agreement End Date": re.compile(
        r'(?:Agreement\s+End\s+Date|Termination\s+Date)'
        r'\s*[:\s]*(\d{1,2}/\d{1,2}/\d{2,4})',
        re.IGNORECASE,
    ),
}
# The amount itself is captured so the label never leaks into the value, and
# both "1,250,000.00" and un-grouped "5000.00" are matched in full.
_AMOUNT_PATTERN = re.compile(
    r'(?:Total\s+Agreement\s+Value|Total\s+Amount|Contract\s+Value|List\s+Price)'
    r'\s*[:\s]*(\$?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?)',
    re.IGNORECASE,
)

# --- Regex patterns for clause extraction ----------------------------------
_NO_WAIVER_IN_PROVISIONS = re.compile(
    r'(?:General\s+Provisions\s*[\s\S]*?NO\s+WAIVER\s*[:\s]*)([\s\S]*?)(?=\n\n|\Z)',
    re.IGNORECASE,
)
_NO_WAIVER_ANYWHERE = re.compile(
    r'(NO\s+WAIVER\s*[:\s]*[\s\S]*?)(?=\n\n|\Z)', re.IGNORECASE
)
_TERMINATION_PATTERN = re.compile(
    r'(?:Termination\s*[:\s]*)([\s\S]*?)(?=\n\n|\Z)', re.IGNORECASE
)


def _page_text(page: dict) -> str:
    """Return the page's 'text' entry, treating a missing/None value as ''."""
    return page.get("text") or ""


def _extract_with_regex(text_data: str, key_values: Dict[str, str]) -> None:
    """Fill *key_values* in place from *text_data* using the regex patterns."""
    name_candidates = _NAME_PATTERN.findall(text_data)
    if name_candidates:
        # First multi-word candidate that is not boilerplate wins.
        key_values["Agreement Name"] = next(
            (
                name.strip()
                for name in name_candidates
                if len(name.split()) > 1
                and "MASTER SUBSCRIPTION AGREEMENT" not in name.upper()
                and "Customer" not in name
            ),
            "Unknown",
        )
    else:
        # No labelled name: infer one from "between <party> and <party>".
        parties = _PARTIES_PATTERN.search(text_data)
        if parties:
            first, second = parties.group(1).strip(), parties.group(2).strip()
            key_values["Agreement Name"] = f"{first} and {second}" if second else "Unknown"

    for field, pattern in _DATE_PATTERNS.items():
        found = pattern.search(text_data)
        if found and not key_values.get(field):
            key_values[field] = found.group(1)

    amount = _AMOUNT_PATTERN.search(text_data)
    if amount:
        key_values["Total Agreement Value"] = amount.group(1)


def _store_layoutlm_pair(key_values: Dict[str, str], key_token, value_tokens: list) -> None:
    """Write one predicted key/value pair into *key_values* if it maps to a target field.

    NOTE(review): the key is the single token that carried the key label, so
    the multi-word checks below can presumably only match once the labelling
    scheme yields whole-phrase keys — confirm against the trained model.
    """
    if not key_token or not value_tokens:
        return
    value_text = " ".join(value_tokens).strip()
    key_text = key_token.lower()
    if (
        "agreement name" in key_text
        and "MASTER SUBSCRIPTION AGREEMENT" not in value_text.upper()
        and "Customer" not in value_text
    ):
        key_values["Agreement Name"] = value_text
    elif "start date" in key_text or "effective date" in key_text or "executed as of" in key_text:
        key_values["Agreement Start Date"] = value_text
    elif "end date" in key_text or "termination date" in key_text:
        key_values["Agreement End Date"] = value_text
    elif "total agreement value" in key_text or "amount" in key_text or "price" in key_text:
        key_values["Total Agreement Value"] = value_text


def _enhance_with_layoutlm(page_data: list, pdf_path: str, key_values: Dict[str, str]) -> None:
    """Run LayoutLMv3 token classification per page and update *key_values* in place."""
    doc = fitz.open(pdf_path)
    try:
        for page_num, page_info in enumerate(page_data):
            text = _page_text(page_info)
            if not text.strip() or "No text detected" in text:
                continue
            words = page_info.get("words", [])
            bboxes = page_info.get("bbox", [])
            if not (words and bboxes):
                # Nothing to classify, so skip the (expensive) page render too.
                continue

            # Render at 300 DPI straight into memory; no temporary PNG is
            # written next to the PDF, so nothing can be left behind on error.
            pix = doc[page_num].get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
            image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

            encoding = tokenizer(
                words,
                boxes=bboxes,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            )
            input_ids = encoding["input_ids"]
            pixel_values = feature_extractor(image, return_tensors="pt")["pixel_values"]

            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=encoding["attention_mask"],
                    bbox=encoding["bbox"],
                    pixel_values=pixel_values,
                )

            labels = torch.argmax(outputs.logits, dim=2)[0].tolist()
            tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

            current_key = None
            current_value = []
            for token, label in zip(tokens, labels):
                if label == 1:  # Key start (hypothetical label, adjust based on training)
                    _store_layoutlm_pair(key_values, current_key, current_value)
                    current_key = token
                    current_value = []
                elif label == 2 and current_key:  # Value (hypothetical label, adjust based on training)
                    current_value.append(token)
            # Flush the pair that was still open when the page ended.
            _store_layoutlm_pair(key_values, current_key, current_value)
    finally:
        doc.close()


def extract_key_values_with_layoutlm(page_data: list, pdf_path: str) -> Dict[str, str]:
    """
    Extract key-value pairs from PDF text using LayoutLMv3-base with focus on
    Agreement Name, Agreement Start Date, Agreement End Date, and Total
    Agreement Value, with regex fallback.

    Regex extraction runs first over the concatenated page text. If the
    LayoutLMv3 components and PyMuPDF are available, each page is then
    rendered and classified, and matching predictions overwrite the regex
    values.

    Args:
        page_data (list): List of dictionaries with 'text' (str), 'words'
            (list of str), 'bbox' (list of [x0, y0, x1, y1] normalized to
            0-1000), and 'image_dims' ([width, height]) per page.
        pdf_path (str): Path to the PDF file.

    Returns:
        dict: The four target fields on success. If any step raises, a dict
        with 'status' ("failed"), 'error' (message) and 'key_values' (what
        was extracted before the failure) is returned instead.
    """
    key_values = {
        "Agreement Name": "Unknown",
        "Agreement Start Date": "",
        "Agreement End Date": "",
        "Total Agreement Value": "",
    }

    try:
        text_data = " ".join(_page_text(page) for page in page_data)
        logger.info("Starting regex-based extraction.")
        _extract_with_regex(text_data, key_values)

        # Attempt LayoutLMv3 processing for enhanced extraction
        components = (tokenizer, feature_extractor, model, fitz)
        if all(component is not None for component in components):
            _enhance_with_layoutlm(page_data, pdf_path, key_values)
        else:
            logger.warning("LayoutLMv3 model components not available, skipping advanced extraction.")

        # NOTE(review): "Agreement Name" defaults to the truthy "Unknown", so
        # this check is practically always true and the failure dict below is
        # effectively unreachable. Left unchanged because callers may rely on
        # always getting the four keys back — confirm the intended contract.
        if any(key_values.values()):
            return key_values
        return {"status": "failed", "error": "No key-value pairs extracted", "key_values": {}}
    except Exception as e:
        logger.error("Error in extract_key_values_with_layoutlm: %s", e)
        return {"status": "failed", "error": str(e), "key_values": key_values}


def extract_clauses(page_data: list) -> Dict[str, str]:
    """
    Extract clauses from PDF text based on keywords, focusing on key clauses
    like NO WAIVER and Termination.

    Args:
        page_data (list): List of dictionaries with 'text' (str) per page.

    Returns:
        dict: Mapping of clause names to their text content. When nothing is
        found, a single "No clauses extracted" entry is returned. If an error
        occurs, whatever was collected so far (possibly empty) is returned.
    """
    clauses = {}
    try:
        text_data = "\n".join(_page_text(page) for page in page_data)
        logger.info("Starting clause extraction.")

        # Prefer a NO WAIVER heading that follows "General Provisions";
        # otherwise accept the heading anywhere (heading text then included).
        scoped = _NO_WAIVER_IN_PROVISIONS.search(text_data)
        if scoped:
            clause_text = scoped.group(1).strip()
            clauses["NO WAIVER"] = clause_text or "NO WAIVER clause found but no content extracted"
        elif "NO WAIVER" in text_data.upper():
            loose = _NO_WAIVER_ANYWHERE.search(text_data)
            clauses["NO WAIVER"] = (
                loose.group(1).strip()
                if loose
                else "NO WAIVER clause identified but no detailed content extracted"
            )

        # Termination clause: text after the first "Termination" up to the
        # next blank line or the end of the document.
        termination = _TERMINATION_PATTERN.search(text_data)
        if termination:
            clauses["Termination"] = termination.group(1).strip()

        if clauses:
            return clauses
        return {"No clauses extracted": "No relevant clauses found in the document"}
    except Exception as e:
        logger.error("Error in extract_clauses: %s", e)
        return clauses


def _load_page_text(pdf_path: str) -> Optional[list]:
    """Read per-page text from *pdf_path*; return None when it cannot be read."""
    if fitz is None or not pdf_path:
        return None
    try:
        if not os.path.exists(pdf_path):
            return None
        doc = fitz.open(pdf_path)
        try:
            return [{"text": page.get_text()} for page in doc]
        finally:
            doc.close()
    except Exception as e:
        logger.warning("Could not read page text from %s: %s", pdf_path, e)
        return None


def run_ai_mapping_with_layoutlm(
    key_values: Dict[str, str],
    object_field_names: List[str],
    pdf_path: str,
    page_data: Optional[list] = None,
) -> Dict:
    """
    Map extracted key-values to object fields, prioritizing Agreement Name,
    Agreement Start Date, Agreement End Date, and Total Agreement Value.

    A field is mapped to the first non-empty value whose key contains the
    field name (case-insensitive substring match).

    Args:
        key_values (dict): Extracted key-value pairs. The failure-shaped dict
            returned by extract_key_values_with_layoutlm ('status', 'error',
            'key_values') is accepted too; its nested 'key_values' is used.
        object_field_names (list): List of object field names.
        pdf_path (str): Path to the PDF file; used to read page text for
            clause extraction when page_data is not supplied.
        page_data (list, optional): Per-page dictionaries with 'text', passed
            to extract_clauses. Defaults to None.

    Returns:
        dict: Mapping results with status, mappings, unmapped fields, error
        (if any) and clauses. 'clauses' is empty when no page text is
        available.
    """
    try:
        logger.info("Starting mapping process.")

        # Unwrap the extractor's failure shape so 'status'/'error' are not
        # mistaken for extracted fields.
        source = key_values
        if key_values.get("status") == "failed" and isinstance(key_values.get("key_values"), dict):
            source = key_values["key_values"]

        mappings = {}
        for field in object_field_names:
            needle = field.lower()
            for key, value in source.items():
                if value and needle in key.lower():
                    mappings[field] = value
                    break
        unmapped_fields = [field for field in object_field_names if field not in mappings]

        if page_data is None:
            page_data = _load_page_text(pdf_path)
        clauses = extract_clauses(page_data) if page_data is not None else {}

        return {
            "status": "success",
            "mappings": mappings,
            "unmapped_fields": unmapped_fields,
            "error": None,
            "clauses": clauses,  # Include clauses in the output
        }
    except Exception as e:
        logger.error("Error in run_ai_mapping_with_layoutlm: %s", e)
        return {
            "status": "failed",
            "error": str(e),
            "mappings": {},
            "unmapped_fields": object_field_names,
            "clauses": {},
        }