Spaces:
Runtime error
Runtime error
import hashlib
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
from PyPDF2 import PdfReader
from simple_salesforce import Salesforce
from transformers import BertForSequenceClassification, BertTokenizer
# Salesforce connection
def connect_to_salesforce(username=None, password=None, security_token=None, domain=None):
    """Open an authenticated Salesforce session.

    Credentials come from the arguments when given, otherwise from the
    environment variables SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN and
    SF_DOMAIN. This keeps secrets out of the source code, e.g. by using
    deployment secrets exposed as environment variables.

    Args:
        username: Salesforce login user name.
        password: Salesforce password.
        security_token: the user's API security token.
        domain: 'login' for production (the default) or 'test' for a sandbox.

    Returns:
        A ``simple_salesforce.Salesforce`` client.

    Raises:
        RuntimeError: if the username, password or security token is missing.
    """
    username = username or os.environ.get("SF_USERNAME")
    password = password or os.environ.get("SF_PASSWORD")
    security_token = security_token or os.environ.get("SF_SECURITY_TOKEN")
    domain = domain or os.environ.get("SF_DOMAIN", "login")

    # Fail fast with an actionable message instead of attempting a login that
    # is guaranteed to be rejected.
    missing = [
        env_name
        for env_name, value in (
            ("SF_USERNAME", username),
            ("SF_PASSWORD", password),
            ("SF_SECURITY_TOKEN", security_token),
        )
        if not value
    ]
    if missing:
        raise RuntimeError(
            "Missing Salesforce credentials: " + ", ".join(missing)
            + ". Pass them as arguments or set them as environment variables."
        )

    return Salesforce(
        username=username,
        password=password,
        security_token=security_token,
        domain=domain,
    )
# Extract text from PDF
def extract_text_from_pdf(pdf_file):
    """Return the text of every page of ``pdf_file``.

    Args:
        pdf_file: anything ``PyPDF2.PdfReader`` accepts, such as a path or a
            binary file-like object like a Streamlit upload.

    Returns:
        The page texts concatenated in page order, each followed by a newline.
        A page with no extractable text (e.g. a scanned image) contributes
        only the newline.
    """
    reader = PdfReader(pdf_file)
    # extract_text() may return None for image-only pages in some PyPDF2
    # versions; treat that as empty text instead of failing on None + "\n".
    # str.join avoids the quadratic cost of repeated += on long documents.
    return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
# Split text into clauses
# A clause starts at a line whose first non-blank token is "<digits>." or
# "<single capital letter>."; the marker itself is dropped from the clause.
_CLAUSE_MARKER_RE = re.compile(r'\n\s*\d+\.\s*|\n\s*[A-Z]\.\s*')


def split_into_clauses(text):
    """Split contract text into clauses at numbered or lettered line markers.

    Caveat: any line that begins with a single capital letter followed by a
    dot (e.g. "U.S. law", "J. Smith") is also treated as a clause marker.

    Args:
        text: raw contract text, e.g. as returned by ``extract_text_from_pdf``.

    Returns:
        list[str]: the stripped, non-empty clause texts, in document order,
        with their markers removed.
    """
    # Prefix a newline so a marker on the very first line ("1. ...") is
    # matched and stripped like every other marker; otherwise the first
    # clause would keep its "1. " prefix.
    pieces = _CLAUSE_MARKER_RE.split("\n" + text)
    return [piece.strip() for piece in pieces if piece.strip()]
