Phase 4 implementation: multi-task learning, P-Tuning, SI/LwF continual learning, automated tests, deployment templates, troubleshooting guide
5a593b3 | """ | |
| Continual Learning Utilities for Nexuss Transformer Framework | |
| Mechanisms to avoid catastrophic forgetting during continuous training | |
| """ | |
| import torch | |
| from torch import nn | |
| from torch.utils.data import Dataset, DataLoader | |
| from dataclasses import dataclass, field | |
| from typing import Optional, List, Dict, Any, Tuple | |
| from collections import OrderedDict | |
| import copy | |
@dataclass
class EWCConfig:
    """Hyperparameters for Elastic Weight Consolidation (EWC).

    Declared as a dataclass so individual fields can be overridden by
    keyword at construction time, e.g. ``EWCConfig(ewc_lambda=0.5)``.
    """

    # Multiplier applied to the summed EWC penalty.
    ewc_lambda: float = 1000.0
    # Fisher estimation stops consuming batches once this many samples
    # have been processed.
    fisher_samples: int = 200
    # Constant added to every Fisher entry after averaging.
    damping: float = 0.1
    # Monte Carlo samples for Fisher estimation.
    # NOTE(review): no code visible in this file reads this field.
    mc_samples: int = 1
@dataclass
class ReplayConfig:
    """Hyperparameters for the experience-replay buffer.

    Declared as a dataclass so individual fields can be overridden by
    keyword at construction time, e.g. ``ReplayConfig(replay_size=50)``.
    """

    # Maximum number of samples kept in the replay buffer.
    replay_size: int = 1000
    # Fraction of the current batch size to draw from the buffer.
    replay_ratio: float = 0.5
    # How an over-full buffer is trimmed: "uniform", "recent" or "diverse".
    selection_strategy: str = "uniform"
    # Use reservoir sampling for streaming data.
    reservoir_sampling: bool = True
@dataclass
class GEMConfig:
    """Hyperparameters for Gradient Episodic Memory (GEM).

    Declared as a dataclass so individual fields can be overridden by
    keyword at construction time, e.g. ``GEMConfig(memory_size=10)``.
    """

    # Maximum number of stored examples per task.
    memory_size: int = 100
    # Expected number of tasks; used to pre-create per-task memory slots.
    num_tasks: int = 5
    # Use quadratic programming for constraint solving.
    # NOTE(review): no code visible in this file reads this flag.
    use_quadprog: bool = True
@dataclass
class ContinualLearningConfig:
    """Unified configuration for continual learning strategies.

    Must be a dataclass: the ``field(default_factory=...)`` defaults below
    are only turned into per-instance config objects by the dataclass
    machinery. On a plain class they would stay raw ``Field`` placeholders.
    """

    # Strategy selector: none, ewc, replay, gem, lwf, si
    strategy: str = "none"

    # Per-strategy sub-configurations; each instance gets its own copy.
    ewc: Optional[EWCConfig] = field(default_factory=EWCConfig)
    replay: Optional[ReplayConfig] = field(default_factory=ReplayConfig)
    gem: Optional[GEMConfig] = field(default_factory=GEMConfig)

    # LwF (Learning without Forgetting) settings
    lwf_alpha: float = 1.0  # Distillation loss weight
    lwf_temperature: float = 2.0  # Temperature for knowledge distillation

    # Regularization
    weight_decay: float = 0.01
    grad_clip: float = 1.0
