Komal133 commited on
Commit
c76c941
·
verified ·
1 Parent(s): dcbd7b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -153
app.py CHANGED
@@ -1,166 +1,133 @@
1
- import gradio as gr
2
- import PyPDF2
3
- import nltk
 
 
 
 
 
4
  import seaborn as sns
5
  import matplotlib.pyplot as plt
6
- from reportlab.lib.pagesizes import letter
7
- from reportlab.pdfgen import canvas
8
- import json
9
- import os
10
- from io import BytesIO
11
- import numpy as np
12
- import logging
13
-
14
- # Set up logging
15
- logging.basicConfig(level=logging.INFO)
16
- logger = logging.getLogger(__name__)
17
-
18
- # Download NLTK data
19
- nltk.download('punkt')
20
-
21
- # Clause types and risk scoring logic
22
- CLAUSE_TYPES = ["penalty", "obligation", "delay"]
23
- RISK_WEIGHTS = {"penalty": 0.8, "obligation": 0.5, "delay": 0.6}
24
-
25
- # Keyword-based heuristic for clause classification
26
- KEYWORD_MAP = {
27
- "penalty": ["penalty", "fee", "fine", "charge", "incur"],
28
- "obligation": ["shall", "must", "obligated", "required", "responsible"],
29
- "delay": ["delay", "late", "beyond", "postpone", "deferred"]
30
- }
31
 
 
 
 
 
 
 
 
 
 
 
 
32
  def extract_text_from_pdf(pdf_file):
33
- """Extract text from uploaded PDF file."""
34
- try:
35
- reader = PyPDF2.PdfReader(pdf_file)
36
- text = ""
37
- for page in reader.pages:
38
- page_text = page.extract_text() or ""
39
- text += page_text + "\n"
40
- logger.info(f"Extracted text length: {len(text)} characters")
41
- logger.debug(f"Extracted text sample: {text[:500]}")
42
- if not text.strip():
43
- return "Error: No text extracted from PDF."
44
- return text
45
- except Exception as e:
46
- logger.error(f"Text extraction error: {str(e)}")
47
- return f"Error extracting text: {str(e)}"
48
-
49
- def parse_contract(text):
50
- """Parse contract text into clauses and classify risks using keyword-based heuristic."""
51
- # Clean text: replace multiple newlines with single, handle LaTeX artifacts
52
- text = text.replace("\n\n", "\n").replace("\t", " ")
53
- sentences = nltk.sent_tokenize(text)
54
- logger.info(f"Number of sentences tokenized: {len(sentences)}")
55
- logger.debug(f"Sample sentences: {sentences[:3]}")
56
-
57
  results = []
