Spaces:
Running
Running
from transformers import pipeline
import torch

# Pick the device index handed to both pipelines below:
# 0 when CUDA is available (reported as "GPU"), otherwise -1 (reported as "CPU").
if torch.cuda.is_available():
    device = 0
else:
    device = -1
print(f"utilizing device: {'GPU' if device == 0 else 'CPU'}")

# 1. LOAD MODELS
# Both pipelines are created once at import time and shared by the functions
# in this module. framework="pt" pins them to the PyTorch backend.
print("Loading Summarization Model...")
summarizer = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-12-6",
    device=device,
    framework="pt",
)

print("Loading Risk Detection Model...")
risk_detector = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=device,
    framework="pt",
)
# Label that marks a chunk as harmless; it is scored but never reported.
_SAFE_LABEL = "Safe Standard Clause"

# The 10 risk categories checked for every chunk, plus the safe label.
_CANDIDATE_LABELS = (
    "Financial Penalty",
    "Privacy Violation",
    "Non-Compete Restriction",
    "Termination Without Cause",
    "Intellectual Property Transfer",
    "Mandatory Arbitration",
    "Indemnification Obligation",
    "Unilateral Amendment",
    "Jurisdiction Waiver",
    "Automatic Renewal",
    _SAFE_LABEL,
)

# Maximum number of characters of the chunk quoted alongside each risk.
_SNIPPET_LENGTH = 200


def analyze_chunk(text_chunk, risk_threshold=0.50):
    """Summarize a single chunk and list the risk categories it matches.

    Args:
        text_chunk: Text of one document chunk.
        risk_threshold: A label is reported only when its score is strictly
            greater than this value. Defaults to 0.50.

    Returns:
        A ``(summary, detected_risks)`` tuple.

        ``summary`` is the generated summary text, or ``""`` when the chunk
        is blank or summarization raised an error.

        ``detected_risks`` is a list of dicts, one per matching label, with
        the keys ``"type"`` (label name), ``"score"`` (rounded to two
        decimals) and ``"text_snippet"`` (the start of the chunk).

    Exceptions raised by the risk classifier are not caught here and
    propagate to the caller.
    """
    # Nothing to analyze: skip both model calls for None / empty / whitespace.
    if not text_chunk or not text_chunk.strip():
        return "", []

    # A. SUMMARIZE (best effort: a failure yields an empty summary)
    try:
        summary_result = summarizer(
            text_chunk, max_length=150, min_length=30, do_sample=False
        )
        summary = summary_result[0]['summary_text']
    except Exception as e:
        print(f"Summarization error: {e}")
        summary = ""

    # B. DETECT RISKS (MULTI-LABEL)
    # multi_label=True scores every label independently, so several risks
    # can exceed the threshold for the same chunk.
    risk_result = risk_detector(
        text_chunk, list(_CANDIDATE_LABELS), multi_label=True
    )

    # Snippet for context; the ellipsis is added only when text was cut off.
    snippet = text_chunk[:_SNIPPET_LENGTH]
    if len(text_chunk) > _SNIPPET_LENGTH:
        snippet += "..."

    # Keep every real risk label whose confidence clears the threshold.
    detected_risks = [
        {
            "type": label,
            "score": round(score, 2),
            "text_snippet": snippet,
        }
        for label, score in zip(risk_result['labels'], risk_result['scores'])
        if label != _SAFE_LABEL and score > risk_threshold
    ]

    return summary, detected_risks
def analyze_document(chunks):
    """Analyze every chunk of a document and merge the results.

    Each chunk is passed to ``analyze_chunk``; the per-chunk summaries are
    joined with single spaces and the per-chunk risk lists are concatenated
    in chunk order.

    Args:
        chunks: Iterable of text chunks (a list or any other iterable).

    Returns:
        A ``(final_executive_summary, all_risks)`` tuple: the combined
        summary string and the flat list of risk dicts from all chunks.
    """
    # Materialize once so generators work and the count can be reported.
    chunks = list(chunks)
    print(f"Starting analysis on {len(chunks)} chunks...")

    summaries = []
    all_risks = []
    for chunk in chunks:
        summary, risks = analyze_chunk(chunk)
        # Chunks whose summarization failed return "", which would otherwise
        # leave stray double spaces in the joined summary.
        if summary:
            summaries.append(summary)
        # Extending with an empty list is a no-op, so no guard is needed.
        all_risks.extend(risks)

    final_executive_summary = " ".join(summaries)
    return final_executive_summary, all_risks