"""
Legal-DeBERTa Training Pipeline - Learning-Based Risk Classification
PHASE 1 IMPROVEMENTS: Focal Loss, Rebalanced weights, Class boosting, LR scheduling
Memory Optimizations: Gradient Accumulation, Mixed Precision (FP16)
"""
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from typing import Dict, List, Tuple, Any
import os
from sklearn.metrics import accuracy_score, classification_report, recall_score
from sklearn.utils.class_weight import compute_class_weight
import json
import time
from config import LegalBertConfig
from model import HierarchicalLegalBERT, LegalBertTokenizer
from risk_discovery import UnsupervisedRiskDiscovery, LDARiskDiscovery
from data_loader import CUADDataLoader
from focal_loss import FocalLoss, compute_class_weights
from risk_postprocessing import merge_duplicate_topics, detect_duplicate_topics, validate_cluster_quality
def collate_batch(batch):
"""
Custom collate function to handle variable-length sequences in batch.
Pads all sequences to the maximum length in the batch.
"""
# Find max length in this batch
max_len = max(item['input_ids'].size(0) for item in batch)
# Prepare batched tensors
input_ids_batch = []
attention_mask_batch = []
risk_labels_batch = []
severity_scores_batch = []
importance_scores_batch = []
for item in batch:
input_ids = item['input_ids']
attention_mask = item['attention_mask']
current_len = input_ids.size(0)
# Pad if needed
if current_len < max_len:
padding_len = max_len - current_len
# Pad with 0 (PAD token) for input_ids
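                # NOTE: assumes the tokenizer's pad token id is 0 (the default for BERT/DeBERTa vocabularies)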
input_ids = torch.cat([input_ids, torch.zeros(padding_len, dtype=torch.long)])
# Pad with 0 for attention_mask (0 = don't attend)
attention_mask = torch.cat([attention_mask, torch.zeros(padding_len, dtype=torch.long)])
input_ids_batch.append(input_ids)
attention_mask_batch.append(attention_mask)
risk_labels_batch.append(item['risk_label'])
severity_scores_batch.append(item['severity_score'])
importance_scores_batch.append(item['importance_score'])
# Stack into batched tensors
return {
'input_ids': torch.stack(input_ids_batch),
'attention_mask': torch.stack(attention_mask_batch),
'risk_label': torch.stack(risk_labels_batch),
'severity_score': torch.stack(severity_scores_batch),
'importance_score': torch.stack(importance_scores_batch)
}
class LegalClauseDataset(Dataset):
"""Dataset for legal clauses with discovered risk labels"""
def __init__(self, clauses: List[str], risk_labels: List[int],
severity_scores: List[float], importance_scores: List[float],
tokenizer: LegalBertTokenizer, max_length: int = 512):
self.clauses = clauses
self.risk_labels = risk_labels
self.severity_scores = severity_scores
self.importance_scores = importance_scores
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.clauses)
def __getitem__(self, idx):
clause = self.clauses[idx]
# Tokenize
encoded = self.tokenizer.tokenize_clauses([clause], self.max_length)
return {
'input_ids': encoded['input_ids'].squeeze(0),
'attention_mask': encoded['attention_mask'].squeeze(0),
'risk_label': torch.tensor(self.risk_labels[idx], dtype=torch.long),
'severity_score': torch.tensor(self.severity_scores[idx], dtype=torch.float),
'importance_score': torch.tensor(self.importance_scores[idx], dtype=torch.float)
}
class LegalBertTrainer:
"""
Trainer for Legal-DeBERTa with discovered risk patterns.
NO hardcoded risk categories!
Includes memory optimizations for DeBERTa: gradient accumulation & mixed precision
"""
def __init__(self, config: LegalBertConfig):
self.config = config
self.device = torch.device(config.device)
# Initialize gradient scaler for mixed precision training
self.use_amp = config.fp16_training and torch.cuda.is_available()
self.scaler = GradScaler() if self.use_amp else None
if self.use_amp:
print("βœ… Mixed Precision (FP16) training enabled - saves GPU memory!")
# Gradient accumulation setup
self.gradient_accumulation_steps = getattr(config, 'gradient_accumulation_steps', 1)
if self.gradient_accumulation_steps > 1:
print(f"βœ… Gradient accumulation enabled: {self.gradient_accumulation_steps} steps")
print(f" Effective batch size: {config.batch_size * self.gradient_accumulation_steps}")
# Initialize risk discovery based on configured method
risk_method = config.risk_discovery_method.lower()
if risk_method == 'lda':
print(f"🎯 Using LDA (Topic Modeling) for risk discovery")
self.risk_discovery = LDARiskDiscovery(
n_clusters=config.risk_discovery_clusters,
doc_topic_prior=config.lda_doc_topic_prior,
topic_word_prior=config.lda_topic_word_prior,
max_iter=config.lda_max_iter,
max_features=config.lda_max_features,
learning_method=config.lda_learning_method,
random_state=42
)
elif risk_method == 'kmeans':
print(f"🎯 Using K-Means for risk discovery")
self.risk_discovery = UnsupervisedRiskDiscovery(
n_clusters=config.risk_discovery_clusters,
random_state=42
)
else:
print(f"⚠️ Unknown risk discovery method '{risk_method}', defaulting to LDA")
self.risk_discovery = LDARiskDiscovery(
n_clusters=config.risk_discovery_clusters,
doc_topic_prior=config.lda_doc_topic_prior,
topic_word_prior=config.lda_topic_word_prior,
max_iter=config.lda_max_iter,
max_features=config.lda_max_features,
learning_method=config.lda_learning_method,
random_state=42
)
self.tokenizer = LegalBertTokenizer(config.bert_model_name)
# Will be initialized during training
self.model = None
self.optimizer = None
self.scheduler = None
# Training state
self.training_history = {
'train_loss': [],
'val_loss': [],
'train_acc': [],
'val_acc': [],
'per_class_recall': [] # Track per-class recall for Classes 0 and 5
}
# PHASE 1 IMPROVEMENT: Initialize loss functions with Focal Loss
if config.use_focal_loss:
print("πŸ”₯ Using Focal Loss for classification (gamma=2.5)")
# Will be initialized after discovering class distribution
self.classification_loss = None # Set in prepare_data
else:
print("⚠️ Using standard CrossEntropyLoss (not recommended)")
self.classification_loss = nn.CrossEntropyLoss()
self.regression_loss = nn.MSELoss()
# Early stopping state
self.best_val_loss = float('inf')
self.patience_counter = 0
def prepare_data(self, data_path: str) -> Tuple[DataLoader, DataLoader, DataLoader]:
"""Load data and discover risk patterns"""
print("πŸ”„ Preparing data with unsupervised risk discovery...")
# Load CUAD data
data_loader = CUADDataLoader(data_path)
df_clauses, contracts = data_loader.load_data()
splits = data_loader.create_splits()
# Get training clauses for risk discovery
train_clauses = splits['train']['clause_text'].tolist()
# Discover risk patterns from training data
discovered_patterns = self.risk_discovery.discover_risk_patterns(train_clauses)
# PHASE 2 IMPROVEMENT: Validate and merge duplicate topics
print("\nπŸ” Validating discovered risk patterns...")
validation_report = validate_cluster_quality(discovered_patterns, min_cluster_size=150)
if not validation_report['is_valid']:
print("⚠️ Cluster quality issues detected:")
for issue in validation_report['issues']:
print(f" - {issue}")
if validation_report['warnings']:
for warning in validation_report['warnings']:
print(f" ⚠️ {warning}")
# Detect and merge duplicate topics (e.g., Classes 0 and 6 both named "LIABILITY")
merge_rules = detect_duplicate_topics(discovered_patterns)
if merge_rules:
print(f"\nπŸ”§ Merging {len(merge_rules)} duplicate topic groups...")
discovered_patterns, original_labels = merge_duplicate_topics(
discovered_patterns,
self.risk_discovery.cluster_labels,
merge_rules
)
# Update risk discovery with merged results
self.risk_discovery.discovered_patterns = discovered_patterns
self.risk_discovery.cluster_labels = original_labels
self.risk_discovery.n_clusters = len(discovered_patterns)
print(f"βœ… Merged to {self.risk_discovery.n_clusters} distinct risk categories\n")
# PHASE 1 IMPROVEMENT: Compute class weights with minority boost
# Get training labels to compute balanced weights
train_risk_labels = self.risk_discovery.get_risk_labels(train_clauses)
if self.config.use_focal_loss:
print("\nπŸ“Š Computing class weights for Focal Loss...")
class_weights = compute_class_weights(
train_risk_labels,
num_classes=self.risk_discovery.n_clusters,
minority_boost=self.config.minority_class_boost
)
# Initialize Focal Loss with computed weights
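            # Focal Loss down-weights well-classified examples:
            #   FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
            # so a larger gamma shifts training toward hard / minority-class clauses.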
self.classification_loss = FocalLoss(
alpha=class_weights,
gamma=self.config.focal_loss_gamma,
reduction='mean'
)
print(f"βœ… Focal Loss initialized with Ξ³={self.config.focal_loss_gamma}\n")
# Create datasets for each split
datasets = {}
dataloaders = {}
for split_name, split_data in splits.items():
clauses = split_data['clause_text'].tolist()
# Get discovered risk labels
risk_labels = self.risk_discovery.get_risk_labels(clauses)
            # Derive heuristic severity and importance scores from clause features
            # (in practice these could be replaced by learned or human-annotated signals)
severity_scores = self._generate_synthetic_scores(clauses, 'severity')
importance_scores = self._generate_synthetic_scores(clauses, 'importance')
# Create dataset
dataset = LegalClauseDataset(
clauses=clauses,
risk_labels=risk_labels,
severity_scores=severity_scores,
importance_scores=importance_scores,
tokenizer=self.tokenizer,
max_length=self.config.max_sequence_length
)
datasets[split_name] = dataset
# Create dataloader
shuffle = (split_name == 'train')
dataloader = DataLoader(
dataset,
batch_size=self.config.batch_size,
shuffle=shuffle,
num_workers=0, # Set to 0 to avoid multiprocessing issues
collate_fn=collate_batch # Custom collate for variable-length sequences
)
dataloaders[split_name] = dataloader
print(f"βœ… Data preparation complete!")
print(f"πŸ“Š Discovered {len(discovered_patterns)} risk patterns")
return dataloaders['train'], dataloaders['val'], dataloaders['test']
    def _generate_synthetic_scores(self, clauses: List[str], score_type: str) -> List[float]:
        """
        Derive severity/importance scores from risk features extracted from each clause.
        These are heuristic, rule-based targets (not human annotations) used to
        supervise the severity and importance regression heads.
        """
scores = []
for clause in clauses:
# Extract risk features from the clause
features = self.risk_discovery.extract_risk_features(clause)
if score_type == 'severity':
# Calculate severity based on risk indicators
# Higher severity for liability, prohibition, and obligation terms
score = (
features.get('risk_intensity', 0) * 30 + # Risk intensity (liability, prohibition)
features.get('obligation_strength', 0) * 20 + # Obligation strength
features.get('prohibition_terms_density', 0) * 100 + # Prohibitions are severe
features.get('liability_terms_density', 0) * 100 + # Liability is severe
min(features.get('monetary_terms_count', 0) * 0.5, 2) # Monetary impact
)
else: # importance
# Calculate importance based on legal complexity and clause characteristics
score = (
features.get('legal_complexity', 0) * 30 + # Legal complexity
min(features.get('clause_length', 0) / 50, 1) * 20 + # Longer = potentially more important
features.get('conditional_risk_density', 0) * 100 + # Conditional clauses are important
features.get('obligation_terms_complexity', 0) * 100 + # Obligations are important
features.get('temporal_urgency_density', 0) * 50 # Time-sensitive = important
)
            # Clamp to the 0-10 scale
            normalized_score = min(max(score, 0), 10)
scores.append(normalized_score)
return scores
def setup_training(self, train_loader: DataLoader):
"""Initialize model, optimizer, and scheduler"""
num_discovered_risks = self.risk_discovery.n_clusters
# Initialize Hierarchical BERT model (context-aware)
print("πŸ“Š Using Hierarchical BERT model (context-aware)")
self.model = HierarchicalLegalBERT(
config=self.config,
num_discovered_risks=num_discovered_risks,
hidden_dim=self.config.hierarchical_hidden_dim,
num_lstm_layers=self.config.hierarchical_num_lstm_layers
).to(self.device)
# Initialize optimizer
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
        # PHASE 1 IMPROVEMENT: Initialize OneCycleLR scheduler.
        # The scheduler is stepped once per optimizer update, so the total step
        # count must account for gradient accumulation.
        steps_per_epoch = max(
            1,
            (len(train_loader) + self.gradient_accumulation_steps - 1)
            // self.gradient_accumulation_steps
        )
        total_steps = steps_per_epoch * self.config.num_epochs
        if self.config.use_lr_scheduler:
            self.scheduler = OneCycleLR(
                self.optimizer,
                max_lr=self.config.learning_rate,
                total_steps=total_steps,
                pct_start=self.config.scheduler_pct_start,  # fraction of steps used for warmup
                anneal_strategy='cos',
                div_factor=25.0,          # initial_lr = max_lr / 25
                final_div_factor=10000.0  # min_lr = initial_lr / 10000
            )
            print(f"πŸ“ˆ OneCycleLR scheduler initialized (warmup={self.config.scheduler_pct_start*100:.0f}%)")
        else:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=total_steps
            )
            print("⚠️ Using basic CosineAnnealingLR (not recommended)")
print(f"πŸ—οΈ Model initialized with {num_discovered_risks} discovered risk categories")
def compute_loss(self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Compute multi-task loss"""
# Classification loss (discovered risk patterns)
classification_loss = self.classification_loss(
outputs['risk_logits'],
batch['risk_label']
)
# Severity regression loss
severity_loss = self.regression_loss(
outputs['severity_score'],
batch['severity_score']
)
# Importance regression loss
importance_loss = self.regression_loss(
outputs['importance_score'],
batch['importance_score']
)
# Weighted combination
total_loss = (
self.config.task_weights['classification'] * classification_loss +
self.config.task_weights['severity'] * severity_loss +
self.config.task_weights['importance'] * importance_loss
)
return {
'total_loss': total_loss,
'classification_loss': classification_loss,
'severity_loss': severity_loss,
'importance_loss': importance_loss
}
def train_epoch(self, train_loader: DataLoader, epoch: int) -> Tuple[float, float, Dict[str, float]]:
"""Train for one epoch with gradient accumulation and mixed precision"""
self.model.train()
total_loss = 0
correct_predictions = 0
total_samples = 0
loss_components = {'classification': 0, 'severity': 0, 'importance': 0}
# Zero gradients at start
self.optimizer.zero_grad()
for batch_idx, batch in enumerate(train_loader):
# Move batch to device
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
risk_labels = batch['risk_label'].to(self.device)
severity_scores = batch['severity_score'].to(self.device)
importance_scores = batch['importance_score'].to(self.device)
# Mixed precision training
with autocast(enabled=self.use_amp):
# Forward pass (hierarchical model in training mode)
outputs = self.model.forward_single_clause(input_ids, attention_mask)
# Prepare batch for loss computation
batch_for_loss = {
'risk_label': risk_labels,
'severity_score': severity_scores,
'importance_score': importance_scores
}
# Compute loss
losses = self.compute_loss(outputs, batch_for_loss)
# Scale loss by accumulation steps
scaled_loss = losses['total_loss'] / self.gradient_accumulation_steps
# Backward pass with gradient scaling (for mixed precision)
if self.use_amp:
self.scaler.scale(scaled_loss).backward()
else:
scaled_loss.backward()
            # Update weights every gradient_accumulation_steps; also flush on the
            # final batch so trailing gradients are not silently discarded
            is_last_batch = (batch_idx + 1) == len(train_loader)
            if (batch_idx + 1) % self.gradient_accumulation_steps == 0 or is_last_batch:
# PHASE 1 IMPROVEMENT: Gradient clipping
if self.use_amp:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.config.gradient_clip_norm
)
# Optimizer step
if self.use_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
# Update metrics
total_loss += losses['total_loss'].item()
# Classification accuracy
predictions = torch.argmax(outputs['risk_logits'], dim=-1)
correct_predictions += (predictions == risk_labels).sum().item()
total_samples += risk_labels.size(0)
# Loss components
loss_components['classification'] += losses['classification_loss'].item()
loss_components['severity'] += losses['severity_loss'].item()
loss_components['importance'] += losses['importance_loss'].item()
# Progress logging
if batch_idx % 50 == 0:
print(f" Batch {batch_idx}/{len(train_loader)}, Loss: {losses['total_loss'].item():.4f}")
avg_loss = total_loss / len(train_loader)
accuracy = correct_predictions / total_samples
# Average loss components
for key in loss_components:
loss_components[key] /= len(train_loader)
return avg_loss, accuracy, loss_components
def validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float, np.ndarray]:
"""Validate for one epoch with per-class recall tracking"""
self.model.eval()
total_loss = 0
correct_predictions = 0
total_samples = 0
# PHASE 1 IMPROVEMENT: Track predictions and labels for per-class metrics
all_predictions = []
all_labels = []
with torch.no_grad():
for batch in val_loader:
# Move batch to device
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
risk_labels = batch['risk_label'].to(self.device)
severity_scores = batch['severity_score'].to(self.device)
importance_scores = batch['importance_score'].to(self.device)
                # Forward pass (single-clause path, model in eval mode)
outputs = self.model.forward_single_clause(input_ids, attention_mask)
# Prepare batch for loss computation
batch_for_loss = {
'risk_label': risk_labels,
'severity_score': severity_scores,
'importance_score': importance_scores
}
# Compute loss
losses = self.compute_loss(outputs, batch_for_loss)
total_loss += losses['total_loss'].item()
# Classification accuracy
predictions = torch.argmax(outputs['risk_logits'], dim=-1)
correct_predictions += (predictions == risk_labels).sum().item()
total_samples += risk_labels.size(0)
# Store for per-class metrics
all_predictions.extend(predictions.cpu().numpy())
all_labels.extend(risk_labels.cpu().numpy())
avg_loss = total_loss / len(val_loader)
accuracy = correct_predictions / total_samples
        # PHASE 1 IMPROVEMENT: Compute per-class recall (especially for Classes 0 and 5)
        per_class_recall = recall_score(
            all_labels,
            all_predictions,
            labels=list(range(self.risk_discovery.n_clusters)),  # keep indices aligned with class ids
            average=None,  # return recall for each class
            zero_division=0
        )
return avg_loss, accuracy, per_class_recall
def train(self, train_loader: DataLoader, val_loader: DataLoader) -> Dict[str, List[float]]:
"""Complete training pipeline"""
print(f"πŸš€ Starting Legal-DeBERTa training...")
print(f"Device: {self.device}")
print(f"Epochs: {self.config.num_epochs}")
print(f"Batch size: {self.config.batch_size}")
self.setup_training(train_loader)
# Track total training time
total_start_time = time.time()
for epoch in range(self.config.num_epochs):
print(f"\nπŸ“ˆ Epoch {epoch+1}/{self.config.num_epochs}")
# Track epoch time
epoch_start_time = time.time()
# Train
train_loss, train_acc, loss_components = self.train_epoch(train_loader, epoch)
# Validate (now returns per-class recall too)
val_loss, val_acc, per_class_recall = self.validate_epoch(val_loader)
# Calculate epoch time
epoch_time = time.time() - epoch_start_time
# Store history
self.training_history['train_loss'].append(train_loss)
self.training_history['val_loss'].append(val_loss)
self.training_history['train_acc'].append(train_acc)
self.training_history['val_acc'].append(val_acc)
self.training_history['per_class_recall'].append(per_class_recall.tolist())
# Print detailed results
print(f" Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
print(f" Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
print(f" Loss Components - Class: {loss_components['classification']:.4f}, "
f"Sev: {loss_components['severity']:.4f}, Imp: {loss_components['importance']:.4f}")
# PHASE 1 IMPROVEMENT: Display per-class recall (focus on Classes 0 and 5)
print(f" Per-Class Recall:")
critical_classes = [0, 5] # Classes with 0% recall in previous training
for cls_idx, recall in enumerate(per_class_recall):
marker = " ⚠️ CRITICAL" if cls_idx in critical_classes else ""
print(f" Class {cls_idx}: {recall:.3f}{marker}")
# Display epoch time
print(f" ⏱️ Epoch Time: {epoch_time:.2f}s ({epoch_time/60:.2f} minutes)")
# PHASE 1 IMPROVEMENT: Early stopping check
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
self.patience_counter = 0
print(f" βœ… New best validation loss: {val_loss:.4f}")
else:
self.patience_counter += 1
print(f" ⚠️ No improvement ({self.patience_counter}/{self.config.early_stopping_patience})")
if self.patience_counter >= self.config.early_stopping_patience:
print(f"\nπŸ›‘ Early stopping triggered after {epoch+1} epochs")
break
# Save checkpoint
self.save_checkpoint(epoch)
# Calculate total training time
total_time = time.time() - total_start_time
print(f"\nβœ… Training complete!")
print(f"⏱️ Total Training Time: {total_time:.2f}s ({total_time/60:.2f} minutes / {total_time/3600:.2f} hours)")
print(f"⏱️ Average Time per Epoch: {total_time/self.config.num_epochs:.2f}s")
return self.training_history
def save_checkpoint(self, epoch: int):
"""Save model checkpoint"""
if not os.path.exists(self.config.checkpoint_dir):
os.makedirs(self.config.checkpoint_dir)
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'training_history': self.training_history,
'config': self.config,
'discovered_patterns': self.risk_discovery.discovered_patterns
}
checkpoint_path = os.path.join(
self.config.checkpoint_dir,
f'legal_bert_epoch_{epoch+1}.pt'
)
torch.save(checkpoint, checkpoint_path)
print(f"πŸ’Ύ Checkpoint saved: {checkpoint_path}")
def load_checkpoint(self, checkpoint_path: str):
"""Load model checkpoint"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Restore model
num_discovered_risks = len(checkpoint['discovered_patterns'])
self.model = HierarchicalLegalBERT(
config=checkpoint['config'],
num_discovered_risks=num_discovered_risks,
hidden_dim=checkpoint['config'].hierarchical_hidden_dim,
num_lstm_layers=checkpoint['config'].hierarchical_num_lstm_layers
).to(self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
        # Restore training state
        self.training_history = checkpoint['training_history']
        self.risk_discovery.discovered_patterns = checkpoint['discovered_patterns']
        self.risk_discovery.n_clusters = num_discovered_risks  # keep cluster count in sync
print(f"βœ… Checkpoint loaded: {checkpoint_path}")
return checkpoint['epoch']
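

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Assumes LegalBertConfig can be
# constructed with its defaults and that the CUAD JSON lives at the placeholder
# path below; adjust both for your environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LegalBertConfig()
    trainer = LegalBertTrainer(config)

    # Discover risk patterns and build train/val/test dataloaders
    train_loader, val_loader, test_loader = trainer.prepare_data("data/CUAD_v1.json")

    # Full training loop (focal loss, OneCycleLR, AMP, gradient accumulation)
    history = trainer.train(train_loader, val_loader)

    # Final validation pass with per-class recall
    val_loss, val_acc, per_class_recall = trainer.validate_epoch(val_loader)
    print(f"Final val accuracy: {val_acc:.4f}")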