Komal133 commited on
Commit
f20ba38
·
verified ·
1 Parent(s): 4fe62f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -141
app.py CHANGED
@@ -1,148 +1,140 @@
1
- import dash
2
- from dash import dcc, html
3
- import dash_bootstrap_components as dbc
4
- from transformers import pipeline
5
  import PyPDF2
6
- import docx
 
 
7
  import matplotlib.pyplot as plt
8
- import numpy as np
9
- import pandas as pd
10
- import requests
11
  import json
 
 
 
 
 
 
 
12
 
13
- # Initialize the BERT-based NLP pipeline
14
- model_name = "dbmdz/bert-large-cased-finetuned-conll03-english" # Example, replace with your model
15
- nlp_pipeline = pipeline("ner", model=model_name)
16
-
17
- # Initialize Dash App
18
- app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
19
-
20
- # Define app layout
21
- app.layout = html.Div([
22
- dbc.Row([
23
- dbc.Col(html.H1("Contract Risk Analyzer", style={'textAlign': 'center'})),
24
- ]),
25
- dbc.Row([
26
- dbc.Col(html.Div([
27
- html.Label("Upload Contract"),
28
- dcc.Upload(
29
- id='upload-data',
30
- children=html.Button('Upload File'),
31
- multiple=False
32
- ),
33
- html.Div(id='file-upload-status'),
34
- ]), width=12),
35
- ]),
36
- dbc.Row([
37
- dbc.Col(html.Div(id='output-text'), width=12),
38
- ]),
39
- dbc.Row([
40
- dbc.Col(dcc.Graph(id='risk-heatmap'), width=12),
41
- ]),
42
- ])
43
-
44
-
45
- # Function to analyze contract text
46
- def analyze_contract(contract_text):
47
  try:
48
- # Run the contract through the NLP pipeline
49
- results = nlp_pipeline(contract_text)
50
-
51
- # Parse and score clauses (this is a simplified version)
52
- risk_score = 0
53
- high_risk_clauses = []
54
-
55
- for result in results:
56
- # This assumes 'labels' are risk-related; adjust as per model output
57
- if result['label'] in ["PENALTY", "OBLIGATION", "DELAY"]: # Customize as per your model's tags
58
- high_risk_clauses.append(result['word'])
59
- risk_score += 10 # Example scoring logic, modify as needed
60
-
61
- return {
62
- "high_risk_clauses": high_risk_clauses,
63
- "risk_score": risk_score
64
- }
65
  except Exception as e:
66
- return {"error": str(e)}
67
-
68
-
69
- # Function to parse uploaded contract
70
- def parse_contract(file_content, file_type):
71
- contract_text = ""
72
- if file_type == "application/pdf":
73
- try:
74
- pdf_reader = PyPDF2.PdfReader(file_content)
75
- for page in pdf_reader.pages:
76
- contract_text += page.extract_text()
77
- except Exception as e:
78
- return f"Error reading PDF: {str(e)}"
79
-
80
- elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
81
- try:
82
- doc = docx.Document(file_content)
83
- for para in doc.paragraphs:
84
- contract_text += para.text
85
- except Exception as e:
86
- return f"Error reading DOCX: {str(e)}"
87
-
88
- elif file_type == "text/plain":
89
- contract_text = file_content.decode("utf-8")
90
-
91
- return contract_text
92
-
93
-
94
- # Callback to handle file upload
95
- @app.callback(
96
- [dash.dependencies.Output('file-upload-status', 'children'),
97
- dash.dependencies.Output('output-text', 'children'),
98
- dash.dependencies.Output('risk-heatmap', 'figure')],
99
- [dash.dependencies.Input('upload-data', 'contents'),
100
- dash.dependencies.State('upload-data', 'filename'),
101
- dash.dependencies.State('upload-data', 'type')]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  )
103
- def update_output(file_contents, filename, file_type):
104
- if file_contents is not None:
105
- # Parse the contract
106
- contract_text = parse_contract(file_contents, file_type)
107
-
108
- if contract_text:
109
- # Analyze the contract
110
- analysis_results = analyze_contract(contract_text)
111
-
112
- if "error" in analysis_results:
113
- return "Error", f"An error occurred during analysis: {analysis_results['error']}", {}
114
-
115
- # Display high-risk clauses and overall risk score
116
- high_risk_clauses = analysis_results["high_risk_clauses"]
117
- risk_score = analysis_results["risk_score"]
118
-
119
- high_risk_text = f"High Risk Clauses: {', '.join(high_risk_clauses)}"
120
- risk_score_text = f"Overall Risk Score: {risk_score}"
121
-
122
- # Generate the risk heatmap (simplified here)
123
- fig, ax = plt.subplots()
124
- ax.barh(['Contract'], [risk_score], color='red')
125
- ax.set_xlim(0, 100) # Assuming risk score ranges from 0 to 100
126
- ax.set_xlabel("Risk Score")
127
-
128
- # Returning results for display
129
- return "File Uploaded Successfully", [high_risk_text, risk_score_text], {
130
- 'data': [{
131
- 'x': ['Contract'],
132
- 'y': [risk_score],
133
- 'type': 'bar',
134
- 'name': 'Risk Score',
135
- 'marker': {'color': 'red'}
136
- }],
137
- 'layout': {
138
- 'title': 'Risk Heatmap',
139
- 'xaxis': {'title': 'Risk Score'},
140
- 'yaxis': {'title': 'Contract'}
141
- }
142
- }
143
-
144
- return "No File Uploaded", "", {}
145
-
146
-
147
- if __name__ == '__main__':
148
- app.run_server(debug=True)
 
