Komal133 commited on
Commit
dcbd7b1
·
verified ·
1 Parent(s): e353374

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -39
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import PyPDF2
3
  import nltk
4
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
5
  import seaborn as sns
6
  import matplotlib.pyplot as plt
7
  from reportlab.lib.pagesizes import letter
@@ -10,7 +9,6 @@ import json
10
  import os
11
  from io import BytesIO
12
  import numpy as np
13
- import torch
14
  import logging
15
 
16
  # Set up logging
@@ -20,23 +18,17 @@ logger = logging.getLogger(__name__)
20
  # Download NLTK data
21
  nltk.download('punkt')
22
 
23
- # Initialize BERT model and tokenizer
24
- model_name = "nlpaueb/legal-bert-base-uncased"
25
- tokenizer = AutoTokenizer.from_pretrained(model_name)
26
- model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3) # 3 labels: penalty, obligation, delay
27
- classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
28
-
29
- # Map model labels to clause types (adjust based on actual model labels after fine-tuning)
30
- LABEL_MAP = {
31
- "LABEL_0": "penalty",
32
- "LABEL_1": "obligation",
33
- "LABEL_2": "delay"
34
- }
35
-
36
  # Clause types and risk scoring logic
37
  CLAUSE_TYPES = ["penalty", "obligation", "delay"]
38
  RISK_WEIGHTS = {"penalty": 0.8, "obligation": 0.5, "delay": 0.6}
39
 
 
 
 
 
 
 
 
40
  def extract_text_from_pdf(pdf_file):
41
  """Extract text from uploaded PDF file."""
42
  try:
@@ -55,7 +47,7 @@ def extract_text_from_pdf(pdf_file):
55
  return f"Error extracting text: {str(e)}"
56
 
57
  def parse_contract(text):
58
- """Parse contract text into clauses and classify risks."""
59
  # Light cleanup: collapse doubled newlines (single pass) and convert tabs to spaces
60
  text = text.replace("\n\n", "\n").replace("\t", " ")
61
  sentences = nltk.sent_tokenize(text)
@@ -70,31 +62,30 @@ def parse_contract(text):
70
  if len(sentence) < 10: # Skip short sentences
71
  logger.debug(f"Skipping short sentence (length {len(sentence)}): {sentence}")
72
  continue
73
- # Classify clause
74
- try:
75
- classification = classifier(sentence)
76
- logger.debug(f"Classification for sentence {idx}: {classification}")
77
- # Map model labels to clause types
78
- top_label = max(classification[0], key=lambda x: x['score'])['label']
79
- clause_type = LABEL_MAP.get(top_label, None)
80
- if clause_type not in CLAUSE_TYPES:
81
- logger.debug(f"Clause type {clause_type} not in {CLAUSE_TYPES}, skipping.")
82
- continue
83
-
84
- # Calculate risk score
85
- score = classification[0][[label for label in LABEL_MAP if LABEL_MAP[label] == clause_type][0]]['score'] * RISK_WEIGHTS[clause_type]
86
- results.append({
87
- "clause_id": idx,
88
- "text": sentence,
89
- "clause_type": clause_type,
90
- "risk_score": round(score, 2)
91
- })
92
- risk_scores.append(score)
93
- logger.info(f"Detected clause {idx}: {clause_type} with risk score {score}")
94
- except Exception as e:
95
- logger.error(f"Error classifying sentence {idx}: {str(e)}")
96
  continue
97
 
 
 
 
 
 
 
 
 
 
 
 
98
  return results, risk_scores
99
 
100
  def generate_heatmap(risk_scores):
 
1
  import gradio as gr
2
  import PyPDF2
3
  import nltk
 
4
  import seaborn as sns
5
  import matplotlib.pyplot as plt
6
  from reportlab.lib.pagesizes import letter
 
9
  import os
10
  from io import BytesIO
11
  import numpy as np
 
12
  import logging
13
 
14
  # Set up logging
 
18
  # Download NLTK data
19
  nltk.download('punkt')
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Clause types and risk scoring logic
22
  CLAUSE_TYPES = ["penalty", "obligation", "delay"]
23
  RISK_WEIGHTS = {"penalty": 0.8, "obligation": 0.5, "delay": 0.6}
24
 
25
+ # Keyword lists for heuristic clause classification; types are checked in dict order and the first match wins
26
+ KEYWORD_MAP = {
27
+ "penalty": ["penalty", "fee", "fine", "charge", "incur"],
28
+ "obligation": ["shall", "must", "obligated", "required", "responsible"],
29
+ "delay": ["delay", "late", "beyond", "postpone", "deferred"]
30
+ }
31
+
32
  def extract_text_from_pdf(pdf_file):
33
  """Extract text from uploaded PDF file."""
34
  try:
 
47
  return f"Error extracting text: {str(e)}"
48
 
49
  def parse_contract(text):
50
+ """Parse contract text into clauses and classify risks using keyword-based heuristic."""
51
  # Light cleanup: collapse doubled newlines (single pass) and convert tabs to spaces
52
  text = text.replace("\n\n", "\n").replace("\t", " ")
53
  sentences = nltk.sent_tokenize(text)
 
62
  if len(sentence) < 10: # Skip short sentences
63
  logger.debug(f"Skipping short sentence (length {len(sentence)}): {sentence}")
64
  continue
65
+
66
+ # Heuristic classification based on keywords
67
+ sentence_lower = sentence.lower()
68
+ clause_type = None
69
+ for c_type, keywords in KEYWORD_MAP.items():
70
+ if any(keyword in sentence_lower for keyword in keywords):
71
+ clause_type = c_type
72
+ break
73
+
74
+ if clause_type not in CLAUSE_TYPES:
75
+ logger.debug(f"No relevant clause type for sentence {idx}: {sentence}")
 
 
 
 
 
 
 
 
 
 
 
 
76
  continue
77
 
78
+ # Assign a dummy score based on keyword presence (simulating model confidence)
79
+ score = RISK_WEIGHTS[clause_type] * 0.9 # 0.9 as a dummy confidence score
80
+ results.append({
81
+ "clause_id": idx,
82
+ "text": sentence,
83
+ "clause_type": clause_type,
84
+ "risk_score": round(score, 2)
85
+ })
86
+ risk_scores.append(score)
87
+ logger.info(f"Detected clause {idx}: {clause_type} with risk score {score}")
88
+
89
  return results, risk_scores
90
 
91
  def generate_heatmap(risk_scores):