import os
import re
import tempfile

import gradio as gr
import torch
from PIL import Image
from tqdm import tqdm
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Tokenizer

from ai_mapping import extract_key_values_with_layoutlm, run_ai_mapping_with_layoutlm, extract_clauses
from ocr_utils import extract_text_from_pdf_with_tesseract_or_layoutlm
from salesforce_utils import get_token, create_or_update_record

# Initialize global state
contract_data = {}  # In-memory contract repository, keyed by "Contract_<n>"
processed_files = 0
total_files = 0

# Load pre-trained LayoutLMv3 model and tokenizer (placeholder for future fine-tuning)
tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")


def save_temp_file(pdf_bytes):
    """Save PDF bytes to a temporary file and return the path.

    The file is created with delete=False, so the caller is responsible
    for removing it (process_contract does this in its finally block).
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
        tmp.write(pdf_bytes)
        return tmp.name


def _parse_amount(raw):
    """Parse a currency string such as "$1,234.56" into a float.

    Returns 0.0 when the value is not numeric. The original code called
    float() directly, so a non-numeric "Amount" (e.g. "1.2M", "N/A")
    raised ValueError and aborted the entire processing run.
    """
    try:
        return float(str(raw).replace("$", "").replace(",", ""))
    except ValueError:
        return 0.0


def detect_risks(data):
    """Detect risks (e.g., missing dates, large amounts).

    Args:
        data: dict of key/value pairs extracted from one contract.

    Returns:
        list[str] of human-readable risk descriptions (empty when none).
    """
    risks = []
    if not data.get("Agreement Start Date") and not data.get("Agreement End Date"):
        risks.append("No agreement dates detected - potential obligation risk.")
    # BUG FIX: guard the numeric conversion so a malformed amount cannot
    # crash risk detection (see _parse_amount).
    if data.get("Amount") and _parse_amount(data.get("Amount", "0")) > 1000000:
        risks.append("Large amount detected - review for financial risk.")
    return risks


def process_contract(pdf_bytes, object_type):
    """Process one contract PDF and simulate the CCI workflow.

    Runs OCR, key-value extraction, clause extraction and risk detection,
    stores the result in the in-memory repository, and attempts a
    best-effort Salesforce sync.

    Args:
        pdf_bytes: raw bytes of the uploaded PDF.
        object_type: Salesforce object base name (e.g. "Contract").

    Returns:
        Tuple of (status message, extracted key data dict, risk list,
        progress string such as "1/1").
    """
    global processed_files, total_files
    total_files = 1
    processed_files = 0
    print("Received file - Starting processing")
    temp_path = save_temp_file(pdf_bytes)
    print(f"Temporary file created at: {temp_path}")
    try:
        page_data = extract_text_from_pdf_with_tesseract_or_layoutlm(temp_path)
        print(f"OCR result pages: {len(page_data)}")
        if not page_data or all("No text detected" in page["text"] for page in page_data):
            print("No text extracted from PDF.")
            return "❌ No text extracted from PDF.", {}, [], "0/1"

        print("Extracting key data")
        key_data = extract_key_values_with_layoutlm(page_data, temp_path)
        print(f"Key data extracted: {key_data}")
        if key_data.get("status") == "failed":
            print(f"Extraction failed: {key_data.get('error', 'Unknown error')}")
            return f"❌ Extraction failed: {key_data.get('error', 'Unknown error')}", {}, [], "0/1"

        print("Extracting clauses")
        clauses = extract_clauses(page_data)
        print(f"Extracted clauses: {clauses}")

        print("Detecting risks")
        risks = detect_risks(key_data)
        print(f"Detected risks: {risks}")
        status = "✅ Processed" if not risks else "⚠️ Processed with risks"

        # Mock CLM fields with Salesforce-ready structure
        clm_fields = {
            "Name": f"Contract_{len(contract_data) + 1}",
            "Type__c": object_type,
            "Status__c": status,
        }
        clm_fields.update(
            {k: v for k, v in key_data.items() if k not in ["status", "error", "key_values"]}
        )
        for clause_name, clause_text in clauses.items():
            clm_fields[f"{clause_name}_Text__c"] = clause_text

        # Optional Salesforce sync: failures are logged but never abort processing.
        try:
            token, instance_url = get_token()
            sf_response = create_or_update_record(f"{object_type}__c", clm_fields, token, instance_url)
            if "error" in sf_response:
                print(f"Salesforce sync failed: {sf_response['error']}")
            else:
                print(f"Salesforce sync successful: {sf_response}")
        except Exception as e:
            print(f"Salesforce sync error: {str(e)}")

        contract_id = f"Contract_{len(contract_data) + 1}"
        contract_data[contract_id] = {
            "data": key_data,
            "clauses": clauses,
            "risks": risks,
            "clm_fields": clm_fields,
            "status": status,
        }
        processed_files = 1
        progress = "1/1"
        print(f"Processing completed - ID: {contract_id}, Progress: {progress}")
        return status, key_data, risks, progress
    finally:
        # BUG FIX: the original unlinked the temp file only on the three
        # explicit exit paths, leaking it whenever OCR/extraction raised
        # an unexpected exception. finally guarantees cleanup on every path.
        os.unlink(temp_path)


def search_contracts(query):
    """Search the contract repository (case-insensitive substring match)."""
    results = {cid: data for cid, data in contract_data.items() if query.lower() in str(data).lower()}
    return results if results else {"No matches": "No contracts found matching the query."}


# Gradio UI
with gr.Blocks(title="Contract Intelligence App") as demo:
    with gr.Row():
        file_input = gr.File(type="binary", file_types=["pdf"], file_count="multiple", label="Upload Contracts")
        upload_progress = gr.Textbox(label="Progress", value="0/0", interactive=False)
    object_type = gr.Dropdown(choices=["Contract", "Agreement", "Invoice"], label="Select Object Type")
    process_button = gr.Button("Process Contracts")
    status_output = gr.Textbox(label="Status", interactive=False)
    extracted_data_output = gr.JSON(label="Extracted Data")
    clauses_output = gr.JSON(label="Extracted Clauses")
    risks_output = gr.Textbox(label="Detected Risks", interactive=False)

    def process_and_display(files, obj_type):
        """Process each uploaded file and aggregate the results for display."""
        if not files:
            return "❌ No files uploaded.", {}, {}, "No risks detected", gr.update(value="0/0")
        results = []
        all_data = {}
        all_clauses = {}
        all_risks = []
        for i, file in enumerate(files):
            count_before = len(contract_data)
            status, data, risks, _ = process_contract(file, obj_type)
            # BUG FIX: the original always read clauses for key
            # f"Contract_{len(contract_data)}", so when a file failed
            # (repository unchanged) it displayed the PREVIOUS contract's
            # clauses under the failed file's name. Only look up clauses
            # when a new repository record was actually created.
            if len(contract_data) > count_before:
                clauses = contract_data[f"Contract_{len(contract_data)}"].get("clauses", {})
            else:
                clauses = {}
            results.append(f"{status} - File: File_{i}")
            all_data[f"File_{i}"] = data
            all_clauses[f"File_{i}"] = clauses
            all_risks.extend(risks)
        progress = f"{len(files)}/{len(files)}"
        return (
            "\n".join(results),
            all_data,
            all_clauses,
            "\n".join(all_risks) if all_risks else "No risks detected",
            gr.update(value=progress),
        )

    process_button.click(
        fn=process_and_display,
        inputs=[file_input, object_type],
        outputs=[status_output, extracted_data_output, clauses_output, risks_output, upload_progress],
    )

    with gr.Tab("Contract Repository"):
        search_query = gr.Textbox(label="Search Contracts", placeholder="Enter keyword...")
        search_results = gr.JSON(label="Search Results")
        search_button = gr.Button("Search")
        search_button.click(
            fn=search_contracts,
            inputs=search_query,
            outputs=search_results,
        )

demo.launch()