"""Gradio app that scores contract clauses for risk and renders a heatmap.

Each non-blank line of the input is treated as one clause and classified
against ``labels`` with a zero-shot NLI model. The per-label scores are drawn
as a (clause x label) heatmap whose image path is returned to the UI.
"""

import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from transformers import pipeline

# Load the zero-shot classifier once at import time; every request reuses it.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# Candidate labels, in the column order used by the heatmap.
labels = ["high risk", "medium risk", "low risk"]


def _split_clauses(text_input):
    """Return the stripped, non-blank lines of *text_input* (``None`` gives ``[]``)."""
    return [line.strip() for line in (text_input or "").splitlines() if line.strip()]


def _score_clause(clause):
    """Return the classifier's scores for *clause*, ordered to match ``labels``."""
    result = classifier(clause, labels)
    # The pipeline returns labels/scores sorted by descending score, NOT in the
    # order of the candidate labels, so re-key the scores by label name before
    # laying them out as heatmap columns.
    score_by_label = dict(zip(result["labels"], result["scores"]))
    return [score_by_label[label] for label in labels]


def _render_heatmap(scores_array, n_clauses):
    """Draw *scores_array* as a heatmap, save it to a new temp PNG, return its path."""
    # Grow the figure with the number of clauses so row labels stay readable.
    fig, ax = plt.subplots(figsize=(10, max(6, 0.5 * n_clauses + 2)))
    try:
        sns.heatmap(
            scores_array,
            annot=True,
            fmt=".2f",
            # Scores are probabilities; a fixed 0..1 range keeps colours
            # comparable between runs instead of rescaling to each input.
            vmin=0.0,
            vmax=1.0,
            xticklabels=labels,
            yticklabels=[f"Clause {i}" for i in range(1, n_clauses + 1)],
            cmap="Reds",
            ax=ax,
        )
        ax.set_title("Contract Clause Risk Heatmap")
        ax.set_xlabel("Risk Level")
        ax.set_ylabel("Clauses")
        fig.tight_layout()
        # One unique file per call: a fixed name such as "heatmap.png" lets
        # concurrent requests overwrite each other's image. The file is left
        # on disk (delete=False) so Gradio can read it after we return.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png")
            return tmp.name
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)


def classify_clauses(text_input):
    """Classify each clause (one per line) and return the path of a risk heatmap PNG.

    Args:
        text_input: Raw textbox contents. Blank lines are ignored.

    Returns:
        Filesystem path of the rendered heatmap image.

    Raises:
        gr.Error: If the input contains no non-blank lines.
    """
    clauses = _split_clauses(text_input)
    if not clauses:
        # Show a readable message in the UI instead of crashing in seaborn
        # on an empty score array.
        raise gr.Error("Please enter at least one clause (one per line).")

    scores_array = np.array([_score_clause(clause) for clause in clauses])
    return _render_heatmap(scores_array, len(clauses))


# Gradio UI
demo = gr.Interface(
    fn=classify_clauses,
    inputs=gr.Textbox(lines=10, label="Enter Contract Clauses (one per line)"),
    outputs=gr.Image(type="filepath"),
    title="Contract Risk Heatmap Generator",
    description="Enter clauses line by line. Uses zero-shot classification to visualize risk levels.",
)

if __name__ == "__main__":
    demo.launch()