code2-repo-roBERTa / train.py
Deepu1965's picture
Upload folder using huggingface_hub
aeb53bb verified
"""
Main Training Script for RoBERTa-base Legal-BERT
Executes Weeks 4-5: model training and validation (test-set evaluation is run separately via evaluate.py)
Uses RoBERTa-base model for legal risk analysis
"""
import torch
import os
import json
import argparse
from datetime import datetime
from config import LegalBertConfig
from trainer import LegalBertTrainer
from utils import set_seed, plot_training_history
def _positive_int(text):
    """Parse a command-line value as a strictly positive integer.

    Used as an ``argparse`` ``type=`` callable so values such as
    ``--epochs 0`` or ``--batch-size -4`` are rejected at parse time
    instead of failing later in the pipeline (presumably an empty training
    history would make the final-metric ``[-1]`` lookups raise).

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer >= 1.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _parse_args(argv=None):
    """Parse the optional command-line overrides for a training run."""
    parser = argparse.ArgumentParser(description='Train RoBERTa-base Legal-BERT model')
    parser.add_argument('--epochs', type=_positive_int, default=None,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=_positive_int, default=None,
                        help='Batch size for training')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility (default: 42)')
    return parser.parse_args(argv)


def _save_final_model(trainer, config, history):
    """Serialize the trained model and its risk-discovery artefacts; return the file path."""
    os.makedirs(config.model_save_path, exist_ok=True)
    final_model_path = os.path.join(config.model_save_path, 'final_model.pt')
    # NOTE: this pickles whole Python objects (the config and the risk-discovery
    # model), not just tensors. Recent PyTorch releases default torch.load to
    # weights_only=True, so loading this file needs weights_only=False - and
    # should only ever be done with trusted checkpoint files.
    torch.save({
        'model_state_dict': trainer.model.state_dict(),
        'model_type': 'roberta-base',
        'config': config,
        'risk_discovery_model': trainer.risk_discovery,
        'discovered_patterns': trainer.risk_discovery.discovered_patterns,
        'training_history': history,
    }, final_model_path)
    return final_model_path


def _save_training_summary(trainer, config, history):
    """Write a JSON summary of the run into ``config.checkpoint_dir``; return the file path."""
    summary = {
        'training_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'config': {
            'batch_size': config.batch_size,
            'num_epochs': config.num_epochs,
            'learning_rate': config.learning_rate,
            # str() keeps this JSON-serializable even if config.device is a
            # torch.device rather than a plain string.
            'device': str(config.device),
        },
        'final_metrics': {
            # float() converts tensor / NumPy scalars, which json cannot encode.
            key: float(history[key][-1])
            for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc')
        },
        'num_discovered_risks': trainer.risk_discovery.n_clusters,
        'discovered_patterns': list(trainer.risk_discovery.discovered_patterns.keys()),
    }
    summary_path = os.path.join(config.checkpoint_dir, 'training_summary.json')
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    return summary_path


def main(argv=None):
    """Execute the RoBERTa-base Legal-BERT training pipeline.

    Steps: parse CLI overrides, build the configuration, prepare data
    (including unsupervised risk discovery), train, then save the model
    checkpoint, a training-history plot and a JSON summary.

    Args:
        argv: Optional argument list to parse instead of the process
            arguments (argparse reads ``sys.argv[1:]`` when this is None),
            so the pipeline can also be driven programmatically.

    Returns:
        ``(trainer, history)`` on success, or ``(None, None)`` if data
        preparation or training fails; those errors are printed, not raised.
    """
    import traceback

    args = _parse_args(argv)
    print("=" * 80)
    print("🏛️ ROBERTA-BASE LEGAL-BERT TRAINING PIPELINE")
    print("=" * 80)

    # Initialize configuration, then apply command-line overrides.
    config = LegalBertConfig()
    if args.epochs is not None:
        config.num_epochs = args.epochs
    if args.batch_size is not None:
        config.batch_size = args.batch_size

    # Set random seed for reproducibility (42 unless --seed is given).
    set_seed(args.seed)
    print("\n📋 Configuration:")
    print(" Model type: RoBERTa-base")
    print(f" Base model: {config.bert_model_name}")
    print(f" Data path: {config.data_path}")
    print(f" Device: {config.device}")
    print(f" Batch size: {config.batch_size}")
    print(f" Epochs: {config.num_epochs}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" Risk discovery clusters: {config.risk_discovery_clusters}")
    print(f" Random seed: {args.seed}")

    trainer = LegalBertTrainer(config)

    # Phase 1: prepare data with unsupervised risk discovery.
    print("\n" + "=" * 80)
    print("📊 PHASE 1: DATA PREPARATION & RISK DISCOVERY")
    print("=" * 80)
    try:
        # The test split is not used here; final evaluation is a separate
        # step (see "Next Steps" printed at the end).
        train_loader, val_loader, _test_loader = trainer.prepare_data(config.data_path)
    except FileNotFoundError:
        print(f"❌ Error: Dataset not found at {config.data_path}")
        print("Please ensure CUAD dataset is downloaded and path is correct.")
        return None, None
    except Exception as e:  # top-level boundary: report and signal failure
        print(f"❌ Error during data preparation: {e}")
        traceback.print_exc()
        return None, None

    print("\n🔍 Discovered Risk Patterns:")
    for pattern_name, pattern_info in trainer.risk_discovery.discovered_patterns.items():
        print(f" • {pattern_name}")
        print(f" Keywords: {', '.join(pattern_info['keywords'][:5])}")

    # Phase 2: train the model.
    print("\n" + "=" * 80)
    print("🏋️ PHASE 2: MODEL TRAINING")
    print("=" * 80)
    try:
        history = trainer.train(train_loader, val_loader)
    except Exception as e:  # top-level boundary: report and signal failure
        print(f"❌ Error during training: {e}")
        traceback.print_exc()
        return None, None

    # Both the plot and the JSON summary are written into checkpoint_dir;
    # create it here rather than relying on other code having done so.
    os.makedirs(config.checkpoint_dir, exist_ok=True)

    print("\n📈 Plotting training history...")
    plot_training_history(history, save_path=os.path.join(config.checkpoint_dir, 'training_history.png'))

    print("\n💾 Saving final model...")
    final_model_path = _save_final_model(trainer, config, history)
    print(f"✅ Model saved to: {final_model_path}")

    summary_path = _save_training_summary(trainer, config, history)
    print(f"\n📄 Training summary saved to: {summary_path}")

    print("\n" + "=" * 80)
    print("✅ TRAINING COMPLETE!")
    print("=" * 80)
    print("\n📊 Final Results:")
    print(f" Train Loss: {history['train_loss'][-1]:.4f}")
    print(f" Train Accuracy: {history['train_acc'][-1]:.4f}")
    print(f" Val Loss: {history['val_loss'][-1]:.4f}")
    print(f" Val Accuracy: {history['val_acc'][-1]:.4f}")
    print("\n🎯 Next Steps:")
    print(" 1. Run evaluation: python evaluate.py")
    print(" 2. Apply calibration methods")
    print(" 3. Generate comprehensive analysis report")
    return trainer, history
if __name__ == "__main__":
    import sys

    # main() reports failure by returning (None, None). That tuple is itself
    # not None, so testing `result is not None` can never detect failure;
    # inspect the trainer slot instead (a bare None is tolerated as well).
    result = main()
    trainer, history = result if result is not None else (None, None)
    if trainer is None:
        print("\n❌ Training failed. Please check errors above.")
        # sys.exit rather than the site-injected `exit`, which is not
        # guaranteed to exist (e.g. under `python -S` or frozen builds).
        sys.exit(1)