# LegalLens-AI / src/analysis.py
# (uploaded by ardhigagan — commit 01042a2)
from transformers import pipeline
import torch
# Check if GPU is available.
# transformers uses device=0 for the first CUDA GPU and device=-1 for CPU.
device = 0 if torch.cuda.is_available() else -1
print(f"utilizing device: {'GPU' if device == 0 else 'CPU'}")
# 1. LOAD MODELS
# NOTE: both pipelines are created at import time, so importing this module
# triggers model downloads/loading as a side effect.
print("Loading Summarization Model...")
# Force PyTorch framework with framework="pt"
# distilbart-cnn-12-6: a distilled BART model fine-tuned for summarization.
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=device, framework="pt")
print("Loading Risk Detection Model...")
# bart-large-mnli: NLI model used for zero-shot classification against the
# candidate risk labels defined in analyze_chunk below.
risk_detector = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device, framework="pt")
# The 10 distinct legal risk categories the model checks for, plus a "safe"
# catch-all label. Defined once at module level so the list is not rebuilt
# on every analyze_chunk call.
CANDIDATE_LABELS = [
    "Financial Penalty",
    "Privacy Violation",
    "Non-Compete Restriction",
    "Termination Without Cause",
    "Intellectual Property Transfer",
    "Mandatory Arbitration",
    "Indemnification Obligation",
    "Unilateral Amendment",
    "Jurisdiction Waiver",
    "Automatic Renewal",
    "Safe Standard Clause",
]


def analyze_chunk(text_chunk):
    """
    Analyze a single chunk of contract text.

    Args:
        text_chunk: The clause/section text to summarize and classify.

    Returns:
        tuple: (summary, detected_risks) where summary is a string ("" if
        summarization failed) and detected_risks is a list of dicts with
        keys "type", "score", and "text_snippet" ([] if risk detection
        failed or nothing scored above the threshold).
    """
    # A. SUMMARIZE — best-effort: a single bad chunk must not abort the
    # whole document, so failures degrade to an empty summary.
    try:
        summary_result = summarizer(text_chunk, max_length=150, min_length=30, do_sample=False)
        summary = summary_result[0]['summary_text']
    except Exception as e:
        print(f"Summarization error: {e}")
        summary = ""

    # B. DETECT RISKS (MULTI-LABEL)
    # multi_label=True scores each label independently, so several risks
    # can exceed the threshold at once. Guarded like the summarizer above
    # so one failing chunk cannot crash the entire analysis.
    try:
        risk_result = risk_detector(text_chunk, CANDIDATE_LABELS, multi_label=True)
    except Exception as e:
        print(f"Risk detection error: {e}")
        return summary, []

    # Collect ALL risk labels above the 50% confidence threshold; the
    # "Safe Standard Clause" catch-all is never reported as a risk.
    detected_risks = [
        {
            "type": label,
            "score": round(score, 2),
            "text_snippet": text_chunk[:200] + "...",  # Snippet for context
        }
        for label, score in zip(risk_result['labels'], risk_result['scores'])
        if label != "Safe Standard Clause" and score > 0.50
    ]
    return summary, detected_risks
def analyze_document(chunks):
    """
    Orchestrate the analysis over every chunk of a document.

    Args:
        chunks: Ordered list of text chunks (strings) to analyze.

    Returns:
        tuple: (final_executive_summary, all_risks) — the per-chunk
        summaries joined with spaces, and a flat list of every risk dict
        found across all chunks.
    """
    full_summary = []
    all_risks = []
    print(f"Starting analysis on {len(chunks)} chunks...")
    # Iterate directly over the chunks — the positional index was unused.
    for chunk in chunks:
        summary, risks = analyze_chunk(chunk)
        full_summary.append(summary)
        # extend() on an empty list is a no-op, so no guard is needed.
        all_risks.extend(risks)
    final_executive_summary = " ".join(full_summary)
    return final_executive_summary, all_risks