pavansuresh committed on
Commit
9f54a59
·
verified ·
1 Parent(s): ec09821

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -97
app.py CHANGED
@@ -1,109 +1,166 @@
1
- from transformers import LayoutLMv3Tokenizer, LayoutLMv3ForTokenClassification, LayoutLMv3ImageProcessor
 
2
  import torch
3
  from PIL import Image
4
- import fitz # PyMuPDF
5
- from typing import Dict, List
6
  import os
7
- from huggingface_hub import login
 
 
8
  import re
 
 
 
9
 
10
- # Optional: Log in to Hugging Face if using a private model
11
- # login(token="your_hf_token")
 
 
12
 
13
# Load pre-trained LayoutLMv3 models.
# NOTE(review): this is the stock "microsoft/layoutlmv3-base" checkpoint — the
# token-classification head appears untrained (the code below falls back to regex),
# so label ids used downstream are placeholders; confirm before relying on them.
tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
# apply_ocr=False: the caller supplies the text itself instead of built-in OCR.
feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
17
 
18
def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str, str]:
    """
    Extract key-value pairs from PDF text using LayoutLMv3-base, with a regex fallback.

    Args:
        text_data (str): Extracted text from the PDF.
        pdf_path (str): Path to the PDF file.

    Returns:
        dict: Key-value pairs extracted from the document, or
              {"status": "failed", "error": ..., "key_values": {}} on total failure.
    """
    key_values: Dict[str, str] = {}
    try:
        # Regex fallback: first date and first dollar amount found in the raw text.
        dates = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', text_data)
        amounts = re.findall(r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?', text_data)
        if dates or amounts:
            # Bug fix: key was the garbled "Date bangs" in the previous revision.
            key_values.update({"Date": dates[0] if dates else "",
                               "Amount": amounts[0] if amounts else ""})

        # Attempt LayoutLMv3 processing.
        # Bug fix: module name was garbled ("fit ভাz") — it is fitz (PyMuPDF).
        doc = fitz.open(pdf_path)
        try:
            for page_num in range(len(doc)):
                page = doc[page_num]
                pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))  # render at 300 DPI
                img_path = f"{pdf_path}_page_{page_num}.png"
                pix.save(img_path)
                try:
                    with Image.open(img_path) as image:
                        # NOTE(review): LayoutLMv3ImageProcessor(apply_ocr=False) normally
                        # expects word lists plus bounding boxes rather than raw text
                        # lines — confirm this call against the transformers docs.
                        encoding = feature_extractor(images=[image],
                                                     text=text_data.splitlines(),
                                                     return_tensors="pt")
                    input_ids = encoding["input_ids"]
                    attention_mask = encoding["attention_mask"]

                    with torch.no_grad():
                        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                        predictions = torch.argmax(outputs.logits, dim=2)

                    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
                    labels = predictions[0].tolist()
                    current_key = None
                    current_value = []
                    for token, label in zip(tokens, labels):
                        if label == 1:  # key start (label ids depend on training)
                            if current_key and current_value:
                                key_values[current_key] = " ".join(current_value).strip()
                            current_key = token
                            current_value = []
                        elif label == 2 and current_key:  # value token
                            current_value.append(token)
                    if current_key and current_value:
                        key_values[current_key] = " ".join(current_value).strip()
                finally:
                    # Robustness fix: the temporary page image previously leaked when
                    # processing raised; always remove it.
                    if os.path.exists(img_path):
                        os.unlink(img_path)
        finally:
            doc.close()

        return key_values if key_values else {"status": "failed",
                                              "error": "No key-value pairs extracted",
                                              "key_values": {}}
    except Exception as e:
        # Robustness fix: the regex-fallback results were previously discarded when
        # the model path raised, contradicting the documented fallback behavior.
        if key_values:
            return key_values
        return {"status": "failed", "error": str(e), "key_values": {}}
