Spaces:
Runtime error
Runtime error
File size: 4,494 Bytes
c76c941 f20ba38 e78607f dcbd7b1 c76c941 f20ba38 c76c941 f20ba38 c76c941 dcbd7b1 c76c941 dcbd7b1 f20ba38 c76c941 f20ba38 c76c941 e353374 c76c941 e353374 c76c941 e353374 c76c941 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import streamlit as st
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification
from simple_salesforce import Salesforce
import torch
from PyPDF2 import PdfReader
import re
import seaborn as sns
import matplotlib.pyplot as plt
# Salesforce connection
def connect_to_salesforce(username=None, password=None, security_token=None, domain=None):
    """Create and return a ``simple_salesforce.Salesforce`` client.

    Each setting is resolved in this order: the explicit argument, then an
    environment variable, then the placeholder value this function has
    always used. Calling it with no arguments and no environment variables
    set therefore passes exactly the same values as before.

    Args:
        username: Salesforce login name (env: ``SALESFORCE_USERNAME``).
        password: Salesforce password (env: ``SALESFORCE_PASSWORD``).
        security_token: Salesforce security token
            (env: ``SALESFORCE_SECURITY_TOKEN``).
        domain: ``'login'`` for production or ``'test'`` for a sandbox
            (env: ``SALESFORCE_DOMAIN``).

    Returns:
        The connected ``Salesforce`` client object.
    """
    # Local import keeps this block self-contained; real credentials should
    # come from the environment rather than being committed to source.
    import os

    sf = Salesforce(
        username=username or os.environ.get('SALESFORCE_USERNAME', 'your_username'),
        password=password or os.environ.get('SALESFORCE_PASSWORD', 'your_password'),
        security_token=security_token
        or os.environ.get('SALESFORCE_SECURITY_TOKEN', 'your_security_token'),
        domain=domain or os.environ.get('SALESFORCE_DOMAIN', 'login'),  # or 'test' for sandbox
    )
    return sf
# Extract text from PDF
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in a PDF.

    Args:
        pdf_file: A path or binary file-like object accepted by
            ``PdfReader`` (for example a Streamlit uploaded file).

    Returns:
        A single string with each page's extracted text followed by a
        newline. A PDF with no pages yields an empty string.
    """
    reader = PdfReader(pdf_file)
    # ``or ''`` guards against a page whose extraction yields nothing, so the
    # concatenation can never fail on a non-string value. ``join`` avoids
    # rebuilding the accumulated string once per page.
    return "".join(f"{page.extract_text() or ''}\n" for page in reader.pages)
# Split text into clauses
def split_into_clauses(text):
    """Split contract text into clauses at numbered or lettered markers.

    A marker is a run of digits or a single capital letter followed by a
    period (``1.``, ``12.``, ``A.``) that appears at the very start of the
    text or right after a line break. Markers are removed from the output.

    Args:
        text: The full contract text. ``None`` or an empty string is
            treated as "no clauses".

    Returns:
        A list of non-empty clause strings with surrounding whitespace
        stripped, in document order.
    """
    if not text:
        return []
    # ``(?:^|\n)`` also matches a marker on the first line; requiring a
    # leading newline left the first marker ("1. ") inside the first clause.
    pieces = re.split(r'(?:^|\n)\s*(?:\d+|[A-Z])\.\s*', text)
    return [piece.strip() for piece in pieces if piece.strip()]
# Load BERT model and tokenizer
@st.cache_resource
def load_model(model_name='bert-base-uncased', num_labels=3):
    """Load the tokenizer and sequence-classification model.

    The result is cached by ``st.cache_resource`` so the model is not
    reloaded on every Streamlit rerun.

    NOTE: the default ``'bert-base-uncased'`` checkpoint is the generic
    pre-trained model, not one fine-tuned for risk classification. Its
    classification head is newly initialised, so pass the name or path of
    a fine-tuned checkpoint via ``model_name`` for meaningful predictions.

    Args:
        model_name: Hugging Face model id or local path to load.
        num_labels: Number of output classes (3 risk levels by default).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    return tokenizer, model
# Process clauses and assign risk scores
def process_clauses(clauses, tokenizer, model):
    """Classify each clause and collect one result record per clause.

    Args:
        clauses: Iterable of clause strings.
        tokenizer: Callable that encodes a clause into model inputs
            (called with ``return_tensors="pt"``, truncation to 512 tokens).
        model: Callable whose output exposes ``.logits``.

    Returns:
        A list of dicts with keys ``clause_text``, ``risk_level``
        (``'Low'``/``'Medium'``/``'High'``, chosen by the highest
        probability), ``severity_score`` (the probability of that chosen
        class, as a float) and ``clause_type``.
    """
    label_names = {0: 'Low', 1: 'Medium', 2: 'High'}
    analysed = []
    for clause in clauses:
        encoded = tokenizer(clause, return_tensors="pt", truncation=True, padding=True, max_length=512)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(**encoded).logits
        # One clause per call, so take the single row of class probabilities.
        probabilities = torch.softmax(logits, dim=1).numpy()[0]
        predicted_label = label_names[np.argmax(probabilities)]
        record = {
            'clause_text': clause,
            'risk_level': predicted_label,
            'severity_score': float(np.max(probabilities)),
            'clause_type': infer_clause_type(clause),  # Simplified clause type inference
        }
        analysed.append(record)
    return analysed