class EWCRegularizer:
    """Elastic Weight Consolidation penalty on drift from stored weights.

    ``compute_fisher`` stores a per-parameter importance estimate together
    with a snapshot of the trainable parameters. ``compute_ewc_loss`` then
    returns ``ewc_lambda * sum(fisher * (param - snapshot) ** 2)``.
    """

    def __init__(self, model: nn.Module, config: "EWCConfig"):
        """Keep references to the model and config; no state is computed yet."""
        self.model = model
        self.config = config
        self.fisher: Dict[str, torch.Tensor] = {}
        self.optimal_params: Dict[str, torch.Tensor] = {}

    def _trainable_named_parameters(self) -> List[Tuple[str, nn.Parameter]]:
        """Return ``(name, parameter)`` pairs for parameters requiring grad."""
        return [
            (name, param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        ]

    def compute_fisher(self, dataloader: DataLoader, device: torch.device):
        """Estimate a diagonal importance matrix and snapshot the weights.

        Args:
            dataloader: Iterable of batches. A dict batch contributes its
                ``"input_ids"`` entry; any other batch is used directly.
            device: Device the inputs are moved to before the forward pass.

        The model is switched to training mode and left in that mode.
        Batches are consumed until ``config.fisher_samples`` samples have
        been processed.
        """
        self.model.train()

        # Names and parameters come from the same filtered list so each
        # gradient is attributed to the parameter that produced it, even
        # when frozen parameters are interleaved with trainable ones.
        trainable = self._trainable_named_parameters()
        if not trainable:
            self.fisher = {}
            self.optimal_params = {}
            return
        names = [name for name, _ in trainable]
        params = [param for _, param in trainable]

        fisher_dict = {name: torch.zeros_like(param) for name, param in trainable}
        samples_processed = 0

        for batch in dataloader:
            if samples_processed >= self.config.fisher_samples:
                break
            self.model.zero_grad()

            # Forward pass
            inputs = batch["input_ids"].to(device) if isinstance(batch, dict) else batch.to(device)
            outputs = self.model(inputs)

            # NOTE(review): the differentiated scalar is the mean over all
            # log-softmax entries, not the log-probability of observed or
            # sampled targets. Confirm this surrogate is intended.
            log_probs = torch.log_softmax(outputs.logits, dim=-1)
            loss = log_probs.mean()

            # allow_unused: parameters that do not take part in this forward
            # pass yield ``None`` instead of raising.
            grads = torch.autograd.grad(loss, params, retain_graph=False, allow_unused=True)

            # Accumulate squared gradients (Fisher diagonal)
            for name, grad in zip(names, grads):
                if grad is not None:
                    fisher_dict[name] += grad.pow(2)

            samples_processed += inputs.size(0)

        # Average over processed samples, then add the damping constant.
        n_samples = max(samples_processed, 1)
        self.fisher = {
            name: tensor / n_samples + self.config.damping
            for name, tensor in fisher_dict.items()
        }

        # Snapshot of the parameters the penalty will pull towards.
        self.optimal_params = {name: param.clone().detach() for name, param in trainable}

    def compute_ewc_loss(self) -> torch.Tensor:
        """Return the EWC penalty for the model's current parameters.

        Returns a zero scalar when ``compute_fisher`` has not been run or no
        trainable parameter has a stored Fisher entry.
        """
        if not self.fisher or not self.optimal_params:
            return torch.tensor(0.0)

        total: Optional[torch.Tensor] = None
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and name in self.fisher):
                continue
            # Stored tensors are moved to the parameter's device so the
            # penalty also works after the model has been moved.
            fisher = self.fisher[name].to(param.device)
            anchor = self.optimal_params[name].to(param.device)
            term = (fisher * (param - anchor).pow(2)).sum()
            total = term if total is None else total + term.to(total.device)

        if total is None:
            return torch.tensor(0.0)
        return self.config.ewc_lambda * total