75
-
76
def run_ai_mapping_with_layoutlm(key_values: Dict[str, str], object_field_names: List[str], pdf_path: str) -> Dict:
    """
    Map extracted key-values to object fields by case-insensitive substring match.

    Args:
        key_values (dict): Extracted key-value pairs.
        object_field_names (list): List of object field names.
        pdf_path (str): Path to the PDF file (unused here; kept for interface compatibility).

    Returns:
        dict: {"status", "mappings", "unmapped_fields", "error"}.
    """
    try:
        mappings: Dict[str, str] = {}
        unmapped_fields = object_field_names.copy()

        for field in object_field_names:
            for key, value in key_values.items():
                # Bug fix: the previous condition included
                # `any(k.lower() in field.lower() for k in key_values.keys())`,
                # which ignored the current `key` — a field could be mapped to the
                # value of an unrelated (first-iterated) key. Match only `key`.
                if field.lower() in key.lower() or key.lower() in field.lower():
                    mappings[field] = value
                    unmapped_fields.remove(field)
                    break

        return {
            "status": "success",
            "mappings": mappings,
            "unmapped_fields": unmapped_fields,
            "error": None
        }
    except Exception as e:
        return {
            "status": "failed",
            "error": str(e),
            "mappings": {},
            "unmapped_fields": object_field_names
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import LayoutLMv3Tokenizer, LayoutLMv3ForTokenClassification
3
  import torch
4
  from PIL import Image
 
 
5
  import os
6
+ import tempfile
7
+ from tqdm import tqdm
8
+ import subprocess
9
  import re
10
+ from ai_mapping import extract_key_values_with_layoutlm, run_ai_mapping_with_layoutlm
11
+ from ocr_utils import extract_text_from_pdf_with_tesseract_or_layoutlm
12
+ from salesforce_utils import get_token, create_or_update_record
13
 
14
# Initialize global state (module-level; mutated by process_contract)
contract_data = {}  # In-memory contract repository: contract_id -> record dict
processed_files = 0  # Files processed in the current batch
total_files = 0  # Files expected in the current batch
18
 
19
# Load pre-trained LayoutLMv3 model and tokenizer (placeholder for future fine-tuning)
# NOTE(review): these objects are loaded at import time but do not appear to be used
# directly in this module (extraction goes through ai_mapping) — confirm before removal.
tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
22
 
23
def check_poppler():
    """Return True when poppler-utils (the pdftoppm binary) is available on PATH."""
    probe = ['pdftoppm', '-v']
    try:
        # Running the binary (output discarded) doubles as the availability probe;
        # a missing executable raises FileNotFoundError.
        subprocess.run(probe, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError:
        return False
    return True
30
+
31
def check_tesseract():
    """Return True when the tesseract-ocr binary is available on PATH."""
    probe = ['tesseract', '-v']
    try:
        # Output is suppressed; only the FileNotFoundError on a missing
        # executable matters here.
        subprocess.run(probe, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError:
        return False
    return True
38
+
39
def save_temp_file(pdf_bytes):
    """Write the given PDF bytes to a new .pdf temp file and return its path.

    The file is created with delete=False, so the caller is responsible for
    removing it when done.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    try:
        tmp.write(pdf_bytes)
    finally:
        tmp.close()
    return tmp.name
44
+
45
def detect_risks(data):
    """
    Detect contract risks (missing dates, large monetary amounts).

    Args:
        data (dict): Extraction result; may contain "dates" (list of str) and
            "amounts" (list of str formatted like "$1,234.56").

    Returns:
        list[str]: Human-readable risk descriptions (possibly empty).
    """
    risks = []
    if not data.get("dates", []):
        risks.append("No expiration date detected - potential obligation risk.")

    def _amount_value(amount):
        # Robustness fix: a malformed amount string previously raised ValueError
        # and aborted risk detection; treat unparseable values as 0 instead.
        try:
            return float(amount.replace('$', '').replace(',', ''))
        except (ValueError, AttributeError):
            return 0.0

    if any(_amount_value(amount) > 1000000 for amount in data.get("amounts", [])):
        risks.append("Large amount detected - review for financial risk.")
    return risks
53
+
54
def process_contract(pdf_bytes, object_type):
    """
    Run the contract pipeline: OCR, key-value extraction, risk detection, and a
    best-effort Salesforce sync; store the result in the in-memory repository.

    Args:
        pdf_bytes (bytes): Raw PDF file content.
        object_type (str): Object type label (e.g. "Contract", "Agreement").

    Returns:
        tuple: (status message, extracted key data dict, risks list, progress "n/m").
    """
    global processed_files, total_files
    total_files = 1
    processed_files = 0

    print("Received file - Starting processing")
    if not check_poppler() or not check_tesseract():
        error_msg = "Error: Required dependencies missing. Install poppler-utils (e.g., 'sudo apt-get install poppler-utils') and tesseract-ocr (e.g., 'sudo apt-get install tesseract-ocr')."
        print(error_msg)
        return error_msg, {}, [], "0/1"

    temp_path = save_temp_file(pdf_bytes)
    print(f"Temporary file created at: {temp_path}")
    try:
        text = extract_text_from_pdf_with_tesseract_or_layoutlm(temp_path)
        print(f"OCR result length: {len(text)}")
        if isinstance(text, str) and not text.strip():
            print("No text extracted from PDF.")
            return "❌ No text extracted from PDF.", {}, [], "0/1"

        print("Extracting key data")
        key_data = extract_key_values_with_layoutlm(text, temp_path)
        print(f"Key data extracted: {key_data}")
        if "status" in key_data and key_data["status"] == "failed":
            print(f"Extraction failed: {key_data.get('error', 'Unknown error')}")
            return f"❌ Extraction failed: {key_data.get('error', 'Unknown error')}", {}, [], "0/1"

        print("Detecting risks")
        risks = detect_risks(key_data)
        print(f"Detected risks: {risks}")
        status = "✅ Processed" if not risks else "⚠️ Processed with risks"

        # Mock CLM fields with Salesforce-ready structure
        clm_fields = {"Name": f"Contract_{len(contract_data) + 1}", "Type__c": object_type, "Status__c": status}
        clm_fields.update({k: v for k, v in key_data.items() if k not in ["status", "error", "key_values"]})

        # Optional Salesforce sync — failures are logged, never fatal.
        try:
            token, instance_url = get_token()
            sf_response = create_or_update_record(f"{object_type}__c", clm_fields, token, instance_url)
            if "error" in sf_response:
                print(f"Salesforce sync failed: {sf_response['error']}")
            else:
                print(f"Salesforce sync successful: {sf_response}")
        except Exception as e:
            print(f"Salesforce sync error: {str(e)}")

        contract_id = f"Contract_{len(contract_data) + 1}"
        contract_data[contract_id] = {
            "data": key_data,
            "risks": risks,
            "clm_fields": clm_fields,
            "status": status
        }
        processed_files = 1
        progress = "1/1"
        print(f"Processing completed - ID: {contract_id}, Progress: {progress}")
        return status, key_data, risks, progress
    finally:
        # Bug fix: the temp PDF previously leaked when detect_risks or the
        # extraction helpers raised; always remove it.
        if os.path.exists(temp_path):
            os.unlink(temp_path)
116
+
117
def search_contracts(query):
    """Case-insensitively search the in-memory contract repository.

    Matches `query` against the string form of each stored record and returns
    the matching {contract_id: record} entries, or a "No matches" placeholder.
    """
    needle = query.lower()
    matches = {}
    for contract_id, record in contract_data.items():
        if needle in str(record).lower():
            matches[contract_id] = record
    if matches:
        return matches
    return {"No matches": "No contracts found matching the query."}
121
+
122
# Gradio UI — widget creation order defines the on-screen layout.
with gr.Blocks(title="Contract Intelligence App") as demo:
    with gr.Row():
        # type="binary" delivers each upload as raw bytes, matching process_contract.
        file_input = gr.File(type="binary", file_types=["pdf"], file_count="multiple", label="Upload Contracts")
        upload_progress = gr.Textbox(label="Progress", value="0/0", interactive=False)

    object_type = gr.Dropdown(choices=["Contract", "Agreement", "Invoice"], label="Select Object Type")

    process_button = gr.Button("Process Contracts")
    status_output = gr.Textbox(label="Status", interactive=False)
    extracted_data_output = gr.JSON(label="Extracted Data")
    risks_output = gr.Textbox(label="Detected Risks", interactive=False)

    def process_and_display(files, obj_type):
        # Run the pipeline per file and aggregate statuses, data, and risks for display.
        if not files:
            return "❌ No files uploaded.", {}, "No risks detected", gr.update(value="0/0")
        results = []
        all_data = {}
        all_risks = []
        for i, file in enumerate(files):
            # Per-file progress from process_contract is discarded; batch progress
            # is reported once below.
            status, data, risks, _ = process_contract(file, obj_type)
            results.append(f"{status} - File: File_{i}")
            all_data.update({f"File_{i}": data})
            all_risks.extend(risks)
        progress = f"{len(files)}/{len(files)}"
        return "\n".join(results), all_data, "\n".join(all_risks) if all_risks else "No risks detected", gr.update(value=progress)

    process_button.click(
        fn=process_and_display,
        inputs=[file_input, object_type],
        outputs=[status_output, extracted_data_output, risks_output, upload_progress]
    )

    with gr.Tab("Contract Repository"):
        search_query = gr.Textbox(label="Search Contracts", placeholder="Enter keyword...")
        search_results = gr.JSON(label="Search Results")
        search_button = gr.Button("Search")

        search_button.click(
            fn=search_contracts,
            inputs=search_query,
            outputs=search_results
        )

# Launch the app (blocking call; serves the UI).
demo.launch()