File size: 28,966 Bytes
5c0f558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
"""
Utilities and helper functions for Legal-BERT project
"""
import os
import json
import re
from typing import Dict, List, Any, Tuple
import logging

def setup_logging(log_level: str = "INFO") -> logging.Logger:
    """Configure root logging for the project and return this module's logger.

    Records are sent both to ``legal_bert.log`` in the current working
    directory and to a ``logging.StreamHandler`` (stderr by default).
    Note that ``logging.basicConfig`` is a no-op when the root logger already
    has handlers; the two handlers are constructed (and the log file opened)
    before that call either way.

    Args:
        log_level: Case-insensitive name of a ``logging`` level constant,
            e.g. ``"debug"`` or ``"INFO"``.

    Returns:
        The logger named after this module.
    """
    level_value = getattr(logging, log_level.upper())
    record_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    outputs = [logging.FileHandler('legal_bert.log'), logging.StreamHandler()]
    logging.basicConfig(level=level_value, format=record_format, handlers=outputs)
    return logging.getLogger(__name__)

def ensure_directory_exists(path: str):
    """Create directory ``path`` (including missing parents) if it does not exist.

    An empty ``path`` -- which is what ``os.path.dirname`` returns for a bare
    file name such as ``"results.json"`` -- refers to the current directory
    and is treated as a no-op instead of being passed to ``os.makedirs``
    (which would raise ``FileNotFoundError``). An existing path, directory
    or not, is also left alone.

    Args:
        path: Directory path to create.
    """
    if not path or os.path.exists(path):
        return
    # exist_ok tolerates another process creating the directory between the
    # existence check above and this call.
    os.makedirs(path, exist_ok=True)
    print(f"📁 Created directory: {path}")

def save_json(data: Dict[str, Any], filepath: str):
    """Serialize ``data`` as indented JSON to ``filepath``.

    The parent directory is created first when ``filepath`` has one; a bare
    file name (no directory component) is written to the current directory
    without attempting to create anything.

    Args:
        data: JSON-serializable object.
        filepath: Destination file path.
    """
    directory = os.path.dirname(filepath)
    if directory:
        ensure_directory_exists(directory)
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    print(f"💾 Saved JSON: {filepath}")

