File size: 2,359 Bytes
97f13a8
 
 
 
576eef7
97f13a8
576eef7
9d35238
 
 
 
 
 
 
 
 
 
 
 
 
 
97f13a8
576eef7
 
9d35238
576eef7
 
9d35238
 
576eef7
 
97f13a8
 
9d35238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97f13a8
576eef7
9d35238
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# utils/risk_detector.py

from transformers import pipeline

# ⚖️ Zero-shot classification pipeline used by detect_risks() for every clause.
# NOTE: pipeline(...) is called at module import time, so importing this module
# pays the model construction cost up front (and presumably fetches the weights
# if they are not in the local transformers cache — confirm for deployment).
classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")

# 🎯 Risk categories passed as `candidate_labels` to the classifier for each
# clause. Extend as needed; when adding a label, also add a matching entry to
# `fallbacks` below, otherwise its verbose "suggestion" will be an empty string.
labels = ["Indemnity", "Exclusivity", "Termination", "Jurisdiction", "Confidentiality", "Fees"]

# Remediation hints keyed by label; detect_risks() attaches the matching entry
# as "suggestion" in verbose output (looked up with a "" default).
fallbacks = {
    "Indemnity": "Consider adding a mutual indemnification clause or capping liability.",
    "Exclusivity": "Suggest clarifying duration and scope of exclusivity.",
    "Termination": "Check for balanced termination rights and notice period.",
    "Jurisdiction": "Ensure forum is neutral or matches your operational base.",
    "Confidentiality": "Include a clear definition of confidential information and duration.",
    "Fees": "Ensure clarity on payment structure, late fees, and reimbursement terms."
}

# ========== Core Function ==========

def detect_risks(text, verbose=False, *, threshold=0.5, min_clause_len=20, max_chars=1000):
    """
    Detect and classify legal risks across the clauses of a contract text.

    The text is split into clauses on periods. Clauses that are
    ``min_clause_len`` characters or shorter (after stripping) are ignored.
    Each remaining clause, truncated to ``max_chars`` characters, is scored
    against every entry in the module-level ``labels`` by the zero-shot
    ``classifier`` with ``multi_label=True``.

    Args:
        text: Contract text to analyse. ``None`` or whitespace-only input
            yields an empty list.
        verbose: Selects the return shape (see Returns).
        threshold: Minimum score for a (clause, label) pair to be reported in
            verbose mode. Not applied in non-verbose mode.
        min_clause_len: A clause must be strictly longer than this many
            characters to be classified.
        max_chars: Maximum number of characters of each clause passed to the
            classifier.

    Returns:
        - If ``verbose`` is true: a list of dicts, one per (clause, label) pair
          whose score is >= ``threshold``. Each dict has the keys ``"clause"``,
          ``"label"``, ``"score"`` (rounded to 3 decimals) and ``"suggestion"``
          (the label's entry in ``fallbacks``, or ``""``).
        - Otherwise: a list of ``(label, total_score)`` tuples. ``total_score``
          is the sum of that label's scores over all clauses, with no
          threshold applied. The list is sorted by total score, highest first.
    """
    from collections import Counter

    if text is None or not text.strip():
        return []

    # Naive segmentation: splitting on "." also breaks abbreviations and
    # decimals (e.g. "U.S.", "2.5"). It is kept as-is so results stay
    # identical for existing callers.
    clauses = [
        clause
        for clause in (part.strip() for part in text.split("."))
        if len(clause) > min_clause_len
    ]

    detailed = []       # verbose-mode records
    totals = Counter()  # non-verbose per-label score sums

    for clause in clauses:
        result = classifier(clause[:max_chars], candidate_labels=labels, multi_label=True)
        scored = zip(result["labels"], result["scores"])

        if verbose:
            detailed.extend(
                {
                    "clause": clause,
                    "label": lbl,
                    "score": round(score, 3),
                    "suggestion": fallbacks.get(lbl, ""),
                }
                for lbl, score in scored
                if score >= threshold
            )
        else:
            for lbl, score in scored:
                totals[lbl] += score

    if verbose:
        return detailed
    # most_common() sorts by value, descending, with the same stable tie order
    # as sorted(items, key=value, reverse=True).
    return totals.most_common()