# app.py - Contract Intelligence Gradio app (author: pavansuresh; commit 00ffeaa "Update app.py")
import gradio as gr
from transformers import LayoutLMv3Tokenizer, LayoutLMv3ForTokenClassification
import torch
from PIL import Image
import os
import tempfile
from tqdm import tqdm
import re
from ai_mapping import extract_key_values_with_layoutlm, run_ai_mapping_with_layoutlm, extract_clauses
from ocr_utils import extract_text_from_pdf_with_tesseract_or_layoutlm
from salesforce_utils import get_token, create_or_update_record
# Initialize global (process-wide) state shared by the handlers below.
# In-memory contract repository; nothing here persists it, so it is lost on
# restart. Keys are "Contract_<n>"; each value is a dict with the keys
# "data", "clauses", "risks", "clm_fields" and "status" (written by
# process_contract).
contract_data = {} # In-memory contract repository
# Per-upload counters, set by process_contract. NOTE(review): nothing visible
# in this file reads them back - confirm whether they are still needed.
processed_files = 0
total_files = 0
# Load pre-trained LayoutLMv3 model and tokenizer at import time (placeholder
# for future fine-tuning). NOTE(review): neither name is referenced elsewhere
# in this file; loading them slows startup and may download weights on first
# run - confirm another module actually uses them before keeping this.
tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
def save_temp_file(pdf_bytes):
    """Write ``pdf_bytes`` to a new ``.pdf`` file in the system temp directory.

    The file is created with ``delete=False`` so it outlives the ``with``
    block; the caller owns it and is responsible for removing it
    (``os.unlink``) when finished.

    Args:
        pdf_bytes: raw PDF content to write.

    Returns:
        str: path of the newly written temporary file.
    """
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as handle:
        written_path = handle.name
        handle.write(pdf_bytes)
    return written_path
def detect_risks(data, amount_threshold=1000000):
    """Detect simple contract risks (missing dates, large amounts).

    Two heuristics are applied:

    * neither "Agreement Start Date" nor "Agreement End Date" has a truthy
      value -> obligation risk;
    * "Amount" parses to a number strictly greater than ``amount_threshold``
      -> financial risk.

    Args:
        data: mapping of extracted field names to values.
        amount_threshold: amount above which a contract is flagged. Defaults
            to 1,000,000, the previously hard-coded value.

    Returns:
        list[str]: human-readable risk messages; empty when nothing is flagged.
    """
    risks = []
    if not data.get("Agreement Start Date") and not data.get("Agreement End Date"):
        risks.append("No agreement dates detected - potential obligation risk.")

    raw_amount = data.get("Amount")
    if raw_amount:
        # Extracted amounts are free text ("$1,200,000", "N/A", ...) or may
        # already be numeric. Normalise via str() and strip "$" and thousands
        # separators. Anything that still does not parse is skipped instead of
        # raising and aborting the whole contract pipeline.
        cleaned = str(raw_amount).replace("$", "").replace(",", "").strip()
        try:
            amount = float(cleaned)
        except ValueError:
            amount = None
        if amount is not None and amount > amount_threshold:
            risks.append("Large amount detected - review for financial risk.")
    return risks