class ReplayBuffer:
    """Bounded store of past samples for experience replay.

    Samples are kept in ``buffer`` (capped at ``config.replay_size``) and,
    when a task id is supplied, additionally in the uncapped per-task lists
    of ``task_data``.
    """

    def __init__(self, config: "ReplayConfig"):
        """Create an empty buffer governed by ``config``."""
        self.config = config
        self.buffer: List[Dict[str, Any]] = []
        self.task_data: Dict[int, List[Dict[str, Any]]] = {}
        # Number of samples offered to the reservoir so far; reservoir
        # sampling needs it to keep every sample equally likely to survive.
        self._num_seen = 0

    def add(self, samples: List[Dict[str, Any]], task_id: Optional[int] = None):
        """Add samples to the buffer, evicting as needed to respect capacity.

        Args:
            samples: Samples to insert.
            task_id: When given, the samples are also appended to
                ``task_data[task_id]`` (which is not size-limited).
        """
        capacity = self.config.replay_size

        if self.config.reservoir_sampling:
            for sample in samples:
                self._num_seen += 1
                if len(self.buffer) < capacity:
                    self.buffer.append(sample)
                    continue
                # Draw a slot among all samples seen so far; the newcomer is
                # kept only if the slot falls inside the buffer, so the
                # replacement probability shrinks as the stream grows.
                slot = int(torch.randint(0, self._num_seen, (1,)).item())
                if slot < capacity:
                    self.buffer[slot] = sample
        else:
            self.buffer.extend(samples)
            if len(self.buffer) > capacity:
                self._trim(capacity)

        # Store by task if task_id provided
        if task_id is not None:
            self.task_data.setdefault(task_id, []).extend(samples)

    def _trim(self, capacity: int) -> None:
        """Shrink ``buffer`` to ``capacity`` items using the configured strategy."""
        if capacity <= 0:
            # Guard: ``[-0:]`` would keep everything and the "diverse" step
            # would divide by zero.
            self.buffer = []
            return

        strategy = self.config.selection_strategy
        if strategy == "recent":
            self.buffer = self.buffer[-capacity:]
        elif strategy == "diverse":
            # Simple diversity: keep every nth item
            step = len(self.buffer) // capacity
            self.buffer = self.buffer[::step][:capacity]
        else:  # uniform
            keep = torch.randperm(len(self.buffer))[:capacity].tolist()
            self.buffer = [self.buffer[i] for i in keep]

    def get_batch(self, current_batch: Dict[str, Any]) -> Dict[str, Any]:
        """Return the batch to train on.

        Replay samples are drawn, but merging them into ``current_batch`` is
        not implemented yet, so ``current_batch`` is returned unchanged.
        """
        if not self.buffer:
            return current_batch

        replay_size = int(current_batch["input_ids"].size(0) * self.config.replay_ratio)
        replay_size = min(replay_size, len(self.buffer))
        if replay_size == 0:
            return current_batch

        # Sample from buffer
        indices = torch.randperm(len(self.buffer))[:replay_size].tolist()
        replay_samples = [self.buffer[i] for i in indices]  # noqa: F841 - kept for the TODO below

        # Combining depends on the data format (padding, key set), which is
        # not defined in this file.
        return current_batch  # TODO: Implement proper batch merging

    def get_task_buffer(self, task_id: int) -> List[Dict[str, Any]]:
        """Return the samples stored for ``task_id`` (empty list if none)."""
        return self.task_data.get(task_id, [])