def load_json(filepath: str) -> Dict[str, Any]:
    """Load and return the JSON document stored at ``filepath``.

    The file is decoded as UTF-8 (the encoding JSON interchange requires)
    rather than the platform's default locale encoding, so non-ASCII text
    loads identically on every OS.

    Args:
        filepath: Path to a JSON file.

    Returns:
        The parsed JSON value.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"JSON file not found: {filepath}")

    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"📂 Loaded JSON: {filepath}")
    return data

def clean_text(text: str) -> str:
    """Clean and normalize text.

    Every character that is not a word character, whitespace, or one of the
    kept punctuation marks ``. , ; : ( ) " -`` is replaced by a space. All
    whitespace runs, including those produced by that replacement, are then
    collapsed to a single space and the result is trimmed.

    Args:
        text: Input text; non-string values yield ``""``.

    Returns:
        The cleaned text.
    """
    if not isinstance(text, str):
        return ""

    # Remove special characters but keep legal punctuation
    text = re.sub(r'[^\w\s.,;:()"-]', ' ', text)

    # Collapse whitespace *after* the substitution so removed characters do
    # not leave double spaces behind, then trim the ends.
    return re.sub(r'\s+', ' ', text).strip()

def extract_contract_metadata(filename: str) -> Dict[str, str]:
    """Extract metadata fields from an underscore-separated contract filename.

    Expected pattern (CUAD): ``COMPANY_DATE_FILING_EXHIBIT_AGREEMENT[.txt]``.
    Only a trailing ``.txt`` extension is removed; ``.txt`` occurring elsewhere
    in the name is kept. Everything from the fifth field onward is re-joined
    with underscores as the agreement type. Missing fields are ``'Unknown'``.

    Args:
        filename: Contract file name.

    Returns:
        Dict with keys ``company``, ``date``, ``filing_type``, ``exhibit``
        and ``agreement_type``.
    """
    suffix = '.txt'
    stem = filename[:-len(suffix)] if filename.endswith(suffix) else filename
    parts = stem.split('_')

    leading_fields = ('company', 'date', 'filing_type', 'exhibit')
    metadata = {
        field: parts[index] if index < len(parts) else 'Unknown'
        for index, field in enumerate(leading_fields)
    }
    metadata['agreement_type'] = '_'.join(parts[4:]) if len(parts) > 4 else 'Unknown'

    return metadata

def format_risk_score(score: float) -> str:
    """Label a numeric risk score with its severity band for display.

    Bands use exclusive upper bounds: LOW below 2, MEDIUM below 5, HIGH
    below 8, and CRITICAL for anything not caught by those. The score is
    appended with two decimal places, e.g. ``"HIGH (7.50)"``.
    """
    bands = ((2, "LOW"), (5, "MEDIUM"), (8, "HIGH"))
    label = next((name for upper, name in bands if score < upper), "CRITICAL")
    return f"{label} ({score:.2f})"

def calculate_statistics(values: List[float]) -> Dict[str, float]:
    """Summarise ``values`` with mean, sample std-dev, min, max and median.

    An empty input yields all-zero statistics. A single value reports a
    standard deviation of 0 because the sample standard deviation needs at
    least two data points.
    """
    if not values:
        return dict.fromkeys(('mean', 'std', 'min', 'max', 'median'), 0)

    import statistics

    mean_value = statistics.mean(values)
    std_value = statistics.stdev(values) if len(values) > 1 else 0
    return {
        'mean': mean_value,
        'std': std_value,
        'min': min(values),
        'max': max(values),
        'median': statistics.median(values),
    }

def set_seed(seed: int = 42):
    """Seed Python's, NumPy's and PyTorch's RNGs for reproducibility.

    NumPy and PyTorch are both optional: each is seeded only if it can be
    imported, so this works in a dependency-free demo environment. When CUDA
    is available, all GPU generators are seeded and cuDNN is switched to
    deterministic, non-benchmarking mode.

    Args:
        seed: Seed value applied to every available generator.
    """
    import random

    random.seed(seed)

    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None:
        np.random.seed(seed)

    try:
        import torch
    except ImportError:
        print(f"🎲 Random seed set to {seed} (torch not available)")
        return

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for run-to-run determinism.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print(f"🎲 Random seed set to {seed}")

def plot_training_history(history: Dict[str, List[float]], save_path: str = None):
    """Plot per-epoch loss and accuracy curves side by side.

    The left panel shows ``train_loss``/``val_loss`` and the right panel
    ``train_acc``/``val_acc``. Series absent from ``history`` are skipped
    instead of raising ``KeyError``, so a loss-only history still plots.

    matplotlib is optional: if it cannot be imported, a warning is printed
    and nothing is drawn.

    Args:
        history: Mapping from series name to one value per epoch.
        save_path: If given, save the figure there at 300 dpi; otherwise
            display it with ``plt.show()``.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("⚠️ matplotlib not available. Skipping training history plot.")
        return

    # Per panel: (y-axis label, title, ((history key, legend label, marker), ...))
    panels = (
        ('Loss', 'Training and Validation Loss',
         (('train_loss', 'Train Loss', 'o'), ('val_loss', 'Val Loss', 's'))),
        ('Accuracy', 'Training and Validation Accuracy',
         (('train_acc', 'Train Accuracy', 'o'), ('val_acc', 'Val Accuracy', 's'))),
    )

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    for ax, (ylabel, title, series) in zip(axes, panels):
        plotted = False
        for key, label, marker in series:
            values = history.get(key)
            if values is not None:
                ax.plot(values, label=label, marker=marker)
                plotted = True
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if plotted:
            # Only add a legend when there is a labelled line to show.
            ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"💾 Training history plot saved to: {save_path}")
    else:
        plt.show()

    # Close this specific figure so repeated calls don't accumulate open figures.
    plt.close(fig)