1
+ import gradio as gr
 
 
 
2
  import PyPDF2
3
+ import nltk
4
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
5
+ import seaborn as sns
6
  import matplotlib.pyplot as plt
7
+ from reportlab.lib.pagesizes import letter
8
+ from reportlab.pdfgen import canvas
 
9
  import json
10
+ import os
11
+ from io import BytesIO
12
+ import numpy as np
13
+ import torch
14
+
15
# Download the sentence tokenizer model used by parse_contract.
nltk.download('punkt')

# Initialize BERT model and tokenizer.
# NOTE(review): this base checkpoint has no sequence-classification head,
# so the 3-label head is freshly (randomly) initialized — scores are
# meaningless until the model is fine-tuned. Confirm whether a fine-tuned
# checkpoint was intended here.
model_name = "nlpaueb/legal-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Map label ids to the clause names the rest of the app checks against.
# Without id2label the pipeline emits "LABEL_0"/"LABEL_1"/"LABEL_2",
# which never match CLAUSE_TYPES, so every clause would be skipped.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=3,  # 3 labels: penalty, obligation, delay
    id2label={0: "penalty", 1: "obligation", 2: "delay"},
    label2id={"penalty": 0, "obligation": 1, "delay": 2},
)
# return_all_scores=True is deprecated in recent transformers (top_k=None
# is the replacement); kept as-is for compatibility with the pinned version.
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)

