Spaces:
Sleeping
Sleeping
File size: 6,578 Bytes
428bcb4 f214078 49c58d2 428bcb4 49c58d2 428bcb4 49c58d2 428bcb4 49c58d2 428bcb4 a15a51f 428bcb4 a15a51f 428bcb4 49c58d2 428bcb4 49c58d2 428bcb4 49c58d2 428bcb4 49c58d2 428bcb4 00ffeaa 428bcb4 49c58d2 428bcb4 49c58d2 428bcb4 49c58d2 428bcb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import gradio as gr
from transformers import LayoutLMv3Tokenizer, LayoutLMv3ForTokenClassification
import torch
from PIL import Image
import os
import tempfile
from tqdm import tqdm
import re
from ai_mapping import extract_key_values_with_layoutlm, run_ai_mapping_with_layoutlm, extract_clauses
from ocr_utils import extract_text_from_pdf_with_tesseract_or_layoutlm
from salesforce_utils import get_token, create_or_update_record
# Initialize global state
# contract_data maps generated IDs ("Contract_<n>") to the record built by
# process_contract(); it lives in memory only and is what search_contracts() scans.
contract_data = {} # In-memory contract repository
# Counters that process_contract() resets on every call via `global`.
processed_files = 0
total_files = 0
# Load pre-trained LayoutLMv3 model and tokenizer (placeholder for future fine-tuning)
# NOTE(review): both objects are created at import time but are not referenced
# anywhere else in the visible code - confirm they are still needed.
tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
def save_temp_file(pdf_bytes):
    """Persist raw PDF bytes to a ``.pdf`` temp file and return its path.

    The file is created with ``delete=False``, so it survives this call;
    removing it afterwards is the caller's job.
    """
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    try:
        handle.write(pdf_bytes)
        return handle.name
    finally:
        # Closing flushes the payload so later readers see the full content.
        handle.close()
def detect_risks(data, amount_threshold=1000000):
    """Detect simple risk indicators in extracted contract data.

    Args:
        data: Mapping of extracted field names to values. Only the
            "Agreement Start Date", "Agreement End Date" and "Amount"
            keys are read.
        amount_threshold: Amounts strictly greater than this value are
            flagged as a financial risk. Defaults to 1000000.

    Returns:
        A list of human-readable risk messages; empty when nothing is flagged.
    """
    risks = []

    # Only flagged when BOTH dates are missing; a single date is accepted.
    if not data.get("Agreement Start Date") and not data.get("Agreement End Date"):
        risks.append("No agreement dates detected - potential obligation risk.")

    raw_amount = data.get("Amount")
    if raw_amount:
        # str() lets numeric values share the clean-up path used for text
        # such as "$1,250,000" instead of failing on a missing .replace().
        cleaned = str(raw_amount).replace("$", "").replace(",", "")
        try:
            amount = float(cleaned)
        except ValueError:
            # Unparseable text (e.g. "TBD") must not abort risk detection.
            amount = None
        if amount is not None and amount > amount_threshold:
            risks.append("Large amount detected - review for financial risk.")

    return risks
def process_contract(pdf_bytes, object_type):
    """Process one contract PDF and simulate the CCI workflow.

    The bytes are written to a temporary file, run through OCR, key-value
    extraction, clause extraction and risk detection, optionally synced to
    Salesforce, and finally stored in the in-memory ``contract_data``
    repository under a generated "Contract_<n>" ID.

    Args:
        pdf_bytes: Raw content of the uploaded PDF.
        object_type: Value stored in ``Type__c``; also used to build the
            Salesforce object name ``f"{object_type}__c"``.

    Returns:
        A 4-tuple ``(status, key_data, risks, progress)``. When OCR yields
        no text or extraction reports failure, the tuple is
        ``(error_message, {}, [], "0/1")`` and nothing is stored.
    """
    global processed_files, total_files
    total_files = 1
    processed_files = 0
    print("Received file - Starting processing")
    temp_path = save_temp_file(pdf_bytes)
    print(f"Temporary file created at: {temp_path}")
    try:
        page_data = extract_text_from_pdf_with_tesseract_or_layoutlm(temp_path)
        # Guard len() so an empty/None OCR result reaches the check below.
        print(f"OCR result pages: {len(page_data) if page_data else 0}")
        if not page_data or all("No text detected" in page["text"] for page in page_data):
            print("No text extracted from PDF.")
            return "❌ No text extracted from PDF.", {}, [], "0/1"

        print("Extracting key data")
        key_data = extract_key_values_with_layoutlm(page_data, temp_path)
        print(f"Key data extracted: {key_data}")
        if "status" in key_data and key_data["status"] == "failed":
            print(f"Extraction failed: {key_data.get('error', 'Unknown error')}")
            return f"❌ Extraction failed: {key_data.get('error', 'Unknown error')}", {}, [], "0/1"

        print("Extracting clauses")
        clauses = extract_clauses(page_data)
        print(f"Extracted clauses: {clauses}")

        print("Detecting risks")
        risks = detect_risks(key_data)
        print(f"Detected risks: {risks}")
        status = "✅ Processed" if not risks else "⚠️ Processed with risks"

        # One ID serves both the CLM "Name" field and the repository key.
        contract_id = f"Contract_{len(contract_data) + 1}"

        # Mock CLM fields with Salesforce-ready structure
        clm_fields = {"Name": contract_id, "Type__c": object_type, "Status__c": status}
        clm_fields.update({k: v for k, v in key_data.items() if k not in ["status", "error", "key_values"]})
        for clause_name, clause_text in clauses.items():
            clm_fields[f"{clause_name}_Text__c"] = clause_text

        # Optional Salesforce sync: best-effort, a failure is reported but
        # never prevents the contract from being stored locally.
        try:
            token, instance_url = get_token()
            sf_response = create_or_update_record(f"{object_type}__c", clm_fields, token, instance_url)
            if "error" in sf_response:
                print(f"Salesforce sync failed: {sf_response['error']}")
            else:
                print(f"Salesforce sync successful: {sf_response}")
        except Exception as e:
            print(f"Salesforce sync error: {str(e)}")

        contract_data[contract_id] = {
            "data": key_data,
            "clauses": clauses,
            "risks": risks,
            "clm_fields": clm_fields,
            "status": status
        }
        processed_files = 1
        progress = "1/1"
        print(f"Processing completed - ID: {contract_id}, Progress: {progress}")
        return status, key_data, risks, progress
    finally:
        # Always remove the temp file, including when a step above raises;
        # a cleanup failure must not mask the real result or exception.
        try:
            os.unlink(temp_path)
        except OSError as cleanup_error:
            print(f"Could not remove temporary file {temp_path}: {cleanup_error}")
