ABDALLAH31 commited on
Commit
6d45671
·
verified ·
1 Parent(s): ee8bd9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -34
app.py CHANGED
@@ -2,47 +2,28 @@ import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import seaborn as sns
4
  import numpy as np
5
- import pdfplumber
6
  import os
7
- from transformers import pipeline
8
 
9
- # Load the zero-shot classification model
10
- classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
11
 
12
- # Function to extract text from PDF
13
- def extract_text_from_pdf(file_path):
14
- text = ""
15
- with pdfplumber.open(file_path) as pdf:
16
- for page in pdf.pages:
17
- text += page.extract_text()
18
- return text
 
 
 
19
 
20
- # Generate heatmap
21
- def generate_heatmap(file):
22
- # Step 1: Extract text from the uploaded PDF
23
- text = extract_text_from_pdf(file.name)
24
-
25
- # Step 2: Split text into individual clauses (simple split by periods)
26
- clauses = text.split(". ")
27
-
28
- # Step 3: Define candidate labels for risk
29
- labels = ["high risk", "medium risk", "low risk"]
30
-
31
- # Step 4: Classify each clause and store the scores
32
- scores = []
33
- for clause in clauses:
34
- result = classifier(clause, labels)
35
- scores.append(result['scores'])
36
-
37
- # Step 5: Create the heatmap data
38
- risk_levels = {"High": 3, "Medium": 2, "Low": 1}
39
- risk_values = [risk_levels.get(r['label'], 1) for r in result['labels']]
40
-
41
  # Plot heatmap
42
  fig = plt.figure(figsize=(10, 6))
43
  sns.heatmap([risk_values], annot=True, xticklabels=clauses, yticklabels=["Risk Levels"], cmap="YlOrRd")
44
-
45
- # Save the heatmap as an image
46
  heatmap_path = os.path.join(os.getcwd(), 'contract_risk_heatmap.png')
47
  plt.savefig(heatmap_path)
48
 
 
2
  import matplotlib.pyplot as plt
3
  import seaborn as sns
4
  import numpy as np
 
5
  import os
 
6
 
7
+ def generate_heatmap(results):
8
+ # Check the structure of the results
9
+ print("Results:", results)
10
 
11
+ # If the results are strings (e.g., just the clauses)
12
+ if isinstance(results, list) and isinstance(results[0], str):
13
+ clauses = results # Directly use clauses
14
+ # For simplicity, assume all clauses are "high risk" here for testing purposes
15
+ risk_values = [3 for _ in clauses] # Replace with actual risk assessment logic
16
+ else:
17
+ # Assuming results are in the format [{'clause': ..., 'risk_level': ...}, ...]
18
+ clauses = [r['clause'] for r in results] # Extract clause text
19
+ risk_levels = {"High": 3, "Medium": 2, "Low": 1}
20
+ risk_values = [risk_levels.get(r['risk_level'], 1) for r in results] # Map risk level to value
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Plot heatmap
23
  fig = plt.figure(figsize=(10, 6))
24
  sns.heatmap([risk_values], annot=True, xticklabels=clauses, yticklabels=["Risk Levels"], cmap="YlOrRd")
25
+
26
+ # Save heatmap image
27
  heatmap_path = os.path.join(os.getcwd(), 'contract_risk_heatmap.png')
28
  plt.savefig(heatmap_path)
29