# ABDALLAH31's picture
# Update app.py
# 113b3f5 verified
import html

import gradio as gr
from transformers import pipeline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from utils import extract_text_from_pdf, simple_clause_split
from report import generate_pdf
from salesforce_stub import send_to_salesforce
# Candidate risk labels passed to the classifier, from most to least severe.
labels = ["high risk", "medium risk", "low risk"]

# Zero-shot classifier that scores each clause against ``labels``.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)
def _ordered_scores(result):
    """Return ``result``'s scores re-ordered to follow the module-level ``labels``.

    The zero-shot pipeline returns ``result['labels']`` / ``result['scores']``
    sorted by descending likelihood, not in the order the candidate labels were
    supplied. Indexing ``result['scores']`` directly therefore gives the top
    score, not the score of a particular label. Labels absent from ``result``
    get 0.0.
    """
    by_label = dict(zip(result.get('labels', []), result['scores']))
    return [float(by_label.get(label, 0.0)) for label in labels]


def _plot_heatmap(scores, plot_path):
    """Draw a clause-by-label heatmap of ``scores`` and save it as an image at ``plot_path``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if scores:
            sns.heatmap(
                np.array(scores),
                annot=True,
                xticklabels=labels,
                yticklabels=[f"Clause {i}" for i in range(1, len(scores) + 1)],
                cmap="Reds",
                # Scores are probabilities; a fixed 0-1 colour scale keeps
                # heatmaps comparable between contracts.
                vmin=0,
                vmax=1,
                ax=ax,
            )
        else:
            # An empty score matrix cannot be drawn as a heatmap; emit a
            # placeholder image so callers still get a valid file.
            ax.text(0.5, 0.5, "No clauses found", ha="center", va="center")
            ax.set_axis_off()
        ax.set_title("Contract Clause Risk Heatmap")
        ax.set_xlabel("Risk Level")
        ax.set_ylabel("Clauses")
        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def classify_clauses(text_input, plot_path="heatmap.png"):
    """Score each non-empty line of ``text_input`` against ``labels`` and plot a heatmap.

    Args:
        text_input: Text with one clause per line; blank lines are ignored.
            ``None`` is treated as empty text.
        plot_path: Where the heatmap image is written (default ``"heatmap.png"``).

    Returns:
        ``(results, plot_path)``. ``results`` is a list of dicts, one per clause,
        with these keys:

        - ``'clause'``: the clause text.
        - ``'risk_score'``: the probability assigned to ``labels[0]`` ("high risk").
        - ``'risk_level'``: the highest-scoring label.

        Heatmap columns follow the order of ``labels``.
    """
    clauses = [line.strip() for line in (text_input or "").split('\n') if line.strip()]
    results = []
    scores = []
    for clause in clauses:
        result = classifier(clause, labels)
        if 'scores' in result:
            clause_scores = _ordered_scores(result)
            risk_level = labels[clause_scores.index(max(clause_scores))]
            risk_score = clause_scores[0]  # labels[0] is "high risk"
        else:
            # No scores returned: treat the clause as the lowest risk level.
            clause_scores = [0.0] * len(labels)
            risk_level = labels[-1]
            risk_score = 0.0
        scores.append(clause_scores)
        results.append({
            'clause': clause,
            'risk_score': risk_score,
            'risk_level': risk_level,
        })
    _plot_heatmap(scores, plot_path)
    return results, plot_path
def _risk_color(risk_level):
    """Map a risk label (e.g. "high risk" or "High") to the CSS colour used in the highlights."""
    level = str(risk_level).lower()
    if level.startswith("high"):
        return "red"
    if level.startswith("medium"):
        return "orange"
    return "green"


def analyze_contract(file):
    """Run the full contract-risk pipeline on an uploaded PDF.

    Steps: extract the text, split and classify the clauses, build the PDF
    report, render the HTML highlights, and push the results to Salesforce.

    Args:
        file: The ``gr.File`` value. This is either a path string or a
            tempfile-like object exposing ``.name``, depending on the Gradio
            version. ``None`` (no upload) is handled without processing.

    Returns:
        A 4-tuple for the Gradio outputs, in this order:

        1. The overall-score text.
        2. The HTML clause highlights.
        3. The PDF report path.
        4. The heatmap image path.
    """
    if file is None:
        return "Please upload a contract PDF.", "", None, None
    tmp_path = getattr(file, "name", file)
    text = extract_text_from_pdf(tmp_path)
    clauses = simple_clause_split(text)
    # NOTE(review): classify_clauses re-splits on '\n', so any clause that
    # itself contains a newline is scored as several clauses — confirm
    # simple_clause_split never returns multi-line clauses.
    results, heatmap_path = classify_clauses("\n".join(clauses))
    # Mean per-clause risk score; 0 when no clauses were found.
    overall_score = sum(r['risk_score'] for r in results) / len(results) if results else 0
    report_path = generate_pdf(results, overall_score)
    # Clause text comes from the uploaded document, so escape it before
    # embedding it in the HTML output.
    highlight_output = "".join(
        f"<div style='color:{_risk_color(r['risk_level'])}'>"
        f"<b>{html.escape(str(r['risk_level']))}</b>: {html.escape(str(r['clause']))}</div><br>"
        for r in results
    )
    send_to_salesforce({
        "clauses": results,
        "overall_score": overall_score,
    })
    return f"Overall Risk Score: {overall_score:.2f}", highlight_output, report_path, heatmap_path
# Gradio UI: one PDF upload feeds analyze_contract. Its four return values
# fill the output panels in order: score text, HTML highlights, report file,
# heatmap image.
demo = gr.Interface(
    fn=analyze_contract,
    inputs=gr.File(label="Upload Contract PDF"),
    outputs=[
        gr.Textbox(label="Overall Risk Score"),
        gr.HTML(label="Clause Risk Highlight"),
        gr.File(label="Download Risk Report (PDF)"),
        gr.Image(type="filepath", label="Risk Heatmap"),
    ],
    title="📜 Contract Risk Heatmap Generator",
    description=(
        "Upload a contract and get clause-level risk scores, high-risk clause "
        "highlights, and a heatmap visualization."
    ),
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()