# Page-status residue captured together with this listing ("Spaces: Sleeping");
# it is not part of the program.
import logging
import os
import re
from typing import Dict, List, Optional

import fitz  # PyMuPDF
import torch
from PIL import Image
from transformers import LayoutLMv3Tokenizer, LayoutLMv3ForTokenClassification, LayoutLMv3ImageProcessor
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load pre-trained LayoutLMv3 models once, at import time.
# The three names are bound to None up front so they always exist: code that
# checks for the components can then skip the model stage cleanly instead of
# hitting a NameError when loading failed.
tokenizer = None
feature_extractor = None
model = None
try:
    tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
    feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
    model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
    logger.info("LayoutLMv3 models loaded successfully.")
except Exception as e:
    # A partial load is as unusable as no load at all, so reset every
    # component rather than leaving a mix of loaded and missing pieces.
    tokenizer = feature_extractor = model = None
    logger.error("Failed to load LayoutLMv3 models: %s", e)
# Label ids the token classifier is expected to emit for "key start" and
# "value" tokens. NOTE(review): these are placeholder ids; adjust them to the
# label map of whatever fine-tuned head is actually loaded.
_KEY_LABEL = 1
_VALUE_LABEL = 2

_NAME_RE = re.compile(
    r'(?:Agreement\s+Name|Contract\s+Title|Agreement\s+Title)\s*[:\s]*'
    r'([A-Za-z0-9\s]+?)(?=\s*(?:Exhibit|\n\n|\Z))',
    re.IGNORECASE,
)
_PARTIES_RE = re.compile(
    r'(?:between\s+([A-Za-z\s]+)\s+and\s+([A-Za-z\s]+))',
    re.IGNORECASE,
)
# (target field, pattern) pairs; the field is stated explicitly instead of
# being guessed from words that happen to appear in the pattern text.
_DATE_RES = (
    (
        "Agreement Start Date",
        re.compile(
            r'(?:Agreement\s+Start\s+Date|Effective\s+Date|executed\s+as\s+of)'
            r'\s*[:\s]*(\d{1,2}/\d{1,2}/\d{2,4})',
            re.IGNORECASE,
        ),
    ),
    (
        "Agreement End Date",
        re.compile(
            r'(?:Agreement\s+End\s+Date|Termination\s+Date)'
            r'\s*[:\s]*(\d{1,2}/\d{1,2}/\d{2,4})',
            re.IGNORECASE,
        ),
    ),
)
# The amount itself is captured so the label never leaks into the value, and
# the leading digit run is unbounded so "50000" is not truncated to "500".
_AMOUNT_RE = re.compile(
    r'(?:Total\s+Agreement\s+Value|Total\s+Amount|Contract\s+Value|List\s+Price)'
    r'\s*[:\s]*(\$?\d+(?:,\d{3})*(?:\.\d{2})?)',
    re.IGNORECASE,
)


def _extract_fields_with_regex(text_data, key_values):
    """Fill ``key_values`` in place from label/value patterns found in ``text_data``."""
    name_candidates = _NAME_RE.findall(text_data)
    if name_candidates:
        # Take the first multi-word candidate that is not just the document
        # heading or a party placeholder.
        key_values["Agreement Name"] = next(
            (
                name.strip()
                for name in name_candidates
                if len(name.split()) > 1
                and "MASTER SUBSCRIPTION AGREEMENT" not in name.upper()
                and "Customer" not in name
            ),
            "Unknown",
        )
    else:
        # No explicit name label: infer a name from "between X and Y".
        parties = _PARTIES_RE.search(text_data)
        if parties:
            key_values["Agreement Name"] = (
                f"{parties.group(1).strip()} and {parties.group(2).strip()}"
            )

    for field, pattern in _DATE_RES:
        found = pattern.search(text_data)
        if found and not key_values.get(field):
            key_values[field] = found.group(1)

    amount = _AMOUNT_RE.search(text_data)
    if amount:
        key_values["Total Agreement Value"] = amount.group(1).strip()