58
- risk_scores = []
59
-
60
- for idx, sentence in enumerate(sentences):
61
- sentence = sentence.strip()
62
- if len(sentence) < 10: # Skip short sentences
63
- logger.debug(f"Skipping short sentence (length {len(sentence)}): {sentence}")
64
- continue
65
-
66
- # Heuristic classification based on keywords
67
- sentence_lower = sentence.lower()
68
- clause_type = None
69
- for c_type, keywords in KEYWORD_MAP.items():
70
- if any(keyword in sentence_lower for keyword in keywords):
71
- clause_type = c_type
72
- break
73
-
74
- if clause_type not in CLAUSE_TYPES:
75
- logger.debug(f"No relevant clause type for sentence {idx}: {sentence}")
76
- continue
77
-
78
- # Assign a dummy score based on keyword presence (simulating model confidence)
79
- score = RISK_WEIGHTS[clause_type] * 0.9 # 0.9 as a dummy confidence score
80
  results.append({
81
- "clause_id": idx,
82
- "text": sentence,
83
- "clause_type": clause_type,
84
- "risk_score": round(score, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  })
86
- risk_scores.append(score)
87
- logger.info(f"Detected clause {idx}: {clause_type} with risk score {score}")
88
-
89
- return results, risk_scores
90
 
91
- def generate_heatmap(risk_scores):
92
- """Generate heatmap for risk scores."""
93
- if not risk_scores:
94
- logger.warning("No risk scores to generate heatmap.")
95
- return None
96
- data = np.array(risk_scores).reshape(1, -1)
97
  plt.figure(figsize=(10, 2))
98
- sns.heatmap(data, cmap="YlOrRd", annot=True, fmt=".2f", cbar_kws={'label': 'Risk Score'})
99
- plt.title("Contract Risk Heatmap")
100
- plt.xlabel("Clause Index")
101
- plt.ylabel("Risk")
102
- buffer = BytesIO()
103
- plt.savefig(buffer, format="png", bbox_inches="tight")
104
- plt.close()
105
- buffer.seek(0)
106
- return buffer
107
-
108
- def generate_pdf_report(results, heatmap_buffer):
109
- """Generate PDF report with summary and heatmap."""
110
- buffer = BytesIO()
111
- c = canvas.Canvas(buffer, pagesize=letter)
112
- c.setFont("Helvetica", 12)
113
- c.drawString(50, 750, "Contract Risk Analysis Report")
114
 
115
- # Summary
116
- c.drawString(50, 720, "Summary of Risk-Prone Clauses:")
117
- y = 700
118
- for result in results[:5]: # Limit to top 5 for brevity
119
- text = f"Clause {result['clause_id']}: {result['clause_type'].capitalize()} (Risk: {result['risk_score']})"
120
- c.drawString(50, y, text[:80] + "..." if len(text) > 80 else text)
121
- y -= 20
122
-
123
- # Embed heatmap
124
- if heatmap_buffer:
125
- c.drawImage(BytesIO(heatmap_buffer.read()), 50, y-200, width=500, height=100)
126
 
127
- c.showPage()
128
- c.save()
129
- buffer.seek(0)
130
- return buffer
131
-
132
- def process_contract(pdf_file):
133
- """Main function to process uploaded contract."""
134
- # Extract text
135
- text = extract_text_from_pdf(pdf_file)
136
- if "Error" in text:
137
- return text, None, None, {"Error": text}
138
-
139
- # Parse and classify
140
- results, risk_scores = parse_contract(text)
141
- if not results:
142
- return "No relevant clauses detected.", None, None, {"Summary": "No risk-prone clauses found."}
143
-
144
- # Generate outputs
145
- json_output = json.dumps(results, indent=2)
146
- heatmap_buffer = generate_heatmap(risk_scores)
147
- pdf_report = generate_pdf_report(results, heatmap_buffer)
148
-
149
- return json_output, heatmap_buffer, pdf_report, {"Summary": f"Detected {len(results)} risk-prone clauses."}
150
-
151
- # Gradio interface
152
- iface = gr.Interface(
153
- fn=process_contract,
154
- inputs=gr.File(label="Upload Contract PDF"),
155
- outputs=[
156
- gr.Textbox(label="JSON Output"),
157
- gr.Image(label="Risk Heatmap"),
158
- gr.File(label="Download PDF Report"),
159
- gr.JSON(label="Summary")
160
- ],
161
- title="Contract Risk Analyzer",
162
- description="Upload a contract PDF to analyze risk-prone clauses and visualize results."
163
- )
164
 
165
  if __name__ == "__main__":
166
- iface.launch()
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from transformers import BertTokenizer, BertForSequenceClassification
5
+ from simple_salesforce import Salesforce
6
+ import torch
7
+ from PyPDF2 import PdfReader
8
+ import re
9
  import seaborn as sns
10
  import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
# Salesforce connection
def connect_to_salesforce():
    """Open an authenticated Salesforce session.

    Credentials are read from the environment (SF_USERNAME, SF_PASSWORD,
    SF_SECURITY_TOKEN, SF_DOMAIN) so secrets are not committed to source
    control; the original placeholder strings remain as fallbacks.

    Returns:
        simple_salesforce.Salesforce: an authenticated client.
    """
    import os  # local import keeps this fix self-contained

    sf = Salesforce(
        username=os.environ.get('SF_USERNAME', 'your_username'),
        password=os.environ.get('SF_PASSWORD', 'your_password'),
        security_token=os.environ.get('SF_SECURITY_TOKEN', 'your_security_token'),
        domain=os.environ.get('SF_DOMAIN', 'login')  # or 'test' for sandbox
    )
    return sf
21
+
22
# Extract text from PDF
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in a PDF.

    Args:
        pdf_file: a path or binary file-like object accepted by PdfReader
            (Streamlit's UploadedFile works directly).

    Returns:
        str: page texts joined by newlines; empty string for a PDF with
        no extractable text.
    """
    reader = PdfReader(pdf_file)
    text = ""
    for page in reader.pages:
        # extract_text() returns None for image-only/scanned pages; guard
        # so the concatenation below cannot raise TypeError.
        text += (page.extract_text() or "") + "\n"
    return text
29
+
30
# Split text into clauses
def split_into_clauses(text):
    """Split contract text into clauses at numbered ("1.") or single
    capital-lettered ("A.") headings that start a new line.

    Args:
        text: full contract text.

    Returns:
        list[str]: non-empty, whitespace-trimmed clause strings, in
        document order.
    """
    heading = r'\n\s*\d+\.\s*|\n\s*[A-Z]\.\s*'
    fragments = (fragment.strip() for fragment in re.split(heading, text))
    return [fragment for fragment in fragments if fragment]
35
+
36
# Load BERT model and tokenizer (cached once per Streamlit server process)
@st.cache_resource
def load_model():
    """Return a (tokenizer, model) pair for 3-way clause risk classification.

    NOTE(review): this loads stock 'bert-base-uncased' with a freshly
    initialized 3-label classification head -- nothing here is fine-tuned,
    so predictions are effectively random until a fine-tuned checkpoint is
    substituted. Confirm the intended model path/weights.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)  # 3 labels: Low/Medium/High risk
    return tokenizer, model
42
+
43
# Process clauses and assign risk scores
def process_clauses(clauses, tokenizer, model):
    """Score every clause with the BERT classifier.

    Args:
        clauses: list of clause strings.
        tokenizer: HuggingFace tokenizer matching `model`.
        model: BertForSequenceClassification with 3 output labels.

    Returns:
        list[dict]: one record per clause with keys 'clause_text',
        'risk_level' ('Low'/'Medium'/'High'), 'severity_score' (float
        probability of the chosen label) and 'clause_type'.
    """
    results = []
    risk_levels = {0: 'Low', 1: 'Medium', 2: 'High'}

    # Hoist the no-grad context out of the loop instead of re-entering it
    # for every clause; inference never needs gradients here.
    with torch.no_grad():
        for clause in clauses:
            inputs = tokenizer(clause, return_tensors="pt", truncation=True,
                               padding=True, max_length=512)
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=1)[0]
            # Convert to a plain int before using it as a dict key; also
            # avoids the per-iteration tensor->numpy round-trip.
            label = int(probs.argmax())
            results.append({
                'clause_text': clause,
                'risk_level': risk_levels[label],
                'severity_score': float(probs[label]),
                'clause_type': infer_clause_type(clause)  # keyword heuristic
            })

    return results
