Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import pipeline | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import numpy as np | |
| from utils import extract_text_from_pdf, simple_clause_split | |
| from report import generate_pdf | |
| from salesforce_stub import send_to_salesforce | |
# Zero-shot classification pipeline backed by the facebook/bart-large-mnli model.
# Built once at import time so every request reuses the same loaded model.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Candidate labels handed to the classifier, ordered from most to least severe.
labels = [
    "high risk",
    "medium risk",
    "low risk",
]
| # Function to classify clauses and generate heatmap | |
def _scores_in_label_order(prediction, candidate_labels):
    """Return the prediction's scores reordered to follow ``candidate_labels``.

    Returns ``None`` when the prediction carries no ``'scores'`` entry.
    """
    if 'scores' not in prediction:
        return None
    # The zero-shot pipeline sorts its labels (and the matching scores) by
    # likelihood, so pair each score with its label before reordering. If no
    # 'labels' entry is present, assume the scores already follow the
    # candidate order.
    returned_labels = prediction.get('labels', candidate_labels)
    score_by_label = dict(zip(returned_labels, prediction['scores']))
    return [score_by_label.get(label, 0.0) for label in candidate_labels]


def classify_clauses(text_input):
    """Classify each non-empty line of ``text_input`` and save a risk heatmap.

    Every line is treated as one clause and scored against the module-level
    ``labels`` using the module-level zero-shot ``classifier``.

    Args:
        text_input: Newline-separated clauses. Blank lines are ignored and
            ``None`` is treated like an empty string.

    Returns:
        A ``(results, plot_path)`` tuple. ``results`` holds one dict per
        clause with the keys ``'clause'``, ``'risk_score'`` (the score given
        to ``labels[0]``, the high-risk label) and ``'risk_level'`` (the
        best-scoring label). ``plot_path`` is the path of the saved heatmap
        image, or ``None`` when there was no clause to plot.
    """
    raw_text = text_input or ""
    clauses = [line.strip() for line in raw_text.strip().split('\n') if line.strip()]
    if not clauses:
        # Nothing to classify, so skip the model and avoid plotting an
        # empty score matrix.
        return [], None

    results = []
    score_rows = []
    for clause in clauses:
        prediction = classifier(clause, labels)
        row = _scores_in_label_order(prediction, labels)
        if row is None:
            # No scores came back: record a zero row and the lowest level.
            row = [0] * len(labels)
            entry = {
                'clause': clause,
                'risk_score': 0,
                'risk_level': 'low risk',
            }
        else:
            best_index = max(range(len(row)), key=row.__getitem__)
            entry = {
                'clause': clause,
                'risk_score': row[0],  # labels[0] is the high-risk label
                'risk_level': labels[best_index],
            }
        score_rows.append(row)
        results.append(entry)

    scores_array = np.array(score_rows)

    # Rows are clauses, columns follow ``labels`` (rows were reordered above).
    figure = plt.figure(figsize=(10, 6))
    try:
        sns.heatmap(
            scores_array,
            annot=True,
            xticklabels=labels,
            yticklabels=[f"Clause {number}" for number in range(1, len(clauses) + 1)],
            cmap="Reds",
        )
        plt.title("Contract Clause Risk Heatmap")
        plt.xlabel("Risk Level")
        plt.ylabel("Clauses")
        plt.tight_layout()

        plot_path = "heatmap.png"
        plt.savefig(plot_path)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(figure)

    return results, plot_path
| # Function to analyze contract and generate results | |
def analyze_contract(file):
    """Run the full risk analysis for one uploaded contract PDF.

    Extracts the PDF text, splits it into clauses, classifies them, generates
    the PDF report, builds the HTML highlight list and sends the results to
    Salesforce.

    Args:
        file: The uploaded file, either an object exposing its path as
            ``.name`` or the path itself. ``None`` means nothing was
            uploaded.

    Returns:
        A 4-tuple ``(summary, highlight_html, report_path, heatmap_path)``.
        When ``file`` is ``None`` no processing happens and the tuple is
        ``("No contract uploaded.", "", None, None)``.
    """
    # Local import keeps this function self-contained.
    from html import escape

    if file is None:
        return "No contract uploaded.", "", None, None

    # Accept both a file object carrying ``.name`` and a plain path.
    tmp_path = getattr(file, "name", file)
    text = extract_text_from_pdf(tmp_path)

    clauses = simple_clause_split(text)
    results, heatmap_path = classify_clauses("\n".join(clauses))

    # Mean of the per-clause risk scores; 0 when no clause was found.
    overall_score = sum(r['risk_score'] for r in results) / len(results) if results else 0

    report_path = generate_pdf(results, overall_score)

    # Levels arrive as strings such as "high risk", so match on the leading
    # word case-insensitively; anything that is not high or medium is green.
    color_by_severity = {"high": "red", "medium": "orange"}
    fragments = []
    for r in results:
        level = str(r['risk_level'])
        words = level.lower().split()
        color = color_by_severity.get(words[0] if words else "", "green")
        # Clause text comes from the uploaded document, so escape it before
        # embedding it in HTML.
        fragments.append(
            f"<div style='color:{color}'><b>{escape(level)}</b>: {escape(str(r['clause']))}</div><br>"
        )
    highlight_output = "".join(fragments)

    send_to_salesforce({
        "clauses": results,
        "overall_score": overall_score
    })

    return f"Overall Risk Score: {overall_score:.2f}", highlight_output, report_path, heatmap_path
# Gradio UI: a single PDF upload is passed to analyze_contract, whose four
# return values fill the score textbox, the HTML highlight panel, the
# downloadable report and the heatmap image, in that order.
demo = gr.Interface(
    fn=analyze_contract,
    inputs=gr.File(label="Upload Contract PDF"),
    outputs=[
        gr.Textbox(label="Overall Risk Score"),
        gr.HTML(label="Clause Risk Highlight"),
        gr.File(label="Download Risk Report (PDF)"),
        gr.Image(type="filepath", label="Risk Heatmap"),
    ],
    title="📜 Contract Risk Heatmap Generator",
    description=(
        "Upload a contract and get clause-level risk scores, "
        "high-risk clause highlights, and a heatmap visualization."
    ),
)

# Start the web app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()