""" Main Training Script for RoBERTa-base Legal-BERT Executes Week 4-5: Model Training and Evaluation Uses RoBERTa-base model for legal risk analysis """ import torch import os import json import argparse from datetime import datetime from config import LegalBertConfig from trainer import LegalBertTrainer from utils import set_seed, plot_training_history def main(): """Execute RoBERTa-base Legal-BERT training pipeline""" # Parse command-line arguments (optional overrides) parser = argparse.ArgumentParser(description='Train RoBERTa-base Legal-BERT model') parser.add_argument('--epochs', type=int, default=None, help='Number of training epochs') parser.add_argument('--batch-size', type=int, default=None, help='Batch size for training') args = parser.parse_args() print("=" * 80) print("šŸ›ļø ROBERTA-BASE LEGAL-BERT TRAINING PIPELINE") print("=" * 80) # Initialize configuration config = LegalBertConfig() # Apply command-line overrides if args.epochs is not None: config.num_epochs = args.epochs if args.batch_size is not None: config.batch_size = args.batch_size # Set random seed for reproducibility set_seed(42) print(f"\nšŸ“‹ Configuration:") print(f" Model type: RoBERTa-base") print(f" Base model: {config.bert_model_name}") print(f" Data path: {config.data_path}") print(f" Device: {config.device}") print(f" Batch size: {config.batch_size}") print(f" Epochs: {config.num_epochs}") print(f" Learning rate: {config.learning_rate}") print(f" Risk discovery clusters: {config.risk_discovery_clusters}") # Initialize trainer trainer = LegalBertTrainer(config) # Prepare data with unsupervised risk discovery print("\n" + "=" * 80) print("šŸ“Š PHASE 1: DATA PREPARATION & RISK DISCOVERY") print("=" * 80) try: train_loader, val_loader, test_loader = trainer.prepare_data(config.data_path) except FileNotFoundError: print(f"āŒ Error: Dataset not found at {config.data_path}") print("Please ensure CUAD dataset is downloaded and path is correct.") return None, None except Exception as e: print(f"āŒ Error during data preparation: {e}") import traceback traceback.print_exc() return None, None # Display discovered risk patterns print("\nšŸ” Discovered Risk Patterns:") for pattern_name, pattern_info in trainer.risk_discovery.discovered_patterns.items(): print(f" • {pattern_name}") print(f" Keywords: {', '.join(pattern_info['keywords'][:5])}") # Train model print("\n" + "=" * 80) print("šŸ‹ļø PHASE 2: MODEL TRAINING") print("=" * 80) try: history = trainer.train(train_loader, val_loader) except Exception as e: print(f"āŒ Error during training: {e}") import traceback traceback.print_exc() return None, None # Plot training history print("\nšŸ“ˆ Plotting training history...") plot_training_history(history, save_path=os.path.join(config.checkpoint_dir, 'training_history.png')) # Save final model print("\nšŸ’¾ Saving final model...") final_model_path = os.path.join(config.model_save_path, 'final_model.pt') os.makedirs(config.model_save_path, exist_ok=True) torch.save({ 'model_state_dict': trainer.model.state_dict(), 'model_type': 'roberta-base', 'config': config, 'risk_discovery_model': trainer.risk_discovery, 'discovered_patterns': trainer.risk_discovery.discovered_patterns, 'training_history': history }, final_model_path) print(f"āœ… Model saved to: {final_model_path}") # Save training summary summary = { 'training_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'config': { 'batch_size': config.batch_size, 'num_epochs': config.num_epochs, 'learning_rate': config.learning_rate, 'device': config.device }, 'final_metrics': { 'train_loss': history['train_loss'][-1], 'val_loss': history['val_loss'][-1], 'train_acc': history['train_acc'][-1], 'val_acc': history['val_acc'][-1] }, 'num_discovered_risks': trainer.risk_discovery.n_clusters, 'discovered_patterns': list(trainer.risk_discovery.discovered_patterns.keys()) } summary_path = os.path.join(config.checkpoint_dir, 'training_summary.json') with open(summary_path, 'w') as f: json.dump(summary, f, indent=2) print(f"\nšŸ“„ Training summary saved to: {summary_path}") # Print final results print("\n" + "=" * 80) print("āœ… TRAINING COMPLETE!") print("=" * 80) print(f"\nšŸ“Š Final Results:") print(f" Train Loss: {history['train_loss'][-1]:.4f}") print(f" Train Accuracy: {history['train_acc'][-1]:.4f}") print(f" Val Loss: {history['val_loss'][-1]:.4f}") print(f" Val Accuracy: {history['val_acc'][-1]:.4f}") print(f"\nšŸŽÆ Next Steps:") print(f" 1. Run evaluation: python evaluate.py") print(f" 2. Apply calibration methods") print(f" 3. Generate comprehensive analysis report") return trainer, history if __name__ == "__main__": result = main() if result is not None: trainer, history = result else: print("\nāŒ Training failed. Please check errors above.") exit(1)