class GEMOptimizer:
    """Gradient Episodic Memory helper.

    Keeps a bounded per-task sample memory and one flattened reference
    gradient per task, and projects new gradients so they do not point
    against the stored ones.
    """

    def __init__(self, model: nn.Module, config: "GEMConfig"):
        """Pre-create empty memories for ``config.num_tasks`` tasks."""
        self.model = model
        self.config = config
        self.memory: Dict[int, List[Dict[str, Any]]] = {i: [] for i in range(config.num_tasks)}
        self.gradient_memory: Dict[int, torch.Tensor] = {}

    def store_in_memory(self, samples: List[Dict[str, Any]], task_id: int):
        """Store samples for ``task_id``, never exceeding ``memory_size``.

        When not all samples fit, a random subset fills the remaining space.
        Task ids outside the pre-created range get a slot on first use.
        """
        slots = self.memory.setdefault(task_id, [])
        available_space = max(self.config.memory_size - len(slots), 0)
        if available_space >= len(samples):
            slots.extend(samples)
        else:
            # Random subsample
            chosen = torch.randperm(len(samples))[:available_space].tolist()
            slots.extend(samples[i] for i in chosen)

    def compute_gradient_constraints(self, task_id: int, device: torch.device) -> List[torch.Tensor]:
        """Return stored gradients of all tasks before ``task_id``, on ``device``."""
        return [
            self.gradient_memory[prev_task_id].to(device)
            for prev_task_id in range(task_id)
            if prev_task_id in self.gradient_memory
        ]

    def project_gradient(self, gradient: torch.Tensor, constraints: List[torch.Tensor]) -> torch.Tensor:
        """Remove gradient components that point against stored gradients.

        Constraints are handled one after another: whenever the running
        result has a negative dot product with a constraint, its component
        along that constraint is subtracted. This is a sequential projection,
        not a quadratic-programming solve, so a later projection can
        re-violate an earlier constraint.

        Returns ``gradient`` itself when there are no constraints, otherwise
        a new tensor.
        """
        if not constraints:
            return gradient

        projected = gradient.clone()
        for constraint in constraints:
            dot_product = torch.dot(projected.flatten(), constraint.flatten())
            if dot_product < 0:
                norm_sq = constraint.pow(2).sum()
                # Skip (near-)zero constraints to avoid dividing by ~0.
                if norm_sq > 1e-8:
                    projected -= (dot_product / norm_sq) * constraint
        return projected

    def update_gradient_memory(self, task_id: int, dataloader: DataLoader, device: torch.device):
        """Store the batch-averaged flattened gradient for ``task_id``.

        Args:
            task_id: Key under which the averaged gradient is stored.
            dataloader: Iterable of batches. A dict batch contributes its
                ``"input_ids"`` entry; any other batch is used directly.
            device: Device the inputs are moved to.

        The model runs in eval mode while gradients are collected and its
        previous train/eval mode is restored afterwards. Nothing is stored
        when the dataloader yields no batch.
        """
        params = [p for p in self.model.parameters() if p.requires_grad]
        if not params:
            return

        was_training = self.model.training
        self.model.eval()
        total_gradient: Optional[torch.Tensor] = None
        count = 0
        try:
            for batch in dataloader:
                self.model.zero_grad()
                inputs = batch["input_ids"].to(device) if isinstance(batch, dict) else batch.to(device)
                outputs = self.model(inputs)

                # An output object may expose ``loss`` but leave it as None
                # (no labels are passed here); fall back to the logits mean.
                loss = getattr(outputs, "loss", None)
                if loss is None:
                    loss = outputs.logits.mean()

                grads = torch.autograd.grad(loss, params, allow_unused=True)

                # Flatten and concatenate; unused parameters contribute zeros
                # so the vector layout stays identical across batches.
                flat_grad = torch.cat([
                    (g if g is not None else torch.zeros_like(p)).flatten()
                    for p, g in zip(params, grads)
                ])
                total_gradient = flat_grad if total_gradient is None else total_gradient + flat_grad
                count += 1
        finally:
            self.model.train(was_training)

        if count > 0 and total_gradient is not None:
            self.gradient_memory[task_id] = total_gradient / count
class LwFLoss(nn.Module):
    """Knowledge-distillation loss for Learning without Forgetting.

    Reads ``lwf_temperature`` and ``lwf_alpha`` from the supplied config on
    every call, so later changes to the config take effect immediately.
    """

    def __init__(self, config: ContinualLearningConfig):
        super().__init__()
        self.config = config
        # 'batchmean' divides the summed divergence by the size of dim 0.
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
        """Return ``lwf_alpha * T**2 * KL(teacher || student)`` on softened logits.

        Args:
            student_logits: Logits of the model being trained.
            teacher_logits: Logits of the reference (old) model.
        """
        cfg = self.config
        temperature = cfg.lwf_temperature

        # Soften both distributions with the same temperature.
        teacher_dist = (teacher_logits / temperature).softmax(dim=-1)
        student_log_dist = (student_logits / temperature).log_softmax(dim=-1)

        # T**2 rescaling as used by the distillation term.
        divergence = self.kl_div(student_log_dist, teacher_dist) * (temperature ** 2)
        return cfg.lwf_alpha * divergence
