
Tutorial 08: Continual Learning & Model Lifecycle Management

Overview

This tutorial covers continual learning strategies for updating models with new data while preventing catastrophic forgetting, along with complete model lifecycle management including versioning, deployment, monitoring, and retirement. We'll integrate NTF's configuration system for hyperparameter tuning and use the ContinualLearningWrapper for EWC regularization.

Table of Contents

  1. Continual Learning Fundamentals
  2. Catastrophic Forgetting
  3. Replay-Based Methods
  4. Regularization-Based Methods
  5. Architecture-Based Methods
  6. Incremental Training Strategies
  7. Hyperparameter Tuning with NTF Config
  8. Model Versioning
  9. Deployment Strategies
  10. Production Monitoring
  11. Model Retirement and Archival

Continual Learning Fundamentals

What is Continual Learning?

Continual learning (also called lifelong learning or incremental learning) enables models to:

  • Learn continuously from new data over time
  • Adapt to changing distributions and tasks
  • Accumulate knowledge without forgetting previous capabilities
  • Operate in non-stationary environments

Key Challenges

  1. Catastrophic Forgetting: Learning new information causes loss of old knowledge
  2. Stability-Plasticity Dilemma: Balance between retaining old knowledge and learning new patterns
  3. Task Boundary Detection: Knowing when a new task or distribution shift occurs
  4. Computational Efficiency: Avoiding full retraining on all historical data
  5. Memory Constraints: Storing representative examples without keeping everything

Continual Learning Scenarios

from enum import Enum

class ContinualLearningScenario(Enum):
    """Different continual learning scenarios."""
    
    TASK_INCREMENTAL = "task_incremental"
    # New tasks arrive sequentially with clear boundaries
    # Example: First train on sentiment analysis, then on QA
    
    DOMAIN_INCREMENTAL = "domain_incremental"  
    # Same task but domains change over time
    # Example: News articles from 2020, then 2021, then 2022
    
    CLASS_INCREMENTAL = "class_incremental"
    # New classes appear over time
    # Example: First classify cats/dogs, then add birds, then fish
    
    INSTANCE_INCREMENTAL = "instance_incremental"
    # Same task and classes, just more data arrives
    # Example: Continuous stream of customer support tickets

Evaluation Metrics for Continual Learning

import numpy as np

class ContinualLearningMetrics:
    def __init__(self):
        self.task_accuracies = {}  # {task_id: {timestamp: accuracy}}
        
    def record_accuracy(self, task_id, timestamp, accuracy):
        """Record accuracy for a task at a specific time."""
        if task_id not in self.task_accuracies:
            self.task_accuracies[task_id] = {}
        self.task_accuracies[task_id][timestamp] = accuracy
    
    def calculate_forward_transfer(self):
        """
        Forward Transfer: How much does learning earlier tasks help with a later task?
        
        Positive values indicate beneficial transfer.
        
        Simplified proxy: for each task after the first, compare its first
        recorded accuracy with its best later accuracy. This gain also includes
        the effect of training on that task itself. The standard metric instead
        compares accuracy on task B before training on B against a randomly
        initialized baseline, which this class does not record.
        """
        transfers = []
        
        task_ids = sorted(self.task_accuracies.keys())
        
        for task_b in task_ids[1:]:
            times_b = sorted(self.task_accuracies[task_b].keys())
            if len(times_b) > 1:
                initial_acc = self.task_accuracies[task_b][times_b[0]]
                best_acc = max(self.task_accuracies[task_b][t] for t in times_b)
                transfers.append(best_acc - initial_acc)
        
        return np.mean(transfers) if transfers else 0.0
    
    def calculate_backward_transfer(self):
        """
        Backward Transfer: How does learning new tasks affect old tasks?
        
        Negative values indicate forgetting. Assumes each task's first recorded
        accuracy was measured right after training on that task.
        """
        transfers = []
        
        task_ids = sorted(self.task_accuracies.keys())
        
        for task_a in task_ids[:-1]:
            times_a = sorted(self.task_accuracies[task_a].keys())
            
            if len(times_a) > 1:
                # Accuracy immediately after training on task A
                initial_acc = self.task_accuracies[task_a][times_a[0]]
                
                # Accuracy after training on all subsequent tasks
                final_acc = self.task_accuracies[task_a][times_a[-1]]
                
                transfers.append(final_acc - initial_acc)
        
        return np.mean(transfers) if transfers else 0.0
    
    def calculate_forgetting_measure(self):
        """
        Forgetting Measure: average, over tasks, of the drop from each task's
        best recorded accuracy to its final accuracy (0 means no forgetting).
        """
        forgetting_scores = []
        
        for task_id, time_accuracies in self.task_accuracies.items():
            times = sorted(time_accuracies.keys())
            
            if len(times) > 1:
                max_acc = max(time_accuracies[t] for t in times)
                final_acc = time_accuracies[times[-1]]
                
                forgetting = max_acc - final_acc
                forgetting_scores.append(forgetting)
        
        return np.mean(forgetting_scores) if forgetting_scores else 0.0
    
    def calculate_average_accuracy(self, final_only=False):
        """
        Average Accuracy across all tasks.
        
        final_only=True averages each task's most recent accuracy;
        otherwise every recorded evaluation is included.
        """
        accuracies = []
        
        for task_id, time_accuracies in self.task_accuracies.items():
            if final_only:
                # Only use the most recent accuracy for this task
                final_time = max(time_accuracies.keys())
                accuracies.append(time_accuracies[final_time])
            else:
                # Use every evaluation of this task
                accuracies.extend(time_accuracies.values())
        
        return np.mean(accuracies) if accuracies else 0.0
    
    def generate_report(self):
        """Generate comprehensive continual learning report."""
        report = {
            'average_accuracy': self.calculate_average_accuracy(final_only=True),
            'forward_transfer': self.calculate_forward_transfer(),
            'backward_transfer': self.calculate_backward_transfer(),
            'forgetting_measure': self.calculate_forgetting_measure()
        }
        
        print("Continual Learning Performance Report")
        print("=" * 50)
        print(f"Average Accuracy:     {report['average_accuracy']:.4f}")
        print(f"Forward Transfer:     {report['forward_transfer']:+.4f}")
        print(f"Backward Transfer:    {report['backward_transfer']:+.4f}")
        print(f"Forgetting Measure:   {report['forgetting_measure']:.4f}")
        print("=" * 50)
        
        if report['forgetting_measure'] < 0.05:
            print("✅ Minimal forgetting detected")
        elif report['forgetting_measure'] < 0.15:
            print("⚠️  Moderate forgetting - consider mitigation")
        else:
            print("❌ Severe forgetting - immediate action needed")
        
        return report
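The class above works from per-task time series. A common alternative, sketched here under the assumption that every task is evaluated after each training stage, keeps an accuracy matrix `R[i][j]` (accuracy on task `j` right after training on task `i`) and derives the metrics directly from it. The function name `summarize_accuracy_matrix` and the sample numbers are illustrative only.

```python
def summarize_accuracy_matrix(R):
    """Summarize a T x T accuracy matrix (hypothetical helper).

    R[i][j] is the accuracy on task j measured right after training on
    task i; only entries with j <= i are used.
    """
    T = len(R)
    final = R[T - 1]
    # Average accuracy over all tasks after the last training stage
    avg_acc = sum(final) / T
    if T < 2:
        return {'average_accuracy': avg_acc, 'backward_transfer': 0.0, 'forgetting': 0.0}
    # Backward transfer: final accuracy minus accuracy just after learning each task
    bwt = sum(final[j] - R[j][j] for j in range(T - 1)) / (T - 1)
    # Forgetting: drop from the best earlier accuracy to the final accuracy
    forgetting = sum(
        max(R[l][j] for l in range(j, T - 1)) - final[j] for j in range(T - 1)
    ) / (T - 1)
    return {'average_accuracy': avg_acc, 'backward_transfer': bwt, 'forgetting': forgetting}


R = [
    [0.90, 0.10, 0.10],  # after training on task 0
    [0.70, 0.85, 0.20],  # after training on task 1
    [0.60, 0.80, 0.88],  # after training on task 2
]
print(summarize_accuracy_matrix(R))
```

For this matrix the average final accuracy is about 0.76, backward transfer about -0.175 and forgetting about 0.175. The two coincide in magnitude here because each task's best accuracy occurs right after it is trained; when later tasks improve an earlier one, the metrics diverge.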

Catastrophic Forgetting

Understanding the Problem

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def demonstrate_catastrophic_forgetting(model, task1_data, task2_data):
    """
    Demonstrate catastrophic forgetting phenomenon.
    
    Train on Task 1, then Task 2, observe performance drop on Task 1.
    """
    metrics = {
        'task1_before': [],
        'task1_after_task2': [],
        'task2_after': []
    }
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    # Phase 1: Train on Task 1
    print("Phase 1: Training on Task 1...")
    for epoch in range(10):
        model.train()
        for texts, labels in task1_data:
            optimizer.zero_grad()
            outputs = model(texts)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
    # Evaluate on Task 1
    task1_acc = evaluate(model, task1_data)
    metrics['task1_before'].append(task1_acc)
    print(f"Task 1 Accuracy after Task 1 training: {task1_acc:.4f}")
    
    # Phase 2: Train on Task 2 (without seeing Task 1 data)
    print("\nPhase 2: Training on Task 2...")
    for epoch in range(10):
        model.train()
        for texts, labels in task2_data:
            optimizer.zero_grad()
            outputs = model(texts)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
    # Evaluate on both tasks
    task1_acc_after = evaluate(model, task1_data)
    task2_acc = evaluate(model, task2_data)
    
    metrics['task1_after_task2'].append(task1_acc_after)
    metrics['task2_after'].append(task2_acc)
    
    print(f"\nTask 1 Accuracy after Task 2 training: {task1_acc_after:.4f}")
    print(f"Task 2 Accuracy: {task2_acc:.4f}")
    
    forgetting = task1_acc - task1_acc_after
    relative_drop = forgetting / task1_acc * 100 if task1_acc > 0 else 0.0
    print(f"\n📉 FORGETTING: {forgetting:.4f} ({relative_drop:.1f}% drop)")
    
    # Visualization
    fig, ax = plt.subplots(figsize=(10, 6))
    
    ax.bar(['Task 1\n(Before)', 'Task 1\n(After)', 'Task 2'], 
           [task1_acc, task1_acc_after, task2_acc],
           color=['green', 'red', 'blue'], alpha=0.7)
    
    ax.set_ylabel('Accuracy')
    ax.set_title('Demonstration of Catastrophic Forgetting')
    ax.set_ylim(0, 1)
    
    for i, v in enumerate([task1_acc, task1_acc_after, task2_acc]):
        ax.text(i, v + 0.02, f'{v:.3f}', ha='center')
    
    plt.tight_layout()
    plt.savefig('catastrophic_forgetting_demo.png')
    plt.show()
    
    return metrics, forgetting

Why Does Forgetting Happen?

  1. Weight Interference: New task optimizes weights in directions that conflict with old task
  2. Representation Shift: Hidden representations change to accommodate new patterns
  3. Decision Boundary Movement: Classification boundaries shift away from old class regions
  4. Capacity Limits: Model has finite capacity; new knowledge displaces old
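Weight interference can be seen without any deep-learning framework. The sketch below is a hypothetical toy: a one-feature logistic regression trained with plain SGD on Task 1, then on a Task 2 whose labels contradict Task 1 on the same inputs, after which Task 1 is re-measured.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def train(w, b, data, lr=0.5, epochs=50):
    """Plain SGD on the logistic (cross-entropy) loss for sigmoid(w*x + b)."""
    for _ in range(epochs):
        for x, y in data:
            err = sigmoid(w * x + b) - y   # d(loss)/d(logit)
            w -= lr * err * x
            b -= lr * err
    return w, b

def accuracy(w, b, data):
    correct = sum((sigmoid(w * x + b) >= 0.5) == (y == 1) for x, y in data)
    return correct / len(data)

# Two conflicting synthetic tasks on the same inputs: the labels are flipped
xs = [i / 10 for i in range(-10, 11) if i != 0]
task1 = [(x, 1 if x > 0 else 0) for x in xs]
task2 = [(x, 1 if x < 0 else 0) for x in xs]

w, b = train(0.0, 0.0, task1)
acc1_before = accuracy(w, b, task1)

w, b = train(w, b, task2)          # same weights, no Task 1 data
acc1_after = accuracy(w, b, task1)
acc2 = accuracy(w, b, task2)

print(f"Task 1 before: {acc1_before:.2f}, Task 1 after: {acc1_after:.2f}, Task 2: {acc2:.2f}")
```

Because the two tasks are exactly contradictory, this is an extreme case: the single weight `w` must change sign to fit Task 2, which destroys Task 1. Real tasks usually overlap only partially, so forgetting tends to be partial rather than total.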

Replay-Based Methods

Experience Replay

import random
from collections import deque
import pickle

class ExperienceReplayBuffer:
    """
    Store and sample past experiences for replay during training.
    """
    
    def __init__(self, max_size=10000, strategy='uniform'):
        """
        Args:
            max_size: Maximum number of samples to store
            strategy: 'uniform' (FIFO: keeps the most recent max_size samples),
                      'reservoir' (keeps a uniform random subset of everything seen),
                      'class_balanced' (separate FIFO buffer per class)
        """
        self.max_size = max_size
        self.strategy = strategy
        self.buffer = deque(maxlen=max_size)
        self.class_buffers = {}  # For class-balanced sampling
        self.n_seen = 0  # Total samples offered so far (used by reservoir sampling)
        
    def add(self, sample, label=None):
        """Add a (sample, label) pair to the replay buffer."""
        self.n_seen += 1
        if self.strategy == 'class_balanced' and label is not None:
            if label not in self.class_buffers:
                # Per-class cap assumes roughly 10 classes; adjust as needed
                self.class_buffers[label] = deque(maxlen=self.max_size // 10)
            self.class_buffers[label].append((sample, label))
        elif self.strategy == 'reservoir' and len(self.buffer) >= self.max_size:
            # Reservoir sampling (Algorithm R): keep the new item with
            # probability max_size / n_seen by overwriting a random slot
            j = random.randrange(self.n_seen)
            if j < self.max_size:
                self.buffer[j] = (sample, label)
        else:
            self.buffer.append((sample, label))
    
    def sample(self, batch_size):
        """Sample a batch of (sample, label) pairs from the replay buffer."""
        if self.strategy == 'class_balanced':
            # Sample (roughly) equally from each class
            per_class = max(1, batch_size // max(1, len(self.class_buffers)))
            all_samples = []
            for class_buf in self.class_buffers.values():
                n_samples = min(per_class, len(class_buf))
                all_samples.extend(random.sample(list(class_buf), n_samples))
            return random.sample(all_samples, min(batch_size, len(all_samples)))
        
        # 'uniform' and 'reservoir': draw uniformly from the stored items
        return random.sample(list(self.buffer), min(batch_size, len(self.buffer)))
    
    def __len__(self):
        return len(self.buffer) + sum(len(b) for b in self.class_buffers.values())
    
    def save(self, path):
        """Save replay buffer to disk."""
        with open(path, 'wb') as f:
            pickle.dump({
                'buffer': list(self.buffer),
                'class_buffers': {k: list(v) for k, v in self.class_buffers.items()},
                'strategy': self.strategy,
                'n_seen': self.n_seen
            }, f)
    
    def load(self, path):
        """Load replay buffer from disk (only load trusted files: pickle can run arbitrary code)."""
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.buffer = deque(data['buffer'], maxlen=self.max_size)
        self.class_buffers = {
            k: deque(v, maxlen=self.max_size // 10)
            for k, v in data['class_buffers'].items()
        }
        self.strategy = data['strategy']
        self.n_seen = data.get('n_seen', len(self))


class ReplayBasedTrainer:
    """
    Trainer with experience replay for continual learning.
    
    Assumes `model(texts)` maps a list of raw texts to logits and that each
    batch is a dict with 'texts' (list of str) and 'labels' (ints or a 1-D tensor).
    """
    
    def __init__(self, model, replay_buffer, config):
        self.model = model
        self.replay_buffer = replay_buffer
        self.config = config
        self.optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
        self.criterion = nn.CrossEntropyLoss()
    
    def train_step(self, current_batch, replay_batch_size=32):
        """
        Train on current data mixed with replayed examples.
        """
        self.model.train()
        
        combined_texts = list(current_batch['texts'])
        combined_labels = [int(label) for label in current_batch['labels']]
        
        # Mix in replay samples, stored as (text, label) pairs
        if len(self.replay_buffer) > 0:
            replay_samples = self.replay_buffer.sample(replay_batch_size)
            combined_texts += [text for text, _ in replay_samples]
            combined_labels += [int(label) for _, label in replay_samples]
        
        # Train on combined batch
        self.optimizer.zero_grad()
        outputs = self.model(combined_texts)
        targets = torch.tensor(combined_labels, device=outputs.device)
        loss = self.criterion(outputs, targets)
        loss.backward()
        self.optimizer.step()
        
        return loss.item()
    
    def train_on_task(self, task_data_loader, task_id, epochs=5):
        """Train on a new task while replaying old experiences."""
        print(f"\nTraining on Task {task_id} with replay...")
        
        for epoch in range(epochs):
            total_loss = 0
            
            for batch in task_data_loader:
                loss = self.train_step(batch, replay_batch_size=32)
                total_loss += loss
                
                # Add each example once (first epoch only) so repeated epochs
                # don't fill the buffer with duplicates of the current task
                if epoch == 0:
                    for text, label in zip(batch['texts'], batch['labels']):
                        self.replay_buffer.add(text, label=int(label))
            
            avg_loss = total_loss / len(task_data_loader)
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")
        
        # Save replay buffer checkpoint
        self.replay_buffer.save(f'replay_buffer_task{task_id}.pkl')
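The choice of buffer strategy matters more than it looks. A `deque(maxlen=...)` is FIFO, so once later tasks stream through it, earlier tasks disappear from memory entirely; reservoir sampling keeps a uniform random subset of everything seen. The stdlib-only sketch below streams three hypothetical tasks of 1,000 examples each into a 300-slot buffer under both policies.

```python
import random
from collections import Counter, deque

def fifo_buffer(stream, max_size):
    """Keep only the most recent max_size items (like deque(maxlen=...))."""
    return list(deque(stream, maxlen=max_size))

def reservoir_buffer(stream, max_size, rng):
    """Algorithm R: keep a uniform random subset of everything seen."""
    buffer = []
    for n_seen, item in enumerate(stream, start=1):
        if len(buffer) < max_size:
            buffer.append(item)
        else:
            j = rng.randrange(n_seen)  # uniform in [0, n_seen)
            if j < max_size:
                buffer[j] = item
    return buffer

# Three tasks arrive one after another; items are (task_id, index)
stream = [(task, i) for task in range(3) for i in range(1000)]

fifo = Counter(task for task, _ in fifo_buffer(stream, 300))
reservoir = Counter(task for task, _ in reservoir_buffer(stream, 300, random.Random(0)))
print("FIFO per task:     ", dict(fifo))
print("Reservoir per task:", dict(reservoir))
```

The FIFO buffer ends up holding only task 2, while the reservoir holds roughly 100 examples from each task, which is what replay needs to protect old tasks.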

Generative Replay

class GenerativeReplay:
    """
    Use a generative model to recreate past data instead of storing it.
    
    Benefits:
    - Privacy: Don't store actual past data
    - Compression: Generate many variations from compact model
    - Scalability: No memory limit on "stored" experiences
    """
    
    def __init__(self, generator_model, generator_config):
        self.generator = generator_model
        self.config = generator_config
    
    def train_generator(self, data_loader, epochs=10):
        """Train generator on current task data."""
        self.generator.train()
        optimizer = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        
        for epoch in range(epochs):
            total_loss = 0
            
            for batch in data_loader:
                optimizer.zero_grad()
                
                # Train generator to reconstruct input data
                reconstructed = self.generator(batch['texts'])
                loss = self.reconstruction_loss(batch['texts'], reconstructed)
                
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            
            print(f"Generator Epoch {epoch + 1}/{epochs}, Loss: {total_loss/len(data_loader):.4f}")
    
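    def generate_replay_samples(self, n_samples=100):
        """
        Generate synthetic samples resembling past data.
        
        Assumes the generator exposes a `decode(latent)` method and that
        `config.latent_dim` matches its latent space (e.g. a VAE decoder).
        Generated inputs carry no labels; in generative replay they are usually
        labeled by the previous task's model before being mixed into training.
        """
        self.generator.eval()
        
        with torch.no_grad():
            # Sample all latent vectors at once from the prior N(0, I)
            latent = torch.randn(n_samples, self.config.latent_dim)
            generated_samples = self.generator.decode(latent)
        
        return generated_samples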
    
    def reconstruction_loss(self, original, reconstructed):
        """Calculate reconstruction loss for generator training."""
        # Implementation depends on generator architecture
        # Could be MSE, cross-entropy, or perceptual loss
        return nn.MSELoss()(original, reconstructed)

Dark Experience Replay

class DarkExperienceReplay:
    """
    Dark Experience Replay (DER): store model outputs (logits) along with inputs.
    
    Instead of storing (x, y), store (x, z) where z is the model's logits at the
    time x was seen. Replaying z preserves the model's learned behavior, not just labels.
    Assumes `model.get_logits(texts)` returns a [batch, num_classes] tensor.
    """
    
    def __init__(self, model, max_size=5000):
        self.model = model
        self.max_size = max_size
        # FIFO memory for simplicity; the original DER paper uses reservoir sampling
        self.memory = deque(maxlen=max_size)
        self.n_seen = 0
    
    def collect_experience(self, texts, labels):
        """Collect experiences together with the model's current logits."""
        self.model.eval()
        
        with torch.no_grad():
            logits = self.model.get_logits(texts)
        
        # Store input, true label, and the raw logits
        for text, label, logit in zip(texts, labels, logits):
            self.memory.append({
                'text': text,
                'true_label': label,
                'logits': logit.cpu(),
                'timestamp': self.n_seen  # insertion order, unaffected by eviction
            })
            self.n_seen += 1
    
    def sample(self, batch_size):
        """Draw a random batch of stored experiences."""
        return random.sample(list(self.memory), min(batch_size, len(self.memory)))
    
    def replay_loss(self, stored_experiences):
        """
        Distillation loss pulling current logits toward the stored ones.
        
        DER uses the squared error between logits, L_der = E[ ||h_θ(x) - z||² ];
        mse_loss averages over elements, which is proportional to it.
        """
        if not stored_experiences:
            return torch.tensor(0.0)
        
        self.model.train()
        
        # Recompute the current model's logits on the stored inputs
        texts = [exp['text'] for exp in stored_experiences]
        current_logits = self.model.get_logits(texts)
        
        # Logits recorded when the examples were first seen
        old_logits = torch.stack(
            [exp['logits'] for exp in stored_experiences]
        ).to(current_logits.device)
        
        return nn.functional.mse_loss(current_logits, old_logits)

Regularization-Based Methods

Elastic Weight Consolidation (EWC)

import torch.nn.functional as F

class EWC:
    """
    Elastic Weight Consolidation: Penalize changes to important weights.
    
    Key idea: Some weights are more important for previous tasks.
    Constrain important weights to stay close to their old values.
    """
    
    def __init__(self, model, fisher_diagonal=None):
        self.model = model
        self.fisher_diagonal = fisher_diagonal  # Importance weights
        self.optimal_weights = None  # Weights after previous task
    
    def estimate_fisher_information(self, data_loader):
        """
        Estimate the diagonal of the Fisher Information Matrix.
        
        Fisher Information measures how sensitive the loss is to each parameter.
        High Fisher = parameter is important, should not change much.
        
        This is an approximation: it uses the true labels (the "empirical"
        Fisher) and squares batch-averaged gradients rather than per-sample ones.
        """
        self.model.eval()
        
        # Initialize Fisher as zeros
        fisher = {
            name: torch.zeros_like(param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        
        # Accumulate squared gradients
        n_batches = 0
        for batch in data_loader:
            self.model.zero_grad()
            
            # Get predictions
            outputs = self.model(batch['texts'])
            loss = F.cross_entropy(outputs, batch['labels'].to(outputs.device))
            
            # Compute gradients
            loss.backward()
            
            # Square and accumulate gradients (Fisher diagonal approximation)
            for name, param in self.model.named_parameters():
                if param.grad is not None and name in fisher:
                    fisher[name] += param.grad.detach().pow(2)
            n_batches += 1
        
        # Average over batches: each term above is one batch-level estimate
        # (dividing by the number of samples would shrink F by the batch size)
        for name in fisher:
            fisher[name] /= max(n_batches, 1)
        
        # Note: this replaces the previous task's Fisher. To protect several
        # earlier tasks, accumulate it instead (e.g. a running sum or decayed average).
        self.fisher_diagonal = fisher
        
        # Store optimal weights (current weights after training)
        self.optimal_weights = {
            name: param.clone().detach()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        
        return fisher
    
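    def ewc_loss(self, lambda_ewc=1000):
        """
        Calculate EWC regularization loss.
        
        L_ewc = λ · Σ_i F_i * (θ_i - θ*_i)^2
        
        where:
        - λ (lambda_ewc) sets the strength of the penalty
        - F_i is Fisher Information for parameter i
        - θ_i is current parameter value
        - θ*_i is optimal parameter value from previous task
        
        Kirkpatrick et al. write the penalty with λ/2; that factor is absorbed into λ here.
        """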
        if self.fisher_diagonal is None or self.optimal_weights is None:
            return torch.tensor(0.0)
        
        ewc_loss = 0
        
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.fisher_diagonal:
                # Squared distance from optimal weights, weighted by Fisher
                diff = param - self.optimal_weights[name]
                ewc_loss += (self.fisher_diagonal[name] * diff.pow(2)).sum()
        
        return lambda_ewc * ewc_loss
    
    def save_checkpoint(self, path):
        """Save EWC state (Fisher and optimal weights)."""
        torch.save({
            'fisher_diagonal': self.fisher_diagonal,
            'optimal_weights': self.optimal_weights
        }, path)
    
    def load_checkpoint(self, path):
        """Load EWC state."""
        checkpoint = torch.load(path)
        self.fisher_diagonal = checkpoint['fisher_diagonal']
        self.optimal_weights = checkpoint['optimal_weights']


class EWC_Trainer:
    """Trainer with EWC regularization for continual learning."""
    
    def __init__(self, model, config, lambda_ewc=1000):
        self.model = model
        self.config = config
        self.lambda_ewc = lambda_ewc
        self.ewc = EWC(model)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    
    def train_on_task(self, task_data_loader, task_id, epochs=5):
        """Train on new task with EWC regularization."""
        print(f"\nTraining Task {task_id} with EWC...")
        
        for epoch in range(epochs):
            total_loss = 0
            total_ewc_loss = 0
            
            for batch in task_data_loader:
                self.model.train()
                self.optimizer.zero_grad()
                
                # Standard task loss
                outputs = self.model(batch['texts'])
                task_loss = F.cross_entropy(outputs, batch['labels'])
                
                # EWC regularization loss
                ewc_loss = self.ewc.ewc_loss(self.lambda_ewc)
                
                # Total loss
                total = task_loss + ewc_loss
                
                total.backward()
                self.optimizer.step()
                
                total_loss += task_loss.item()
                total_ewc_loss += ewc_loss.item()
            
            avg_loss = total_loss / len(task_data_loader)
            avg_ewc = total_ewc_loss / len(task_data_loader)
            
            print(f"Epoch {epoch + 1}/{epochs}:")
            print(f"  Task Loss: {avg_loss:.4f}")
            print(f"  EWC Loss: {avg_ewc:.4f}")
        
        # After training, update Fisher and optimal weights
        print("Updating Fisher Information Matrix...")
        self.ewc.estimate_fisher_information(task_data_loader)
        self.ewc.save_checkpoint(f'ewc_checkpoint_task{task_id}.pth')
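To make the penalty concrete, here is a framework-free worked example of the same formula used by `EWC.ewc_loss`, with three made-up weights. The helper names are hypothetical; the point is that two weights moved by the same amount are penalized very differently depending on their Fisher values.

```python
def ewc_penalty(params, anchor, fisher, lambda_ewc):
    """λ · Σ_i F_i (θ_i − θ*_i)², the same form as EWC.ewc_loss."""
    return lambda_ewc * sum(f * (p - a) ** 2 for p, a, f in zip(params, anchor, fisher))

def ewc_penalty_grad(params, anchor, fisher, lambda_ewc):
    """Gradient of the penalty: 2λ F_i (θ_i − θ*_i) for each parameter."""
    return [2 * lambda_ewc * f * (p - a) for p, a, f in zip(params, anchor, fisher)]

anchor = [1.0, -2.0, 0.5]   # θ*: weights after the previous task
params = [1.5, -2.0, 0.0]   # θ: weights while training the new task
fisher = [4.0, 1.0, 0.01]   # F: estimated importance of each weight

print(ewc_penalty(params, anchor, fisher, lambda_ewc=10))       # ≈ 10.025
print(ewc_penalty_grad(params, anchor, fisher, lambda_ewc=10))  # ≈ [40.0, 0.0, -0.1]
```

Weights 0 and 2 both moved by 0.5, but weight 0 contributes 400 times more penalty (and gradient) because its Fisher value is 400 times larger, so the optimizer is free to repurpose weight 2 while weight 0 is pulled back toward its old value.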

Synaptic Intelligence (SI)

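class SynapticIntelligence:
    """
    Synaptic Intelligence: Track parameter importance during training.
    
    Unlike EWC, which computes Fisher information after training,
    SI accumulates importance online as a path integral:
    ω_j += -g_j · Δθ_j after every optimizer step. At the end of a task,
    Ω_j += ω_j / ((θ_j(end) - θ_j(start))² + ξ).
    
    Per step: loss.backward(); optimizer.step(); si.update_importance()
    Per task: call si.consolidate_task() once training on the task ends.
    """
    
    def __init__(self, model, xi=1e-3):
        self.model = model
        self.xi = xi  # Damping term that avoids division by zero
        params = {
            name: param for name, param in model.named_parameters()
            if param.requires_grad
        }
        # Running path integral ω for the current task
        self.omega = {name: torch.zeros_like(p) for name, p in params.items()}
        # Consolidated importance Ω accumulated over finished tasks
        self.importance = {name: torch.zeros_like(p) for name, p in params.items()}
        # Parameters at the end of the previous task (θ*), the anchor for the penalty
        self.task_start_params = {name: p.clone().detach() for name, p in params.items()}
        # Parameters before the most recent optimizer step
        self.previous_params = {name: p.clone().detach() for name, p in params.items()}
    
    def update_importance(self):
        """
        Accumulate each parameter's contribution to the loss decrease.
        
        Call right after optimizer.step(), while param.grad still holds the
        gradient used for that step.
        """
        for name, param in self.model.named_parameters():
            if name not in self.omega:
                continue
            delta_param = param.detach() - self.previous_params[name]
            if param.grad is not None:
                # -g · Δθ approximates the loss decrease caused by this update
                self.omega[name] += -param.grad.detach() * delta_param
            self.previous_params[name] = param.clone().detach()
    
    def consolidate_task(self):
        """Fold the current task's ω into Ω and reset for the next task."""
        for name, param in self.model.named_parameters():
            if name not in self.omega:
                continue
            total_change = param.detach() - self.task_start_params[name]
            self.importance[name] += self.omega[name] / (total_change.pow(2) + self.xi)
            self.omega[name].zero_()
            self.task_start_params[name] = param.clone().detach()
            self.previous_params[name] = param.clone().detach()
    
    def si_loss(self, c_si=100):
        """
        Calculate SI regularization loss.
        
        L_si = c · Σ_j Ω_j * (θ_j - θ*_j)^2
        
        where Ω_j is the consolidated importance of parameter j and θ*_j its
        value at the end of the previous task. The scale of c_si is problem-dependent.
        """
        si_loss = 0.0
        
        for name, param in self.model.named_parameters():
            if name in self.importance:
                diff = param - self.task_start_params[name]
                si_loss = si_loss + (self.importance[name] * diff.pow(2)).sum()
        
        return c_si * si_loss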
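Why does the path integral single out important weights? The hypothetical setup below runs SI's bookkeeping on two independent weights, each minimizing its own quadratic loss `a·(θ − 1)²` with plain gradient descent, one with steep curvature (`a = 5`) and one nearly flat (`a = 0.1`). It is a framework-free sketch of the ω and Ω updates, not the tutorial's class.

```python
def si_importance_demo(curvatures, target=1.0, lr=0.05, steps=100, xi=1e-3):
    """Run SI's path integral on independent 1-D quadratics a·(θ − target)²."""
    omegas, importances = [], []
    for a in curvatures:
        theta = theta_start = 0.0
        omega = 0.0
        for _ in range(steps):
            grad = 2 * a * (theta - target)   # gradient at the current point
            delta = -lr * grad                # plain gradient-descent step
            omega += -grad * delta            # ω accumulates −g · Δθ
            theta += delta
        # Consolidate: Ω = ω / ((θ_end − θ_start)² + ξ)
        importances.append(omega / ((theta - theta_start) ** 2 + xi))
        omegas.append(omega)
    return omegas, importances

omegas, importances = si_importance_demo([5.0, 0.1])
print(omegas)       # ω: about 6.67 for the steep weight, about 0.087 for the flat one
print(importances)  # Ω: about 6.66 vs about 0.22
```

ω tracks the loss decrease each weight produced (with a finite step size it overestimates it: 6.67 versus an actual drop of 5 for the steep weight). After normalizing by the squared distance travelled, the steep weight is roughly 30 times more important, so the penalty will hold it in place much more firmly.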

Learning without Forgetting (LwF)

class LearningWithoutForgetting:
    """
    Learning without Forgetting: Use knowledge distillation from old model.
    
    Keep a copy of the old model and distill its knowledge
    while training on new data.
    """
    
    def __init__(self, model):
        self.model = model
        self.old_model = None  # Copy of model from previous task
    
    def create_old_model_copy(self):
        """Create a frozen copy of current model before training on new task."""
        import copy
        self.old_model = copy.deepcopy(self.model)
        
        # Freeze old model
        for param in self.old_model.parameters():
            param.requires_grad = False
        
        self.old_model.eval()
    
    def distillation_loss(self, texts, temperature=2.0):
        """
        Calculate distillation loss to match old model's outputs.
        
        Uses softened probability distributions (with temperature)
        to capture dark knowledge.
        """
        if self.old_model is None:
            return torch.tensor(0.0)
        
        self.model.train()
        self.old_model.eval()
        
        with torch.no_grad():
            # Old model's soft predictions
            old_logits = self.old_model.get_logits(texts)
            old_soft_probs = F.softmax(old_logits / temperature, dim=-1)
        
        # Current model's predictions
        current_logits = self.model.get_logits(texts)
        current_soft_probs = F.log_softmax(current_logits / temperature, dim=-1)
        
        # KL divergence between distributions
        dist_loss = F.kl_div(
            current_soft_probs,
            old_soft_probs,
            reduction='batchmean'
        ) * (temperature ** 2)
        
        return dist_loss
    
    def combined_loss(self, task_loss, dist_loss, alpha=0.5):
        """
        Combine task loss and distillation loss.
        
        L_total = α * L_task + (1-α) * L_distill
        """
        return alpha * task_loss + (1 - alpha) * dist_loss
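The temperature in `distillation_loss` is easiest to understand numerically. The stdlib sketch below (hypothetical helper names) computes the per-example quantity that `F.kl_div(log_q, p, reduction='batchmean') * T**2` averages over a batch: the KL divergence from the teacher's softened distribution to the student's, scaled by T².

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)                       # subtract max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0)

def distillation_kl(student_logits, teacher_logits, temperature=2.0):
    """T² · KL(teacher_T || student_T) for a single example."""
    p = softmax(teacher_logits, temperature)   # old model (target)
    q = softmax(student_logits, temperature)   # current model
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2

teacher = [4.0, 1.0, 0.0]
print(softmax(teacher, 1.0))
print(softmax(teacher, 2.0))  # flatter: more mass on the non-argmax classes
print(distillation_kl([3.0, 2.0, 0.0], teacher))
```

Raising T flattens the teacher's distribution so the relative scores of the wrong classes (the "dark knowledge") carry weight in the loss; the T² factor compensates for the gradients shrinking roughly as 1/T² when logits are divided by T.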

Architecture-Based Methods

Progressive Neural Networks

class ProgressiveNeuralNetwork(nn.Module):
    """
    Progressive Neural Networks: Add new columns for new tasks.
    
    Each task gets its own neural network column.
    Columns can read from previous columns but not modify them.
    """
    
    def __init__(self, base_column_class, config):
        super().__init__()
        self.base_column_class = base_column_class
        self.config = config
        self.columns = nn.ModuleList()
        self.task_to_column = {}
    
    def add_task_column(self, task_id):
        """Add a new column for a new task."""
        # Create new column
        new_column = self.base_column_class(self.config)
        
        # If there are previous columns, add lateral connections
        if len(self.columns) > 0:
            new_column.add_lateral_connections(self.columns)
        
        self.columns.append(new_column)
        self.task_to_column[task_id] = len(self.columns) - 1
        
        # Freeze all previous columns
        for col_idx in range(len(self.columns) - 1):
            for param in self.columns[col_idx].parameters():
                param.requires_grad = False
        
        print(f"Added column {len(self.columns) - 1} for task {task_id}")
    
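    def forward(self, x, task_id):
        """
        Forward pass through the column for `task_id`.
        
        Lateral inputs from earlier (frozen) columns are assumed to be computed
        inside the column, via the references passed to `add_lateral_connections`.
        """
        column_idx = self.task_to_column[task_id]
        return self.columns[column_idx](x)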
    
    def get_total_parameters(self):
        """Count total parameters across all columns."""
        return sum(p.numel() for col in self.columns for p in col.parameters())
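`get_total_parameters` makes the cost of this design easy to track. As a rough back-of-the-envelope sketch, assuming fully connected columns where every layer above the first receives one linear lateral map from each earlier column (the original method's lateral adapters differ in detail, and biases are ignored), the total size grows quadratically with the number of tasks:

```python
def progressive_param_count(num_columns, num_layers, width):
    """Weights in a hypothetical progressive net of fully connected columns.

    Each column has num_layers (width x width) weight matrices; every layer
    above the first also gets one (width x width) lateral matrix from each
    earlier column. Biases are ignored.
    """
    total = 0
    for k in range(num_columns):           # k = number of earlier columns
        own = num_layers * width * width
        lateral = (num_layers - 1) * k * width * width
        total += own + lateral
    return total

for k in (1, 2, 4, 8):
    print(k, progressive_param_count(k, num_layers=3, width=100))
```

Under these assumptions the counts are 30,000, 80,000, 240,000 and 800,000 weights: doubling the number of tasks from 4 to 8 more than triples the size, which is why progressive networks suit short task sequences better than long-running streams.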

Adapter Modules

class AdapterModule(nn.Module):
    """
    Adapter modules: Small trainable modules inserted into frozen backbone.
    
    Keep pretrained model frozen, only train lightweight adapters.
    Different adapters for different tasks.
    """
    
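    def __init__(self, hidden_dim, adapter_dim=64):
        super().__init__()
        self.down_project = nn.Linear(hidden_dim, adapter_dim)
        self.activation = nn.GELU()
        self.up_project = nn.Linear(adapter_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        
        # Zero-init the up-projection so the bottleneck branch starts at 0
        # (near-identity adapters, as in Houlsby et al.); the LayerNorm below
        # still rescales x, so the module is not an exact identity at init
        nn.init.zeros_(self.up_project.weight)
        nn.init.zeros_(self.up_project.bias)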
    
    def forward(self, x):
        # Residual connection
        residual = x
        
        # Adapter bottleneck
        x = self.down_project(x)
        x = self.activation(x)
        x = self.up_project(x)
        
        # Add residual and normalize
        x = self.layer_norm(x + residual)
        
        return x


class AdapterConfig:
    """Configuration for adapter-based fine-tuning."""
    
    def __init__(self, 
                 adapter_dim=64,
                 adapter_locations='all_layers',
                 freeze_backbone=True,
                 task_adapters=True):
        self.adapter_dim = adapter_dim
        self.adapter_locations = adapter_locations  # 'all_layers' or 'last_<n>', e.g. 'last_4'
        self.freeze_backbone = freeze_backbone
        self.task_adapters = task_adapters  # Separate adapter per task


def insert_adapters_into_transformer(transformer_model, config):
    """Insert adapter modules into a transformer model."""
    
    if config.freeze_backbone:
        # Freeze all backbone parameters
        for param in transformer_model.parameters():
            param.requires_grad = False
    
    adapters = {}
    
    # Insert adapters into each layer
    for layer_idx, layer in enumerate(transformer_model.layers):
        if config.adapter_locations == 'all_layers' or \
           (config.adapter_locations.startswith('last_') and 
            layer_idx >= len(transformer_model.layers) - int(config.adapter_locations[5:])):
            
            # Create adapter
            adapter = AdapterModule(
                hidden_dim=layer.hidden_dim,
                adapter_dim=config.adapter_dim
            )
            
            # Insert after attention sublayer
            layer.insert_adapter(adapter)
            
            adapters[f'layer_{layer_idx}'] = adapter
    
    return adapters
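You can check two things without loading a model: the layer-selection rule in insert_adapters_into_transformer, and the size of one AdapterModule. The sketch below assumes 12 layers with hidden size 768. These are illustrative, BERT-base-like values, not values taken from this tutorial.

```python
def select_adapter_layers(num_layers, adapter_locations):
    """Indices of layers that receive an adapter ('all_layers' or 'last_<n>')."""
    if adapter_locations == 'all_layers':
        return list(range(num_layers))
    if adapter_locations.startswith('last_'):
        n = int(adapter_locations[len('last_'):])
        return list(range(max(0, num_layers - n), num_layers))
    raise ValueError(f"Unsupported adapter_locations: {adapter_locations!r}")


def adapter_param_count(hidden_dim, adapter_dim):
    """Parameters in one AdapterModule as defined above."""
    down = hidden_dim * adapter_dim + adapter_dim   # down_project weight + bias
    up = adapter_dim * hidden_dim + hidden_dim      # up_project weight + bias
    norm = 2 * hidden_dim                           # LayerNorm gain + bias
    return down + up + norm


layers = select_adapter_layers(12, 'last_4')
trainable = len(layers) * adapter_param_count(768, 64)
print(layers, trainable)
```

The trainable count depends only on `hidden_dim`, `adapter_dim` and the number of adapted layers. That is why adapters stay small even on large backbones.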

Dynamic Architecture Expansion

class DynamicArchitectureExpansion:
    """
    Dynamically expand model capacity when needed.
    
    Monitor performance; if degradation detected, add capacity.
    """
    
    def __init__(self, model, expansion_threshold=0.05):
        self.model = model
        self.expansion_threshold = expansion_threshold
        self.baseline_performance = None
        self.expansion_history = []
    
    def monitor_and_expand(self, validation_data, current_task_id):
        """
        Check if model needs expansion based on performance drop.
        """
        current_performance = self.evaluate(self.model, validation_data)
        
        if self.baseline_performance is not None:
            performance_drop = self.baseline_performance - current_performance
            
            if performance_drop > self.expansion_threshold:
                print(f"Performance drop detected: {performance_drop:.4f}")
                print("Expanding model architecture...")
                
                self.expand_architecture(current_task_id)
                
                # Re-evaluate after expansion
                new_performance = self.evaluate(self.model, validation_data)
                print(f"Performance after expansion: {new_performance:.4f}")
        
        self.baseline_performance = current_performance
    
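    def expand_architecture(self, task_id):
        """Add new capacity to the model."""
        # Strategy 1: Add new neurons to hidden layers
        # Strategy 2: Add new layers
        # Strategy 3: Add task-specific heads
        
        # Example: add a task-specific output head. hidden_dim and
        # num_new_classes are assumed attributes of the model.
        device = next(self.model.parameters()).device
        new_head = nn.Linear(self.model.hidden_dim, self.model.num_new_classes).to(device)
        # Assigning an nn.Module attribute registers it as a submodule
        setattr(self.model, f'task_{task_id}_head', new_head)
        # An existing optimizer will not see these parameters: register them with
        # optimizer.add_param_group({'params': new_head.parameters()}) or rebuild it
        
        self.expansion_history.append({
            'task_id': task_id,
            'expansion_type': 'new_head',
            'expansion_index': len(self.expansion_history)  # sequential index, not a clock time
        })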
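The trigger rule in monitor_and_expand can be simulated on a list of validation scores. The plain-Python sketch below mirrors that rule: each evaluation is compared with the previous one, and the latest score becomes the new baseline. The function name is hypothetical.

```python
def expansion_steps(performance_history, expansion_threshold=0.05):
    """Return indices at which a drop larger than the threshold triggers expansion.

    Mirrors monitor_and_expand: compare each evaluation with the previous one
    (the baseline) and expand when the drop exceeds the threshold.
    """
    triggered = []
    baseline = None
    for step, current in enumerate(performance_history):
        if baseline is not None and baseline - current > expansion_threshold:
            triggered.append(step)
        baseline = current  # the latest score becomes the next baseline
    return triggered


print(expansion_steps([0.90, 0.88, 0.80, 0.79]))
print(expansion_steps([0.90, 0.87, 0.84, 0.81, 0.78]))
```

The baseline tracks the latest score. So a slow decline of 0.03 per evaluation never triggers expansion, even though accuracy falls 12 points overall. Comparing against the best score seen so far is one alternative.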

Incremental Training Strategies

Scheduled Fine-Tuning

class ScheduledFineTuner:
    """
    Schedule fine-tuning with decreasing learning rates and selective unfreezing.
    """
    
    def __init__(self, model, config, optimizer):
        self.model = model
        self.config = config
        self.optimizer = optimizer  # built over all model parameters
        self.training_history = []
    
    def progressive_unfreezing(self, n_stages=3):
        """
        Build a progressive unfreezing schedule (top layers first).
        
        Stage 1: Only train top layers
        Stage 2: Train top + middle layers
        Stage 3: Train all layers
        
        Only the schedule is computed here; freezing is applied stage by
        stage in train_with_schedule().
        """
        total_layers = len(self.model.layers)
        layers_per_stage = max(1, total_layers // n_stages)
        
        schedule = []
        
        for stage in range(n_stages):
            # Number of top layers to unfreeze; the last stage always unfreezes
            # everything, even when total_layers is not divisible by n_stages
            if stage == n_stages - 1:
                n_unfrozen = total_layers
            else:
                n_unfrozen = min(total_layers, (stage + 1) * layers_per_stage)
            
            # Learning rate halves at each stage as more layers are unfrozen
            lr = self.config.base_lr * (0.5 ** stage)
            
            schedule.append({
                'stage': stage,
                'n_unfrozen': n_unfrozen,
                'learning_rate': lr
            })
        
        return schedule
    
    def apply_freezing(self, n_unfrozen):
        """Freeze every layer except the top n_unfrozen."""
        total_layers = len(self.model.layers)
        for i, layer in enumerate(self.model.layers):
            trainable = i >= total_layers - n_unfrozen
            for param in layer.parameters():
                param.requires_grad = trainable
    
    def train_with_schedule(self, data_loader, schedule):
        """Train using progressive unfreezing schedule."""
        for stage_config in schedule:
            print(f"\nStage {stage_config['stage'] + 1}: "
                  f"Unfreezing {stage_config['n_unfrozen']} layers")
            
            # Apply this stage's freezing before training it
            self.apply_freezing(stage_config['n_unfrozen'])
            
            # Set learning rate
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = stage_config['learning_rate']
            
            # Train for this stage (train_epoch is assumed to be defined elsewhere)
            self.train_epoch(data_loader)
            
            # Record history
            self.training_history.append(stage_config)

Curriculum Learning for Continual Learning

class CurriculumContinualLearner:
    """
    Apply curriculum learning principles to continual learning.
    
    Order tasks/examples from easy to hard to facilitate transfer.
    """
    
    def __init__(self, model):
        self.model = model
        self.task_difficulty = {}
    
    def estimate_task_difficulty(self, task_data):
        """
        Estimate difficulty of a task based on initial performance.
        """
        self.model.eval()
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in task_data:
                outputs = self.model(batch['texts'])
                predictions = outputs.argmax(dim=-1)
                
                correct += (predictions == batch['labels']).sum().item()
                total += len(batch['labels'])
        
        initial_accuracy = correct / total if total > 0 else 0.0
        
        # Difficulty inversely related to accuracy
        difficulty = 1.0 - initial_accuracy
        
        return difficulty
    
    def order_tasks_by_curriculum(self, tasks):
        """
        Order tasks from easiest to hardest.
        """
        # Estimate difficulty for each task
        for task_id, task_data in tasks.items():
            self.task_difficulty[task_id] = self.estimate_task_difficulty(task_data)
        
        # Sort by difficulty (easy to hard)
        ordered_tasks = sorted(
            tasks.items(),
            key=lambda x: self.task_difficulty[x[0]]
        )
        
        print("Task Curriculum (Easy → Hard):")
        for task_id, _ in ordered_tasks:
            print(f"  {task_id}: difficulty = {self.task_difficulty[task_id]:.3f}")
        
        return ordered_tasks
    
    def train_with_curriculum(self, ordered_tasks):
        """Train on tasks in curriculum order."""
        for task_id, task_data in ordered_tasks:
            print(f"\n{'='*50}")
            print(f"Training on task: {task_id}")
            print(f"Difficulty: {self.task_difficulty[task_id]:.3f}")
            print(f"{'='*50}")
            
            # Train on this task
            self.train_on_task(task_data, task_id)
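The ordering step uses difficulty defined as 1 - accuracy. It reduces to sorting on that value, shown here with made-up task names and accuracies for illustration.

```python
def curriculum_order(task_accuracies):
    """Order task ids from easiest to hardest, using difficulty = 1 - accuracy."""
    difficulty = {task: 1.0 - acc for task, acc in task_accuracies.items()}
    return sorted(difficulty, key=difficulty.get)


print(curriculum_order({'sentiment': 0.62, 'topic': 0.81, 'nli': 0.41}))
```

Python's sort is stable, so tasks with equal difficulty keep their input order.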

Model Versioning

Semantic Versioning for Models

from datetime import datetime
from pathlib import Path
import json
import hashlib

class ModelVersion:
    """
    Semantic versioning for ML models.
    
    Format: MAJOR.MINOR.PATCH
    
    - MAJOR: Breaking changes (architecture change, incompatible API)
    - MINOR: New features, improved performance (backward compatible)
    - PATCH: Bug fixes, minor improvements (fully backward compatible)
    """
    
    def __init__(self, major=0, minor=0, patch=0, metadata=None):
        self.major = major
        self.minor = minor
        self.patch = patch
        self.metadata = metadata or {}
        self.created_at = datetime.now().isoformat()
    
    def bump_major(self):
        """Increment major version (breaking change)."""
        self.major += 1
        self.minor = 0
        self.patch = 0
    
    def bump_minor(self):
        """Increment minor version (new feature)."""
        self.minor += 1
        self.patch = 0
    
    def bump_patch(self):
        """Increment patch version (bug fix)."""
        self.patch += 1
    
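    def __str__(self):
        version_str = f"{self.major}.{self.minor}.{self.patch}"
        
        if self.metadata:
            # SemVer build metadata: a single '+' followed by dot-separated
            # identifiers (keys and values should be alphanumeric or hyphens)
            metadata_str = '.'.join(f"{k}-{v}" for k, v in self.metadata.items())
            version_str += f"+{metadata_str}"
        
        return version_str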
    
    def to_dict(self):
        return {
            'version': str(self),
            'major': self.major,
            'minor': self.minor,
            'patch': self.patch,
            'metadata': self.metadata,
            'created_at': self.created_at
        }
    
    @classmethod
    def from_string(cls, version_str):
        """Parse version string to ModelVersion object."""
        # Simple parsing (can be extended for metadata)
        parts = version_str.split('+')[0].split('.')
        return cls(
            major=int(parts[0]),
            minor=int(parts[1]) if len(parts) > 1 else 0,
            patch=int(parts[2]) if len(parts) > 2 else 0
        )


class ModelRegistry:
    """
    Centralized registry for model versions and artifacts.
    """
    
    def __init__(self, registry_path='./model_registry'):
        self.registry_path = Path(registry_path)
        self.registry_path.mkdir(parents=True, exist_ok=True)
        self.models = {}  # {model_name: {version: metadata}}
        self.load_registry()
    
    def register_model(self, model_name, version, model_path, metrics, metadata=None):
        """Register a new model version."""
        if model_name not in self.models:
            self.models[model_name] = {}
        
        # Calculate model hash for integrity
        model_hash = self.calculate_file_hash(model_path)
        
        # Create metadata
        model_metadata = {
            'version': str(version),
            'model_path': str(model_path),
            'model_hash': model_hash,
            'metrics': metrics,
            'metadata': metadata or {},
            'registered_at': datetime.now().isoformat(),
            'status': 'active'  # active, deprecated, archived
        }
        
        self.models[model_name][str(version)] = model_metadata
        
        # Save registry
        self.save_registry()
        
        print(f"Registered {model_name} v{version}")
        print(f"  Path: {model_path}")
        print(f"  Hash: {model_hash[:16]}...")
        print(f"  Metrics: {metrics}")
        
        return model_metadata
    
    def get_latest_version(self, model_name):
        """Get the latest active version of a model."""
        if model_name not in self.models:
            return None
        
        active_versions = [
            v for v, m in self.models[model_name].items()
            if m['status'] == 'active'
        ]
        
        if not active_versions:
            return None
        
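        # Sort by numeric (major, minor, patch); ModelVersion defines no ordering,
        # so the objects themselves cannot be compared by max()
        def version_key(v):
            parsed = ModelVersion.from_string(v)
            return (parsed.major, parsed.minor, parsed.patch)
        
        latest = max(active_versions, key=version_key)
        return latest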
    
    def deprecate_version(self, model_name, version_str):
        """Mark a model version as deprecated."""
        if model_name in self.models and version_str in self.models[model_name]:
            self.models[model_name][version_str]['status'] = 'deprecated'
            self.models[model_name][version_str]['deprecated_at'] = datetime.now().isoformat()
            self.save_registry()
            print(f"Deprecated {model_name} v{version_str}")
    
    def calculate_file_hash(self, file_path):
        """Calculate SHA256 hash of model file."""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()
    
    def save_registry(self):
        """Save registry to disk."""
        registry_file = self.registry_path / 'registry.json'
        with open(registry_file, 'w') as f:
            json.dump(self.models, f, indent=2)
    
    def load_registry(self):
        """Load registry from disk."""
        registry_file = self.registry_path / 'registry.json'
        if registry_file.exists():
            with open(registry_file, 'r') as f:
                self.models = json.load(f)
    
    def list_models(self):
        """List all registered models and versions."""
        print("Model Registry")
        print("=" * 70)
        
        for model_name, versions in self.models.items():
            print(f"\n{model_name}:")
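            # Sort numerically so that 1.10.0 is listed after 1.9.0
            for version_str, metadata in sorted(
                versions.items(),
                key=lambda item: tuple(int(p) for p in item[0].split('+')[0].split('.'))
            ):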
                status = metadata['status']
                metrics = metadata.get('metrics', {})
                accuracy = metrics.get('accuracy', 'N/A')
                
                print(f"  v{version_str:15} [{status:10}] Acc: {accuracy}")
        
        print("=" * 70)
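Version strings must be compared numerically, not as text. Otherwise "1.10.0" sorts before "1.9.0", which breaks both get_latest_version and the listing order.

The sketch below parses MAJOR.MINOR.PATCH into a tuple and ignores anything after '+'. The Semantic Versioning spec also excludes build metadata from precedence. Pre-release tags such as "-rc.1" are not handled here, and `version_key` is a hypothetical helper name.

```python
def version_key(version_str):
    """Sort key for 'MAJOR.MINOR.PATCH[+metadata]' strings (metadata ignored)."""
    core = version_str.split('+', 1)[0]
    parts = [int(p) for p in core.split('.')]
    parts += [0] * (3 - len(parts))  # treat '2.0' as '2.0.0'
    return tuple(parts[:3])


versions = ['1.9.0', '1.10.0', '1.2.3+build.7', '2.0']
print(sorted(versions))                    # plain string order
print(sorted(versions, key=version_key))   # numeric order
```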

Deployment Strategies

Canary Deployments

import random
from datetime import datetime

class CanaryDeployment:
    """
    Gradually roll out new model version to subset of traffic.
    """
    
    def __init__(self, old_model, new_model, initial_percentage=5):
        self.old_model = old_model
        self.new_model = new_model
        self.canary_percentage = initial_percentage
        self.deployment_log = []
    
    def route_request(self, request):
        """Route request to old or new model based on canary percentage."""
        if random.random() < self.canary_percentage / 100:
            model = self.new_model
            variant = 'canary'
        else:
            model = self.old_model
            variant = 'stable'
        
        prediction = model.predict(request)
        
        # Log for monitoring
        self.deployment_log.append({
            'timestamp': datetime.now().isoformat(),
            'variant': variant,
            'request_id': request.get('id'),
            'prediction': prediction
        })
        
        return prediction
    
    def increase_canary(self, increment=10):
        """Increase canary traffic percentage."""
        self.canary_percentage = min(100, self.canary_percentage + increment)
        print(f"Canary traffic increased to {self.canary_percentage}%")
    
    def rollback(self):
        """Rollback to 100% old model."""
        self.canary_percentage = 0
        print("Rolled back to stable model")
    
    def analyze_canary_performance(self, ground_truth):
        """Compare performance of canary vs stable."""
        canary_correct = 0
        canary_total = 0
        stable_correct = 0
        stable_total = 0
        
        for log_entry in self.deployment_log:
            # Match with ground truth (simplified); check_prediction is an
            # assumed helper that returns True/False for one log entry
            is_correct = check_prediction(log_entry, ground_truth)
            
            if log_entry['variant'] == 'canary':
                canary_correct += is_correct
                canary_total += 1
            else:
                stable_correct += is_correct
                stable_total += 1
        
        # Without samples on both variants there is nothing to compare yet
        if canary_total == 0 or stable_total == 0:
            print("⚠️  Not enough traffic on both variants - continue monitoring")
            return 'monitor'
        
        canary_acc = canary_correct / canary_total
        stable_acc = stable_correct / stable_total
        
        print(f"Canary Accuracy: {canary_acc:.4f} (n={canary_total})")
        print(f"Stable Accuracy: {stable_acc:.4f} (n={stable_total})")
        
        improvement = canary_acc - stable_acc
        
        if improvement > 0.02:  # 2 percentage points (absolute; not a significance test)
            print("✅ Canary performing better - consider increasing traffic")
            return 'promote'
        elif improvement < -0.02:
            print("❌ Canary performing worse - consider rollback")
            return 'rollback'
        else:
            print("⚠️  Similar performance - continue monitoring")
            return 'monitor'
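A fixed 2-point accuracy threshold ignores how many requests each variant has seen. One standard way to account for sample size is a two-proportion z-test, using the normal approximation. This is a swapped-in technique, not the tutorial's method.

The sketch assumes independent requests. Its threshold of 1.96 corresponds to roughly a 5% two-sided significance level.

```python
import math


def two_proportion_z(correct_a, total_a, correct_b, total_b):
    """z statistic for the accuracy difference between variants A and B."""
    p_a = correct_a / total_a
    p_b = correct_b / total_b
    pooled = (correct_a + correct_b) / (total_a + total_b)
    se = math.sqrt(pooled * (1 - pooled) * (1 / total_a + 1 / total_b))
    return (p_a - p_b) / se


# Same 2-point accuracy gap, different amounts of traffic per variant
for correct_a, correct_b, n in [(51, 49, 100), (5100, 4900, 10_000)]:
    z = two_proportion_z(correct_a, n, correct_b, n)
    verdict = "significant" if abs(z) > 1.96 else "not significant"
    print(f"n={n}: z={z:.2f} ({verdict} at the 5% level)")
```

The same 2-point gap is noise at 100 requests per variant but clearly significant at 10,000. The function divides by zero when both variants are 0% or 100% accurate, so guard that case in practice.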

Blue-Green Deployment

class BlueGreenDeployment:
    """
    Maintain two identical production environments.
    
    - Blue: Currently serving all traffic
    - Green: Idle environment with new model
    
    Switch traffic instantly when ready.
    """
    
    def __init__(self, initial_model):
        self.active_environment = 'blue'
        self.environments = {
            'blue': {'model': initial_model, 'status': 'active'},
            'green': {'model': None, 'status': 'inactive'}
        }
    
    def deploy_to_inactive(self, new_model):
        """Deploy new model to inactive environment (replacing whatever it held)."""
        inactive_env = 'green' if self.active_environment == 'blue' else 'blue'
        
        self.environments[inactive_env]['model'] = new_model
        self.environments[inactive_env]['status'] = 'ready'
        
        print(f"Deployed new model to {inactive_env} environment")
    
    def switch_traffic(self):
        """Switch all traffic to the other environment."""
        target_env = 'green' if self.active_environment == 'blue' else 'blue'
        if self.environments[target_env]['model'] is None:
            raise RuntimeError(f"No model deployed to {target_env}; refusing to switch")
        
        old_active = self.active_environment
        self.active_environment = target_env
        
        # The old environment keeps its model, so rollback is just another switch
        self.environments[old_active]['status'] = 'inactive'
        self.environments[target_env]['status'] = 'active'
        
        print(f"Traffic switched from {old_active} to {self.active_environment}")
    
    def predict(self, request):
        """Route request to active environment."""
        model = self.environments[self.active_environment]['model']
        return model.predict(request)
    
    def rollback(self):
        """Quick rollback by switching environments."""
        self.switch_traffic()
        print("Rolled back to previous environment")

Shadow Mode Deployment

class ShadowModeDeployment:
    """
    Run new model in shadow mode alongside production.
    
    New model receives all requests but doesn't serve predictions.
    Used for validation without risk.
    """
    
    def __init__(self, production_model, shadow_model):
        self.production_model = production_model
        self.shadow_model = shadow_model
        self.shadow_predictions = []
    
    def predict(self, request):
        """Serve from production, record shadow predictions."""
        # Production prediction (served to user)
        production_pred = self.production_model.predict(request)
        
        # Shadow prediction (recorded only). It runs inline here, so it adds
        # latency; an exception in the shadow model must never reach the user.
        try:
            shadow_pred = self.shadow_model.predict(request)
        except Exception as exc:
            shadow_pred = None
            print(f"Shadow model error (ignored): {exc!r}")
        
        # Store for analysis
        self.shadow_predictions.append({
            'request': request,
            'production': production_pred,
            'shadow': shadow_pred,
            'timestamp': datetime.now().isoformat()
        })
        
        return production_pred
    
    def analyze_discrepancies(self, ground_truth=None):
        """Analyze differences between production and shadow."""
        discrepancies = 0
        total = len(self.shadow_predictions)
        
        for entry in self.shadow_predictions:
            if entry['production'] != entry['shadow']:
                discrepancies += 1
        
        discrepancy_rate = discrepancies / total if total > 0 else 0
        
        print("Shadow Mode Analysis:")
        print(f"  Total requests: {total}")
        print(f"  Discrepancies: {discrepancies} ({discrepancy_rate:.2%})")
        
        if ground_truth:
            # Labels are assumed to align with shadow_predictions by position;
            # zip stops at the shorter of the two sequences
            n_labeled = min(total, len(ground_truth))
            prod_correct = sum(
                1 for e, gt in zip(self.shadow_predictions, ground_truth)
                if e['production'] == gt
            )
            shadow_correct = sum(
                1 for e, gt in zip(self.shadow_predictions, ground_truth)
                if e['shadow'] == gt
            )
            
            if n_labeled > 0:
                print(f"  Production accuracy: {prod_correct/n_labeled:.4f} (n={n_labeled})")
                print(f"  Shadow accuracy: {shadow_correct/n_labeled:.4f} (n={n_labeled})")
        
        return discrepancy_rate
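In ShadowModeDeployment.predict the shadow call runs inline, so every user request waits for both models. One way to keep shadow latency off the serving path is a background thread pool; mirroring traffic at the proxy or queue level is also common.

The sketch below uses only the standard library. `AsyncShadowRunner` and `StubModel` are hypothetical names, and any model with a `predict(request)` method is assumed to work.

```python
from concurrent.futures import ThreadPoolExecutor


class AsyncShadowRunner:
    """Serve production predictions; run the shadow model in a background thread."""

    def __init__(self, production_model, shadow_model, max_workers=2):
        self.production_model = production_model
        self.shadow_model = shadow_model
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        # list.append is thread-safe in CPython; a queue or lock is safer in general
        self.records = []

    def _run_shadow(self, request, production_pred):
        try:
            shadow_pred = self.shadow_model.predict(request)
        except Exception as exc:  # a failing shadow model must never affect users
            shadow_pred = f"error: {exc!r}"
        self.records.append({'production': production_pred, 'shadow': shadow_pred})

    def predict(self, request):
        production_pred = self.production_model.predict(request)
        # Fire and forget: the caller does not wait for the shadow model
        self.executor.submit(self._run_shadow, request, production_pred)
        return production_pred

    def close(self):
        self.executor.shutdown(wait=True)  # finish pending shadow calls


class StubModel:
    """Hypothetical stand-in for a model with a predict() method."""

    def __init__(self, offset):
        self.offset = offset

    def predict(self, request):
        return request['x'] + self.offset


runner = AsyncShadowRunner(StubModel(0), StubModel(1))
served = [runner.predict({'x': i}) for i in range(3)]
runner.close()
print(served, len(runner.records))
```

Shadow records can arrive in any order, so a real system would also log a request id to join them with production outcomes.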

Production Monitoring

Real-Time Performance Monitoring

import time
from collections import defaultdict, deque
from datetime import datetime

import numpy as np

class ProductionMonitor:
    """
    Monitor model performance in production.
    """
    
    def __init__(self, window_size=1000):
        self.window_size = window_size
        
        # Metrics windows
        self.latency_window = deque(maxlen=window_size)
        self.throughput_window = deque(maxlen=window_size)
        self.prediction_distribution = defaultdict(int)
        self.confidence_window = deque(maxlen=window_size)
        
        # Alerts
        self.alerts = []
        self.alert_thresholds = {
            'latency_p99': 1000,  # ms
            'throughput_min': 10,  # requests/sec
            'confidence_low': 0.5
        }
    
    def record_prediction(self, prediction, confidence, latency_ms):
        """Record a prediction event."""
        timestamp = time.time()
        
        # Record metrics
        self.latency_window.append((timestamp, latency_ms))
        self.prediction_distribution[prediction] += 1
        self.confidence_window.append((timestamp, confidence))
        self.throughput_window.append(timestamp)
        
        # Check for alerts
        self.check_alerts()
    
    def check_alerts(self):
        """Check if any metrics exceed thresholds."""
        current_time = time.time()
        
        # P99 Latency
        latencies = [l for _, l in self.latency_window]
        if latencies:
            p99_latency = np.percentile(latencies, 99)
            if p99_latency > self.alert_thresholds['latency_p99']:
                self.create_alert('HIGH_LATENCY', f"P99 latency: {p99_latency:.0f}ms")
        
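        # Throughput (requests in the last second). Skip the check until the
        # window spans at least one second; otherwise the first requests after
        # startup always look like low throughput. The deque holds at most
        # window_size timestamps, so very high rates are capped.
        if self.throughput_window and current_time - self.throughput_window[0] >= 1.0:
            recent_throughput = sum(
                1 for t in self.throughput_window
                if current_time - t < 1.0  # Last second
            )
            if recent_throughput < self.alert_thresholds['throughput_min']:
                self.create_alert('LOW_THROUGHPUT', f"Throughput: {recent_throughput} req/s")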
        
        # Low confidence
        confidences = [c for _, c in self.confidence_window]
        if confidences:
            low_conf_ratio = np.mean([c < self.alert_thresholds['confidence_low'] for c in confidences])
            if low_conf_ratio > 0.2:  # More than 20% low confidence
                self.create_alert('LOW_CONFIDENCE', f"Low confidence ratio: {low_conf_ratio:.2%}")
    
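    def create_alert(self, alert_type, message, cooldown_s=60):
        """Create an alert, suppressing repeats of the same type within cooldown_s seconds."""
        now = datetime.now()
        
        # check_alerts runs on every prediction; without a cooldown a persistent
        # condition would emit one alert per request
        for previous in reversed(self.alerts):
            if previous['type'] == alert_type:
                elapsed = (now - datetime.fromisoformat(previous['timestamp'])).total_seconds()
                if elapsed < cooldown_s:
                    return
                break
        
        alert = {
            'type': alert_type,
            'message': message,
            'timestamp': now.isoformat()
        }
        self.alerts.append(alert)
        print(f"🚨 ALERT [{alert_type}]: {message}")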
    
    def get_dashboard_metrics(self):
        """Get current metrics for dashboard."""
        current_time = time.time()
        
        # Latency stats
        latencies = [l for _, l in self.latency_window]
        latency_stats = {
            'mean': np.mean(latencies) if latencies else 0,
            'p50': np.percentile(latencies, 50) if latencies else 0,
            'p95': np.percentile(latencies, 95) if latencies else 0,
            'p99': np.percentile(latencies, 99) if latencies else 0
        }
        
        # Throughput
        throughput = sum(1 for t in self.throughput_window if current_time - t < 1.0)
        
        # Confidence stats
        confidences = [c for _, c in self.confidence_window]
        confidence_stats = {
            'mean': np.mean(confidences) if confidences else 0,
            'std': np.std(confidences) if confidences else 0
        }
        
        # Prediction distribution
        total_preds = sum(self.prediction_distribution.values())
        pred_distribution = {
            k: v / total_preds if total_preds > 0 else 0
            for k, v in self.prediction_distribution.items()
        }
        
        return {
            'latency': latency_stats,
            'throughput': throughput,
            'confidence': confidence_stats,
            'prediction_distribution': pred_distribution,
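            # Alerts raised in the last hour; timestamps are ISO strings, so parse them
            'active_alerts': len([
                a for a in self.alerts
                if datetime.fromisoformat(a['timestamp']).timestamp() > current_time - 3600
            ])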
        }
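ProductionMonitor measures throughput by scanning a deque capped at `window_size` events. That means it cannot report more than `window_size` requests per second, and every check is O(window_size).

A time-based window avoids both problems. It evicts events once they are older than the window, so each timestamp is removed exactly once. The sketch below is a hypothetical helper that takes explicit timestamps, which also makes it easy to test.

```python
from collections import deque


class SlidingRate:
    """Events per second over the last window_s seconds, using explicit timestamps."""

    def __init__(self, window_s=1.0):
        self.window_s = window_s
        self.events = deque()

    def record(self, t):
        self.events.append(t)

    def rate(self, now):
        # Drop events that have aged out of the window
        while self.events and now - self.events[0] >= self.window_s:
            self.events.popleft()
        return len(self.events) / self.window_s


meter = SlidingRate(window_s=1.0)
for t in (0.0, 0.5, 0.9, 1.2):
    meter.record(t)
print(meter.rate(now=1.3))
```

In production you would pass `time.monotonic()` for both `record` and `rate`, so wall-clock adjustments cannot distort the window.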

Drift Detection in Production

from collections import deque
from datetime import datetime

import numpy as np
from scipy import stats

class ProductionDriftDetector:
    """
    Detect data drift and concept drift in production.
    """
    
    def __init__(self, reference_data, model, detection_method='ks_test'):
        self.reference_data = reference_data
        self.model = model
        self.detection_method = detection_method
        
        # Reference statistics (includes the reference feature matrix)
        self.reference_stats = self.compute_reference_statistics()
        
        # Production data window
        self.production_window = deque(maxlen=1000)
        
        # Drift alerts
        self.drift_alerts = []
    
    def compute_reference_statistics(self):
        """Compute statistics from reference (training) data."""
        # Named ref_stats so it does not shadow scipy.stats
        ref_stats = {}
        
        # Feature statistics
        features = self.extract_features(self.reference_data)
        ref_stats['features'] = features
        ref_stats['feature_means'] = np.mean(features, axis=0)
        ref_stats['feature_stds'] = np.std(features, axis=0)
        
        # Prediction distribution (model.predict is assumed to return integer class ids)
        reference_preds = self.model.predict(self.reference_data)
        ref_stats['prediction_distribution'] = np.bincount(reference_preds) / len(reference_preds)
        
        # Confidence distribution (get_confidences is an assumed model method)
        reference_confs = self.model.get_confidences(self.reference_data)
        ref_stats['confidence_mean'] = np.mean(reference_confs)
        ref_stats['confidence_std'] = np.std(reference_confs)
        
        return ref_stats
    
    def extract_features(self, data):
        """Return a 2-D array of shape (n_samples, n_features) for drift detection."""
        # Implementation depends on data type:
        # raw numeric features, embeddings, etc.
        raise NotImplementedError
    
    def detect_feature_drift(self, new_data):
        """Detect drift in input features."""
        new_features = self.extract_features(new_data)
        reference_features = self.reference_stats['features']
        
        if self.detection_method == 'ks_test':
            # Two-sample Kolmogorov-Smirnov test per feature, comparing the
            # reference features (not the raw reference data) with new features
            drift_scores = []
            p_values = []
            
            for i in range(new_features.shape[1]):
                stat, p_val = stats.ks_2samp(
                    reference_features[:, i],
                    new_features[:, i]
                )
                drift_scores.append(stat)
                p_values.append(p_val)
            
            # Fraction of features with p < 0.05 (no multiple-testing correction)
            significant_drift = np.mean([p < 0.05 for p in p_values])
            
            return {
                'drift_detected': significant_drift > 0.3,  # 30% features drifted
                'drift_scores': drift_scores,
                'p_values': p_values,
                'fraction_drifted': significant_drift
            }
        
        elif self.detection_method == 'population_stability_index':
            # PSI for categorical/binned features: not implemented here
            raise NotImplementedError("PSI-based drift detection is not implemented")
        
        raise ValueError(f"Unknown detection_method: {self.detection_method}")
    
    def detect_prediction_drift(self, new_predictions):
        """Detect drift in the prediction distribution (Jensen-Shannon divergence)."""
        ref_dist = np.asarray(self.reference_stats['prediction_distribution'], dtype=float)
        
        # Align lengths: minlength covers classes unseen in production, padding
        # covers classes unseen in the reference data
        new_counts = np.bincount(new_predictions, minlength=len(ref_dist))
        new_dist = new_counts / new_counts.sum()
        if len(new_dist) > len(ref_dist):
            ref_dist = np.pad(ref_dist, (0, len(new_dist) - len(ref_dist)))
        
        epsilon = 1e-10
        
        def kl(p, q):
            return np.sum(p * np.log((p + epsilon) / (q + epsilon)))
        
        # JS divergence compares each distribution with their mixture m;
        # it is symmetric and bounded by ln(2) when using natural logs
        m = 0.5 * (ref_dist + new_dist)
        js_div = 0.5 * kl(ref_dist, m) + 0.5 * kl(new_dist, m)
        
        drift_detected = js_div > 0.1  # Threshold
        
        return {
            'drift_detected': drift_detected,
            'js_divergence': js_div,
            'reference_distribution': ref_dist,
            'new_distribution': new_dist
        }
    
    def monitor(self, new_data, new_predictions):
        """Run all drift detection methods."""
        results = {}
        
        # Feature drift
        results['feature_drift'] = self.detect_feature_drift(new_data)
        
        # Prediction drift
        results['prediction_drift'] = self.detect_prediction_drift(new_predictions)
        
        # Overall drift decision
        overall_drift = (
            results['feature_drift']['drift_detected'] or
            results['prediction_drift']['drift_detected']
        )
        
        if overall_drift:
            self.create_drift_alert(results)
        
        return {
            'drift_detected': overall_drift,
            'details': results
        }
    
    def create_drift_alert(self, drift_results):
        """Create drift alert."""
        alert = {
            'timestamp': datetime.now().isoformat(),
            'feature_drift': drift_results['feature_drift'],
            'prediction_drift': drift_results['prediction_drift']
        }
        self.drift_alerts.append(alert)
        
        print("🚨 DRIFT DETECTED!")
        print(f"  Feature drift: {drift_results['feature_drift']['fraction_drifted']:.2%}")
        print(f"  Prediction drift (JS): {drift_results['prediction_drift']['js_divergence']:.4f}")
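The 'population_stability_index' option in detect_feature_drift is left unimplemented. Below is a standard-library sketch of the usual PSI formula: the sum over bins of (actual - expected) * ln(actual / expected).

It assumes both inputs are proportions over the same bins. How to bin continuous features, for example by deciles of the reference data, is left to the caller. A small epsilon guards against empty bins.

```python
import math


def population_stability_index(expected, actual, epsilon=1e-6):
    """PSI between two binned distributions given as lists of proportions."""
    if len(expected) != len(actual):
        raise ValueError("distributions must use the same bins")
    psi = 0.0
    for e, a in zip(expected, actual):
        e = max(e, epsilon)  # avoid log(0) and division by zero for empty bins
        a = max(a, epsilon)
        psi += (a - e) * math.log(a / e)
    return psi


print(population_stability_index([0.5, 0.5], [0.5, 0.5]))
print(round(population_stability_index([0.5, 0.5], [0.6, 0.4]), 4))
```

A widely quoted rule of thumb reads PSI below 0.1 as stable, 0.1 to 0.25 as a moderate shift, and above 0.25 as a significant shift. These cut-offs are conventions, not values from this tutorial.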

Model Retirement and Archival

Model Deprecation Process

from datetime import datetime, timedelta
from pathlib import Path

class ModelDeprecationManager:
    """
    Manage the deprecation and retirement lifecycle of models.
    """
    
    def __init__(self, model_registry):
        self.registry = model_registry
        self.deprecation_schedule = {}
    
    def initiate_deprecation(self, model_name, version, reason, timeline_days=90):
        """
        Begin the deprecation process for a model version.
        """
        deprecation_date = datetime.now()
        retirement_date = deprecation_date + timedelta(days=timeline_days)
        
        self.deprecation_schedule[f"{model_name}:{version}"] = {
            'model_name': model_name,
            'version': version,
            'reason': reason,
            'deprecation_date': deprecation_date.isoformat(),
            'retirement_date': retirement_date.isoformat(),
            'status': 'deprecated',
            'replacement': None,
            'migration_guide': None
        }
        
        # Update registry
        self.registry.deprecate_version(model_name, version)
        
        print(f"Initiated deprecation for {model_name} v{version}")
        print(f"  Reason: {reason}")
        print(f"  Deprecation date: {deprecation_date.strftime('%Y-%m-%d')}")
        print(f"  Retirement date: {retirement_date.strftime('%Y-%m-%d')}")
        
        return self.deprecation_schedule[f"{model_name}:{version}"]
    
    def set_replacement(self, model_name, old_version, new_model_name, new_version):
        """Specify replacement model for deprecated version."""
        key = f"{model_name}:{old_version}"
        
        if key in self.deprecation_schedule:
            self.deprecation_schedule[key]['replacement'] = {
                'model_name': new_model_name,
                'version': new_version
            }
            
            # Generate migration guide
            self.generate_migration_guide(model_name, old_version, new_model_name, new_version)
    
    def generate_migration_guide(self, old_model, old_version, new_model, new_version):
        """Generate a Markdown migration guide for users and save it to disk."""
        info = self.deprecation_schedule[f"{old_model}:{old_version}"]
        guide = f"""
# Migration Guide: {old_model} v{old_version} → {new_model} v{new_version}

## Timeline
- Deprecation Date: {info['deprecation_date']}
- Retirement Date: {info['retirement_date']}

## Breaking Changes
[List breaking changes here]

## API Differences
[Document API changes]

## Performance Improvements
[Document improvements]

## Migration Steps
1. Update model reference in configuration
2. Test with new model in staging environment
3. Validate outputs match expectations
4. Deploy to production with canary rollout
5. Monitor for issues

## Support
Contact: ml-platform@company.com
        """
        
        info["migration_guide"] = guide
        
        # Save guide (create the output directory first; open() fails if it is missing)
        guide_dir = Path("migration_guides")
        guide_dir.mkdir(parents=True, exist_ok=True)
        guide_path = guide_dir / f"{old_model}_{old_version}_to_{new_model}_{new_version}.md"
        with open(guide_path, 'w', encoding='utf-8') as f:
            f.write(guide)
    
    def retire_model(self, model_name, version):
        """
        Fully retire a model version (after deprecation period).
        """
        key = f"{model_name}:{version}"
        
        if key not in self.deprecation_schedule:
            print(f"No deprecation record found for {model_name}:{version}")
            return False
        
        deprecation_info = self.deprecation_schedule[key]
        retirement_date = datetime.fromisoformat(deprecation_info['retirement_date'])
        
        if datetime.now() < retirement_date:
            print(f"Cannot retire before {retirement_date.strftime('%Y-%m-%d')}")
            return False
        
        # Archive model
        self.archive_model(model_name, version)
        
        # Update status
        deprecation_info['status'] = 'retired'
        deprecation_info['retired_at'] = datetime.now().isoformat()
        
        print(f"Retired {model_name} v{version}")
        
        return True
    
    def archive_model(self, model_name, version):
        """Move model files to an archive directory and compress them for cold storage."""
        import shutil
        
        # Get model path from registry (assumes registry.models[name][version] holds metadata)
        model_metadata = self.registry.models[model_name][version]
        model_path = Path(model_metadata['model_path'])
        
        # Create archive directory (str() in case versions are not strings)
        archive_dir = Path('./model_archives') / model_name / str(version)
        archive_dir.mkdir(parents=True, exist_ok=True)
        
        # Move model files
        shutil.move(str(model_path), str(archive_dir / model_path.name))
        
        # Save metadata (default=str handles datetimes and other non-JSON values)
        metadata_path = archive_dir / 'metadata.json'
        with open(metadata_path, 'w') as f:
            json.dump(model_metadata, f, indent=2, default=str)
        
        # Compress archive: writes <archive_dir>.tar.gz next to the directory
        archive_file = shutil.make_archive(str(archive_dir), 'gztar', root_dir=archive_dir)
        
        print(f"Archived {model_name} v{version} to {archive_file}")
        return archive_file
    
    def get_deprecation_status(self, model_name, version=None):
        """Get deprecation status for a model."""
        if version:
            key = f"{model_name}:{version}"
            return self.deprecation_schedule.get(key, None)
        else:
            # Return all versions
            return {
                k: v for k, v in self.deprecation_schedule.items()
                if v['model_name'] == model_name
            }
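
The 90-day gate in `retire_model` reads the wall clock directly, which makes it awkward to test. Below is a minimal sketch of the same date arithmetic written as pure functions with an injectable `now`. The names `deprecation_window`, `can_retire`, and `days_remaining` are hypothetical helpers and are not part of the class above.

```python
from datetime import datetime, timedelta
from typing import Optional


def deprecation_window(start: datetime, timeline_days: int = 90) -> dict:
    """Deprecation record dates, computed the same way as initiate_deprecation."""
    return {
        "deprecation_date": start.isoformat(),
        "retirement_date": (start + timedelta(days=timeline_days)).isoformat(),
        "status": "deprecated",
    }


def can_retire(record: dict, now: Optional[datetime] = None) -> bool:
    """Mirror of retire_model's check: retirement is allowed once now >= retirement_date."""
    now = now or datetime.now()
    return now >= datetime.fromisoformat(record["retirement_date"])


def days_remaining(record: dict, now: Optional[datetime] = None) -> int:
    """Whole days left before retirement is allowed (0 once the window has elapsed)."""
    now = now or datetime.now()
    remaining = datetime.fromisoformat(record["retirement_date"]) - now
    return max(remaining.days, 0)


record = deprecation_window(datetime(2024, 1, 1), timeline_days=90)
print(record["retirement_date"])                         # 2024-03-31T00:00:00
print(can_retire(record, now=datetime(2024, 3, 30)))     # False
print(days_remaining(record, now=datetime(2024, 3, 1)))  # 30
```

Passing `now` explicitly lets unit tests exercise both sides of the retirement boundary without waiting out the deprecation period or patching `datetime`.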

Model Lineage Tracking

from datetime import datetime


class ModelLineageTracker:
    """
    Track complete lineage of models from training to retirement.
    """
    
    def __init__(self):
        self.lineage_graph = {}  # {model_id: lineage_info}
    
    def record_training_run(self, model_id, training_config, data_version, 
                           code_version, hyperparameters):
        """Record details of a training run."""
        self.lineage_graph[model_id] = {
            'model_id': model_id,
            'created_at': datetime.now().isoformat(),
            'training': {
                'config': training_config,
                'data_version': data_version,
                'code_version': code_version,
                'hyperparameters': hyperparameters,
                'environment': self.capture_environment()
            },
            'parent_models': [],  # For fine-tuned models
            'child_models': [],   # Models derived from this one
            'evaluation_results': {},
            'deployment_history': [],
            'retirement_info': None
        }
    
    def record_fine_tuning(self, child_model_id, parent_model_id, 
                          fine_tuning_data, fine_tuning_config):
        """Record fine-tuning relationship."""
        if parent_model_id in self.lineage_graph:
            # Add child to parent
            self.lineage_graph[parent_model_id]['child_models'].append(child_model_id)
            
            # Create child record
            self.lineage_graph[child_model_id] = {
                'model_id': child_model_id,
                'created_at': datetime.now().isoformat(),
                'parent_models': [parent_model_id],
                'fine_tuning': {
                    'data': fine_tuning_data,
                    'config': fine_tuning_config
                },
                'child_models': [],
                'evaluation_results': {},
                'deployment_history': [],
                'retirement_info': None
            }
    
    def record_evaluation(self, model_id, dataset_name, metrics):
        """Record evaluation results."""
        if model_id in self.lineage_graph:
            self.lineage_graph[model_id]['evaluation_results'][dataset_name] = {
                'metrics': metrics,
                'recorded_at': datetime.now().isoformat()
            }
    
    def record_deployment(self, model_id, environment, deployment_config):
        """Record deployment event."""
        if model_id in self.lineage_graph:
            self.lineage_graph[model_id]['deployment_history'].append({
                'environment': environment,
                'config': deployment_config,
                'deployed_at': datetime.now().isoformat()
            })
    
    def capture_environment(self):
        """Capture training environment details."""
        import sys
        import platform
        
        return {
            'python_version': sys.version,
            'platform': platform.platform(),
            'packages': self.get_installed_packages()
        }
    
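    def get_installed_packages(self):
        """Get installed packages and versions (importlib.metadata replaces the deprecated pkg_resources)."""
        from importlib.metadata import distributions
        return {
            dist.metadata['Name'].lower(): dist.version
            for dist in distributions()
            if dist.metadata['Name']  # skip distributions with broken metadata
        }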
    
    def get_model_lineage(self, model_id):
        """Get complete lineage for a model."""
        if model_id not in self.lineage_graph:
            return None
        
        lineage = self.lineage_graph[model_id].copy()
        
        # Recursively get parent lineages
        if lineage['parent_models']:
            lineage['parent_lineages'] = [
                self.get_model_lineage(parent_id)
                for parent_id in lineage['parent_models']
            ]
        
        return lineage
    
    def visualize_lineage(self, model_id):
        """Visualize model lineage as a graph."""
        try:
            import graphviz
        except ImportError:
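            print("Install graphviz: pip install graphviz (the Graphviz system binaries are also required)")
            return
        
        dot = graphviz.Digraph(comment='Model Lineage')
        visited = set()
        
        def add_node(mid):
            # Skip unknown IDs and models already drawn (e.g., an ancestor shared by two parents),
            # otherwise shared ancestors get their parent edges emitted more than once
            if mid in visited or mid not in self.lineage_graph:
                return
            visited.add(mid)
            
            info = self.lineage_graph[mid]
            label = f"{mid}\\n{info['created_at'][:10]}"
            
            # Grey out retired models
            if info['retirement_info']:
                dot.node(mid, label, style='filled', fillcolor='lightgray')
            else:
                dot.node(mid, label)
            
            # Add edges to parents
            for parent_id in info['parent_models']:
                dot.edge(parent_id, mid)
                add_node(parent_id)
        
        add_node(model_id)
        
        # Render graph to model_lineage.gv (+ .pdf) and open it in the default viewer
        dot.render('model_lineage.gv', view=True)
        return dot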
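
`get_model_lineage` recurses without tracking visited nodes. A model whose parents share an ancestor therefore repeats that ancestor's lineage, and an accidental cycle would recurse until Python's recursion limit. The following is a minimal iterative sketch over a dict in the same shape as `lineage_graph`. The names `ancestors` and `descendants` are hypothetical, and the example graph is built by hand, since `record_fine_tuning` as written records only one parent. The descendant direction is the one you would use for impact analysis, for example to find every model downstream of a base model whose training data later turns out to be bad.

```python
from collections import deque
from typing import Dict, List


def _walk(graph: Dict[str, dict], start: str, edge_key: str) -> List[str]:
    """Breadth-first walk along edge_key ('parent_models' or 'child_models')."""
    seen = {start}
    order = []
    queue = deque([start])
    while queue:
        current = queue.popleft()
        for neighbor in graph.get(current, {}).get(edge_key, []):
            if neighbor not in seen:  # shared ancestors and cycles are visited once
                seen.add(neighbor)
                order.append(neighbor)
                queue.append(neighbor)
    return order


def ancestors(graph: Dict[str, dict], model_id: str) -> List[str]:
    """All models model_id was derived from, nearest first."""
    return _walk(graph, model_id, "parent_models")


def descendants(graph: Dict[str, dict], model_id: str) -> List[str]:
    """All models derived from model_id, nearest first."""
    return _walk(graph, model_id, "child_models")


# A base model fine-tuned twice, with both children merged into one model
lineage_graph = {
    "base": {"parent_models": [], "child_models": ["chat", "code"]},
    "chat": {"parent_models": ["base"], "child_models": ["merged"]},
    "code": {"parent_models": ["base"], "child_models": ["merged"]},
    "merged": {"parent_models": ["chat", "code"], "child_models": []},
}
print(ancestors(lineage_graph, "merged"))  # ['chat', 'code', 'base']
print(descendants(lineage_graph, "base"))  # ['chat', 'code', 'merged']
```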

Best Practices Checklist

Continual Learning Best Practices

  • Choose the Right Strategy: Select replay, regularization, or architecture-based methods according to your memory, compute, and data-retention constraints
  • Monitor Forgetting: Track backward transfer and forgetting metrics
  • Balance Stability-Plasticity: Tune hyperparameters such as regularization strength or replay ratio to trade off retaining old tasks against learning new ones
  • Validate Frequently: Evaluate on all tasks after each new task
  • Document Task Boundaries: Clearly define when new tasks begin
  • Plan Capacity: Ensure the model has enough capacity for the expected number of tasks
  • Test Transfer: Measure forward transfer between tasks
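
The forgetting and transfer metrics named above can all be read off one accuracy matrix. This is a minimal sketch. It assumes you evaluate every task after training on each new one and store the results as `R[i][j]`. The definitions follow forms commonly used in the continual-learning literature: backward and forward transfer as in GEM, and forgetting as the drop from the best accuracy reached after a task was learned. The function name and the example numbers are illustrative.

```python
from typing import List, Optional, Sequence


def continual_metrics(R: List[List[float]],
                      baseline: Optional[Sequence[float]] = None) -> dict:
    """Summary metrics from a T x T accuracy matrix.

    R[i][j] = accuracy on task j measured after training on task i
    (tasks in training order). baseline[j] = accuracy of the untrained
    model on task j; it is only needed for forward transfer.
    """
    T = len(R)
    if T < 2 or any(len(row) != T for row in R):
        raise ValueError("R must be a square matrix covering at least 2 tasks")
    final = R[-1]
    metrics = {"avg_accuracy": sum(final) / T}

    # Backward transfer: effect of later training on earlier tasks (negative = forgetting)
    metrics["bwt"] = sum(final[j] - R[j][j] for j in range(T - 1)) / (T - 1)

    # Forgetting: best accuracy task j reached after being learned, minus its final accuracy
    metrics["forgetting"] = sum(
        max(R[k][j] for k in range(j, T - 1)) - final[j] for j in range(T - 1)
    ) / (T - 1)

    # Forward transfer: accuracy on task j just before training on it, relative to baseline
    if baseline is not None:
        metrics["fwt"] = sum(R[j - 1][j] - baseline[j] for j in range(1, T)) / (T - 1)
    return metrics


R = [
    [0.90, 0.20, 0.10],
    [0.80, 0.85, 0.15],
    [0.70, 0.75, 0.88],
]
print(continual_metrics(R, baseline=[0.10, 0.10, 0.10]))
```

In this example, forgetting (0.15) equals minus BWT (-0.15) because accuracy on each old task only falls after the task is learned. In general, forgetting is at least minus BWT.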

Model Lifecycle Best Practices

  • Version Everything: Models, data, code, configurations
  • Automate Deployment: Use CI/CD pipelines for model deployment
  • Monitor Continuously: Track performance, latency, drift in production
  • Plan Deprecation: Have clear retirement criteria and processes
  • Maintain Lineage: Track complete history from training to retirement
  • Document Decisions: Record why models were created, changed, or retired
  • Security: Control access to model artifacts and endpoints
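
"Version Everything" is easiest to enforce when a model's version identifier is derived from its inputs. This is a minimal sketch. It assumes a model is fully described by its config dict, a data version string, a code version (e.g., a git commit hash), and optionally its artifact files. `model_fingerprint` and its helpers are hypothetical, not an NTF API.

```python
import hashlib
import json
from typing import Iterable


def _feed(h, data: bytes) -> None:
    # Length-prefix each field so ("ab", "c") and ("a", "bc") hash differently
    h.update(len(data).to_bytes(8, "big"))
    h.update(data)


def file_digest(path: str) -> bytes:
    """SHA-256 of a file, read in 1 MiB chunks so large checkpoints need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.digest()


def model_fingerprint(config: dict, data_version: str, code_version: str,
                      artifact_paths: Iterable[str] = ()) -> str:
    """Deterministic SHA-256 hex ID for a config + data + code (+ artifacts) combination."""
    h = hashlib.sha256()
    # sort_keys makes the hash independent of dict insertion order
    _feed(h, json.dumps(config, sort_keys=True, default=str).encode("utf-8"))
    _feed(h, data_version.encode("utf-8"))
    _feed(h, code_version.encode("utf-8"))
    for path in sorted(artifact_paths):
        _feed(h, file_digest(path))
    return h.hexdigest()


print(model_fingerprint({"lr": 3e-4, "epochs": 3}, "data-v2", "a1b2c3d")[:12])
```

Without `artifact_paths`, identical config, data, and code inputs always hash to the same ID. Including the weight files ties the ID to one specific trained checkpoint. On Python 3.11+, `hashlib.file_digest(f, "sha256")` can replace the manual chunked read.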

Next Steps

In the next tutorial, we'll cover:

  • Complete Production Pipeline: End-to-end example from training to serving
  • Scaling Strategies: Distributed training and inference at scale
  • Cost Optimization: Reducing training and inference costs
  • Team Collaboration: MLOps workflows for teams
  • Case Studies: Real-world examples and lessons learned