"""
Main Training Script for Hierarchical Legal-BERT
Executes Week 4-5: Model Training and Evaluation
Uses Hierarchical BERT (context-aware) model
"""
import os
import sys
import json
import argparse
import traceback
from datetime import datetime

import torch

from config import LegalBertConfig
from trainer import LegalBertTrainer
from utils import set_seed, plot_training_history

def main():
    """Execute Hierarchical Legal-BERT training pipeline"""
    
    # Parse command-line arguments (optional overrides)
    parser = argparse.ArgumentParser(description='Train Hierarchical Legal-BERT model')
    parser.add_argument('--epochs', type=int, default=None,
                       help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=None,
                       help='Batch size for training')
    args = parser.parse_args()
    
    print("=" * 80)
    print("πŸ›οΈ  HIERARCHICAL LEGAL-BERT TRAINING PIPELINE")
    print("=" * 80)
    
    # Initialize configuration
    config = LegalBertConfig()
    
    # Apply command-line overrides
    if args.epochs is not None:
        config.num_epochs = args.epochs
    if args.batch_size is not None:
        config.batch_size = args.batch_size
    
    # Set random seed for reproducibility
    set_seed(42)
    
    print(f"\nπŸ“‹ Configuration:")
    print(f"  Model type: Hierarchical BERT (context-aware)")
    print(f"  Data path: {config.data_path}")
    print(f"  Device: {config.device}")
    print(f"  Batch size: {config.batch_size}")
    print(f"  Epochs: {config.num_epochs}")
    print(f"  Learning rate: {config.learning_rate}")
    print(f"  Risk discovery clusters: {config.risk_discovery_clusters}")
    print(f"  Hierarchical hidden dim: {config.hierarchical_hidden_dim}")
    print(f"  Hierarchical LSTM layers: {config.hierarchical_num_lstm_layers}")
    
    # Initialize trainer
    trainer = LegalBertTrainer(config)
    
    # Prepare data with unsupervised risk discovery
    print("\n" + "=" * 80)
    print("πŸ“Š PHASE 1: DATA PREPARATION & RISK DISCOVERY")
    print("=" * 80)
    
    try:
        train_loader, val_loader, test_loader = trainer.prepare_data(config.data_path)
    except FileNotFoundError:
        print(f"❌ Error: Dataset not found at {config.data_path}")
        print("Please ensure CUAD dataset is downloaded and path is correct.")
        return None, None
    except Exception as e:
        print(f"❌ Error during data preparation: {e}")
        traceback.print_exc()
        return None, None
    
    # Display discovered risk patterns
    print("\n🔍 Discovered Risk Patterns:")
    for pattern_name, pattern_info in trainer.risk_discovery.discovered_patterns.items():
        print(f"  • {pattern_name}")
        print(f"    Keywords: {', '.join(pattern_info['keywords'][:5])}")
    
    # Train model
    print("\n" + "=" * 80)
    print("πŸ‹οΈ  PHASE 2: MODEL TRAINING")
    print("=" * 80)
    
    try:
        history = trainer.train(train_loader, val_loader)
    except Exception as e:
        print(f"❌ Error during training: {e}")
        traceback.print_exc()
        return None, None
    
    # Plot training history (ensure the checkpoint directory exists first)
    print("\n📈 Plotting training history...")
    os.makedirs(config.checkpoint_dir, exist_ok=True)
    plot_training_history(history, save_path=os.path.join(config.checkpoint_dir, 'training_history.png'))
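    # plot_training_history is assumed to render with matplotlib; on a headless
    # machine a non-interactive backend may be needed (e.g. run with MPLBACKEND=Agg).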
    
    # Save final model
    print("\nπŸ’Ύ Saving final model...")
    final_model_path = os.path.join(config.model_save_path, 'final_model.pt')
    os.makedirs(config.model_save_path, exist_ok=True)
    
    torch.save({
        'model_state_dict': trainer.model.state_dict(),
        'model_type': 'hierarchical',
        'config': config,
        'risk_discovery_model': trainer.risk_discovery,
        'discovered_patterns': trainer.risk_discovery.discovered_patterns,
        'training_history': history
    }, final_model_path)
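    # Note: this checkpoint pickles full Python objects (config, the risk-discovery
    # model, and its discovered patterns), not just tensors. PyTorch >= 2.6 defaults
    # torch.load to weights_only=True, so reload it with an explicit override, e.g.:
    #   checkpoint = torch.load(final_model_path, weights_only=False)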
    
    print(f"βœ… Model saved to: {final_model_path}")
    
    # Save training summary
    summary = {
        'training_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'config': {
            'batch_size': config.batch_size,
            'num_epochs': config.num_epochs,
            'learning_rate': config.learning_rate,
            'device': config.device
        },
        'final_metrics': {
            'train_loss': history['train_loss'][-1],
            'val_loss': history['val_loss'][-1],
            'train_acc': history['train_acc'][-1],
            'val_acc': history['val_acc'][-1]
        },
        'num_discovered_risks': trainer.risk_discovery.n_clusters,
        'discovered_patterns': list(trainer.risk_discovery.discovered_patterns.keys())
    }
    
    summary_path = os.path.join(config.checkpoint_dir, 'training_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    
    print(f"\nπŸ“„ Training summary saved to: {summary_path}")
    
    # Print final results
    print("\n" + "=" * 80)
    print("βœ… TRAINING COMPLETE!")
    print("=" * 80)
    print(f"\nπŸ“Š Final Results:")
    print(f"  Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"  Train Accuracy: {history['train_acc'][-1]:.4f}")
    print(f"  Val Loss: {history['val_loss'][-1]:.4f}")
    print(f"  Val Accuracy: {history['val_acc'][-1]:.4f}")
    print(f"\n🎯 Next Steps:")
    print(f"  1. Run evaluation: python evaluate.py")
    print(f"  2. Apply calibration methods")
    print(f"  3. Generate comprehensive analysis report")
    
    return trainer, history

if __name__ == "__main__":
    # main() returns (None, None) on failure, so test the unpacked trainer rather
    # than the tuple itself (a (None, None) tuple is truthy and "is not None").
    trainer, history = main()
    if trainer is None:
        print("\n❌ Training failed. Please check errors above.")
        sys.exit(1)