# Simplified clause type inference (extend with more sophisticated logic as needed)
def infer_clause_type(clause):
    """Guess a clause's type from keywords, case-insensitively.

    Keywords are checked in order, so a clause mentioning both
    "liability" and "payment" is reported as ``'Liability'``.

    Args:
        clause: The clause text.

    Returns:
        ``'Liability'``, ``'Payment'`` or ``'General'`` when no keyword
        matches.
    """
    lowered = clause.lower()
    keyword_labels = (('liability', 'Liability'), ('payment', 'Payment'))
    for keyword, label in keyword_labels:
        if keyword in lowered:
            return label
    return 'General'
# Save results to Salesforce
def save_to_salesforce(sf, results, contract_id):
    """Create one ``Contract_Risk__c`` record per analysed clause.

    Args:
        sf: Salesforce client exposing ``Contract_Risk__c.create``.
        results: Iterable of result dicts with ``clause_text``,
            ``risk_level``, ``severity_score`` and ``clause_type`` keys.
        contract_id: Value stored in each record's ``Contract__c`` field.

    Returns:
        None. Records are created one call at a time, in order.
    """
    for entry in results:
        create_record = sf.Contract_Risk__c.create
        payload = {
            'Contract__c': contract_id,
            # Only the first 255 characters of the clause are stored.
            'Clause_Text__c': entry['clause_text'][:255],
            'Risk_Level__c': entry['risk_level'],
            'Severity_Score__c': entry['severity_score'],
            'Clause_Type__c': entry['clause_type'],
        }
        create_record(payload)
# Generate heatmap
def generate_heatmap(results):
    """Render a one-row heatmap of per-clause severity scores in Streamlit.

    Args:
        results: List of result dicts, each with a ``severity_score`` key.
            An empty list shows an informational message instead of a
            chart (previously this raised ``KeyError``).

    Returns:
        None. The chart is written to the Streamlit page.
    """
    if not results:
        st.info("No clauses available to plot.")
        return

    scores = [result['severity_score'] for result in results]

    # Draw on an explicit figure rather than pyplot's global state, so
    # concurrent Streamlit sessions cannot draw onto each other's chart.
    fig, ax = plt.subplots(figsize=(10, 2))
    sns.heatmap(
        [scores],
        cmap='RdYlGn_r',
        annot=True,
        fmt='.2f',
        cbar_kws={'label': 'Risk Severity'},
        ax=ax,
    )
    ax.set_title('Contract Clause Risk Heatmap')
    ax.set_xlabel('Clause Index')
    ax.set_yticks([])
    st.pyplot(fig)
    # Release the figure; otherwise one figure leaks per script rerun.
    plt.close(fig)
# Streamlit interface
def main():
    """Run the Streamlit page: upload a PDF, analyse clauses, show and save results.

    Nothing is processed until both a PDF and a contract ID are provided.
    If no clauses can be extracted from the PDF, a warning is shown and the
    analysis, heatmap and save steps are skipped.
    """
    st.title("Contract Risk Analyzer")

    # File upload
    uploaded_file = st.file_uploader("Upload Contract PDF", type=["pdf"])
    contract_id = st.text_input("Enter Contract ID")
    if not (uploaded_file and contract_id):
        return

    # Extract and process text
    text = extract_text_from_pdf(uploaded_file)
    clauses = split_into_clauses(text)
    if not clauses:
        # E.g. a scanned PDF with no text layer; the heatmap cannot be
        # drawn from an empty result set.
        st.warning("No clauses could be extracted from the uploaded PDF.")
        return

    # Load model and process clauses
    tokenizer, model = load_model()
    results = process_clauses(clauses, tokenizer, model)

    # Display results
    st.subheader("Clause Analysis Results")
    for i, result in enumerate(results, 1):
        clause_text = result['clause_text']
        # Only show an ellipsis when the preview actually cuts the text.
        preview = clause_text[:100] + ("..." if len(clause_text) > 100 else "")
        st.write(f"**Clause {i}**")
        st.write(f"Text: {preview}")
        st.write(f"Clause Type: {result['clause_type']}")
        st.write(f"Risk Level: {result['risk_level']}")
        st.write(f"Severity Score: {result['severity_score']:.2f}")
        st.write("---")

    # Generate and display heatmap
    generate_heatmap(results)

    # Save to Salesforce
    if st.button("Save to Salesforce"):
        sf = connect_to_salesforce()
        save_to_salesforce(sf, results, contract_id)
        st.success("Results saved to Salesforce!")
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()