def format_time(seconds: float) -> str:
    """Render a duration in seconds as a short human-readable string.

    Under a minute: seconds with one decimal (``"42.5s"``). Under an hour:
    whole minutes and seconds (``"3m 7s"``). Otherwise: whole hours and
    minutes (``"2h 15m"``), with leftover seconds dropped.
    """
    if seconds < 60:
        return f"{seconds:.1f}s"
    if seconds < 3600:
        whole_minutes, leftover = divmod(seconds, 60)
        return f"{int(whole_minutes)}m {int(leftover)}s"
    whole_hours, remainder = divmod(seconds, 3600)
    return f"{int(whole_hours)}h {int(remainder // 60)}m"

def print_progress_bar(iteration: int, total: int, prefix: str = 'Progress',
                       suffix: str = 'Complete', length: int = 50, fill: str = '█'):
    """Print a single-line, carriage-return-refreshed text progress bar.

    Args:
        iteration: Number of completed steps.
        total: Total number of steps. A non-positive total is rendered as a
            full bar at 100% instead of raising ``ZeroDivisionError``.
        prefix: Text shown before the bar.
        suffix: Text shown after the percentage.
        length: Bar width in characters.
        fill: Single character used for the completed part of the bar.

    A newline is printed once ``iteration == total`` so later output starts
    on a fresh line.
    """
    if total > 0:
        percent = 100 * (iteration / float(total))
        filled_length = int(length * iteration // total)
    else:
        percent = 100.0
        filled_length = length
    # Keep the bar exactly `length` wide even if iteration is out of range.
    filled_length = max(0, min(length, filled_length))
    bar = fill * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}| {percent:.1f}% {suffix}', end='')
    if iteration == total:
        print()

def validate_config(config) -> List[str]:
    """Check a configuration object for missing or out-of-range settings.

    Args:
        config: Any object exposing its settings as attributes.

    Returns:
        Human-readable error messages, in this order: missing required
        fields, nonexistent ``data_path``, non-positive ``batch_size``,
        non-positive ``num_epochs``, ``learning_rate`` outside (0, 1].
        Empty when every check passes. Optional settings that are absent
        are not checked.
    """
    required_fields = ('bert_model_name', 'data_path', 'batch_size', 'num_epochs')
    errors = [
        f"Missing required config field: {field}"
        for field in required_fields
        if not hasattr(config, field)
    ]

    if hasattr(config, 'data_path') and not os.path.exists(config.data_path):
        errors.append(f"Data path does not exist: {config.data_path}")

    # (attribute, predicate flagging a bad value, message), in reporting order
    range_checks = (
        ('batch_size', lambda v: v <= 0, "Batch size must be positive"),
        ('num_epochs', lambda v: v <= 0, "Number of epochs must be positive"),
        ('learning_rate', lambda v: v <= 0 or v > 1, "Learning rate must be between 0 and 1"),
    )
    for attr, is_invalid, message in range_checks:
        if hasattr(config, attr) and is_invalid(getattr(config, attr)):
            errors.append(message)

    return errors

def create_model_summary(model, config) -> str:
    """Build a multi-line text summary of the model and its configuration.

    Parameter counts come from ``model.parameters()`` (each item needs
    ``numel()`` and ``requires_grad``). If that fails for any ordinary
    exception, both counts are reported as ``"Unknown"``; interpreter-level
    exceptions such as ``KeyboardInterrupt`` are no longer swallowed.

    Args:
        model: Model object, presumably a ``torch.nn.Module``.
        config: Object providing ``bert_model_name``, ``num_risk_categories``,
            ``max_sequence_length``, ``dropout_rate`` and ``device``.

    Returns:
        The summary, one field per line.
    """
    try:
        # Materialize once: parameters() may be a one-shot generator.
        params = list(model.parameters())
        total_params = sum(p.numel() for p in params)
        trainable_params = sum(p.numel() for p in params if p.requires_grad)
    except Exception:
        total_params = "Unknown"
        trainable_params = "Unknown"

    summary = [
        "📋 MODEL SUMMARY",
        "=" * 50,
        "Architecture: Legal-BERT (Fully Learning-Based)",
        f"Base Model: {config.bert_model_name}",
        f"Risk Categories: {config.num_risk_categories} (discovered)",
        f"Max Sequence Length: {config.max_sequence_length}",
        f"Dropout Rate: {config.dropout_rate}",
        f"Total Parameters: {total_params}",
        f"Trainable Parameters: {trainable_params}",
        f"Device: {config.device}",
        "=" * 50,
    ]

    return "\n".join(summary)

