Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import LayoutLMv3Tokenizer, LayoutLMv3ForTokenClassification | |
| import torch | |
| from PIL import Image | |
| import os | |
| import tempfile | |
| from tqdm import tqdm | |
| import re | |
| from ai_mapping import extract_key_values_with_layoutlm, run_ai_mapping_with_layoutlm, extract_clauses | |
| from ocr_utils import extract_text_from_pdf_with_tesseract_or_layoutlm | |
| from salesforce_utils import get_token, create_or_update_record | |
# Initialize global state
# In-memory contract repository; process_contract adds one entry per
# successfully processed file.
contract_data = dict()
# Progress counters; process_contract resets both on every call.
processed_files = total_files = 0
# Load pre-trained LayoutLMv3 model and tokenizer (placeholder for future fine-tuning).
# Both objects are loaded at import time; nothing in the visible part of this
# file reads them afterwards.
_LAYOUTLM_CHECKPOINT = "microsoft/layoutlmv3-base"
tokenizer = LayoutLMv3Tokenizer.from_pretrained(_LAYOUTLM_CHECKPOINT)
model = LayoutLMv3ForTokenClassification.from_pretrained(_LAYOUTLM_CHECKPOINT)
def save_temp_file(pdf_bytes):
    """Write *pdf_bytes* to a new ``.pdf`` temporary file and return its path.

    The file is created with ``delete=False``, so it outlives this function;
    the caller is responsible for removing it when done.

    Args:
        pdf_bytes: Bytes-like content to store.

    Returns:
        The filesystem path of the newly written temporary file.

    Raises:
        Whatever the underlying write raises (for example ``TypeError`` for a
        non-bytes payload or ``OSError`` when the disk is full). In that case
        the partially written file is removed before the error propagates.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    try:
        # Closing (and therefore flushing) happens inside the ``try`` so that
        # a failure while flushing is cleaned up as well.
        with tmp:
            tmp.write(pdf_bytes)
    except BaseException:
        # The handle is closed at this point, so the file can be removed on
        # every platform. Cleanup is best-effort: the original error is the
        # one worth reporting.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        raise
    return tmp.name
def _parse_amount(raw):
    """Convert an extracted amount to ``float``; return ``None`` if impossible.

    Numbers are accepted as-is. Anything else is stringified and stripped of
    ``$`` signs, thousands separators and surrounding whitespace first.
    """
    if isinstance(raw, (int, float)):
        return float(raw)
    cleaned = str(raw).replace("$", "").replace(",", "").strip()
    try:
        return float(cleaned)
    except ValueError:
        return None


def detect_risks(data, large_amount_threshold=1000000):
    """Detect risks (e.g., missing dates, large amounts).

    Args:
        data: Mapping of extracted fields. The keys read here are
            ``"Agreement Start Date"``, ``"Agreement End Date"`` and
            ``"Amount"``.
        large_amount_threshold: Amounts strictly greater than this value are
            flagged as a financial risk.

    Returns:
        A list of human-readable risk messages; empty when nothing is flagged.
    """
    risks = []
    # Only flagged when *both* dates are missing.
    if not data.get("Agreement Start Date") and not data.get("Agreement End Date"):
        risks.append("No agreement dates detected - potential obligation risk.")

    raw_amount = data.get("Amount")
    if raw_amount:
        amount = _parse_amount(raw_amount)
        if amount is None:
            # An unreadable amount must not abort processing of the whole
            # contract; report it and skip the amount check instead.
            print(f"Could not parse amount {raw_amount!r}; skipping amount check.")
        elif amount > large_amount_threshold:
            risks.append("Large amount detected - review for financial risk.")
    return risks
def process_contract(pdf_bytes, object_type):
    """Process one contract PDF and simulate the CCI workflow.

    The bytes are written to a temporary PDF and passed through OCR,
    key-value extraction, clause extraction and risk detection. On success
    the result is stored in the in-memory ``contract_data`` repository and a
    best-effort Salesforce sync is attempted.

    Args:
        pdf_bytes: Raw content of the uploaded PDF.
        object_type: Object type label; stored as ``Type__c`` and used to
            build the Salesforce object name ``f"{object_type}__c"``.

    Returns:
        A 4-tuple ``(status, key_data, risks, progress)``. When OCR yields no
        text or extraction reports failure, ``key_data`` is ``{}``, ``risks``
        is ``[]`` and ``progress`` is ``"0/1"``; nothing is stored in that
        case. On success ``progress`` is ``"1/1"``.
    """
    global processed_files, total_files
    total_files = 1
    processed_files = 0
    print("Received file - Starting processing")
    temp_path = save_temp_file(pdf_bytes)
    print(f"Temporary file created at: {temp_path}")
    # try/finally guarantees the temporary file is removed on every exit
    # path, including exceptions raised by OCR or extraction.
    try:
        page_data = extract_text_from_pdf_with_tesseract_or_layoutlm(temp_path)
        # Guard the length so an empty/None OCR result reaches the check
        # below instead of failing inside the log statement.
        print(f"OCR result pages: {len(page_data) if page_data else 0}")
        if not page_data or all("No text detected" in page["text"] for page in page_data):
            print("No text extracted from PDF.")
            return "❌ No text extracted from PDF.", {}, [], "0/1"

        print("Extracting key data")
        key_data = extract_key_values_with_layoutlm(page_data, temp_path)
        print(f"Key data extracted: {key_data}")
        if "status" in key_data and key_data["status"] == "failed":
            print(f"Extraction failed: {key_data.get('error', 'Unknown error')}")
            return f"❌ Extraction failed: {key_data.get('error', 'Unknown error')}", {}, [], "0/1"

        print("Extracting clauses")
        clauses = extract_clauses(page_data)
        print(f"Extracted clauses: {clauses}")

        print("Detecting risks")
        risks = detect_risks(key_data)
        print(f"Detected risks: {risks}")
        status = "✅ Processed" if not risks else "⚠️ Processed with risks"

        # Computed once so the CLM record name and the repository key agree.
        contract_id = f"Contract_{len(contract_data) + 1}"

        # Mock CLM fields with Salesforce-ready structure
        clm_fields = {"Name": contract_id, "Type__c": object_type, "Status__c": status}
        clm_fields.update({k: v for k, v in key_data.items() if k not in ["status", "error", "key_values"]})
        for clause_name, clause_text in clauses.items():
            clm_fields[f"{clause_name}_Text__c"] = clause_text

        # Optional Salesforce sync: deliberately best-effort, so any failure
        # is reported and local processing continues.
        try:
            token, instance_url = get_token()
            sf_response = create_or_update_record(f"{object_type}__c", clm_fields, token, instance_url)
            if "error" in sf_response:
                print(f"Salesforce sync failed: {sf_response['error']}")
            else:
                print(f"Salesforce sync successful: {sf_response}")
        except Exception as e:
            print(f"Salesforce sync error: {str(e)}")

        contract_data[contract_id] = {
            "data": key_data,
            "clauses": clauses,
            "risks": risks,
            "clm_fields": clm_fields,
            "status": status
        }
        processed_files = 1
        progress = "1/1"
        print(f"Processing completed - ID: {contract_id}, Progress: {progress}")
        return status, key_data, risks, progress
    finally:
        os.unlink(temp_path)
def search_contracts(query):
    """Search contract repository.

    A contract matches when the lower-cased query occurs anywhere in the
    lower-cased string form of its stored record.

    Returns:
        A dict of matching ``contract id -> record`` entries, or a
        single-entry ``"No matches"`` dict when nothing matches.
    """
    matches = {}
    for contract_id, record in contract_data.items():
        needle = query.lower()
        haystack = str(record).lower()
        if needle in haystack:
            matches[contract_id] = record
    if not matches:
        return {"No matches": "No contracts found matching the query."}
    return matches
# Gradio UI
# NOTE(review): the nesting of the layout below (what sits inside the Row and
# the Tab) was reconstructed - confirm against the intended layout.
with gr.Blocks(title="Contract Intelligence App") as demo:
    with gr.Row():
        # Extensions are given with a leading dot; a bare "pdf" is read as a
        # file *type* name rather than the .pdf extension.
        file_input = gr.File(type="binary", file_types=[".pdf"], file_count="multiple", label="Upload Contracts")
        upload_progress = gr.Textbox(label="Progress", value="0/0", interactive=False)
    object_type = gr.Dropdown(choices=["Contract", "Agreement", "Invoice"], label="Select Object Type")
    process_button = gr.Button("Process Contracts")
    status_output = gr.Textbox(label="Status", interactive=False)
    extracted_data_output = gr.JSON(label="Extracted Data")
    clauses_output = gr.JSON(label="Extracted Clauses")
    risks_output = gr.Textbox(label="Detected Risks", interactive=False)

    def process_and_display(files, obj_type):
        """Process every uploaded file and collect the results for the UI.

        Args:
            files: Uploaded file contents (one item per file), or an empty
                value when nothing was uploaded.
            obj_type: Value selected in the object-type dropdown.

        Returns:
            A 5-tuple matching the click handler's outputs: status text,
            extracted data per file, clauses per file, risk text and a
            progress update.
        """
        if not files:
            return "❌ No files uploaded.", {}, {}, "No risks detected", gr.update(value="0/0")
        results = []
        all_data = {}
        all_clauses = {}
        all_risks = []
        for i, file in enumerate(files):
            stored_before = len(contract_data)
            status, data, risks, _ = process_contract(file, obj_type)
            # process_contract stores an entry only on success. Look up the
            # newest entry only when one was added, so a failed file does not
            # inherit the clauses of the previously processed contract.
            if len(contract_data) > stored_before:
                clauses = contract_data.get(f"Contract_{len(contract_data)}", {}).get("clauses", {})
            else:
                clauses = {}
            results.append(f"{status} - File: File_{i}")
            all_data[f"File_{i}"] = data
            all_clauses[f"File_{i}"] = clauses
            all_risks.extend(risks)
        progress = f"{len(files)}/{len(files)}"
        risk_text = "\n".join(all_risks) if all_risks else "No risks detected"
        return "\n".join(results), all_data, all_clauses, risk_text, gr.update(value=progress)

    process_button.click(
        fn=process_and_display,
        inputs=[file_input, object_type],
        outputs=[status_output, extracted_data_output, clauses_output, risks_output, upload_progress]
    )

    with gr.Tab("Contract Repository"):
        search_query = gr.Textbox(label="Search Contracts", placeholder="Enter keyword...")
        search_results = gr.JSON(label="Search Results")
        search_button = gr.Button("Search")
        search_button.click(
            fn=search_contracts,
            inputs=search_query,
            outputs=search_results
        )

demo.launch()