Kushalmanda commited on
Commit
2d14f55
·
verified ·
1 Parent(s): 478583e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -36
app.py CHANGED
@@ -1,60 +1,211 @@
1
- import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from transformers import BertTokenizer, BertForSequenceClassification
3
  import torch
4
- from flask import Flask, request, jsonify
5
 
6
- # Initialize Flask app
7
- app = Flask(__name__)
 
8
 
9
- # Load the pre-trained BERT model and tokenizer (Replace with your own fine-tuned model path)
10
- model = BertForSequenceClassification.from_pretrained('path_to_your_finetuned_model') # Replace with your model path
11
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # Replace with your tokenizer path if custom
 
 
12
 
13
- # Define the possible risk levels (These are the output classes from your fine-tuned BERT model)
14
- risk_labels = ["low", "medium", "high"]
 
15
 
16
- # Function to process contract text through the BERT model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def process_contract(contract_text):
18
- # Tokenize input contract text
19
  inputs = tokenizer(contract_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
20
 
21
- # Perform inference with the BERT model to get the logits (raw prediction scores)
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
 
25
- # Extract logits (raw prediction scores)
26
  logits = outputs.logits
27
-
28
- # Predict the risk level (index of max logit score)
29
  predicted_class = torch.argmax(logits, dim=1).item()
30
 
31
- # Map prediction index to risk level
32
  risk_tag = risk_labels[predicted_class]
33
 
34
- # Return the predicted risk tag and the confidence score (highest logit)
35
- return {
36
- "risk_tag": risk_tag,
37
- "score": logits.max().item(), # Confidence score
38
- "raw_scores": logits.squeeze().tolist() # Optionally return the raw logits for more insight
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # Flask route to process the contract text (POST request)
42
- @app.route('/process_contract', methods=['POST'])
43
- def analyze_contract():
44
  try:
45
- # Get JSON data from the incoming POST request
46
- data = request.json
47
- contract_text = data['contract_text']
 
 
48
 
49
- # Process the contract text and extract risk details
50
- result = process_contract(contract_text)
51
 
52
- # Return the result as JSON
53
- return jsonify(result), 200
 
 
54
  except Exception as e:
55
- # Return error message in case of failure
56
- return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Run the Flask app
59
- if __name__ == '__main__':
60
- app.run(debug=True, host="0.0.0.0", port=5000)
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from io import BytesIO
6
+ import os
7
+ import logging
8
+ import base64
9
+ import shutil
10
+ import tempfile
11
+ from simple_salesforce import Salesforce
12
+ from reportlab.lib.pagesizes import letter
13
+ from reportlab.pdfgen import canvas
14
+ from reportlab.lib import colors
15
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image, Table, TableStyle
16
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
17
  from transformers import BertTokenizer, BertForSequenceClassification
18
  import torch
 
19
 
20
+ # Configure logging to show detailed messages
21
+ logging.basicConfig(level=logging.DEBUG)
22
+ logger = logging.getLogger(__name__)
23
 
24
+ # Salesforce credentials (use environment variables in production)
25
+ SALESFORCE_USERNAME = os.getenv("SALESFORCE_USERNAME", "user@example.com")
26
+ SALESFORCE_PASSWORD = os.getenv("SALESFORCE_PASSWORD", "password")
27
+ SALESFORCE_SECURITY_TOKEN = os.getenv("SALESFORCE_SECURITY_TOKEN", "security_token")
28
+ SALESFORCE_DOMAIN = os.getenv("SALESFORCE_DOMAIN", "login")
29
 
30
+ # Load the BERT model and tokenizer for risk classification (fine-tuned for contract clauses)
31
+ model = BertForSequenceClassification.from_pretrained('path_to_finetuned_model') # Replace with your model path
32
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
33
 
34
+ # Function to authenticate with Salesforce
35
+ def get_salesforce_connection():
36
+ try:
37
+ sf = Salesforce(
38
+ username=SALESFORCE_USERNAME,
39
+ password=SALESFORCE_PASSWORD,
40
+ security_token=SALESFORCE_SECURITY_TOKEN,
41
+ domain=SALESFORCE_DOMAIN
42
+ )
43
+ return sf
44
+ except Exception as e:
45
+ logger.error(f"Failed to connect to Salesforce: {str(e)}", exc_info=True)
46
+ return None
47
+
48
+ # Function to parse contract and predict risk score using BERT model
49
  def process_contract(contract_text):
 
50
  inputs = tokenizer(contract_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
51
 
 
52
  with torch.no_grad():
53
  outputs = model(**inputs)
54
 
 
55
  logits = outputs.logits
 
 
56
  predicted_class = torch.argmax(logits, dim=1).item()
57
 
58
+ risk_labels = ["low", "medium", "high"]
59
  risk_tag = risk_labels[predicted_class]
60
 
61
+ return risk_tag, logits.max().item()
62
+
63
+ # Function to generate a heatmap of the contract with section-wise risk levels
64
+ def generate_heatmap(contract_text):
65
+ # Assuming the contract is split into sections; this is a simplified approach
66
+ sections = contract_text.split("\n\n") # Split by paragraphs/sections
67
+ risks = []
68
+ for section in sections:
69
+ risk_tag, score = process_contract(section)
70
+ risks.append((section, risk_tag, score))
71
+
72
+ # Create a heatmap visualization
73
+ fig, ax = plt.subplots(figsize=(10, len(sections) * 0.5))
74
+ ax.barh(range(len(sections)), [r[2] for r in risks], color='red', height=0.4)
75
+
76
+ ax.set_yticks(range(len(sections)))
77
+ ax.set_yticklabels([r[0][:50] for r in risks]) # Display first 50 characters of each section as label
78
+ ax.set_xlabel('Risk Score')
79
+ ax.set_title('Risk Heatmap of Contract Sections')
80
+
81
+ # Adjust layout and return the figure
82
+ plt.tight_layout()
83
+ return fig
84
+
85
+ # Function to generate comprehensive PDF report
86
+ def generate_pdf_report(project_title, risk_tags, ai_plan_score, estimated_duration, location, weather, gantt_chart_path=None):
87
+ pdf_file = BytesIO()
88
+ doc = SimpleDocTemplate(pdf_file, pagesize=letter)
89
+ styles = getSampleStyleSheet()
90
+ elements = []
91
+
92
+ title_style = ParagraphStyle('Title', parent=styles['Heading1'], fontSize=18, alignment=1, spaceAfter=20)
93
+ elements.append(Paragraph(f"Project Report: {project_title}", title_style))
94
+
95
+ details_style = styles['BodyText']
96
+ details = [
97
+ f"<b>Location:</b> {location}",
98
+ f"<b>Weather:</b> {weather.capitalize()}",
99
+ f"<b>Estimated Duration:</b> {estimated_duration} days",
100
+ f"<b>AI Plan Score:</b> {ai_plan_score:.1f}%",
101
+ ]
102
+ for detail in details:
103
+ elements.append(Paragraph(detail, details_style))
104
+
105
+ elements.append(Spacer(1, 12))
106
+ elements.append(Paragraph("<b>Risk Assessment:</b>", styles['Heading2']))
107
+
108
+ for risk in risk_tags.split("\n"):
109
+ elements.append(Paragraph(f"• {risk}", details_style))
110
+
111
+ if gantt_chart_path:
112
+ elements.append(Spacer(1, 24))
113
+ elements.append(Paragraph("<b>Project Timeline:</b>", styles['Heading2']))
114
+ img = Image(gantt_chart_path, width=6 * inch, height=4 * inch)
115
+ elements.append(img)
116
+
117
+ doc.build(elements)
118
+ pdf_file.seek(0)
119
+ return pdf_file
120
+
121
+ # Function to upload the generated PDF to Salesforce
122
+ def upload_pdf_to_salesforce(pdf_file, project_title):
123
+ sf = get_salesforce_connection()
124
+ if not sf:
125
+ logger.error("Salesforce connection failed. Cannot upload PDF.")
126
+ return None, None
127
+
128
+ encoded_pdf_data = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
129
+ content_version_data = {
130
+ "Title": f"{project_title} - Comprehensive Report",
131
+ "PathOnClient": f"{project_title}_Report.pdf",
132
+ "VersionData": encoded_pdf_data,
133
  }
134
+ content_version = sf.ContentVersion.create(content_version_data)
135
+ content_version_id = content_version["id"]
136
+ result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version_id}'")
137
+ content_document_id = result['records'][0]['ContentDocumentId']
138
+ file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version_id}"
139
+ return content_version_id, file_url
140
+
141
+ # Function to log project data to Salesforce
142
+ def send_to_salesforce(project_title, gantt_chart_url, ai_plan_score, estimated_duration, risk_tags, status="Draft", record_id=None, location="", weather_type=""):
143
+ sf = get_salesforce_connection()
144
+ if not sf:
145
+ logger.error("Salesforce connection failed. Cannot proceed with record creation/update.")
146
+ return None
147
+
148
+ sf_data = {
149
+ "Name": project_title[:80],
150
+ "Project_Title__c": project_title,
151
+ "Estimated_Duration__c": estimated_duration,
152
+ "AI_Plan_Score__c": ai_plan_score,
153
+ "Status__c": status,
154
+ "Location__c": location,
155
+ "Weather_Type__c": weather_type,
156
+ "Risk_Tags__c": risk_tags,
157
+ }
158
+
159
+ if gantt_chart_url:
160
+ sf_data["Gantt_Chart_PDF__c"] = gantt_chart_url
161
+
162
+ if record_id:
163
+ sf.AI_Project_Timeline__c.update(record_id, sf_data)
164
+ return record_id
165
+ else:
166
+ project_record = sf.AI_Project_Timeline__c.create(sf_data)
167
+ return project_record['id']
168
 
169
+ # Gradio interface function
170
+ def gradio_interface(boq_file, weather, location, project_title):
 
171
  try:
172
+ if not boq_file:
173
+ return None, "Error: No BOQ file uploaded", None, None
174
+
175
+ fig = generate_heatmap(boq_file)
176
+ risk_tags = "Risk tags will be displayed here..." # Generate risk tags logic based on contract analysis
177
 
178
+ # Generating PDF report
179
+ pdf_report = generate_pdf_report(project_title, risk_tags, ai_plan_score=90, estimated_duration=30, location=location, weather=weather)
180
 
181
+ # Upload to Salesforce
182
+ pdf_content_id, pdf_url = upload_pdf_to_salesforce(pdf_report, project_title)
183
+
184
+ return fig, risk_tags, pdf_url, pdf_report
185
  except Exception as e:
186
+ logger.error(f"Error in Gradio interface: {str(e)}")
187
+ return None, f"Error in Gradio interface: {str(e)}", None, None
188
+
189
+ # Create Gradio interface
190
+ demo = gr.Blocks()
191
+ with demo:
192
+ gr.Markdown("## Contract Risk Analyzer")
193
+ gr.Markdown("Upload a contract, and the system will generate a heatmap and PDF report highlighting risk-prone clauses.")
194
+
195
+ with gr.Row():
196
+ with gr.Column():
197
+ contract_file = gr.File(label="Upload Contract (PDF or Text)")
198
+ weather = gr.Dropdown(label="Weather", choices=["sunny", "rainy", "cloudy"], value="sunny")
199
+ location = gr.Textbox(label="Location", placeholder="Enter project location")
200
+ project_title = gr.Textbox(label="Project Title", placeholder="Enter project title")
201
+ submit_btn = gr.Button("Analyze Contract")
202
+
203
+ with gr.Column():
204
+ plot_output = gr.Plot(label="Heatmap Visualization")
205
+ risk_tags_output = gr.Textbox(label="Risk Tags")
206
+ download_pdf = gr.File(label="Download Full Report (PDF)")
207
+
208
+ submit_btn.click(fn=gradio_interface, inputs=[contract_file, weather, location, project_title], outputs=[plot_output, risk_tags_output, download_pdf])
209
 
210
+ if __name__ == "__main__":
211
+ demo.launch()