def check_dependencies() -> Dict[str, bool]:
    """Report which optional third-party packages can be imported.

    Each package is actually imported, and ``ImportError`` marks it as
    missing. Other exceptions raised during import propagate.

    Returns:
        Ordered mapping of torch, transformers, sklearn, numpy, pandas to
        ``True``/``False``.
    """
    def _importable(module_name):
        try:
            __import__(module_name)
        except ImportError:
            return False
        return True

    package_names = ('torch', 'transformers', 'sklearn', 'numpy', 'pandas')
    return {name: _importable(name) for name in package_names}

def print_dependency_status():
    """Print an availability table for the optional dependencies.

    Uses check_dependencies(). When any package is missing, an install hint
    is printed as well. Status markers are real emoji; the previous literals
    were mis-decoded UTF-8 and printed as garbage.
    """
    deps = check_dependencies()

    print("📦 DEPENDENCY STATUS")
    print("-" * 30)

    for dep, available in deps.items():
        status = "✅ Available" if available else "❌ Missing"
        print(f"{dep:12} : {status}")

    missing = [dep for dep, available in deps.items() if not available]

    if missing:
        print(f"\n⚠️  Missing dependencies: {', '.join(missing)}")
        print("Install with: pip install torch transformers scikit-learn numpy pandas")
        print("For demo mode, dependencies are not required.")
    else:
        print("\n🎉 All dependencies available!")

def get_sample_contract_text() -> str:
    """Return a short, indented sample services agreement for testing.

    The text starts with a newline, has every line indented by four spaces,
    and ends with a trailing four-space line. It contains six numbered,
    upper-case section headings (SERVICES ... GOVERNING LAW).
    """
    lines = [
        "",
        "    SERVICES AGREEMENT",
        "    ",
        '    This Services Agreement ("Agreement") is entered into as of the Effective Date',
        '    by and between Company A ("Provider") and Company B ("Client").',
        "    ",
        "    1. SERVICES",
        '    Provider shall provide the services described in Exhibit A ("Services") to Client',
        "    in accordance with the terms and conditions set forth herein.",
        "    ",
        "    2. PAYMENT TERMS",
        "    Client shall pay Provider the fees specified in Exhibit B within thirty (30) days",
        "    of receipt of each invoice. Late payments shall incur a penalty of 1.5% per month.",
        "    ",
        "    3. INDEMNIFICATION",
        "    Each party shall indemnify and hold harmless the other party from and against any",
        "    third-party claims arising out of such party's breach of this Agreement.",
        "    ",
        "    4. LIMITATION OF LIABILITY",
        "    In no event shall either party's liability exceed the total amount paid under this",
        "    Agreement in the twelve (12) months preceding the claim.",
        "    ",
        "    5. TERMINATION",
        "    Either party may terminate this Agreement upon thirty (30) days written notice",
        "    to the other party. Upon termination, all confidential information shall be returned.",
        "    ",
        "    6. GOVERNING LAW",
        "    This Agreement shall be governed by and construed in accordance with the laws",
        "    of the State of Delaware.",
        "    ",
    ]
    return "\n".join(lines)


def split_into_clauses(text: str, method: str = 'sentence') -> List[str]:
    """
    Split contract text into individual clauses for per-clause analysis.

    The model scores one clause at a time, so a full contract must be
    segmented first.

    Args:
        text: Full contract text or a single paragraph. Empty or non-string
            input yields ``[]`` without validating ``method``.
        method: ``'sentence'`` splits after ``.``/``;`` followed by an
            upper-case letter, and at line starts that begin with an
            upper-case letter. ``'legal'`` first splits on numbered
            upper-case headings on their own line (e.g. "1. SERVICES"), keeping
            the heading text, then splits each piece after ``.``/``;`` before
            an upper-case letter or ``(``.

    Returns:
        Stripped clauses; fragments shorter than 10 characters are dropped.

    Raises:
        ValueError: If ``method`` is not ``'sentence'`` or ``'legal'``.

    Example:
        >>> split_into_clauses("The Company shall not be liable. Either party may terminate.")
        ['The Company shall not be liable.', 'Either party may terminate.']
    """
    if not text or not isinstance(text, str):
        return []

    min_length = 10

    if method == 'sentence':
        pieces = re.split(r'(?<=[.;])\s+(?=[A-Z])|(?<=\n)\s*(?=[A-Z])', text)
        stripped = (piece.strip() for piece in pieces)
        return [piece for piece in stripped if piece and len(piece) >= min_length]

    if method == 'legal':
        result = []
        # The capturing group keeps heading text in the split output.
        for chunk in re.split(r'\n\s*(\d+\.?\s+[A-Z][A-Z\s]+)\n', text):
            if not chunk.strip():
                continue
            for sentence in re.split(r'(?<=[.;])\s+(?=[A-Z(])', chunk):
                sentence = sentence.strip()
                if len(sentence) >= min_length:
                    result.append(sentence)
        return result

    raise ValueError(f"Unknown method: {method}. Use 'sentence' or 'legal'")