def _store_labelled_value(key_values, key_token, value_tokens):
    """Route one predicted (key token, value tokens) pair into ``key_values``.

    Does nothing when either the key or the value is missing.
    """
    if not key_token or not value_tokens:
        return
    value = " ".join(value_tokens).strip()
    key_lower = key_token.lower()
    if (
        "agreement name" in key_lower
        and "MASTER SUBSCRIPTION AGREEMENT" not in value.upper()
        and "Customer" not in value
    ):
        key_values["Agreement Name"] = value
    elif "start date" in key_lower or "effective date" in key_lower or "executed as of" in key_lower:
        key_values["Agreement Start Date"] = value
    elif "end date" in key_lower or "termination date" in key_lower:
        key_values["Agreement End Date"] = value
    elif "total agreement value" in key_lower or "amount" in key_lower or "price" in key_lower:
        key_values["Total Agreement Value"] = value


def _render_page_image(page, img_path):
    """Render ``page`` at 300 DPI via a temporary PNG and return it as an RGB image.

    The temporary file is removed even when rendering or decoding fails.
    """
    pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))  # 300 DPI
    try:
        pix.save(img_path)
        # convert() returns a fully loaded copy, so the file handle can be
        # closed (and the file deleted) before the image is used.
        with Image.open(img_path) as raw_image:
            return raw_image.convert("RGB")
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)


def _extract_fields_with_model(page_data, pdf_path, key_values, tok, image_processor, classifier):
    """Run the token classifier over each usable page and update ``key_values`` in place."""
    doc = fitz.open(pdf_path)
    try:
        for page_num, page_info in enumerate(page_data):
            if page_num >= doc.page_count:
                # page_data describes more pages than the PDF contains.
                break
            page_text = page_info.get("text") or ""
            if not page_text.strip() or "No text detected" in page_text:
                continue
            words = page_info.get("words", [])
            bboxes = page_info.get("bbox", [])
            if not (words and bboxes):
                # Nothing to tokenize, so skip the (expensive) render as well.
                continue

            image = _render_page_image(doc[page_num], f"{pdf_path}_page_{page_num}.png")
            encoding = tok(
                words,
                boxes=bboxes,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            )
            input_ids = encoding["input_ids"]
            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
            with torch.no_grad():
                outputs = classifier(
                    input_ids=input_ids,
                    attention_mask=encoding["attention_mask"],
                    bbox=encoding["bbox"],
                    pixel_values=pixel_values,
                )
            labels = torch.argmax(outputs.logits, dim=2)[0].tolist()
            tokens = tok.convert_ids_to_tokens(input_ids[0])

            # A key-start token closes the previous key/value pair and opens a
            # new one; value tokens accumulate until the next key start.
            current_key = None
            current_value = []
            for token, label in zip(tokens, labels):
                if label == _KEY_LABEL:
                    _store_labelled_value(key_values, current_key, current_value)
                    current_key = token
                    current_value = []
                elif label == _VALUE_LABEL and current_key:
                    current_value.append(token)
            _store_labelled_value(key_values, current_key, current_value)
    finally:
        doc.close()


def extract_key_values_with_layoutlm(page_data: list, pdf_path: str) -> Dict[str, str]:
    """
    Extract Agreement Name, Agreement Start Date, Agreement End Date and
    Total Agreement Value from a PDF.

    A regex pass over the concatenated page text runs first. If the
    module-level LayoutLMv3 components are available, a model pass follows and
    may overwrite the regex results.

    Args:
        page_data (list): One dict per page with 'text' (str), 'words' (list of str),
            'bbox' (list of [x0, y0, x1, y1] normalized to 0-1000), and
            'image_dims' ([width, height]).
        pdf_path (str): Path to the PDF file; only opened for the model pass.

    Returns:
        dict: The four fields above. If any step raises, a dict with
        "status": "failed", "error" and the partial "key_values" is returned
        instead.
    """
    key_values = {
        "Agreement Name": "Unknown",
        "Agreement Start Date": "",
        "Agreement End Date": "",
        "Total Agreement Value": "",
    }
    try:
        # A missing or None page text contributes nothing.
        text_data = " ".join(page.get("text") or "" for page in page_data)
        logger.info("Starting regex-based extraction.")
        _extract_fields_with_regex(text_data, key_values)

        # The module-level model load may have failed before binding these
        # names; treat "undefined" the same as "not loaded".
        try:
            components = (tokenizer, feature_extractor, model)
        except NameError:
            components = ()
        if components and all(components):
            _extract_fields_with_model(page_data, pdf_path, key_values, *components)
        else:
            logger.warning("LayoutLMv3 model components not available, skipping advanced extraction.")

        # NOTE(review): "Agreement Name" starts as "Unknown", which is truthy,
        # so in practice this always returns key_values.
        return key_values if any(key_values.values()) else {"status": "failed", "error": "No key-value pairs extracted", "key_values": {}}
    except Exception as e:
        logger.error("Error in extract_key_values_with_layoutlm: %s", e)
        return {"status": "failed", "error": str(e), "key_values": key_values}