def search_contracts(query, repository=None):
    """Search the contract repository for a case-insensitive substring.

    Each record is matched against its full ``str()`` rendering, so field
    names as well as values can produce a hit, and an empty query matches
    every stored contract.

    Args:
        query: Text to look for. ``None`` is treated as an empty query.
        repository: Mapping of contract ID to record. Defaults to the
            module-level ``contract_data`` store.

    Returns:
        A dict of matching ``{contract_id: record}`` pairs, or a single
        "No matches" entry when nothing matches.
    """
    if repository is None:
        repository = contract_data
    # Normalise once instead of re-lowering the query for every record.
    needle = "" if query is None else str(query).lower()
    results = {cid: record for cid, record in repository.items() if needle in str(record).lower()}
    return results if results else {"No matches": "No contracts found matching the query."}
# Gradio UI
# NOTE(review): the original indentation was lost; the nesting below (Row
# holding only the upload widget and the progress box, Tab inside the Blocks)
# is a reconstruction - confirm against the intended layout.
with gr.Blocks(title="Contract Intelligence App") as demo:
    with gr.Row():
        # Extensions need a leading dot; a bare "pdf" is read as a file *type*.
        file_input = gr.File(type="binary", file_types=[".pdf"], file_count="multiple", label="Upload Contracts")
        upload_progress = gr.Textbox(label="Progress", value="0/0", interactive=False)
    # NOTE(review): no default value, so the handler receives None until the
    # user picks an entry - confirm whether a default should be set.
    object_type = gr.Dropdown(choices=["Contract", "Agreement", "Invoice"], label="Select Object Type")
    process_button = gr.Button("Process Contracts")
    status_output = gr.Textbox(label="Status", interactive=False)
    extracted_data_output = gr.JSON(label="Extracted Data")
    clauses_output = gr.JSON(label="Extracted Clauses")
    risks_output = gr.Textbox(label="Detected Risks", interactive=False)

    def process_and_display(files, obj_type):
        """Run every uploaded file through process_contract and merge the results.

        Args:
            files: Uploaded file payloads from ``file_input`` (one per file).
            obj_type: Selected entry of the object-type dropdown.

        Returns:
            A 5-tuple feeding, in order: status text (one line per file),
            extracted data per file, clauses per file, newline-joined risk
            messages (or "No risks detected"), and a progress-box update.
        """
        if not files:
            return "❌ No files uploaded.", {}, {}, "No risks detected", gr.update(value="0/0")
        results = []
        all_data = {}
        all_clauses = {}
        all_risks = []
        for i, file in enumerate(files):
            stored_before = len(contract_data)
            status, data, risks, _ = process_contract(file, obj_type)
            if len(contract_data) > stored_before:
                # process_contract stored a new record under the latest ID.
                clauses = contract_data.get(f"Contract_{len(contract_data)}", {}).get("clauses", {})
            else:
                # Nothing was stored (early failure): do not show the
                # previous contract's clauses for this file.
                clauses = {}
            results.append(f"{status} - File: File_{i}")
            all_data[f"File_{i}"] = data
            all_clauses[f"File_{i}"] = clauses
            all_risks.extend(risks)
        progress = f"{len(files)}/{len(files)}"
        return "\n".join(results), all_data, all_clauses, "\n".join(all_risks) if all_risks else "No risks detected", gr.update(value=progress)

    process_button.click(
        fn=process_and_display,
        inputs=[file_input, object_type],
        outputs=[status_output, extracted_data_output, clauses_output, risks_output, upload_progress]
    )

    with gr.Tab("Contract Repository"):
        search_query = gr.Textbox(label="Search Contracts", placeholder="Enter keyword...")
        search_results = gr.JSON(label="Search Results")
        search_button = gr.Button("Search")
        search_button.click(
            fn=search_contracts,
            inputs=search_query,
            outputs=search_results
        )

demo.launch()