def analyze_full_document(
    text: str,
    model,
    return_details: bool = True,
    use_context: bool = True,
    context_window: int = 1
) -> Dict[str, Any]:
    """
    Analyze a full contract document clause by clause and aggregate the results.

    The text is split with ``split_into_clauses(text, method='legal')``.
    Each clause is passed to ``model.predict``. When ``use_context`` is
    True, the clause is joined with up to ``context_window`` neighbouring
    clauses on each side, so references like "Such Services" or "Section 5"
    are seen with their surroundings. Clauses whose prediction raises are
    reported and skipped, so one bad clause does not abort the document.

    Args:
        text: Full contract text.
        model: Object with ``predict(text) -> dict``. Keys read are
            'risk_type', 'risk_name', 'confidence', 'severity', 'importance'.
        return_details: If True, include every per-clause prediction under
            'all_clauses'.
        use_context: If True, analyze each clause together with its
            neighbours (recommended).
        context_window: Clauses before/after to include (1 = prev + curr + next).
            Negative values are treated as 0.

    Returns:
        {'document_summary': {...}, 'risk_distribution': {...},
         'high_risk_clauses': [...], ['all_clauses': [...]]}.
        Severity/importance averages include legitimate 0.0 scores and skip
        only missing (None) ones. 'high_risk_clauses' holds up to 10 clauses
        with severity > 7, highest severity first. When nothing could be
        split or analyzed, returns {'error': ..., 'n_clauses': ...} instead.

    Example:
        >>> results = analyze_full_document(contract, model, use_context=True)
        >>> print(f"Document risk: {results['document_summary']['overall_severity']}")
        >>> print(f"High-risk clauses: {len(results['high_risk_clauses'])}")
    """
    from collections import Counter

    # Step 1: Split into clauses
    clauses = split_into_clauses(text, method='legal')

    if not clauses:
        return {
            'error': 'No clauses found in document',
            'n_clauses': 0
        }

    n_clauses = len(clauses)
    # A negative window would produce an empty slice (and an empty model input).
    window = max(0, context_window)

    # Step 2: Analyze each clause (optionally with surrounding context)
    if use_context:
        print(f"📄 Analyzing document with {n_clauses} clauses (context-aware)...")
        print(f"   Context window: ±{window} clauses")
    else:
        print(f"📄 Analyzing document with {n_clauses} clauses...")

    clause_predictions = []
    for i, clause in enumerate(clauses):
        try:
            if use_context:
                # [previous clauses] + [current] + [next clauses], space-joined.
                # The target clause is not explicitly marked inside the context.
                input_text = " ".join(clauses[max(0, i - window):i + window + 1])
            else:
                input_text = clause

            pred = model.predict(input_text)

            clause_predictions.append({
                'clause_id': i,
                'clause_text': clause,  # the bare clause, not the context string
                'analyzed_with_context': use_context,
                'risk_type': pred.get('risk_type'),
                'risk_name': pred.get('risk_name'),
                'confidence': pred.get('confidence'),
                'severity': pred.get('severity'),
                'importance': pred.get('importance')
            })

            if (i + 1) % 10 == 0:
                print(f"  Processed {i + 1}/{n_clauses} clauses...")

        except Exception as e:
            # Best-effort: report and move on to the next clause.
            print(f"⚠️  Error analyzing clause {i}: {e}")
            continue

    # Step 3: Aggregate results
    if not clause_predictions:
        return {
            'error': 'Failed to analyze any clauses',
            'n_clauses': n_clauses
        }

    # `is not None` (not truthiness) so a valid 0.0 score is still counted.
    severities = [p['severity'] for p in clause_predictions if p['severity'] is not None]
    importances = [p['importance'] for p in clause_predictions if p['importance'] is not None]

    # High-risk = severity > 7, sorted so the "top 10" really are the most severe
    # (stable sort keeps document order among equal severities).
    high_risk_clauses = sorted(
        (p for p in clause_predictions
         if p['severity'] is not None and p['severity'] > 7.0),
        key=lambda p: p['severity'],
        reverse=True,
    )

    risk_counts = Counter(p['risk_name'] for p in clause_predictions if p.get('risk_name'))
    total = len(clause_predictions)
    risk_distribution = {
        risk: count / total
        for risk, count in risk_counts.items()
    }

    dominant_risk = risk_counts.most_common(1)[0] if risk_counts else ('UNKNOWN', 0)

    result = {
        'document_summary': {
            'total_clauses': n_clauses,
            'analyzed_clauses': total,
            'overall_severity': sum(severities) / len(severities) if severities else 0,
            'max_severity': max(severities) if severities else 0,
            'overall_importance': sum(importances) / len(importances) if importances else 0,
            'high_risk_clause_count': len(high_risk_clauses),
            'dominant_risk_type': dominant_risk[0],
            'dominant_risk_percentage': (dominant_risk[1] / total * 100) if total > 0 else 0
        },
        'risk_distribution': risk_distribution,
        'high_risk_clauses': high_risk_clauses[:10]  # top 10 by severity
    }

    if return_details:
        result['all_clauses'] = clause_predictions

    print("✅ Analysis complete!")
    print(f"   Overall Severity: {result['document_summary']['overall_severity']:.2f}")
    print(f"   High-Risk Clauses: {len(high_risk_clauses)}")
    print(f"   Dominant Risk: {dominant_risk[0]} ({dominant_risk[1]} clauses)")

    return result


