# Legal document analyzer: chunk-level summarization + zero-shot risk detection.
from transformers import pipeline
import torch
# Check if GPU is available
# transformers convention: device index 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1
print(f"utilizing device: {'GPU' if device == 0 else 'CPU'}")
# 1. LOAD MODELS
# NOTE: both pipelines are loaded once at import time and shared by the
# functions below via these module-level names.
print("Loading Summarization Model...")
# Force PyTorch framework with framework="pt"
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=device, framework="pt")
print("Loading Risk Detection Model...")
risk_detector = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device, framework="pt")
def analyze_chunk(text_chunk):
    """
    Analyze a single text chunk: summarize it and flag legal risks.

    Uses the module-level ``summarizer`` (abstractive summarization) and
    ``risk_detector`` (zero-shot classification) pipelines.

    Parameters
    ----------
    text_chunk : str
        Raw text of one document chunk.

    Returns
    -------
    tuple[str, list[dict]]
        ``(summary, detected_risks)``. Each risk dict has keys
        ``"type"`` (label), ``"score"`` (rounded to 2 decimals) and
        ``"text_snippet"`` (first 200 chars of the chunk). On a model
        error the summary is ``""`` and/or the risk list is empty —
        errors are printed, never raised.
    """
    # A. SUMMARIZE — best-effort: keep going with an empty summary if the
    # pipeline errors (e.g. on a chunk too short for min_length).
    try:
        summary_result = summarizer(text_chunk, max_length=150, min_length=30, do_sample=False)
        summary = summary_result[0]['summary_text']
    except Exception as e:
        print(f"Summarization error: {e}")
        summary = ""
    # B. DETECT RISKS (MULTI-LABEL)
    # The AI will now check for these 10 distinct legal traps + "Safe"
    candidate_labels = [
        "Financial Penalty",
        "Privacy Violation",
        "Non-Compete Restriction",
        "Termination Without Cause",
        "Intellectual Property Transfer",
        "Mandatory Arbitration",
        "Indemnification Obligation",
        "Unilateral Amendment",
        "Jurisdiction Waiver",
        "Automatic Renewal",
        "Safe Standard Clause"
    ]
    detected_risks = []
    # multi_label=True scores each label independently, so several risks
    # can fire on the same chunk.
    try:
        risk_result = risk_detector(text_chunk, candidate_labels, multi_label=True)
    except Exception as e:
        # FIX: previously an error here crashed the whole document run;
        # mirror the summarizer's best-effort handling instead.
        print(f"Risk detection error: {e}")
        return summary, detected_risks
    # Collect ALL risks above the threshold (50%)
    for label, score in zip(risk_result['labels'], risk_result['scores']):
        # If it's a risk label AND confidence is > 50%
        if label != "Safe Standard Clause" and score > 0.50:
            # FIX: append "..." only when the snippet was actually truncated.
            snippet = text_chunk[:200] + ("..." if len(text_chunk) > 200 else "")
            detected_risks.append({
                "type": label,
                "score": round(score, 2),
                "text_snippet": snippet,  # Snippet for context
            })
    return summary, detected_risks
def analyze_document(chunks):
    """
    Orchestrate the analysis over all chunks of a document.

    Parameters
    ----------
    chunks : sequence[str]
        Pre-split document chunks, analyzed in order.

    Returns
    -------
    tuple[str, list[dict]]
        ``(final_executive_summary, all_risks)`` — the per-chunk summaries
        joined into one string, and the flat list of every detected risk.
    """
    full_summary = []
    all_risks = []
    print(f"Starting analysis on {len(chunks)} chunks...")
    for chunk in chunks:
        summary, risks = analyze_chunk(chunk)
        full_summary.append(summary)
        # Add all found risks to the master list
        if risks:
            all_risks.extend(risks)
    # FIX: drop empty summaries (from failed chunks) so the joined text
    # doesn't contain stray double spaces.
    final_executive_summary = " ".join(s for s in full_summary if s)
    return final_executive_summary, all_risks