def process_contract(pdf_bytes, object_type):
    """Process one uploaded contract and simulate the CCI workflow.

    Pipeline:

    1. Save the upload to a temporary PDF and OCR it.
    2. Extract key/value data and clauses, then run ``detect_risks``.
    3. Build a Salesforce-style field map and attempt an optional
       Salesforce sync.
    4. Store the outcome in the in-memory ``contract_data`` repository
       under ``Contract_<n>``.

    Args:
        pdf_bytes: raw bytes of the uploaded PDF.
        object_type: object type chosen in the UI; written to ``Type__c`` and
            used as ``f"{object_type}__c"`` for the Salesforce object name.

    Returns:
        tuple: ``(status, key_data, risks, progress)``.

        On success, ``status`` is "✅ Processed" or "⚠️ Processed with
        risks" and ``progress`` is "1/1".

        When OCR yields no text, or extraction reports
        ``status == "failed"``, ``status`` is an error message. In that case
        ``key_data`` is ``{}``, ``risks`` is ``[]``, ``progress`` is "0/1",
        and nothing is stored in ``contract_data``.

    The temporary PDF is removed on every exit path, including when OCR or
    extraction raises.
    """
    global processed_files, total_files
    total_files = 1
    processed_files = 0
    print("Received file - Starting processing")
    temp_path = save_temp_file(pdf_bytes)
    print(f"Temporary file created at: {temp_path}")
    try:
        page_data = extract_text_from_pdf_with_tesseract_or_layoutlm(temp_path)
        # `or []` so a None OCR result reaches the "no text" branch below
        # instead of raising TypeError in len().
        print(f"OCR result pages: {len(page_data or [])}")
        if not page_data or all("No text detected" in page["text"] for page in page_data):
            print("No text extracted from PDF.")
            return "❌ No text extracted from PDF.", {}, [], "0/1"

        print("Extracting key data")
        key_data = extract_key_values_with_layoutlm(page_data, temp_path)
        print(f"Key data extracted: {key_data}")
        if key_data.get("status") == "failed":
            error = key_data.get("error", "Unknown error")
            print(f"Extraction failed: {error}")
            return f"❌ Extraction failed: {error}", {}, [], "0/1"

        print("Extracting clauses")
        clauses = extract_clauses(page_data)
        print(f"Extracted clauses: {clauses}")
        print("Detecting risks")
        risks = detect_risks(key_data)
        print(f"Detected risks: {risks}")
        status = "✅ Processed" if not risks else "⚠️ Processed with risks"

        # contract_data is only ever added to, so len + 1 is an unused ID.
        # Compute it once so the CLM "Name" and the repository key always agree.
        contract_id = f"Contract_{len(contract_data) + 1}"

        # Mock CLM fields with Salesforce-ready structure. Extraction
        # bookkeeping keys are dropped. NOTE(review): extracted field names
        # (e.g. "Agreement Start Date") are copied verbatim and are not
        # Salesforce API names - confirm create_or_update_record maps them.
        clm_fields = {"Name": contract_id, "Type__c": object_type, "Status__c": status}
        clm_fields.update(
            {k: v for k, v in key_data.items() if k not in ("status", "error", "key_values")}
        )
        for clause_name, clause_text in clauses.items():
            clm_fields[f"{clause_name}_Text__c"] = clause_text

        # Optional Salesforce sync: deliberately best-effort. Failures are
        # logged and local processing still succeeds.
        try:
            token, instance_url = get_token()
            sf_response = create_or_update_record(f"{object_type}__c", clm_fields, token, instance_url)
            if "error" in sf_response:
                print(f"Salesforce sync failed: {sf_response['error']}")
            else:
                print(f"Salesforce sync successful: {sf_response}")
        except Exception as e:
            print(f"Salesforce sync error: {str(e)}")

        contract_data[contract_id] = {
            "data": key_data,
            "clauses": clauses,
            "risks": risks,
            "clm_fields": clm_fields,
            "status": status,
        }
        processed_files = 1
        progress = "1/1"
        print(f"Processing completed - ID: {contract_id}, Progress: {progress}")
        return status, key_data, risks, progress
    finally:
        # Previously the file was unlinked only on the explicit return paths,
        # so any exception above leaked it.
        os.unlink(temp_path)
def _iter_leaf_text(value):
    """Yield the lower-cased text of every scalar nested in dicts/lists/tuples/sets."""
    if isinstance(value, dict):
        for item in value.values():
            yield from _iter_leaf_text(item)
    elif isinstance(value, (list, tuple, set)):
        for item in value:
            yield from _iter_leaf_text(item)
    else:
        yield str(value).lower()