def analyze_with_section_context(text: str, model, return_details: bool = True) -> Dict[str, Any]:
    """
    Context-aware analysis that uses the document's numbered section structure.

    The text is split on numbered, upper-case headings on their own line
    (e.g. "1. SERVICES", "2.1 PAYMENT TERMS"). Text before the first
    heading becomes a "Preamble" section. Each section is split into
    sentence-level clauses, and every clause is scored by ``model.predict``
    on an input that starts with the section title. That way references such
    as "such Services" are read together with their section.

    Args:
        text: Full contract text.
        model: Object with ``predict(text) -> dict``. Keys read are
            'risk_type', 'risk_name', 'confidence', 'severity', 'importance'.
        return_details: If True, 'all_clauses' holds every per-clause
            prediction; otherwise it is an empty list.

    Returns:
        {'document_summary': {...}, 'sections': [...],
         'risk_distribution': {...}, 'all_clauses': [...]},
        or {'error': 'No predictions generated'} when nothing was scored.
        In the summary, 'total_clauses' counts clauses found and
        'analyzed_clauses' counts those the model scored successfully.
        Severity statistics include legitimate 0.0 scores.

    Example:
        >>> results = analyze_with_section_context(contract, model)
        >>> for section in results['sections']:
        ...     print(f"{section['title']}: {section['avg_severity']}")
    """
    from collections import Counter

    print("📄 Analyzing document with section-aware context...")

    # Numbered upper-case headings on their own line, e.g. "1. SERVICES", "2.1 PAYMENT TERMS".
    section_pattern = r'\n\s*(\d+\.?\d*\s+[A-Z][A-Z\s]+)\n'
    # With one capturing group, re.split yields [preamble, title1, body1, title2, body2, ...].
    parts = re.split(section_pattern, text)

    titled_bodies = [('Preamble', parts[0])]
    titled_bodies.extend(
        (parts[i].strip(), parts[i + 1]) for i in range(1, len(parts) - 1, 2)
    )

    # Whitespace-only sections are skipped.
    sections = [
        {'title': title, 'text': body, 'clauses': split_into_clauses(body, method='sentence')}
        for title, body in titled_bodies
        if body.strip()
    ]

    print(f"   Identified {len(sections)} sections")

    all_predictions = []
    section_summaries = []

    for section in sections:
        section_title = section['title']
        section_text = section['text']
        clauses = section['clauses']

        print(f"   Analyzing section: {section_title} ({len(clauses)} clauses)")

        section_predictions = []

        for clause_idx, clause in enumerate(clauses):
            try:
                # Context = section title + full section text, so "such Services"
                # is read inside "1. SERVICES".
                # NOTE(review): while the section is <= 1000 chars this input does
                # not depend on `clause`, so every clause of a short section gets
                # the same model input (and, if predict() is deterministic, the
                # same prediction). Consider marking or prepending the target clause.
                context_input = f"{section_title}. {section_text}"

                # Long sections: title + up to 2 clauses on either side
                # (1000 chars is roughly 200 BERT tokens).
                if len(context_input) > 1000:
                    window_start = max(0, clause_idx - 2)
                    window_end = min(len(clauses), clause_idx + 3)
                    nearby = " ".join(clauses[window_start:window_end])
                    context_input = f"{section_title}. {nearby}"

                pred = model.predict(context_input)

                prediction = {
                    'clause_id': len(all_predictions),
                    'section': section_title,
                    'clause_text': clause,
                    'risk_type': pred.get('risk_type'),
                    'risk_name': pred.get('risk_name'),
                    'confidence': pred.get('confidence'),
                    'severity': pred.get('severity'),
                    'importance': pred.get('importance'),
                    'analyzed_with_section_context': True
                }

                section_predictions.append(prediction)
                all_predictions.append(prediction)

            except Exception as e:
                # Best-effort: report and continue with the next clause.
                print(f"⚠️  Error in {section_title}, clause {clause_idx}: {e}")
                continue

        if section_predictions:
            # `is not None` so a valid 0.0 severity is not discarded.
            severities = [p['severity'] for p in section_predictions if p['severity'] is not None]
            avg_severity = sum(severities) / len(severities) if severities else 0

            section_summaries.append({
                'title': section_title,
                'clause_count': len(clauses),
                'avg_severity': avg_severity,
                'max_severity': max(severities) if severities else 0,
                'high_risk_count': sum(1 for s in severities if s > 7)
            })

    if not all_predictions:
        return {'error': 'No predictions generated'}

    severities = [p['severity'] for p in all_predictions if p['severity'] is not None]
    risk_counts = Counter(p['risk_name'] for p in all_predictions if p.get('risk_name'))
    total = len(all_predictions)
    total_found = sum(len(section['clauses']) for section in sections)

    result = {
        'document_summary': {
            'total_sections': len(sections),
            'total_clauses': total_found,
            'analyzed_clauses': total,
            'overall_severity': sum(severities) / len(severities) if severities else 0,
            'max_severity': max(severities) if severities else 0,
            'high_risk_clause_count': sum(1 for s in severities if s > 7)
        },
        'sections': section_summaries,
        'risk_distribution': {risk: count / total for risk, count in risk_counts.items()},
        'all_clauses': all_predictions if return_details else []
    }

    print("✅ Analysis complete!")
    print(f"   {len(sections)} sections analyzed")
    print(f"   Overall severity: {result['document_summary']['overall_severity']:.2f}")

    return result


