Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import PyPDF2
|
| 3 |
import nltk
|
| 4 |
-
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
|
| 5 |
import seaborn as sns
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
from reportlab.lib.pagesizes import letter
|
|
@@ -10,7 +9,6 @@ import json
|
|
| 10 |
import os
|
| 11 |
from io import BytesIO
|
| 12 |
import numpy as np
|
| 13 |
-
import torch
|
| 14 |
import logging
|
| 15 |
|
| 16 |
# Set up logging
|
|
@@ -20,23 +18,17 @@ logger = logging.getLogger(__name__)
|
|
| 20 |
# Download NLTK data
|
| 21 |
nltk.download('punkt')
|
| 22 |
|
| 23 |
-
# Initialize BERT model and tokenizer
|
| 24 |
-
model_name = "nlpaueb/legal-bert-base-uncased"
|
| 25 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 26 |
-
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3) # 3 labels: penalty, obligation, delay
|
| 27 |
-
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
|
| 28 |
-
|
| 29 |
-
# Map model labels to clause types (adjust based on actual model labels after fine-tuning)
|
| 30 |
-
LABEL_MAP = {
|
| 31 |
-
"LABEL_0": "penalty",
|
| 32 |
-
"LABEL_1": "obligation",
|
| 33 |
-
"LABEL_2": "delay"
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
# Clause types and risk scoring logic
|
| 37 |
CLAUSE_TYPES = ["penalty", "obligation", "delay"]
|
| 38 |
RISK_WEIGHTS = {"penalty": 0.8, "obligation": 0.5, "delay": 0.6}
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def extract_text_from_pdf(pdf_file):
|
| 41 |
"""Extract text from uploaded PDF file."""
|
| 42 |
try:
|
|
@@ -55,7 +47,7 @@ def extract_text_from_pdf(pdf_file):
|
|
| 55 |
return f"Error extracting text: {str(e)}"
|
| 56 |
|
| 57 |
def parse_contract(text):
|
| 58 |
-
"""Parse contract text into clauses and classify risks."""
|
| 59 |
# Clean text: replace multiple newlines with single, handle LaTeX artifacts
|
| 60 |
text = text.replace("\n\n", "\n").replace("\t", " ")
|
| 61 |
sentences = nltk.sent_tokenize(text)
|
|
@@ -70,31 +62,30 @@ def parse_contract(text):
|
|
| 70 |
if len(sentence) < 10: # Skip short sentences
|
| 71 |
logger.debug(f"Skipping short sentence (length {len(sentence)}): {sentence}")
|
| 72 |
continue
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
# Calculate risk score
|
| 85 |
-
score = classification[0][[label for label in LABEL_MAP if LABEL_MAP[label] == clause_type][0]]['score'] * RISK_WEIGHTS[clause_type]
|
| 86 |
-
results.append({
|
| 87 |
-
"clause_id": idx,
|
| 88 |
-
"text": sentence,
|
| 89 |
-
"clause_type": clause_type,
|
| 90 |
-
"risk_score": round(score, 2)
|
| 91 |
-
})
|
| 92 |
-
risk_scores.append(score)
|
| 93 |
-
logger.info(f"Detected clause {idx}: {clause_type} with risk score {score}")
|
| 94 |
-
except Exception as e:
|
| 95 |
-
logger.error(f"Error classifying sentence {idx}: {str(e)}")
|
| 96 |
continue
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
return results, risk_scores
|
| 99 |
|
| 100 |
def generate_heatmap(risk_scores):
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import PyPDF2
|
| 3 |
import nltk
|
|
|
|
| 4 |
import seaborn as sns
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
from reportlab.lib.pagesizes import letter
|
|
|
|
| 9 |
import os
|
| 10 |
from io import BytesIO
|
| 11 |
import numpy as np
|
|
|
|
| 12 |
import logging
|
| 13 |
|
| 14 |
# Set up logging
|
|
|
|
| 18 |
# Download NLTK data
|
| 19 |
nltk.download('punkt')
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# Clause types and risk scoring logic
|
| 22 |
CLAUSE_TYPES = ["penalty", "obligation", "delay"]
|
| 23 |
RISK_WEIGHTS = {"penalty": 0.8, "obligation": 0.5, "delay": 0.6}
|
| 24 |
|
| 25 |
+
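# Keyword-based heuristic for clause classification: each keyword is matched as a substring of the lower-cased sentence, and the first clause type (in dict order) with any match wins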
|
| 26 |
+
KEYWORD_MAP = {
|
| 27 |
+
"penalty": ["penalty", "fee", "fine", "charge", "incur"],
|
| 28 |
+
"obligation": ["shall", "must", "obligated", "required", "responsible"],
|
| 29 |
+
"delay": ["delay", "late", "beyond", "postpone", "deferred"]
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
def extract_text_from_pdf(pdf_file):
|
| 33 |
"""Extract text from uploaded PDF file."""
|
| 34 |
try:
|
|
|
|
| 47 |
return f"Error extracting text: {str(e)}"
|
| 48 |
|
| 49 |
def parse_contract(text):
|
| 50 |
+
"""Parse contract text into clauses and classify risks using a keyword-based heuristic."""
|
| 51 |
# Clean text: replace multiple newlines with single, handle LaTeX artifacts
|
| 52 |
text = text.replace("\n\n", "\n").replace("\t", " ")
|
| 53 |
sentences = nltk.sent_tokenize(text)
|
|
|
|
| 62 |
if len(sentence) < 10: # Skip short sentences
|
| 63 |
logger.debug(f"Skipping short sentence (length {len(sentence)}): {sentence}")
|
| 64 |
continue
|
| 65 |
+
|
| 66 |
+
# Heuristic classification based on keywords
|
| 67 |
+
sentence_lower = sentence.lower()
|
| 68 |
+
clause_type = None
|
| 69 |
+
for c_type, keywords in KEYWORD_MAP.items():
|
| 70 |
+
if any(keyword in sentence_lower for keyword in keywords):
|
| 71 |
+
clause_type = c_type
|
| 72 |
+
break
|
| 73 |
+
|
| 74 |
+
if clause_type not in CLAUSE_TYPES:
|
| 75 |
+
logger.debug(f"No relevant clause type for sentence {idx}: {sentence}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
continue
|
| 77 |
|
| 78 |
+
# Assign a placeholder score: the clause type's risk weight scaled by a constant 0.9 (a stand-in for model confidence; it does not vary with how many keywords matched)
|
| 79 |
+
score = RISK_WEIGHTS[clause_type] * 0.9 # 0.9 as a dummy confidence score
|
| 80 |
+
results.append({
|
| 81 |
+
"clause_id": idx,
|
| 82 |
+
"text": sentence,
|
| 83 |
+
"clause_type": clause_type,
|
| 84 |
+
"risk_score": round(score, 2)
|
| 85 |
+
})
|
| 86 |
+
risk_scores.append(score)
|
| 87 |
+
logger.info(f"Detected clause {idx}: {clause_type} with risk score {score}")
|
| 88 |
+
|
| 89 |
return results, risk_scores
|
| 90 |
|
| 91 |
def generate_heatmap(risk_scores):
|