def search_contracts(query):
    """Search the in-memory contract repository by keyword.

    Matching is case-insensitive. It runs over the *values* stored for each
    contract: extracted fields, clause text, risks, status and CLM field
    values, recursing into nested containers. Dict keys are not searched.
    Matching ``str(record)`` previously made structural words such as "data",
    "risks" or "status" hit every contract.

    Args:
        query: search keyword. ``None`` or ``""`` matches every stored
            contract that has at least one value.

    Returns:
        dict: ``{contract_id: record}`` for matching contracts, or
        ``{"No matches": "No contracts found matching the query."}`` when
        nothing matches.
    """
    needle = (query or "").lower()
    results = {
        cid: record
        for cid, record in contract_data.items()
        if any(needle in text for text in _iter_leaf_text(record))
    }
    return results if results else {"No matches": "No contracts found matching the query."}
# Gradio UI


def process_and_display(files, obj_type):
    """Process every uploaded PDF and aggregate the results for the UI.

    Args:
        files: uploaded files as raw bytes (the input is
            ``gr.File(type="binary", file_count="multiple")``), or
            ``None``/empty when nothing was uploaded.
        obj_type: object type selected in the dropdown (``None`` if unset).

    Returns:
        tuple: ``(status_text, data_by_file, clauses_by_file, risks_text,
        progress_update)``, matching the ``outputs`` of
        ``process_button.click``.
    """
    if not files:
        return "❌ No files uploaded.", {}, {}, "No risks detected", gr.update(value="0/0")
    if not obj_type:
        # Without this guard the Salesforce sync would target "None__c".
        return "❌ Please select an object type.", {}, {}, "No risks detected", gr.update(value="0/0")

    results = []
    all_data = {}
    all_clauses = {}
    all_risks = []
    for i, file in enumerate(files):
        label = f"File_{i}"
        stored_before = len(contract_data)
        try:
            status, data, risks, _ = process_contract(file, obj_type)
        except Exception as exc:  # UI boundary: one bad PDF must not abort the batch
            print(f"Processing {label} failed: {exc}")
            status, data, risks = f"❌ Processing error: {exc}", {}, []
        # process_contract only stores a record on success. Reading the
        # latest entry unconditionally attributed the previous file's clauses
        # to a file that failed.
        if len(contract_data) > stored_before:
            clauses = contract_data.get(f"Contract_{len(contract_data)}", {}).get("clauses", {})
        else:
            clauses = {}
        results.append(f"{status} - File: {label}")
        all_data[label] = data
        all_clauses[label] = clauses
        all_risks.extend(risks)

    progress = f"{len(files)}/{len(files)}"
    risks_text = "\n".join(all_risks) if all_risks else "No risks detected"
    return "\n".join(results), all_data, all_clauses, risks_text, gr.update(value=progress)


with gr.Blocks(title="Contract Intelligence App") as demo:
    with gr.Row():
        # file_types takes extensions with a leading dot (".pdf").
        file_input = gr.File(type="binary", file_types=[".pdf"], file_count="multiple", label="Upload Contracts")
        upload_progress = gr.Textbox(label="Progress", value="0/0", interactive=False)
    object_type = gr.Dropdown(choices=["Contract", "Agreement", "Invoice"], label="Select Object Type")
    process_button = gr.Button("Process Contracts")
    status_output = gr.Textbox(label="Status", interactive=False)
    extracted_data_output = gr.JSON(label="Extracted Data")
    clauses_output = gr.JSON(label="Extracted Clauses")
    risks_output = gr.Textbox(label="Detected Risks", interactive=False)
    process_button.click(
        fn=process_and_display,
        inputs=[file_input, object_type],
        outputs=[status_output, extracted_data_output, clauses_output, risks_output, upload_progress],
    )
    with gr.Tab("Contract Repository"):
        search_query = gr.Textbox(label="Search Contracts", placeholder="Enter keyword...")
        search_results = gr.JSON(label="Search Results")
        search_button = gr.Button("Search")
        search_button.click(
            fn=search_contracts,
            inputs=search_query,
            outputs=search_results,
        )

demo.launch()