64
+
65
# Simplified clause type inference (extend with more sophisticated logic as needed)
def infer_clause_type(clause):
    """Label a clause as 'Liability', 'Payment', or 'General' by keyword.

    Matching is case-insensitive; 'liability' takes precedence over
    'payment' when both keywords appear.
    """
    lowered = clause.lower()
    for keyword, label in (('liability', 'Liability'), ('payment', 'Payment')):
        if keyword in lowered:
            return label
    return 'General'
73
+
74
# Save results to Salesforce
def save_to_salesforce(sf, results, contract_id):
    """Persist one Contract_Risk__c record per analyzed clause.

    Args:
        sf: authenticated simple_salesforce client.
        results: clause dicts produced by process_clauses().
        contract_id: Salesforce Id of the parent Contract record.
    """
    for entry in results:
        record = {
            'Contract__c': contract_id,
            # Standard Salesforce text fields cap at 255 chars; truncate to fit.
            'Clause_Text__c': entry['clause_text'][:255],
            'Risk_Level__c': entry['risk_level'],
            'Severity_Score__c': entry['severity_score'],
            'Clause_Type__c': entry['clause_type'],
        }
        sf.Contract_Risk__c.create(record)
 
 
 
 
84
 
85
# Generate heatmap
def generate_heatmap(results):
    """Render a one-row heatmap of clause severity scores in Streamlit.

    Args:
        results: clause dicts from process_clauses(); may be empty.
    """
    # Guard: an empty result set would otherwise crash when extracting the
    # severity column.
    if not results:
        st.warning("No clauses to visualize.")
        return
    risk_scores = [entry['severity_score'] for entry in results]
    fig = plt.figure(figsize=(10, 2))
    sns.heatmap([risk_scores], cmap='RdYlGn_r', annot=True, fmt='.2f', cbar_kws={'label': 'Risk Severity'})
    plt.title('Contract Clause Risk Heatmap')
    plt.xlabel('Clause Index')
    plt.yticks([])
    # Pass the figure object (module-level pyplot use is deprecated in
    # Streamlit) and close it to avoid leaking figures across reruns.
    st.pyplot(fig)
    plt.close(fig)
95
+
96
# Streamlit interface
def main():
    """Streamlit entry point: upload a contract PDF, score its clauses,
    visualize risk, and optionally push the results to Salesforce."""
    st.title("Contract Risk Analyzer")

    # File upload
    uploaded_file = st.file_uploader("Upload Contract PDF", type=["pdf"])
    contract_id = st.text_input("Enter Contract ID")

    if uploaded_file and contract_id:
        # Extract and process text
        text = extract_text_from_pdf(uploaded_file)
        clauses = split_into_clauses(text)

        # Robustness: a scanned/image-only PDF yields no clauses; stop with
        # a clear message instead of running the model on nothing.
        if not clauses:
            st.error("No text could be extracted from this PDF.")
            return

        # Load model and process clauses
        tokenizer, model = load_model()
        results = process_clauses(clauses, tokenizer, model)

        # Display results
        st.subheader("Clause Analysis Results")
        for i, result in enumerate(results, 1):
            st.write(f"**Clause {i}**")
            st.write(f"Text: {result['clause_text'][:100]}...")
            st.write(f"Clause Type: {result['clause_type']}")
            st.write(f"Risk Level: {result['risk_level']}")
            st.write(f"Severity Score: {result['severity_score']:.2f}")
            st.write("---")

        # Generate and display heatmap
        generate_heatmap(results)

        # Save to Salesforce
        if st.button("Save to Salesforce"):
            sf = connect_to_salesforce()
            save_to_salesforce(sf, results, contract_id)
            st.success("Results saved to Salesforce!")
 
 
 
 
 
 
 
 
 
 
131
 
132
if __name__ == "__main__":
    # Script entry point: runs the Streamlit UI (launch with `streamlit run app.py`).
    main()