# Clause types and risk scoring logic.
CLAUSE_TYPES = ["penalty", "obligation", "delay"]
RISK_WEIGHTS = {"penalty": 0.8, "obligation": 0.5, "delay": 0.6}
27
+
28
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in *pdf_file*.

    Any failure (unreadable file, malformed PDF, ...) is reported as a
    string beginning with "Error extracting text:" rather than raised —
    process_contract keys off that prefix.
    """
    try:
        pages = PyPDF2.PdfReader(pdf_file).pages
        # extract_text() may return None for image-only pages; treat as "".
        return "".join(page.extract_text() or "" for page in pages)
    except Exception as exc:
        return f"Error extracting text: {str(exc)}"
38
+
39
def parse_contract(text):
    """Parse contract text into clauses and classify risks.

    Returns (results, risk_scores): results is a list of dicts with keys
    clause_id, text, clause_type, risk_score; risk_scores is the parallel
    list of raw (unrounded) weighted scores used by the heatmap.
    """
    sentences = nltk.sent_tokenize(text)
    results = []
    risk_scores = []

    for idx, sentence in enumerate(sentences):
        if len(sentence.strip()) < 10:  # Skip fragments too short to be a clause
            continue
        # With return_all_scores=True the pipeline yields, per input, a
        # list of {label, score} dicts covering every label.
        scores = classifier(sentence)[0]
        best = max(scores, key=lambda entry: entry['score'])
        clause_type = best['label']
        # NOTE(review): this membership test only passes if the model's
        # id2label maps to these names; default "LABEL_n" labels would
        # cause every sentence to be skipped — confirm model config.
        if clause_type not in CLAUSE_TYPES:
            continue

        # Use the winning label's own score. The previous code re-indexed
        # the score list by CLAUSE_TYPES position, silently assuming the
        # pipeline emits labels in exactly that order.
        score = best['score'] * RISK_WEIGHTS[clause_type]
        results.append({
            "clause_id": idx,
            "text": sentence,
            "clause_type": clause_type,
            "risk_score": round(score, 2),
        })
        risk_scores.append(score)

    return results, risk_scores
65
+
66
def generate_heatmap(risk_scores):
    """Render per-clause risk scores as a 1xN heatmap PNG.

    Returns a BytesIO positioned at offset 0 containing the PNG image,
    or None when there are no scores to plot.
    """
    if not risk_scores:
        return None
    row = np.array(risk_scores).reshape(1, -1)  # single-row matrix for seaborn
    plt.figure(figsize=(10, 2))
    sns.heatmap(row, cmap="YlOrRd", annot=True, fmt=".2f",
                cbar_kws={'label': 'Risk Score'})
    plt.title("Contract Risk Heatmap")
    plt.xlabel("Clause Index")
    plt.ylabel("Risk")
    png = BytesIO()
    plt.savefig(png, format="png", bbox_inches="tight")
    plt.close()  # free the figure so repeated uploads don't accumulate state
    png.seek(0)
    return png
81
+
82
def generate_pdf_report(results, heatmap_buffer):
    """Build a one-page PDF report (summary + heatmap) as a BytesIO.

    results        -- clause dicts produced by parse_contract
    heatmap_buffer -- PNG BytesIO from generate_heatmap, or None
    """
    # Local import keeps the extra reportlab dependency next to its only use.
    from reportlab.lib.utils import ImageReader

    buffer = BytesIO()
    c = canvas.Canvas(buffer, pagesize=letter)
    c.setFont("Helvetica", 12)
    c.drawString(50, 750, "Contract Risk Analysis Report")

    # Summary of the highest-listed clauses (capped at 5 to fit the page).
    c.drawString(50, 720, "Summary of Risk-Prone Clauses:")
    y = 700
    for result in results[:5]:  # Limit to top 5 for brevity
        text = f"Clause {result['clause_id']}: {result['clause_type'].capitalize()} (Risk: {result['risk_score']})"
        c.drawString(50, y, text[:80] + "..." if len(text) > 80 else text)
        y -= 20

    # Embed heatmap. Canvas.drawImage accepts a filename or ImageReader;
    # the previous code passed a raw BytesIO, which reportlab rejects.
    if heatmap_buffer:
        c.drawImage(ImageReader(heatmap_buffer), 50, y - 200, width=500, height=100)

    c.showPage()
    c.save()
    buffer.seek(0)
    return buffer
105
+
106
def process_contract(pdf_file):
    """Run the full analysis pipeline on an uploaded contract PDF.

    Returns (json_text, heatmap_png_buffer, pdf_report_buffer, summary_dict).
    On failure the first element is a message and the others are None.
    """
    text = extract_text_from_pdf(pdf_file)
    # extract_text_from_pdf signals failure with an "Error ..." string.
    # NOTE(review): a contract whose body contains the word "Error" would
    # false-positive here — consider a sentinel or exception instead.
    if "Error" in text:
        return text, None, None, None

    results, risk_scores = parse_contract(text)
    if not results:
        return "No relevant clauses detected.", None, None, None

    heatmap_buffer = generate_heatmap(risk_scores)
    return (
        json.dumps(results, indent=2),
        heatmap_buffer,
        generate_pdf_report(results, heatmap_buffer),
        {"Summary": f"Detected {len(results)} risk-prone clauses."},
    )
124
+
125
# Gradio interface: wires process_contract to one file-upload input and
# four outputs (JSON text, heatmap image, downloadable PDF, summary dict).
# NOTE(review): process_contract returns BytesIO buffers for the Image and
# File outputs — confirm the installed Gradio version accepts in-memory
# buffers (recent releases expect file paths / PIL / numpy for gr.Image).
iface = gr.Interface(
    fn=process_contract,
    inputs=gr.File(label="Upload Contract PDF"),
    outputs=[
        gr.Textbox(label="JSON Output"),
        gr.Image(label="Risk Heatmap"),
        gr.File(label="Download PDF Report"),
        gr.JSON(label="Summary")
    ],
    title="Contract Risk Analyzer",
    description="Upload a contract PDF to analyze risk-prone clauses and visualize results."
)

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()