"""
Main Training Script for Hierarchical Legal-BERT
Executes Week 4-5: Model Training and Evaluation
Uses Hierarchical BERT (context-aware) model
"""
import torch
import os
import json
import argparse
from datetime import datetime
from config import LegalBertConfig
from trainer import LegalBertTrainer
from utils import set_seed, plot_training_history
def _print_config(config):
    """Print the effective run configuration so logs are self-describing."""
    print("\n📋 Configuration:")
    print("   Model type: Hierarchical BERT (context-aware)")
    print(f"   Data path: {config.data_path}")
    print(f"   Device: {config.device}")
    print(f"   Batch size: {config.batch_size}")
    print(f"   Epochs: {config.num_epochs}")
    print(f"   Learning rate: {config.learning_rate}")
    print(f"   Risk discovery clusters: {config.risk_discovery_clusters}")
    print(f"   Hierarchical hidden dim: {config.hierarchical_hidden_dim}")
    print(f"   Hierarchical LSTM layers: {config.hierarchical_num_lstm_layers}")


def _save_final_model(trainer, config, history):
    """Persist the trained model plus everything needed to reload/interpret it.

    The checkpoint bundles the risk-discovery model and discovered patterns so
    evaluation scripts do not have to re-run clustering.
    """
    print("\n💾 Saving final model...")
    os.makedirs(config.model_save_path, exist_ok=True)
    final_model_path = os.path.join(config.model_save_path, 'final_model.pt')
    torch.save({
        'model_state_dict': trainer.model.state_dict(),
        'model_type': 'hierarchical',
        'config': config,
        'risk_discovery_model': trainer.risk_discovery,
        'discovered_patterns': trainer.risk_discovery.discovered_patterns,
        'training_history': history
    }, final_model_path)
    print(f"✅ Model saved to: {final_model_path}")


def _save_summary(trainer, config, history):
    """Write a small JSON summary (config, final metrics, discovered risks)."""
    summary = {
        'training_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'config': {
            'batch_size': config.batch_size,
            'num_epochs': config.num_epochs,
            'learning_rate': config.learning_rate,
            'device': config.device
        },
        'final_metrics': {
            'train_loss': history['train_loss'][-1],
            'val_loss': history['val_loss'][-1],
            'train_acc': history['train_acc'][-1],
            'val_acc': history['val_acc'][-1]
        },
        'num_discovered_risks': trainer.risk_discovery.n_clusters,
        'discovered_patterns': list(trainer.risk_discovery.discovered_patterns.keys())
    }
    os.makedirs(config.checkpoint_dir, exist_ok=True)
    summary_path = os.path.join(config.checkpoint_dir, 'training_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f"\n📄 Training summary saved to: {summary_path}")


def main():
    """Execute the Hierarchical Legal-BERT training pipeline.

    Phases: (1) data preparation + unsupervised risk discovery,
    (2) model training, then artifact saving (model checkpoint, history
    plot, JSON summary).

    Returns:
        (trainer, history) on success; (None, None) on any recoverable
        failure (missing dataset, data-prep or training exception).
    """
    # Optional command-line overrides of the defaults in LegalBertConfig.
    parser = argparse.ArgumentParser(description='Train Hierarchical Legal-BERT model')
    parser.add_argument('--epochs', type=int, default=None,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=None,
                        help='Batch size for training')
    args = parser.parse_args()

    print("=" * 80)
    print("🏛️ HIERARCHICAL LEGAL-BERT TRAINING PIPELINE")
    print("=" * 80)

    config = LegalBertConfig()
    if args.epochs is not None:
        config.num_epochs = args.epochs
    if args.batch_size is not None:
        config.batch_size = args.batch_size

    # Fixed seed for reproducibility across runs.
    set_seed(42)
    _print_config(config)

    trainer = LegalBertTrainer(config)

    # ---- Phase 1: data preparation & unsupervised risk discovery ----
    print("\n" + "=" * 80)
    print("📊 PHASE 1: DATA PREPARATION & RISK DISCOVERY")
    print("=" * 80)
    try:
        train_loader, val_loader, test_loader = trainer.prepare_data(config.data_path)
    except FileNotFoundError:
        print(f"❌ Error: Dataset not found at {config.data_path}")
        print("Please ensure CUAD dataset is downloaded and path is correct.")
        return None, None
    except Exception as e:
        print(f"❌ Error during data preparation: {e}")
        import traceback
        traceback.print_exc()
        return None, None

    # Show what the clustering step discovered (first 5 keywords per pattern).
    print("\n🔍 Discovered Risk Patterns:")
    for pattern_name, pattern_info in trainer.risk_discovery.discovered_patterns.items():
        print(f"  • {pattern_name}")
        print(f"    Keywords: {', '.join(pattern_info['keywords'][:5])}")

    # ---- Phase 2: model training ----
    print("\n" + "=" * 80)
    print("🏋️ PHASE 2: MODEL TRAINING")
    print("=" * 80)
    try:
        history = trainer.train(train_loader, val_loader)
    except Exception as e:
        print(f"❌ Error during training: {e}")
        import traceback
        traceback.print_exc()
        return None, None

    # ---- Artifacts: plot, checkpoint, summary ----
    print("\n📈 Plotting training history...")
    os.makedirs(config.checkpoint_dir, exist_ok=True)
    plot_training_history(history, save_path=os.path.join(config.checkpoint_dir, 'training_history.png'))

    _save_final_model(trainer, config, history)
    _save_summary(trainer, config, history)

    print("\n" + "=" * 80)
    print("✅ TRAINING COMPLETE!")
    print("=" * 80)
    print("\n📊 Final Results:")
    print(f"   Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"   Train Accuracy: {history['train_acc'][-1]:.4f}")
    print(f"   Val Loss: {history['val_loss'][-1]:.4f}")
    print(f"   Val Accuracy: {history['val_acc'][-1]:.4f}")
    print("\n🎯 Next Steps:")
    print("   1. Run evaluation: python evaluate.py")
    print("   2. Apply calibration methods")
    print("   3. Generate comprehensive analysis report")

    return trainer, history
if __name__ == "__main__":
    result = main()
    # main() returns the tuple (None, None) on failure, which is truthy and
    # not `None` — so checking `result is not None` alone would silently
    # unpack None values and exit 0. Inspect the contents instead.
    if result is None or result[0] is None:
        print("\n❌ Training failed. Please check errors above.")
        # SystemExit works even when the site module's exit() helper is absent.
        raise SystemExit(1)
    trainer, history = result