"""
Main Training Script for Hierarchical Legal-DeBERTa
Executes Week 4-5: Model Training and Evaluation
Uses Hierarchical DeBERTa (context-aware) model
"""
import torch
import os
import json
import argparse
import sys
import traceback
from datetime import datetime
from config import LegalBertConfig
from trainer import LegalBertTrainer
from utils import set_seed, plot_training_history

def main():
    """Execute the Hierarchical Legal-DeBERTa training pipeline."""
    # Parse arguments
    parser = argparse.ArgumentParser(description='Train Hierarchical Legal-DeBERTa model')
    parser.add_argument('--epochs', type=int, default=None,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=None,
                        help='Batch size for training')
    args = parser.parse_args()

    print("=" * 80)
    print("🏛️ HIERARCHICAL LEGAL-DeBERTa TRAINING PIPELINE")
    print("=" * 80)

    # Initialize configuration
    config = LegalBertConfig()

    # Apply command-line overrides
    if args.epochs is not None:
        config.num_epochs = args.epochs
    if args.batch_size is not None:
        config.batch_size = args.batch_size

    # Set random seed for reproducibility
    set_seed(42)
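    # set_seed (from utils) is assumed to seed Python's `random`, NumPy, and
    # torch so that runs are repeatable; exact coverage depends on its
    # implementation.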
print(f"\nπŸ“‹ Configuration:")
print(f" Model type: Hierarchical BERT (context-aware)")
print(f" Data path: {config.data_path}")
print(f" Device: {config.device}")
print(f" Batch size: {config.batch_size}")
print(f" Epochs: {config.num_epochs}")
print(f" Learning rate: {config.learning_rate}")
print(f" Risk discovery clusters: {config.risk_discovery_clusters}")
print(f" Hierarchical hidden dim: {config.hierarchical_hidden_dim}")
print(f" Hierarchical LSTM layers: {config.hierarchical_num_lstm_layers}")
    # Initialize trainer
    trainer = LegalBertTrainer(config)

    # Prepare data with unsupervised risk discovery
    print("\n" + "=" * 80)
    print("📊 PHASE 1: DATA PREPARATION & RISK DISCOVERY")
    print("=" * 80)
    try:
        train_loader, val_loader, test_loader = trainer.prepare_data(config.data_path)
    except FileNotFoundError:
        print(f"❌ Error: Dataset not found at {config.data_path}")
        print("Please ensure the CUAD dataset is downloaded and the path is correct.")
        return None
    except Exception as e:
        print(f"❌ Error during data preparation: {e}")
        traceback.print_exc()
        return None
    # Display discovered risk patterns
    print("\n🔍 Discovered Risk Patterns:")
    for pattern_name, pattern_info in trainer.risk_discovery.discovered_patterns.items():
        print(f"  • {pattern_name}")
        print(f"    Keywords: {', '.join(pattern_info['keywords'][:5])}")
    # Train model
    print("\n" + "=" * 80)
    print("🏋️ PHASE 2: MODEL TRAINING")
    print("=" * 80)
    try:
        history = trainer.train(train_loader, val_loader)
    except Exception as e:
        print(f"❌ Error during training: {e}")
        traceback.print_exc()
        return None
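    # `history` is expected to be a dict of per-epoch lists keyed by
    # 'train_loss', 'val_loss', 'train_acc', and 'val_acc', as consumed by
    # the plotting call and the summary below.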
    # Plot training history
    print("\n📈 Plotting training history...")
    os.makedirs(config.checkpoint_dir, exist_ok=True)  # dir is also used for the summary below
    plot_training_history(history, save_path=os.path.join(config.checkpoint_dir, 'training_history.png'))
    # Save final model
    print("\n💾 Saving final model...")
    final_model_path = os.path.join(config.model_save_path, 'final_model.pt')
    os.makedirs(config.model_save_path, exist_ok=True)
    torch.save({
        'model_state_dict': trainer.model.state_dict(),
        'model_type': 'hierarchical',
        'config': config,
        'risk_discovery_model': trainer.risk_discovery,
        'discovered_patterns': trainer.risk_discovery.discovered_patterns,
        'training_history': history
    }, final_model_path)
    print(f"✅ Model saved to: {final_model_path}")
    # Save training summary
    summary = {
        'training_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'config': {
            'batch_size': config.batch_size,
            'num_epochs': config.num_epochs,
            'learning_rate': config.learning_rate,
            'device': str(config.device)  # str() in case device is a torch.device, which json can't serialize
        },
        'final_metrics': {
            'train_loss': history['train_loss'][-1],
            'val_loss': history['val_loss'][-1],
            'train_acc': history['train_acc'][-1],
            'val_acc': history['val_acc'][-1]
        },
        'num_discovered_risks': trainer.risk_discovery.n_clusters,
        'discovered_patterns': list(trainer.risk_discovery.discovered_patterns.keys())
    }
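    # The resulting training_summary.json looks roughly like (values
    # illustrative):
    #   {"training_date": "2025-01-01 12:00:00", "config": {...},
    #    "final_metrics": {"train_loss": 0.31, ...},
    #    "num_discovered_risks": 10, "discovered_patterns": ["...", ...]}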
    summary_path = os.path.join(config.checkpoint_dir, 'training_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f"\n📄 Training summary saved to: {summary_path}")

    # Print final results
    print("\n" + "=" * 80)
    print("✅ TRAINING COMPLETE!")
    print("=" * 80)
    print("\n📊 Final Results:")
    print(f"   Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"   Train Accuracy: {history['train_acc'][-1]:.4f}")
    print(f"   Val Loss: {history['val_loss'][-1]:.4f}")
    print(f"   Val Accuracy: {history['val_acc'][-1]:.4f}")
    print("\n🎯 Next Steps:")
    print("   1. Run evaluation: python evaluate.py")
    print("   2. Apply calibration methods")
    print("   3. Generate comprehensive analysis report")

    return trainer, history
if __name__ == "__main__":
result = main()
if result is not None:
trainer, history = result
else:
print("\n❌ Training failed. Please check errors above.")
exit(1)