# Load BERT model and tokenizer
@st.cache_resource
def load_model(model_name='bert-base-uncased', num_labels=3):
    """Load the tokenizer and sequence classifier once per server process.

    Streamlit reruns the whole script on every widget interaction.
    ``st.cache_resource`` stops the weights from being reloaded on each rerun.
    It also keeps predictions stable between the rerun that displays the
    results and the rerun triggered by the save button.

    Args:
        model_name: Hugging Face model id or local checkpoint path.
            NOTE(review): the default 'bert-base-uncased' is not fine-tuned
            for risk scoring. Its classification head is newly initialised,
            so the scores carry no meaning until a fine-tuned checkpoint is
            supplied here.
        num_labels: number of output classes. ``process_clauses`` maps
            indices 0/1/2 to Low/Medium/High.

    Returns:
        tuple: ``(tokenizer, model)``, with the model in eval mode.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    model.eval()  # inference only: make sure dropout is disabled
    return tokenizer, model
# Process clauses and assign risk scores
def process_clauses(clauses, tokenizer, model):
    """Classify each clause and build one result dict per clause.

    Args:
        clauses: iterable of clause strings.
        tokenizer: callable that turns a string into model inputs; presumably
            the tokenizer returned by ``load_model``.
        model: sequence classifier whose output has ``.logits`` with three
            classes (index 0/1/2 -> Low/Medium/High).

    Returns:
        list[dict]: in input order, each with keys 'clause_text', 'risk_level',
        'severity_score' and 'clause_type'.

    NOTE(review): 'severity_score' is the softmax probability of the
    predicted label, i.e. the model's confidence. It is not a risk
    magnitude, so a confident 'Low' gets a high score.
    """
    label_names = {0: 'Low', 1: 'Medium', 2: 'High'}

    def _score(clause):
        # Inputs are cut at 512 tokens, BERT's maximum sequence length.
        encoded = tokenizer(clause, return_tensors="pt", truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            probabilities = torch.softmax(model(**encoded).logits, dim=1).numpy()[0]
        return {
            'clause_text': clause,
            'risk_level': label_names[np.argmax(probabilities)],
            'severity_score': float(np.max(probabilities)),
            'clause_type': infer_clause_type(clause),  # keyword heuristic
        }

    return [_score(clause) for clause in clauses]
# Simplified clause type inference (extend with more sophisticated logic as needed)
def infer_clause_type(clause):
    """Guess a clause's category from case-insensitive keyword matches.

    'liability' takes precedence over 'payment'. Text containing neither
    keyword is 'General'.
    """
    lowered = clause.lower()
    # Checked in order; the first keyword found wins.
    for keyword, category in (('liability', 'Liability'), ('payment', 'Payment')):
        if keyword in lowered:
            return category
    return 'General'
# Save results to Salesforce
def save_to_salesforce(sf, results, contract_id):
    """Create one ``Contract_Risk__c`` record per analysed clause.

    Args:
        sf: an authenticated ``simple_salesforce.Salesforce`` client.
        results: clause dicts with 'clause_text', 'risk_level',
            'severity_score' and 'clause_type' keys.
        contract_id: value written to each record's ``Contract__c`` field.

    Records are created with one API call each, in ``results`` order.
    """
    for entry in results:
        record = dict(
            Contract__c=contract_id,
            # Keep only the first 255 characters; presumably the target field
            # is a 255-character Text field. Confirm against the org schema.
            Clause_Text__c=entry['clause_text'][:255],
            Risk_Level__c=entry['risk_level'],
            Severity_Score__c=entry['severity_score'],
            Clause_Type__c=entry['clause_type'],
        )
        sf.Contract_Risk__c.create(record)
# Generate heatmap
def generate_heatmap(results):
    """Render a one-row heatmap of per-clause 'severity_score' values.

    Args:
        results: list of dicts, each with a 'severity_score' key (the shape
            produced by ``process_clauses``).

    Draws into the current Streamlit page and returns None. An empty
    ``results`` shows an info message instead of raising KeyError.
    """
    if not results:
        st.info("No clauses to plot.")
        return

    scores = pd.DataFrame(results)['severity_score'].to_numpy()
    # Label columns 1..n so the axis matches the 1-based "Clause {i}"
    # numbering used in the results list. seaborn takes tick labels from
    # the DataFrame columns and still thins them when there are many clauses.
    grid = pd.DataFrame([scores], columns=range(1, len(scores) + 1))

    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        sns.heatmap(grid, ax=ax, cmap='RdYlGn_r', annot=True, fmt='.2f',
                    cbar_kws={'label': 'Risk Severity'})
        ax.set_title('Contract Clause Risk Heatmap')
        ax.set_xlabel('Clause Index')
        ax.set_yticks([])
        # Hand Streamlit the figure explicitly rather than the pyplot module,
        # which relies on global pyplot state.
        st.pyplot(fig)
    finally:
        # Close the figure so figures do not pile up across Streamlit reruns.
        plt.close(fig)
# Streamlit interface
def _analyze_uploaded_pdf(uploaded_file):
    """Extract, split and score the uploaded PDF, reusing results for identical content.

    Streamlit reruns the whole script on every interaction, including the
    "Save to Salesforce" click. Caching the analysis in ``st.session_state``,
    keyed by a hash of the file bytes, makes the save rerun store exactly
    the results that were displayed instead of re-scoring the document.

    Returns:
        list[dict]: per-clause results as produced by ``process_clauses``.
    """
    file_key = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    cached = st.session_state.get("contract_risk_analysis")
    if cached is not None and cached[0] == file_key:
        return cached[1]

    uploaded_file.seek(0)  # make sure PdfReader starts at the beginning
    text = extract_text_from_pdf(uploaded_file)
    clauses = split_into_clauses(text)
    tokenizer, model = load_model()
    results = process_clauses(clauses, tokenizer, model)
    st.session_state["contract_risk_analysis"] = (file_key, results)
    return results


def main():
    """Render the Contract Risk Analyzer page.

    Flow:
      1. Upload a contract PDF and enter a contract ID.
      2. Extract the text and split it into clauses, then score each clause.
      3. Show per-clause results and a heatmap.
      4. Optionally save the results to Salesforce.
    """
    st.title("Contract Risk Analyzer")
    # File upload
    uploaded_file = st.file_uploader("Upload Contract PDF", type=["pdf"])
    contract_id = st.text_input("Enter Contract ID")
    if not (uploaded_file and contract_id):
        return

    results = _analyze_uploaded_pdf(uploaded_file)
    if not results:
        st.warning("No clauses could be extracted from this PDF (it may be scanned or image-only).")
        return

    # Display results
    st.subheader("Clause Analysis Results")
    for i, result in enumerate(results, 1):
        st.write(f"**Clause {i}**")
        st.write(f"Text: {result['clause_text'][:100]}...")
        st.write(f"Clause Type: {result['clause_type']}")
        st.write(f"Risk Level: {result['risk_level']}")
        st.write(f"Severity Score: {result['severity_score']:.2f}")
        st.write("---")

    # Generate and display heatmap
    generate_heatmap(results)

    # Save to Salesforce
    if st.button("Save to Salesforce"):
        try:
            sf = connect_to_salesforce()
            save_to_salesforce(sf, results, contract_id)
        except Exception as exc:  # UI boundary: report the failure instead of a raw traceback
            st.error(f"Saving to Salesforce failed: {exc}")
        else:
            st.success("Results saved to Salesforce!")


if __name__ == "__main__":
    main()