class SIRegularizer:
    """Synaptic Intelligence style regularizer for continual learning.

    Tracks how far each trainable parameter moves per step, turns reported
    loss changes into per-parameter importance, and penalizes drift away
    from the parameter snapshot taken by ``initialize_trajectory``.
    """

    def __init__(self, model: nn.Module, c: float = 0.1):
        """Store the model and the penalty strength ``c``; state starts empty."""
        self.model = model
        self.c = c  # Importance weight
        self.importance: Dict[str, torch.Tensor] = {}
        # Parameters as of the previous ``update_trajectory`` call.
        self.prev_params: Dict[str, torch.Tensor] = {}
        self.trajectory: Dict[str, torch.Tensor] = {}
        # Parameters as of ``initialize_trajectory``; the penalty is measured
        # against these. ``prev_params`` cannot serve as the reference
        # because it is overwritten after every step.
        self.anchor_params: Dict[str, torch.Tensor] = {}

    def initialize_trajectory(self):
        """Snapshot trainable parameters and reset trajectory and importance."""
        trainable = [
            (name, param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        ]
        self.prev_params = {name: param.clone().detach() for name, param in trainable}
        self.anchor_params = {name: param.clone().detach() for name, param in trainable}
        self.trajectory = {name: torch.zeros_like(param) for name, param in trainable}
        self.importance = {name: torch.zeros_like(param) for name, param in trainable}

    def update_trajectory(self):
        """Accumulate the squared per-step parameter change.

        Intended to be called after each optimizer step. Parameters that
        were not tracked by ``initialize_trajectory`` are ignored.
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if param.requires_grad and name in self.prev_params:
                    delta = param - self.prev_params[name]
                    self.trajectory[name] += delta.pow(2)
                    self.prev_params[name] = param.clone().detach()

    def compute_importance(self, loss_change: float):
        """Add ``loss_change / (trajectory + 1e-8)`` to every importance tensor.

        Args:
            loss_change: Change in loss from the previous iteration. Only
                positive values update the importance; others are ignored.
        """
        if loss_change > 0:
            for name in self.importance:
                if name in self.trajectory:
                    # Small epsilon keeps the division finite for parameters
                    # that have not moved.
                    denom = self.trajectory[name] + 1e-8
                    self.importance[name] += loss_change / denom

    def compute_si_loss(self) -> torch.Tensor:
        """Return ``c * sum(importance * (param - anchor) ** 2)``.

        Returns a zero scalar when nothing is tracked. Parameters without a
        stored anchor contribute nothing.
        """
        total: Optional[torch.Tensor] = None
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and name in self.importance):
                continue
            anchor = self.anchor_params.get(name)
            if anchor is None:
                continue
            weight = self.importance[name].to(param.device)
            term = (weight * (param - anchor.to(param.device)).pow(2)).sum()
            total = term if total is None else total + term.to(total.device)

        if total is None:
            return torch.tensor(0.0)
        return self.c * total
class LwFRegularizer(nn.Module):
    """Learning without Forgetting using knowledge distillation.

    Teacher logits can be passed directly to ``forward`` or stored ahead of
    time with ``store_old_outputs`` and looked up by task name.
    """

    def __init__(self, alpha: float = 0.5, temperature: float = 2.0):
        super().__init__()
        self.alpha = alpha  # Distillation loss weight
        self.temperature = temperature
        # 'batchmean' divides the summed divergence by the size of dim 0.
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.old_outputs: Dict[str, torch.Tensor] = {}

    def store_old_outputs(self, task_name: str, outputs: torch.Tensor):
        """Store detached old-model outputs under ``task_name``."""
        self.old_outputs[task_name] = outputs.detach()

    def clear_old_outputs(self):
        """Clear stored old outputs."""
        self.old_outputs.clear()

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: Optional[torch.Tensor] = None,
        task_name: Optional[str] = None
    ) -> torch.Tensor:
        """
        Compute LwF distillation loss.

        Args:
            student_logits: Current model logits
            teacher_logits: Old model logits. When omitted, the outputs
                stored under ``task_name`` are used instead.
            task_name: Name of the stored outputs to use when
                ``teacher_logits`` is not given

        Returns:
            ``alpha * T**2 * KL(teacher || student)`` on temperature-softened
            distributions

        Raises:
            ValueError: If ``teacher_logits`` is omitted and nothing is
                stored under ``task_name``.
        """
        if teacher_logits is None:
            if task_name is None or task_name not in self.old_outputs:
                raise ValueError(
                    "teacher_logits not provided and no stored outputs for "
                    f"task_name={task_name!r}"
                )
            teacher_logits = self.old_outputs[task_name].to(student_logits.device)

        T = self.temperature

        # Apply temperature scaling
        student_log_probs = torch.log_softmax(student_logits / T, dim=-1)
        teacher_probs = torch.softmax(teacher_logits / T, dim=-1)

        # Knowledge distillation loss
        kd_loss = self.kl_div(student_log_probs, teacher_probs) * (T ** 2)
        return self.alpha * kd_loss
def _add_penalty_to_compute_loss(trainer, penalty_fn):
    """Replace ``trainer.compute_loss`` with a version that adds ``penalty_fn()``."""
    original_compute_loss = trainer.compute_loss

    def compute_loss_with_penalty(model, inputs, return_outputs=False, **kwargs):
        result = original_compute_loss(model, inputs, return_outputs, **kwargs)
        penalty = penalty_fn()
        if return_outputs:
            # With return_outputs the wrapped call yields (loss, outputs);
            # only the loss part receives the penalty.
            loss, outputs = result
            return loss + penalty.to(loss.device), outputs
        return result + penalty.to(result.device)

    trainer.compute_loss = compute_loss_with_penalty


def create_continual_learning_wrapper(trainer, config: "ContinualLearningConfig"):
    """
    Attach continual learning components to an existing trainer.

    Depending on ``config.strategy`` the trainer is modified in place:

    - ``"ewc"``: sets ``trainer.ewc_regularizer`` and wraps
      ``trainer.compute_loss`` to add the EWC penalty.
    - ``"replay"``: sets ``trainer.replay_buffer`` (no hooks installed).
    - ``"gem"``: sets ``trainer.gem_optimizer`` (no hooks installed).
    - ``"lwf"``: sets ``trainer.lwf_loss`` (no hooks installed).
    - ``"si"``: sets ``trainer.si_regularizer``, wraps
      ``trainer.compute_loss`` to add the SI penalty and, when the trainer
      already has an optimizer, wraps ``optimizer.step`` to update the
      parameter trajectory after each step.

    Any other strategy leaves the trainer untouched.

    Returns:
        The same ``trainer`` object.
    """
    strategy = config.strategy

    if strategy == "ewc":
        trainer.ewc_regularizer = EWCRegularizer(trainer.model, config.ewc)
        # The regularizer is looked up on each call so a replaced
        # ``trainer.ewc_regularizer`` is picked up.
        _add_penalty_to_compute_loss(
            trainer, lambda: trainer.ewc_regularizer.compute_ewc_loss()
        )

    elif strategy == "replay":
        trainer.replay_buffer = ReplayBuffer(config.replay)
        # Modify data loading to include replay
        # Implementation depends on trainer's data loading mechanism

    elif strategy == "gem":
        trainer.gem_optimizer = GEMOptimizer(trainer.model, config.gem)
        # Hook into optimization step to project gradients
        # Implementation depends on trainer's optimization loop

    elif strategy == "lwf":
        trainer.lwf_loss = LwFLoss(config)
        # Store teacher model outputs for distillation
        # Implementation depends on training setup

    elif strategy == "si":
        # NOTE(review): the SI strength is taken from ``weight_decay``;
        # confirm this reuse is intended rather than a dedicated setting.
        trainer.si_regularizer = SIRegularizer(trainer.model, c=config.weight_decay)
        trainer.si_regularizer.initialize_trajectory()

        _add_penalty_to_compute_loss(
            trainer, lambda: trainer.si_regularizer.compute_si_loss()
        )

        # Hook into optimizer step to update trajectory. The optimizer may be
        # missing or still None if the trainer creates it lazily.
        optimizer = getattr(trainer, "optimizer", None)
        if optimizer is not None:
            original_step = optimizer.step

            def step_with_trajectory(*args, **kwargs):
                # Forward arguments (e.g. a closure) and the return value so
                # the wrapper is transparent to callers of ``step``.
                result = original_step(*args, **kwargs)
                trainer.si_regularizer.update_trajectory()
                return result

            optimizer.step = step_with_trajectory

    return trainer
class ContinualLearningWrapper:
    """
    High-level wrapper for applying continual learning methods.

    Provides a unified API for EWC, SI, and LwF regularization, plus a
    progressive-unfreezing schedule.

    Args:
        model: Model to wrap
        method: Continual learning method (ewc, si, lwf). Any other value
            creates the wrapper without a regularizer.
    """

    def __init__(self, model: nn.Module, method: str = "ewc"):
        self.model = model
        self.method = method
        self.ewc = None
        self.si = None
        self.lwf = None

        if method == "ewc":
            self.ewc = EWCRegularizer(model, EWCConfig())
        elif method == "si":
            self.si = SIRegularizer(model)
            self.si.initialize_trajectory()
        elif method == "lwf":
            self.lwf = LwFRegularizer()

    def apply_ewc_regularization(self, lambda_ewc: float = 0.5):
        """Enable EWC (or update its strength) and return ``self``."""
        if self.ewc is None:
            # The field is assigned after construction so this does not
            # depend on the config class accepting keyword arguments.
            ewc_config = EWCConfig()
            ewc_config.ewc_lambda = lambda_ewc
            self.ewc = EWCRegularizer(self.model, ewc_config)
        else:
            self.ewc.config.ewc_lambda = lambda_ewc
        return self

    def apply_si_regularization(self, c: float = 0.1):
        """Enable Synaptic Intelligence (or update its strength) and return ``self``."""
        if self.si is None:
            self.si = SIRegularizer(self.model, c=c)
            self.si.initialize_trajectory()
        else:
            self.si.c = c
        return self

    def apply_lwf_regularization(self, alpha: float = 0.5):
        """Enable Learning without Forgetting (or update alpha) and return ``self``."""
        if self.lwf is None:
            self.lwf = LwFRegularizer(alpha=alpha)
        else:
            self.lwf.alpha = alpha
        return self

    def compute_fisher(self, dataloader: DataLoader, device: torch.device):
        """Compute Fisher information for EWC; a no-op when EWC is not enabled."""
        if self.ewc is not None:
            self.ewc.compute_fisher(dataloader, device)

    def get_regularization_loss(self) -> torch.Tensor:
        """Return the summed penalty of every enabled weight regularizer.

        EWC and SI penalties are added together when both are enabled. LwF
        is not included because it needs logits. Returns a zero scalar when
        neither is enabled.
        """
        total: Optional[torch.Tensor] = None
        if self.ewc is not None:
            total = self.ewc.compute_ewc_loss()
        if self.si is not None:
            si_loss = self.si.compute_si_loss()
            total = si_loss if total is None else total + si_loss.to(total.device)
        return torch.tensor(0.0) if total is None else total

    def progressive_unfreeze(
        self,
        start_layers: int = 4,
        unfreeze_every_n_epochs: int = 2,
        max_layers: Optional[int] = None
    ):
        """
        Start a progressive unfreezing schedule.

        All parameters are frozen first, then the last ``start_layers``
        modules are unfrozen. ``step_epoch`` unfreezes more over time.

        Args:
            start_layers: Number of layers to keep unfrozen initially
            unfreeze_every_n_epochs: Epochs between unfreezing
            max_layers: Maximum layers to unfreeze (None = all)

        Raises:
            ValueError: If ``unfreeze_every_n_epochs`` is smaller than 1.
        """
        if unfreeze_every_n_epochs < 1:
            raise ValueError("unfreeze_every_n_epochs must be >= 1")

        self.start_layers = start_layers
        self.unfreeze_every_n_epochs = unfreeze_every_n_epochs
        self.max_layers = max_layers
        self.current_epoch = 0

        # Freeze everything first; unfreezing alone would be a no-op on a
        # model whose parameters are all trainable.
        for param in self.model.parameters():
            param.requires_grad = False
        self._unfreeze_layers(start_layers)

    def _unfreeze_layers(self, num_layers: int):
        """Mark the parameters of the last ``num_layers`` modules as trainable.

        "Layers" are the entries of ``model.modules()``, which lists the
        model itself first followed by every nested submodule. A count
        that reaches the first entry therefore unfreezes the whole model.
        """
        if num_layers <= 0:
            # ``layers[-0:]`` would select every module.
            return
        layers = list(self.model.modules())
        for layer in layers[-num_layers:]:
            for param in layer.parameters():
                param.requires_grad = True

    def step_epoch(self):
        """Advance the unfreezing schedule; call at the end of each epoch.

        Does nothing until ``progressive_unfreeze`` has been called. Every
        ``unfreeze_every_n_epochs`` epochs the unfrozen layer count grows by
        two, capped at ``max_layers`` when that is set.
        """
        if not hasattr(self, 'unfreeze_every_n_epochs'):
            return

        self.current_epoch += 1
        if self.current_epoch % self.unfreeze_every_n_epochs != 0:
            return

        intervals_done = self.current_epoch // self.unfreeze_every_n_epochs
        target = self.start_layers + intervals_done * 2
        if self.max_layers is not None:
            # Clamp rather than skip so the schedule still reaches the cap.
            target = min(target, self.max_layers)
        self._unfreeze_layers(target)