ABDALLAH31 commited on
Commit
a2aaebb
·
verified ·
1 Parent(s): 9d951f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -42
app.py CHANGED
@@ -1,58 +1,33 @@
1
- from transformers import pipeline
2
  import matplotlib.pyplot as plt
3
  import seaborn as sns
4
  import numpy as np
5
- import gradio as gr
6
-
7
- # Step 1: Load the classifier model
8
- classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
9
-
10
- # Step 2: Function to process the input and generate the heatmap
11
- def generate_heatmap(file):
12
- # Example clauses from a contract
13
- clauses = [
14
- "The contract may be terminated by either party with a 30-day notice.",
15
- "The vendor is liable for all damages during service interruptions.",
16
- "User data may be shared with third parties without notice.",
17
- "Payment must be made within 15 days of invoice receipt.",
18
- "No liability is accepted for indirect damages."
19
- ]
20
-
21
- # Step 3: Define candidate labels for risk
22
- labels = ["high risk", "medium risk", "low risk"]
23
-
24
- # Step 4: Generate scores for each clause
25
- scores = []
26
- for clause in clauses:
27
- result = classifier(clause, labels)
28
- scores.append(result['scores'])
29
 
30
- # Convert to NumPy array for heatmap
31
- scores_array = np.array(scores)
 
 
 
 
32
 
33
- # Step 5: Plot heatmap
34
- plt.figure(figsize=(10, 6))
35
- sns.heatmap(scores_array, annot=True, xticklabels=labels, yticklabels=[f"Clause {i+1}" for i in range(len(clauses))], cmap="Reds")
36
- plt.title("Contract Clause Risk Heatmap")
37
- plt.xlabel("Risk Level")
38
- plt.ylabel("Clauses")
39
 
40
- # Save the heatmap as an image file
41
- heatmap_path = "/mnt/data/contract_risk_heatmap.png"
42
  plt.savefig(heatmap_path)
43
- plt.close()
44
 
45
  return heatmap_path
46
 
47
- # Step 6: Create Gradio interface
48
  iface = gr.Interface(
49
  fn=generate_heatmap,
50
- inputs=gr.File(label="Upload Contract Document"),
51
- outputs=gr.Image(label="Risk Heatmap"),
52
  title="📜 Contract Risk Heatmap Generator",
53
  description="Upload a contract and get a clause-level risk heatmap."
54
  )
55
 
56
- # Ensure the Gradio app starts properly
57
- if __name__ == "__main__":
58
- iface.launch()
 
1
+ import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import seaborn as sns
4
  import numpy as np
5
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ def generate_heatmap(results):
8
+ # Sample risk levels
9
+ risk_levels = {"High": 3, "Medium": 2, "Low": 1}
10
+
11
+ clauses = [r['clause'] for r in results]
12
+ risk_values = [risk_levels.get(r['risk_level'], 1) for r in results]
13
 
14
+ fig = plt.figure(figsize=(10, 6))
15
+ sns.heatmap([risk_values], annot=True, xticklabels=clauses, yticklabels=["Risk Levels"], cmap="YlOrRd")
 
 
 
 
16
 
17
+ # Save heatmap image to current directory
18
+ heatmap_path = os.path.join(os.getcwd(), 'contract_risk_heatmap.png')
19
  plt.savefig(heatmap_path)
 
20
 
21
  return heatmap_path
22
 
23
+ # Gradio interface
24
  iface = gr.Interface(
25
  fn=generate_heatmap,
26
+ inputs=gr.File(label="Upload Contract PDF"),
27
+ outputs=gr.File(label="Download Heatmap Image"),
28
  title="📜 Contract Risk Heatmap Generator",
29
  description="Upload a contract and get a clause-level risk heatmap."
30
  )
31
 
32
+ # Start the app
33
+ iface.launch()