# NO WAIVER text that follows a "General Provisions" heading (heading excluded).
_NO_WAIVER_IN_PROVISIONS_RE = re.compile(
    r'(?:General\s+Provisions\s*[\s\S]*?NO\s+WAIVER\s*[:\s]*)([\s\S]*?)(?=\n\n|\Z)',
    re.IGNORECASE,
)
# Any NO WAIVER paragraph (heading included), used when the scoped form is absent.
_NO_WAIVER_RE = re.compile(
    r'(NO\s+WAIVER\s*[:\s]*[\s\S]*?)(?=\n\n|\Z)',
    re.IGNORECASE,
)
_TERMINATION_RE = re.compile(
    r'(?:Termination\s*[:\s]*)([\s\S]*?)(?=\n\n|\Z)',
    re.IGNORECASE,
)


def extract_clauses(page_data: list) -> Dict[str, str]:
    """
    Extract the NO WAIVER and Termination clauses from the page text.

    A clause runs from its heading to the next blank line or the end of the
    text. Page texts are joined with newlines before searching.

    Args:
        page_data (list): List of dictionaries with 'text' (str) per page.

    Returns:
        dict: Mapping of clause names to their text content. When nothing is
        found, a single "No clauses extracted" entry is returned. If an
        exception occurs, whatever was collected so far is returned.
    """
    clauses = {}
    try:
        # A missing or None page text contributes an empty string.
        text_data = "\n".join(page.get("text") or "" for page in page_data)
        logger.info("Starting clause extraction.")

        scoped_match = _NO_WAIVER_IN_PROVISIONS_RE.search(text_data)
        if scoped_match:
            clause_text = scoped_match.group(1).strip()
            clauses["NO WAIVER"] = clause_text if clause_text else "NO WAIVER clause found but no content extracted"
        else:
            # Let the regex decide whether a heading exists: a plain substring
            # check would miss headings written with extra spaces or a line
            # break between the two words.
            standalone_match = _NO_WAIVER_RE.search(text_data)
            if standalone_match:
                clauses["NO WAIVER"] = standalone_match.group(1).strip()

        termination_match = _TERMINATION_RE.search(text_data)
        if termination_match:
            clauses["Termination"] = termination_match.group(1).strip()

        return clauses if clauses else {"No clauses extracted": "No relevant clauses found in the document"}
    except Exception as e:
        logger.error("Error in extract_clauses: %s", e)
        return clauses
def run_ai_mapping_with_layoutlm(
    key_values: Dict[str, str],
    object_field_names: List[str],
    pdf_path: str,
    page_data: Optional[list] = None,
) -> Dict:
    """
    Map extracted key-values to object fields.

    A field is mapped to the first key-value entry whose key contains the
    field name (case-insensitive) and whose value is non-empty.

    Args:
        key_values (dict): Extracted key-value pairs.
        object_field_names (list): List of object field names.
        pdf_path (str): Path to the PDF file (not read here; kept for callers).
        page_data (list, optional): Per-page dicts with 'text'. When given,
            clauses are extracted from it and included in the result; when
            omitted, "clauses" is an empty dict.

    Returns:
        dict: "status", "mappings", "unmapped_fields", "error" and "clauses".
        On any exception, status is "failed" and every field is reported as
        unmapped.
    """
    try:
        logger.info("Starting mapping process.")
        mappings = {}
        for field in object_field_names:
            wanted = field.lower()
            for key, value in key_values.items():
                if wanted in key.lower() and value:
                    mappings[field] = value
                    break
        unmapped_fields = [field for field in object_field_names if field not in mappings]

        # Clauses need the page text, which only callers that pass page_data
        # can provide.
        clauses = extract_clauses(page_data) if page_data else {}
        return {
            "status": "success",
            "mappings": mappings,
            "unmapped_fields": unmapped_fields,
            "error": None,
            "clauses": clauses,
        }
    except Exception as e:
        logger.error("Error in run_ai_mapping_with_layoutlm: %s", e)
        return {
            "status": "failed",
            "error": str(e),
            "mappings": {},
            "unmapped_fields": object_field_names,
            "clauses": {},
        }