def print_document_analysis(results: Dict[str, Any]):
    """
    Pretty-print a document risk analysis report to stdout.

    Designed for the output of analyze_full_document(). Every field is read
    with ``.get`` and a default, so dictionaries lacking some keys still
    print, with 0/N/A for what is missing. If ``results`` contains an
    ``'error'`` entry (as the analyzers return when nothing could be
    analyzed), the error is printed instead of a misleading all-zero report.

    Args:
        results: Analysis dictionary, normally from analyze_full_document().
    """
    rule = "=" * 80
    print("\n" + rule)
    print("📊 DOCUMENT RISK ANALYSIS REPORT")
    print(rule)

    if 'error' in results:
        print(f"\n❌ Analysis failed: {results['error']}")
        print("\n" + rule)
        return

    summary = results.get('document_summary', {})

    print("\n📄 Document Overview:")
    print(f"   Total Clauses: {summary.get('total_clauses', 0)}")
    print(f"   Analyzed: {summary.get('analyzed_clauses', 0)}")

    print("\n⚠️  Risk Assessment:")
    severity = summary.get('overall_severity', 0)
    print(f"   Overall Severity: {severity:.2f}/10 - {format_risk_score(severity)}")
    print(f"   Maximum Severity: {summary.get('max_severity', 0):.2f}/10")
    print(f"   Overall Importance: {summary.get('overall_importance', 0):.2f}/10")

    print("\n🔴 High-Risk Clauses:")
    print(f"   Count: {summary.get('high_risk_clause_count', 0)}")

    print("\n📊 Risk Distribution:")
    for risk_type, share in results.get('risk_distribution', {}).items():
        print(f"   {risk_type}: {share*100:.1f}%")

    print("\n🎯 Dominant Risk:")
    print(f"   {summary.get('dominant_risk_type', 'N/A')} "
          f"({summary.get('dominant_risk_percentage', 0):.1f}% of clauses)")

    # Show up to five of the reported high-risk clauses, text truncated to 100 chars.
    high_risk = results.get('high_risk_clauses', [])
    if high_risk:
        print("\n🔍 Top High-Risk Clauses:")
        for rank, clause in enumerate(high_risk[:5], 1):
            print(f"\n   {rank}. {clause['risk_name']} (Severity: {clause['severity']:.1f})")
            clause_text = clause['clause_text']
            snippet = clause_text[:100] + "..." if len(clause_text) > 100 else clause_text
            print(f"      \"{snippet}\"")

    print("\n" + rule)


def parse_document_hierarchically(text: str) -> List[List[str]]:
    """
    Parse a document into a two-level structure: sections → clauses.

    The text is split on numbered, upper-case headings on their own line
    (e.g. "1. SERVICES"). Any text before the first heading (the preamble,
    e.g. the parties recital) is kept as its own leading section instead of
    being dropped. Heading text itself is not included. Each section is split
    into sentence-level clauses, and sections yielding no clauses are
    omitted. If nothing is produced, the whole document is treated as one
    section.

    Args:
        text: Full document text.

    Returns:
        List of sections, each a list of clause strings, e.g.
        [['clause1', 'clause2'],   # preamble / section 1
         ['clause3', 'clause4']]   # next section
    """
    section_pattern = r'\n\s*(\d+\.?\d*\s+[A-Z][A-Z\s]+)\n'
    # With one capturing group, re.split yields [preamble, title1, body1, title2, body2, ...].
    parts = re.split(section_pattern, text)

    bodies = [parts[0]] + [parts[i + 1] for i in range(1, len(parts) - 1, 2)]

    document_structure = []
    for body in bodies:
        clauses = split_into_clauses(body.strip(), method='sentence')
        if clauses:
            document_structure.append(clauses)

    # Fallback: treat the whole document as a single section.
    if not document_structure:
        clauses = split_into_clauses(text, method='sentence')
        if clauses:
            document_structure.append(clauses)

    return document_structure


def prepare_hierarchical_input(clauses: List[str], tokenizer, max_length: int = 128) -> List[Dict[str, Any]]:
    """
    Tokenize each clause separately for hierarchical model input.

    Args:
        clauses: Clause texts, in document order.
        tokenizer: Object exposing ``tokenize_clauses(texts, max_length=...)``
            that returns a mapping with ``'input_ids'`` and ``'attention_mask'``
            entries supporting ``.squeeze(0)`` (presumably batch-first tensors;
            a LegalBertTokenizer in this project).
        max_length: Per-clause token limit passed to the tokenizer. The
            default of 128 matches the previously hard-coded value.

    Returns:
        One dict per clause with ``input_ids`` and ``attention_mask``, each
        with axis 0 (the batch-of-one axis) squeezed away.
    """
    clause_inputs = []

    for clause in clauses:
        # Encode as a batch of one, then drop that leading axis.
        encoded = tokenizer.tokenize_clauses([clause], max_length=max_length)
        clause_inputs.append({
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0)
        })

    return clause_inputs