Commit ·
f42d9a1
1
Parent(s): 5a593b3
Phase 4: Add multi-task learning, P-Tuning, SI/LwF continual learning, automated tests, deployment templates
Browse files- continual_learning.py +589 -0
- multi_task.py +427 -0
- p_tuning.py +295 -0
- test_tutorial_examples.py +249 -0
- utils/__init__.py +41 -75
continual_learning.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Continual Learning Utilities for Nexuss Transformer Framework
|
| 3 |
+
Mechanisms to avoid catastrophic forgetting during continuous training
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Optional, List, Dict, Any, Tuple
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
import copy
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class EWCConfig:
|
| 17 |
+
"""Configuration for Elastic Weight Consolidation"""
|
| 18 |
+
|
| 19 |
+
ewc_lambda: float = 1000.0 # Strength of EWC regularization
|
| 20 |
+
fisher_samples: int = 200 # Number of samples to estimate Fisher information
|
| 21 |
+
damping: float = 0.1 # Damping factor for Fisher matrix
|
| 22 |
+
mc_samples: int = 1 # Monte Carlo samples for Fisher estimation
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class ReplayConfig:
|
| 27 |
+
"""Configuration for Experience Replay"""
|
| 28 |
+
|
| 29 |
+
replay_size: int = 1000 # Size of replay buffer
|
| 30 |
+
replay_ratio: float = 0.5 # Ratio of replay data in each batch
|
| 31 |
+
selection_strategy: str = "uniform" # uniform, recent, diverse
|
| 32 |
+
reservoir_sampling: bool = True # Use reservoir sampling for streaming data
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class GEMConfig:
|
| 37 |
+
"""Configuration for Gradient Episodic Memory"""
|
| 38 |
+
|
| 39 |
+
memory_size: int = 100 # Number of examples per task
|
| 40 |
+
num_tasks: int = 5 # Expected number of tasks
|
| 41 |
+
use_quadprog: bool = True # Use quadratic programming for constraint solving
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass
|
| 45 |
+
class ContinualLearningConfig:
|
| 46 |
+
"""Unified configuration for continual learning strategies"""
|
| 47 |
+
|
| 48 |
+
strategy: str = "none" # none, ewc, replay, gem, lwf
|
| 49 |
+
ewc: Optional[EWCConfig] = field(default_factory=EWCConfig)
|
| 50 |
+
replay: Optional[ReplayConfig] = field(default_factory=ReplayConfig)
|
| 51 |
+
gem: Optional[GEMConfig] = field(default_factory=GEMConfig)
|
| 52 |
+
|
| 53 |
+
# LwF (Learning without Forgetting) settings
|
| 54 |
+
lwf_alpha: float = 1.0 # Distillation loss weight
|
| 55 |
+
lwf_temperature: float = 2.0 # Temperature for knowledge distillation
|
| 56 |
+
|
| 57 |
+
# Regularization
|
| 58 |
+
weight_decay: float = 0.01
|
| 59 |
+
grad_clip: float = 1.0
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class EWCRegularizer:
|
| 63 |
+
"""Elastic Weight Consolidation implementation"""
|
| 64 |
+
|
| 65 |
+
def __init__(self, model: nn.Module, config: EWCConfig):
|
| 66 |
+
self.model = model
|
| 67 |
+
self.config = config
|
| 68 |
+
self.fisher: Dict[str, torch.Tensor] = {}
|
| 69 |
+
self.optimal_params: Dict[str, torch.Tensor] = {}
|
| 70 |
+
|
| 71 |
+
def compute_fisher(self, dataloader: DataLoader, device: torch.device):
|
| 72 |
+
"""Compute Fisher Information Matrix diagonal approximation"""
|
| 73 |
+
|
| 74 |
+
self.model.train()
|
| 75 |
+
fisher_dict = {name: torch.zeros_like(param)
|
| 76 |
+
for name, param in self.model.named_parameters()
|
| 77 |
+
if param.requires_grad}
|
| 78 |
+
|
| 79 |
+
samples_processed = 0
|
| 80 |
+
|
| 81 |
+
for batch in dataloader:
|
| 82 |
+
if samples_processed >= self.config.fisher_samples:
|
| 83 |
+
break
|
| 84 |
+
|
| 85 |
+
self.model.zero_grad()
|
| 86 |
+
|
| 87 |
+
# Forward pass
|
| 88 |
+
inputs = batch["input_ids"].to(device) if isinstance(batch, dict) else batch.to(device)
|
| 89 |
+
outputs = self.model(inputs)
|
| 90 |
+
|
| 91 |
+
# Compute log-likelihood gradient
|
| 92 |
+
log_probs = torch.log_softmax(outputs.logits, dim=-1)
|
| 93 |
+
loss = log_probs.mean()
|
| 94 |
+
|
| 95 |
+
# Compute gradients
|
| 96 |
+
grads = torch.autograd.grad(loss, [p for p in self.model.parameters() if p.requires_grad],
|
| 97 |
+
retain_graph=False)
|
| 98 |
+
|
| 99 |
+
# Accumulate squared gradients (Fisher diagonal)
|
| 100 |
+
for (name, _), grad in zip(self.model.named_parameters(), grads):
|
| 101 |
+
if name in fisher_dict:
|
| 102 |
+
fisher_dict[name] += grad.pow(2)
|
| 103 |
+
|
| 104 |
+
samples_processed += inputs.size(0)
|
| 105 |
+
|
| 106 |
+
# Average and store
|
| 107 |
+
n_samples = max(samples_processed, 1)
|
| 108 |
+
self.fisher = {name: tensor / n_samples + self.config.damping
|
| 109 |
+
for name, tensor in fisher_dict.items()}
|
| 110 |
+
|
| 111 |
+
# Store optimal parameters
|
| 112 |
+
self.optimal_params = {name: param.clone().detach()
|
| 113 |
+
for name, param in self.model.named_parameters()
|
| 114 |
+
if param.requires_grad}
|
| 115 |
+
|
| 116 |
+
def compute_ewc_loss(self) -> torch.Tensor:
|
| 117 |
+
"""Compute EWC regularization loss"""
|
| 118 |
+
|
| 119 |
+
if not self.fisher or not self.optimal_params:
|
| 120 |
+
return torch.tensor(0.0)
|
| 121 |
+
|
| 122 |
+
ewc_loss = torch.tensor(0.0)
|
| 123 |
+
|
| 124 |
+
for name, param in self.model.named_parameters():
|
| 125 |
+
if param.requires_grad and name in self.fisher:
|
| 126 |
+
delta = param - self.optimal_params[name]
|
| 127 |
+
ewc_loss += (self.fisher[name] * delta.pow(2)).sum()
|
| 128 |
+
|
| 129 |
+
return self.config.ewc_lambda * ewc_loss
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class ReplayBuffer:
|
| 133 |
+
"""Experience Replay Buffer for continual learning"""
|
| 134 |
+
|
| 135 |
+
def __init__(self, config: ReplayConfig):
|
| 136 |
+
self.config = config
|
| 137 |
+
self.buffer: List[Dict[str, Any]] = []
|
| 138 |
+
self.task_data: Dict[int, List[Dict[str, Any]]] = {}
|
| 139 |
+
|
| 140 |
+
def add(self, samples: List[Dict[str, Any]], task_id: Optional[int] = None):
|
| 141 |
+
"""Add samples to replay buffer"""
|
| 142 |
+
|
| 143 |
+
if self.config.reservoir_sampling and len(self.buffer) + len(samples) > self.config.replay_size:
|
| 144 |
+
# Reservoir sampling for streaming data
|
| 145 |
+
for sample in samples:
|
| 146 |
+
if len(self.buffer) < self.config.replay_size:
|
| 147 |
+
self.buffer.append(sample)
|
| 148 |
+
else:
|
| 149 |
+
# Randomly replace with decreasing probability
|
| 150 |
+
j = torch.randint(0, len(self.buffer) + 1, (1,)).item()
|
| 151 |
+
if j < self.config.replay_size:
|
| 152 |
+
self.buffer[j] = sample
|
| 153 |
+
else:
|
| 154 |
+
self.buffer.extend(samples)
|
| 155 |
+
|
| 156 |
+
# Trim if exceeds size
|
| 157 |
+
if len(self.buffer) > self.config.replay_size:
|
| 158 |
+
if self.config.selection_strategy == "recent":
|
| 159 |
+
self.buffer = self.buffer[-self.config.replay_size:]
|
| 160 |
+
elif self.config.selection_strategy == "diverse":
|
| 161 |
+
# Simple diversity: keep every nth item
|
| 162 |
+
step = len(self.buffer) // self.config.replay_size
|
| 163 |
+
self.buffer = self.buffer[::step][:self.config.replay_size]
|
| 164 |
+
else: # uniform
|
| 165 |
+
indices = torch.randperm(len(self.buffer))[:self.config.replay_size]
|
| 166 |
+
self.buffer = [self.buffer[i] for i in indices]
|
| 167 |
+
|
| 168 |
+
# Store by task if task_id provided
|
| 169 |
+
if task_id is not None:
|
| 170 |
+
if task_id not in self.task_data:
|
| 171 |
+
self.task_data[task_id] = []
|
| 172 |
+
self.task_data[task_id].extend(samples)
|
| 173 |
+
|
| 174 |
+
def get_batch(self, current_batch: Dict[str, Any]) -> Dict[str, Any]:
|
| 175 |
+
"""Mix replay data with current batch"""
|
| 176 |
+
|
| 177 |
+
if not self.buffer:
|
| 178 |
+
return current_batch
|
| 179 |
+
|
| 180 |
+
replay_size = int(current_batch["input_ids"].size(0) * self.config.replay_ratio)
|
| 181 |
+
replay_size = min(replay_size, len(self.buffer))
|
| 182 |
+
|
| 183 |
+
if replay_size == 0:
|
| 184 |
+
return current_batch
|
| 185 |
+
|
| 186 |
+
# Sample from buffer
|
| 187 |
+
indices = torch.randperm(len(self.buffer))[:replay_size]
|
| 188 |
+
replay_samples = [self.buffer[i] for i in indices]
|
| 189 |
+
|
| 190 |
+
# Combine with current batch (simplified - in practice need proper merging)
|
| 191 |
+
# This is a placeholder - actual implementation depends on your data format
|
| 192 |
+
return current_batch # TODO: Implement proper batch merging
|
| 193 |
+
|
| 194 |
+
def get_task_buffer(self, task_id: int) -> List[Dict[str, Any]]:
|
| 195 |
+
"""Get replay buffer for specific task"""
|
| 196 |
+
return self.task_data.get(task_id, [])
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class GEMOptimizer:
|
| 200 |
+
"""Gradient Episodic Memory optimizer"""
|
| 201 |
+
|
| 202 |
+
def __init__(self, model: nn.Module, config: GEMConfig):
|
| 203 |
+
self.model = model
|
| 204 |
+
self.config = config
|
| 205 |
+
self.memory: Dict[int, List[Dict[str, Any]]] = {i: [] for i in range(config.num_tasks)}
|
| 206 |
+
self.gradient_memory: Dict[int, torch.Tensor] = {}
|
| 207 |
+
|
| 208 |
+
def store_in_memory(self, samples: List[Dict[str, Any]], task_id: int):
|
| 209 |
+
"""Store samples in task-specific memory"""
|
| 210 |
+
|
| 211 |
+
available_space = self.config.memory_size - len(self.memory[task_id])
|
| 212 |
+
|
| 213 |
+
if available_space >= len(samples):
|
| 214 |
+
self.memory[task_id].extend(samples)
|
| 215 |
+
else:
|
| 216 |
+
# Random subsample
|
| 217 |
+
indices = torch.randperm(len(samples))[:available_space]
|
| 218 |
+
self.memory[task_id].extend([samples[i] for i in indices])
|
| 219 |
+
|
| 220 |
+
def compute_gradient_constraints(self, task_id: int, device: torch.device) -> List[torch.Tensor]:
|
| 221 |
+
"""Compute stored gradients for previous tasks"""
|
| 222 |
+
|
| 223 |
+
constraints = []
|
| 224 |
+
|
| 225 |
+
for prev_task_id in range(task_id):
|
| 226 |
+
if prev_task_id not in self.gradient_memory:
|
| 227 |
+
continue
|
| 228 |
+
|
| 229 |
+
constraints.append(self.gradient_memory[prev_task_id])
|
| 230 |
+
|
| 231 |
+
return constraints
|
| 232 |
+
|
| 233 |
+
def project_gradient(self, gradient: torch.Tensor, constraints: List[torch.Tensor]) -> torch.Tensor:
|
| 234 |
+
"""Project gradient to satisfy memory constraints using quadratic programming"""
|
| 235 |
+
|
| 236 |
+
if not constraints:
|
| 237 |
+
return gradient
|
| 238 |
+
|
| 239 |
+
projected = gradient.clone()
|
| 240 |
+
|
| 241 |
+
for constraint in constraints:
|
| 242 |
+
# Check if gradient violates constraint
|
| 243 |
+
dot_product = torch.dot(projected.flatten(), constraint.flatten())
|
| 244 |
+
|
| 245 |
+
if dot_product < 0:
|
| 246 |
+
# Project gradient
|
| 247 |
+
norm_sq = constraint.pow(2).sum()
|
| 248 |
+
if norm_sq > 1e-8:
|
| 249 |
+
projection_coef = dot_product / norm_sq
|
| 250 |
+
projected -= projection_coef * constraint
|
| 251 |
+
|
| 252 |
+
return projected
|
| 253 |
+
|
| 254 |
+
def update_gradient_memory(self, task_id: int, dataloader: DataLoader, device: torch.device):
|
| 255 |
+
"""Update stored gradients for current task"""
|
| 256 |
+
|
| 257 |
+
self.model.eval()
|
| 258 |
+
|
| 259 |
+
# Compute average gradient over memory samples
|
| 260 |
+
total_gradient = None
|
| 261 |
+
count = 0
|
| 262 |
+
|
| 263 |
+
for batch in dataloader:
|
| 264 |
+
self.model.zero_grad()
|
| 265 |
+
|
| 266 |
+
inputs = batch["input_ids"].to(device) if isinstance(batch, dict) else batch.to(device)
|
| 267 |
+
outputs = self.model(inputs)
|
| 268 |
+
|
| 269 |
+
loss = outputs.loss if hasattr(outputs, 'loss') else outputs.logits.mean()
|
| 270 |
+
|
| 271 |
+
grads = torch.autograd.grad(loss, [p for p in self.model.parameters() if p.requires_grad])
|
| 272 |
+
|
| 273 |
+
# Flatten and concatenate all gradients
|
| 274 |
+
flat_grad = torch.cat([g.flatten() for g in grads])
|
| 275 |
+
|
| 276 |
+
if total_gradient is None:
|
| 277 |
+
total_gradient = flat_grad
|
| 278 |
+
else:
|
| 279 |
+
total_gradient += flat_grad
|
| 280 |
+
|
| 281 |
+
count += 1
|
| 282 |
+
|
| 283 |
+
if count > 0 and total_gradient is not None:
|
| 284 |
+
self.gradient_memory[task_id] = total_gradient / count
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class LwFLoss(nn.Module):
|
| 288 |
+
"""Learning without Forgetting loss using knowledge distillation"""
|
| 289 |
+
|
| 290 |
+
def __init__(self, config: ContinualLearningConfig):
|
| 291 |
+
super().__init__()
|
| 292 |
+
self.config = config
|
| 293 |
+
self.kl_div = nn.KLDivLoss(reduction='batchmean')
|
| 294 |
+
|
| 295 |
+
def forward(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
|
| 296 |
+
"""Compute LwF distillation loss"""
|
| 297 |
+
|
| 298 |
+
T = self.config.lwf_temperature
|
| 299 |
+
|
| 300 |
+
# Apply temperature scaling
|
| 301 |
+
student_log_probs = torch.log_softmax(student_logits / T, dim=-1)
|
| 302 |
+
teacher_probs = torch.softmax(teacher_logits / T, dim=-1)
|
| 303 |
+
|
| 304 |
+
# Knowledge distillation loss
|
| 305 |
+
kd_loss = self.kl_div(student_log_probs, teacher_probs) * (T ** 2)
|
| 306 |
+
|
| 307 |
+
return self.config.lwf_alpha * kd_loss
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class SIRegularizer:
|
| 311 |
+
"""Synaptic Intelligence implementation for continual learning."""
|
| 312 |
+
|
| 313 |
+
def __init__(self, model: nn.Module, c: float = 0.1):
|
| 314 |
+
self.model = model
|
| 315 |
+
self.c = c # Importance weight
|
| 316 |
+
self.importance: Dict[str, torch.Tensor] = {}
|
| 317 |
+
self.prev_params: Dict[str, torch.Tensor] = {}
|
| 318 |
+
self.trajectory: Dict[str, torch.Tensor] = {}
|
| 319 |
+
|
| 320 |
+
def initialize_trajectory(self):
|
| 321 |
+
"""Initialize trajectory tracking for parameters."""
|
| 322 |
+
self.prev_params = {
|
| 323 |
+
name: param.clone().detach()
|
| 324 |
+
for name, param in self.model.named_parameters()
|
| 325 |
+
if param.requires_grad
|
| 326 |
+
}
|
| 327 |
+
self.trajectory = {
|
| 328 |
+
name: torch.zeros_like(param)
|
| 329 |
+
for name, param in self.model.named_parameters()
|
| 330 |
+
if param.requires_grad
|
| 331 |
+
}
|
| 332 |
+
self.importance = {
|
| 333 |
+
name: torch.zeros_like(param)
|
| 334 |
+
for name, param in self.model.named_parameters()
|
| 335 |
+
if param.requires_grad
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
def update_trajectory(self):
|
| 339 |
+
"""Update parameter change trajectory after each training step."""
|
| 340 |
+
with torch.no_grad():
|
| 341 |
+
for name, param in self.model.named_parameters():
|
| 342 |
+
if param.requires_grad and name in self.prev_params:
|
| 343 |
+
delta = param - self.prev_params[name]
|
| 344 |
+
self.trajectory[name] += delta.pow(2)
|
| 345 |
+
self.prev_params[name] = param.clone().detach()
|
| 346 |
+
|
| 347 |
+
def compute_importance(self, loss_change: float):
|
| 348 |
+
"""
|
| 349 |
+
Compute parameter importance based on loss change.
|
| 350 |
+
|
| 351 |
+
Args:
|
| 352 |
+
loss_change: Change in loss from previous iteration
|
| 353 |
+
"""
|
| 354 |
+
if loss_change > 0: # Only update if loss decreased
|
| 355 |
+
for name in self.importance:
|
| 356 |
+
if name in self.trajectory:
|
| 357 |
+
denom = self.trajectory[name] + 1e-8
|
| 358 |
+
self.importance[name] += loss_change / denom
|
| 359 |
+
|
| 360 |
+
def compute_si_loss(self) -> torch.Tensor:
|
| 361 |
+
"""Compute Synaptic Intelligence regularization loss."""
|
| 362 |
+
si_loss = torch.tensor(0.0)
|
| 363 |
+
|
| 364 |
+
for name, param in self.model.named_parameters():
|
| 365 |
+
if param.requires_grad and name in self.importance:
|
| 366 |
+
delta = param - self.prev_params.get(name, param)
|
| 367 |
+
si_loss += (self.importance[name] * delta.pow(2)).sum()
|
| 368 |
+
|
| 369 |
+
return self.c * si_loss
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class LwFRegularizer(nn.Module):
|
| 373 |
+
"""Learning without Forgetting using knowledge distillation."""
|
| 374 |
+
|
| 375 |
+
def __init__(self, alpha: float = 0.5, temperature: float = 2.0):
|
| 376 |
+
super().__init__()
|
| 377 |
+
self.alpha = alpha # Distillation loss weight
|
| 378 |
+
self.temperature = temperature
|
| 379 |
+
self.kl_div = nn.KLDivLoss(reduction='batchmean')
|
| 380 |
+
self.old_outputs: Dict[str, torch.Tensor] = {}
|
| 381 |
+
|
| 382 |
+
def store_old_outputs(self, task_name: str, outputs: torch.Tensor):
|
| 383 |
+
"""Store outputs from old model for distillation."""
|
| 384 |
+
self.old_outputs[task_name] = outputs.detach()
|
| 385 |
+
|
| 386 |
+
def clear_old_outputs(self):
|
| 387 |
+
"""Clear stored old outputs."""
|
| 388 |
+
self.old_outputs.clear()
|
| 389 |
+
|
| 390 |
+
def forward(
|
| 391 |
+
self,
|
| 392 |
+
student_logits: torch.Tensor,
|
| 393 |
+
teacher_logits: torch.Tensor,
|
| 394 |
+
task_name: Optional[str] = None
|
| 395 |
+
) -> torch.Tensor:
|
| 396 |
+
"""
|
| 397 |
+
Compute LwF distillation loss.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
student_logits: Current model logits
|
| 401 |
+
teacher_logits: Old model logits (stored or provided)
|
| 402 |
+
task_name: Optional task name for stored outputs
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
Knowledge distillation loss
|
| 406 |
+
"""
|
| 407 |
+
T = self.temperature
|
| 408 |
+
|
| 409 |
+
# Apply temperature scaling
|
| 410 |
+
student_log_probs = torch.log_softmax(student_logits / T, dim=-1)
|
| 411 |
+
teacher_probs = torch.softmax(teacher_logits / T, dim=-1)
|
| 412 |
+
|
| 413 |
+
# Knowledge distillation loss
|
| 414 |
+
kd_loss = self.kl_div(student_log_probs, teacher_probs) * (T ** 2)
|
| 415 |
+
|
| 416 |
+
return self.alpha * kd_loss
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def create_continual_learning_wrapper(trainer, config: ContinualLearningConfig):
|
| 420 |
+
"""
|
| 421 |
+
Wrap existing trainer with continual learning capabilities.
|
| 422 |
+
Returns modified trainer with CL methods integrated.
|
| 423 |
+
"""
|
| 424 |
+
|
| 425 |
+
if config.strategy == "ewc":
|
| 426 |
+
trainer.ewc_regularizer = EWCRegularizer(trainer.model, config.ewc)
|
| 427 |
+
|
| 428 |
+
# Hook into training loop to add EWC loss
|
| 429 |
+
original_compute_loss = trainer.compute_loss
|
| 430 |
+
|
| 431 |
+
def compute_loss_with_ewc(model, inputs, return_outputs=False):
|
| 432 |
+
loss = original_compute_loss(model, inputs, return_outputs)
|
| 433 |
+
ewc_loss = trainer.ewc_regularizer.compute_ewc_loss()
|
| 434 |
+
|
| 435 |
+
if return_outputs:
|
| 436 |
+
return loss + ewc_loss, outputs
|
| 437 |
+
return loss + ewc_loss
|
| 438 |
+
|
| 439 |
+
trainer.compute_loss = compute_loss_with_ewc
|
| 440 |
+
|
| 441 |
+
elif config.strategy == "replay":
|
| 442 |
+
trainer.replay_buffer = ReplayBuffer(config.replay)
|
| 443 |
+
|
| 444 |
+
# Modify data loading to include replay
|
| 445 |
+
# Implementation depends on trainer's data loading mechanism
|
| 446 |
+
|
| 447 |
+
elif config.strategy == "gem":
|
| 448 |
+
trainer.gem_optimizer = GEMOptimizer(trainer.model, config.gem)
|
| 449 |
+
|
| 450 |
+
# Hook into optimization step to project gradients
|
| 451 |
+
# Implementation depends on trainer's optimization loop
|
| 452 |
+
|
| 453 |
+
elif config.strategy == "lwf":
|
| 454 |
+
trainer.lwf_loss = LwFLoss(config)
|
| 455 |
+
# Store teacher model outputs for distillation
|
| 456 |
+
# Implementation depends on training setup
|
| 457 |
+
|
| 458 |
+
elif config.strategy == "si":
|
| 459 |
+
trainer.si_regularizer = SIRegularizer(trainer.model, c=config.weight_decay)
|
| 460 |
+
|
| 461 |
+
# Initialize trajectory tracking
|
| 462 |
+
trainer.si_regularizer.initialize_trajectory()
|
| 463 |
+
|
| 464 |
+
# Hook into training loop
|
| 465 |
+
original_compute_loss = trainer.compute_loss
|
| 466 |
+
|
| 467 |
+
def compute_loss_with_si(model, inputs, return_outputs=False):
|
| 468 |
+
loss = original_compute_loss(model, inputs, return_outputs)
|
| 469 |
+
si_loss = trainer.si_regularizer.compute_si_loss()
|
| 470 |
+
|
| 471 |
+
if return_outputs:
|
| 472 |
+
return loss + si_loss, outputs
|
| 473 |
+
return loss + si_loss
|
| 474 |
+
|
| 475 |
+
trainer.compute_loss = compute_loss_with_si
|
| 476 |
+
|
| 477 |
+
# Hook into optimizer step to update trajectory
|
| 478 |
+
original_step = trainer.optimizer.step if hasattr(trainer, 'optimizer') else None
|
| 479 |
+
if original_step:
|
| 480 |
+
def step_with_trajectory():
|
| 481 |
+
original_step()
|
| 482 |
+
trainer.si_regularizer.update_trajectory()
|
| 483 |
+
trainer.optimizer.step = step_with_trajectory
|
| 484 |
+
|
| 485 |
+
return trainer
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class ContinualLearningWrapper:
|
| 489 |
+
"""
|
| 490 |
+
High-level wrapper for applying continual learning methods.
|
| 491 |
+
|
| 492 |
+
Provides a unified API for EWC, SI, and LwF regularization.
|
| 493 |
+
|
| 494 |
+
Args:
|
| 495 |
+
model: Model to wrap
|
| 496 |
+
method: Continual learning method (ewc, si, lwf)
|
| 497 |
+
"""
|
| 498 |
+
|
| 499 |
+
def __init__(self, model: nn.Module, method: str = "ewc"):
|
| 500 |
+
self.model = model
|
| 501 |
+
self.method = method
|
| 502 |
+
self.ewc = None
|
| 503 |
+
self.si = None
|
| 504 |
+
self.lwf = None
|
| 505 |
+
|
| 506 |
+
if method == "ewc":
|
| 507 |
+
self.ewc = EWCRegularizer(model, EWCConfig())
|
| 508 |
+
elif method == "si":
|
| 509 |
+
self.si = SIRegularizer(model)
|
| 510 |
+
self.si.initialize_trajectory()
|
| 511 |
+
elif method == "lwf":
|
| 512 |
+
self.lwf = LwFRegularizer()
|
| 513 |
+
|
| 514 |
+
def apply_ewc_regularization(self, lambda_ewc: float = 0.5):
|
| 515 |
+
"""Apply Elastic Weight Consolidation regularization."""
|
| 516 |
+
if self.ewc is None:
|
| 517 |
+
self.ewc = EWCRegularizer(self.model, EWCConfig(ewc_lambda=lambda_ewc))
|
| 518 |
+
else:
|
| 519 |
+
self.ewc.config.ewc_lambda = lambda_ewc
|
| 520 |
+
return self
|
| 521 |
+
|
| 522 |
+
def apply_si_regularization(self, c: float = 0.1):
|
| 523 |
+
"""Apply Synaptic Intelligence regularization."""
|
| 524 |
+
if self.si is None:
|
| 525 |
+
self.si = SIRegularizer(self.model, c=c)
|
| 526 |
+
self.si.initialize_trajectory()
|
| 527 |
+
else:
|
| 528 |
+
self.si.c = c
|
| 529 |
+
return self
|
| 530 |
+
|
| 531 |
+
def apply_lwf_regularization(self, alpha: float = 0.5):
|
| 532 |
+
"""Apply Learning without Forgetting regularization."""
|
| 533 |
+
if self.lwf is None:
|
| 534 |
+
self.lwf = LwFRegularizer(alpha=alpha)
|
| 535 |
+
else:
|
| 536 |
+
self.lwf.alpha = alpha
|
| 537 |
+
return self
|
| 538 |
+
|
| 539 |
+
def compute_fisher(self, dataloader: DataLoader, device: torch.device):
|
| 540 |
+
"""Compute Fisher information matrix for EWC."""
|
| 541 |
+
if self.ewc:
|
| 542 |
+
self.ewc.compute_fisher(dataloader, device)
|
| 543 |
+
|
| 544 |
+
def get_regularization_loss(self) -> torch.Tensor:
|
| 545 |
+
"""Get current regularization loss."""
|
| 546 |
+
if self.ewc:
|
| 547 |
+
return self.ewc.compute_ewc_loss()
|
| 548 |
+
elif self.si:
|
| 549 |
+
return self.si.compute_si_loss()
|
| 550 |
+
return torch.tensor(0.0)
|
| 551 |
+
|
| 552 |
+
def progressive_unfreeze(
|
| 553 |
+
self,
|
| 554 |
+
start_layers: int = 4,
|
| 555 |
+
unfreeze_every_n_epochs: int = 2,
|
| 556 |
+
max_layers: Optional[int] = None
|
| 557 |
+
):
|
| 558 |
+
"""
|
| 559 |
+
Progressive unfreezing strategy for continual learning.
|
| 560 |
+
|
| 561 |
+
Args:
|
| 562 |
+
start_layers: Number of layers to keep unfrozen initially
|
| 563 |
+
unfreeze_every_n_epochs: Epochs between unfreezing
|
| 564 |
+
max_layers: Maximum layers to unfreeze (None = all)
|
| 565 |
+
"""
|
| 566 |
+
self.start_layers = start_layers
|
| 567 |
+
self.unfreeze_every_n_epochs = unfreeze_every_n_epochs
|
| 568 |
+
self.max_layers = max_layers
|
| 569 |
+
self.current_epoch = 0
|
| 570 |
+
|
| 571 |
+
# Initially freeze all but top layers
|
| 572 |
+
self._unfreeze_layers(start_layers)
|
| 573 |
+
|
| 574 |
+
def _unfreeze_layers(self, num_layers: int):
|
| 575 |
+
"""Unfreeze top N layers of the model."""
|
| 576 |
+
layers = list(self.model.modules())
|
| 577 |
+
# Unfreeze from the end (top layers)
|
| 578 |
+
for layer in layers[-num_layers:]:
|
| 579 |
+
for param in layer.parameters():
|
| 580 |
+
param.requires_grad = True
|
| 581 |
+
|
| 582 |
+
def step_epoch(self):
|
| 583 |
+
"""Call at end of each epoch for progressive unfreezing."""
|
| 584 |
+
if hasattr(self, 'unfreeze_every_n_epochs'):
|
| 585 |
+
self.current_epoch += 1
|
| 586 |
+
if self.current_epoch % self.unfreeze_every_n_epochs == 0:
|
| 587 |
+
current_unfrozen = self.start_layers + (self.current_epoch // self.unfreeze_every_n_epochs) * 2
|
| 588 |
+
if self.max_layers is None or current_unfrozen <= self.max_layers:
|
| 589 |
+
self._unfreeze_layers(current_unfrozen)
|
multi_task.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-Task Learning Implementation for NTF
|
| 3 |
+
Supports task-specific heads for different fine-tuning objectives
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from typing import Dict, List, Optional, Any, Union
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from enum import Enum
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TaskType(str, Enum):
|
| 14 |
+
"""Supported task types for multi-task learning."""
|
| 15 |
+
CLASSIFICATION = "classification"
|
| 16 |
+
SEQUENCE_TO_SEQUENCE = "sequence_to_sequence"
|
| 17 |
+
TOKEN_CLASSIFICATION = "token_classification"
|
| 18 |
+
QUESTION_ANSWERING = "question_answering"
|
| 19 |
+
GENERATION = "generation"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class TaskHeadConfig:
|
| 24 |
+
"""Configuration for a task-specific head."""
|
| 25 |
+
|
| 26 |
+
task_name: str
|
| 27 |
+
head_type: TaskType
|
| 28 |
+
config: Dict[str, Any] = field(default_factory=dict)
|
| 29 |
+
|
| 30 |
+
def __post_init__(self):
|
| 31 |
+
if isinstance(self.head_type, str):
|
| 32 |
+
self.head_type = TaskType(self.head_type)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ClassificationHead(nn.Module):
|
| 36 |
+
"""Classification head for sequence classification tasks."""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
hidden_size: int,
|
| 41 |
+
num_labels: int,
|
| 42 |
+
dropout: float = 0.1,
|
| 43 |
+
**kwargs
|
| 44 |
+
):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.dropout = nn.Dropout(dropout)
|
| 47 |
+
self.classifier = nn.Linear(hidden_size, num_labels)
|
| 48 |
+
self.num_labels = num_labels
|
| 49 |
+
|
| 50 |
+
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 51 |
+
# Use pooled output (last hidden state of [CLS] or mean pooling)
|
| 52 |
+
if attention_mask is not None:
|
| 53 |
+
# Mean pooling with mask
|
| 54 |
+
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
|
| 55 |
+
sum_embeddings = (hidden_states * mask_expanded).sum(1)
|
| 56 |
+
sum_mask = mask_expanded.sum(1).clamp(min=1e-9)
|
| 57 |
+
pooled_output = sum_embeddings / sum_mask
|
| 58 |
+
else:
|
| 59 |
+
pooled_output = hidden_states[:, -1, :] # Last token
|
| 60 |
+
|
| 61 |
+
pooled_output = self.dropout(pooled_output)
|
| 62 |
+
return self.classifier(pooled_output)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SequenceToSequenceHead(nn.Module):
|
| 66 |
+
"""Sequence-to-sequence head for generation tasks."""
|
| 67 |
+
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
hidden_size: int,
|
| 71 |
+
vocab_size: int,
|
| 72 |
+
max_length: int = 512,
|
| 73 |
+
**kwargs
|
| 74 |
+
):
|
| 75 |
+
super().__init__()
|
| 76 |
+
self.output_projection = nn.Linear(hidden_size, vocab_size)
|
| 77 |
+
self.max_length = max_length
|
| 78 |
+
self.vocab_size = vocab_size
|
| 79 |
+
|
| 80 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
return self.output_projection(hidden_states)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class TokenClassificationHead(nn.Module):
|
| 85 |
+
"""Token-level classification head (NER, POS tagging, etc.)."""
|
| 86 |
+
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
hidden_size: int,
|
| 90 |
+
num_labels: int,
|
| 91 |
+
dropout: float = 0.1,
|
| 92 |
+
**kwargs
|
| 93 |
+
):
|
| 94 |
+
super().__init__()
|
| 95 |
+
self.dropout = nn.Dropout(dropout)
|
| 96 |
+
self.classifier = nn.Linear(hidden_size, num_labels)
|
| 97 |
+
self.num_labels = num_labels
|
| 98 |
+
|
| 99 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
hidden_states = self.dropout(hidden_states)
|
| 101 |
+
return self.classifier(hidden_states)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class QuestionAnsweringHead(nn.Module):
|
| 105 |
+
"""Head for extractive question answering."""
|
| 106 |
+
|
| 107 |
+
def __init__(
|
| 108 |
+
self,
|
| 109 |
+
hidden_size: int,
|
| 110 |
+
dropout: float = 0.1,
|
| 111 |
+
**kwargs
|
| 112 |
+
):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.qa_outputs = nn.Linear(hidden_size, 2) # start and end logits
|
| 115 |
+
|
| 116 |
+
def forward(self, hidden_states: torch.Tensor) -> tuple:
|
| 117 |
+
logits = self.qa_outputs(hidden_states)
|
| 118 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 119 |
+
return start_logits.squeeze(-1), end_logits.squeeze(-1)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class TaskHead(nn.Module):
|
| 123 |
+
"""Wrapper for task-specific heads."""
|
| 124 |
+
|
| 125 |
+
HEAD_CLASSES = {
|
| 126 |
+
TaskType.CLASSIFICATION: ClassificationHead,
|
| 127 |
+
TaskType.SEQUENCE_TO_SEQUENCE: SequenceToSequenceHead,
|
| 128 |
+
TaskType.TOKEN_CLASSIFICATION: TokenClassificationHead,
|
| 129 |
+
TaskType.QUESTION_ANSWERING: QuestionAnsweringHead,
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
def __init__(self, config: TaskHeadConfig, hidden_size: int, vocab_size: Optional[int] = None):
|
| 133 |
+
super().__init__()
|
| 134 |
+
self.config = config
|
| 135 |
+
self.task_name = config.task_name
|
| 136 |
+
self.head_type = config.head_type
|
| 137 |
+
|
| 138 |
+
head_config = dict(config.config)
|
| 139 |
+
head_config["hidden_size"] = hidden_size
|
| 140 |
+
|
| 141 |
+
if vocab_size is not None:
|
| 142 |
+
head_config["vocab_size"] = vocab_size
|
| 143 |
+
|
| 144 |
+
head_class = self.HEAD_CLASSES.get(head_type)
|
| 145 |
+
if head_class is None:
|
| 146 |
+
raise ValueError(f"Unsupported task type: {head_type}")
|
| 147 |
+
|
| 148 |
+
self.head = head_class(**head_config)
|
| 149 |
+
|
| 150 |
+
def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 151 |
+
return self.head(hidden_states, **kwargs)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class MultiTaskModel(nn.Module):
|
| 155 |
+
"""
|
| 156 |
+
Multi-task model with task-specific heads sharing a common base.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
base_model: Base transformer model
|
| 160 |
+
base_model_name: Name or path of base model
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def __init__(self, base_model=None, base_model_name: Optional[str] = None):
|
| 164 |
+
super().__init__()
|
| 165 |
+
|
| 166 |
+
if base_model is None and base_model_name is None:
|
| 167 |
+
raise ValueError("Must provide either base_model or base_model_name")
|
| 168 |
+
|
| 169 |
+
if base_model is None:
|
| 170 |
+
from transformers import AutoModelForCausalLM
|
| 171 |
+
self.base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
|
| 172 |
+
else:
|
| 173 |
+
self.base_model = base_model
|
| 174 |
+
|
| 175 |
+
# Get hidden size from base model
|
| 176 |
+
self.hidden_size = getattr(self.base_model.config, 'hidden_size', 768)
|
| 177 |
+
self.vocab_size = getattr(self.base_model.config, 'vocab_size', None)
|
| 178 |
+
|
| 179 |
+
# Task heads registry
|
| 180 |
+
self.task_heads: Dict[str, TaskHead] = nn.ModuleDict()
|
| 181 |
+
self.active_task: Optional[str] = None
|
| 182 |
+
|
| 183 |
+
# Task weights for balanced training
|
| 184 |
+
self.task_weights: Dict[str, float] = {}
|
| 185 |
+
|
| 186 |
+
def add_task_head(
|
| 187 |
+
self,
|
| 188 |
+
task_name: str,
|
| 189 |
+
head_type: Union[str, TaskType],
|
| 190 |
+
config: Optional[Dict[str, Any]] = None
|
| 191 |
+
):
|
| 192 |
+
"""
|
| 193 |
+
Add a task-specific head to the model.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
task_name: Unique name for this task
|
| 197 |
+
head_type: Type of task (classification, seq2seq, etc.)
|
| 198 |
+
config: Task-specific configuration
|
| 199 |
+
"""
|
| 200 |
+
if config is None:
|
| 201 |
+
config = {}
|
| 202 |
+
|
| 203 |
+
task_config = TaskHeadConfig(
|
| 204 |
+
task_name=task_name,
|
| 205 |
+
head_type=head_type,
|
| 206 |
+
config=config
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
task_head = TaskHead(task_config, self.hidden_size, self.vocab_size)
|
| 210 |
+
self.task_heads[task_name] = task_head
|
| 211 |
+
self.task_weights[task_name] = 1.0 # Default equal weight
|
| 212 |
+
|
| 213 |
+
def set_task_weights(self, weights: Dict[str, float]):
|
| 214 |
+
"""Set weights for each task in multi-task training."""
|
| 215 |
+
for task_name, weight in weights.items():
|
| 216 |
+
if task_name in self.task_heads:
|
| 217 |
+
self.task_weights[task_name] = weight
|
| 218 |
+
|
| 219 |
+
def set_active_task(self, task_name: str):
|
| 220 |
+
"""Set the currently active task for single-task inference."""
|
| 221 |
+
if task_name not in self.task_heads:
|
| 222 |
+
raise ValueError(f"Task '{task_name}' not found. Available: {list(self.task_heads.keys())}")
|
| 223 |
+
self.active_task = task_name
|
| 224 |
+
|
| 225 |
+
def forward(
|
| 226 |
+
self,
|
| 227 |
+
input_ids: torch.Tensor,
|
| 228 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 229 |
+
labels: Optional[torch.Tensor] = None,
|
| 230 |
+
task_name: Optional[str] = None,
|
| 231 |
+
**kwargs
|
| 232 |
+
) -> Dict[str, torch.Tensor]:
|
| 233 |
+
"""
|
| 234 |
+
Forward pass through base model and task head.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
input_ids: Input token IDs
|
| 238 |
+
attention_mask: Attention mask
|
| 239 |
+
labels: Optional labels for loss computation
|
| 240 |
+
task_name: Task to use (overrides active_task)
|
| 241 |
+
|
| 242 |
+
Returns:
|
| 243 |
+
Dictionary containing logits and optionally loss
|
| 244 |
+
"""
|
| 245 |
+
# Determine which task to use
|
| 246 |
+
task = task_name or self.active_task
|
| 247 |
+
|
| 248 |
+
if task is None and len(self.task_heads) == 1:
|
| 249 |
+
task = list(self.task_heads.keys())[0]
|
| 250 |
+
elif task is None:
|
| 251 |
+
raise ValueError("No task specified and multiple heads available")
|
| 252 |
+
|
| 253 |
+
if task not in self.task_heads:
|
| 254 |
+
raise ValueError(f"Task '{task}' not found")
|
| 255 |
+
|
| 256 |
+
# Get base model outputs
|
| 257 |
+
base_outputs = self.base_model(
|
| 258 |
+
input_ids=input_ids,
|
| 259 |
+
attention_mask=attention_mask,
|
| 260 |
+
output_hidden_states=True,
|
| 261 |
+
**kwargs
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# Get last hidden state
|
| 265 |
+
hidden_states = base_outputs.hidden_states[-1]
|
| 266 |
+
|
| 267 |
+
# Apply task head
|
| 268 |
+
head = self.task_heads[task]
|
| 269 |
+
head_output = head(hidden_states, attention_mask=attention_mask)
|
| 270 |
+
|
| 271 |
+
result = {"logits": head_output}
|
| 272 |
+
|
| 273 |
+
# Compute loss if labels provided
|
| 274 |
+
if labels is not None:
|
| 275 |
+
if head.head_type == TaskType.CLASSIFICATION:
|
| 276 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 277 |
+
loss = loss_fct(head_output.view(-1, head.num_labels), labels.view(-1))
|
| 278 |
+
elif head.head_type == TaskType.SEQUENCE_TO_SEQUENCE:
|
| 279 |
+
shift_logits = head_output[..., :-1, :].contiguous()
|
| 280 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 281 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 282 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 283 |
+
else:
|
| 284 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 285 |
+
loss = loss_fct(head_output.view(-1, head_output.size(-1)), labels.view(-1))
|
| 286 |
+
|
| 287 |
+
result["loss"] = loss
|
| 288 |
+
|
| 289 |
+
return result
|
| 290 |
+
|
| 291 |
+
def get_num_tasks(self) -> int:
|
| 292 |
+
"""Return number of task heads."""
|
| 293 |
+
return len(self.task_heads)
|
| 294 |
+
|
| 295 |
+
def list_tasks(self) -> List[str]:
|
| 296 |
+
"""Return list of task names."""
|
| 297 |
+
return list(self.task_heads.keys())
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class MultiTaskTrainer:
|
| 301 |
+
"""
|
| 302 |
+
Trainer for multi-task learning with task-balanced loss.
|
| 303 |
+
|
| 304 |
+
Args:
|
| 305 |
+
model: MultiTaskModel instance
|
| 306 |
+
task_datasets: Dictionary mapping task names to datasets
|
| 307 |
+
task_weights: Optional dictionary of task weights
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
def __init__(
|
| 311 |
+
self,
|
| 312 |
+
model: MultiTaskModel,
|
| 313 |
+
task_datasets: Dict[str, Any],
|
| 314 |
+
task_weights: Optional[Dict[str, float]] = None,
|
| 315 |
+
tokenizer=None,
|
| 316 |
+
device: Optional[torch.device] = None
|
| 317 |
+
):
|
| 318 |
+
self.model = model
|
| 319 |
+
self.task_datasets = task_datasets
|
| 320 |
+
self.tokenizer = tokenizer
|
| 321 |
+
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 322 |
+
|
| 323 |
+
# Set task weights
|
| 324 |
+
if task_weights:
|
| 325 |
+
self.model.set_task_weights(task_weights)
|
| 326 |
+
|
| 327 |
+
# Move model to device
|
| 328 |
+
self.model.to(self.device)
|
| 329 |
+
|
| 330 |
+
def train_epoch(
|
| 331 |
+
self,
|
| 332 |
+
optimizer: torch.optim.Optimizer,
|
| 333 |
+
batch_sizes: Dict[str, int] = None,
|
| 334 |
+
gradient_accumulation_steps: int = 1
|
| 335 |
+
) -> Dict[str, float]:
|
| 336 |
+
"""
|
| 337 |
+
Train one epoch across all tasks.
|
| 338 |
+
|
| 339 |
+
Args:
|
| 340 |
+
optimizer: Optimizer for training
|
| 341 |
+
batch_sizes: Batch size per task
|
| 342 |
+
gradient_accumulation_steps: Steps before optimizer update
|
| 343 |
+
|
| 344 |
+
Returns:
|
| 345 |
+
Dictionary of losses per task
|
| 346 |
+
"""
|
| 347 |
+
self.model.train()
|
| 348 |
+
task_losses = {task: 0.0 for task in self.task_datasets.keys()}
|
| 349 |
+
task_counts = {task: 0 for task in self.task_datasets.keys()}
|
| 350 |
+
|
| 351 |
+
# Simple round-robin training across tasks
|
| 352 |
+
for task_name, dataset in self.task_datasets.items():
|
| 353 |
+
weight = self.model.task_weights.get(task_name, 1.0)
|
| 354 |
+
|
| 355 |
+
for batch in dataset:
|
| 356 |
+
# Move batch to device
|
| 357 |
+
inputs = {k: v.to(self.device) if hasattr(v, 'to') else v
|
| 358 |
+
for k, v in batch.items()}
|
| 359 |
+
|
| 360 |
+
optimizer.zero_grad()
|
| 361 |
+
|
| 362 |
+
# Forward pass
|
| 363 |
+
outputs = self.model(
|
| 364 |
+
input_ids=inputs.get("input_ids"),
|
| 365 |
+
attention_mask=inputs.get("attention_mask"),
|
| 366 |
+
labels=inputs.get("labels"),
|
| 367 |
+
task_name=task_name
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
loss = outputs["loss"] * weight
|
| 371 |
+
loss.backward()
|
| 372 |
+
|
| 373 |
+
optimizer.step()
|
| 374 |
+
|
| 375 |
+
task_losses[task_name] += loss.item() / weight
|
| 376 |
+
task_counts[task_name] += 1
|
| 377 |
+
|
| 378 |
+
# Average losses
|
| 379 |
+
avg_losses = {
|
| 380 |
+
task: task_losses[task] / max(task_counts[task], 1)
|
| 381 |
+
for task in task_losses
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
return avg_losses
|
| 385 |
+
|
| 386 |
+
def evaluate(
|
| 387 |
+
self,
|
| 388 |
+
eval_datasets: Dict[str, Any],
|
| 389 |
+
metrics_fn: Optional[Dict[str, callable]] = None
|
| 390 |
+
) -> Dict[str, Dict[str, float]]:
|
| 391 |
+
"""
|
| 392 |
+
Evaluate model on all tasks.
|
| 393 |
+
|
| 394 |
+
Args:
|
| 395 |
+
eval_datasets: Evaluation datasets per task
|
| 396 |
+
metrics_fn: Optional metric functions per task
|
| 397 |
+
|
| 398 |
+
Returns:
|
| 399 |
+
Dictionary of metrics per task
|
| 400 |
+
"""
|
| 401 |
+
self.model.eval()
|
| 402 |
+
results = {}
|
| 403 |
+
|
| 404 |
+
with torch.no_grad():
|
| 405 |
+
for task_name, dataset in eval_datasets.items():
|
| 406 |
+
task_results = {"loss": 0.0, "count": 0}
|
| 407 |
+
|
| 408 |
+
for batch in dataset:
|
| 409 |
+
inputs = {k: v.to(self.device) if hasattr(v, 'to') else v
|
| 410 |
+
for k, v in batch.items()}
|
| 411 |
+
|
| 412 |
+
outputs = self.model(
|
| 413 |
+
input_ids=inputs.get("input_ids"),
|
| 414 |
+
attention_mask=inputs.get("attention_mask"),
|
| 415 |
+
labels=inputs.get("labels"),
|
| 416 |
+
task_name=task_name
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
task_results["loss"] += outputs["loss"].item()
|
| 420 |
+
task_results["count"] += 1
|
| 421 |
+
|
| 422 |
+
if task_results["count"] > 0:
|
| 423 |
+
task_results["loss"] /= task_results["count"]
|
| 424 |
+
|
| 425 |
+
results[task_name] = task_results
|
| 426 |
+
|
| 427 |
+
return results
|
p_tuning.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
P-Tuning / Prefix Tuning Implementation for NTF
|
| 3 |
+
Parameter-efficient tuning using learnable continuous prompts
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from typing import Dict, List, Optional, Any, Union
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from enum import Enum
|
| 11 |
+
|
| 12 |
+
from peft import (
|
| 13 |
+
PrefixTuningConfig,
|
| 14 |
+
PromptTuningConfig,
|
| 15 |
+
P_TUNING_TASK_TYPE,
|
| 16 |
+
get_peft_model,
|
| 17 |
+
TaskType,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PTuningMethod(str, Enum):
|
| 22 |
+
"""P-Tuning method types."""
|
| 23 |
+
P_TUNING_V1 = "p_tuning_v1"
|
| 24 |
+
P_TUNING_V2 = "p_tuning_v2"
|
| 25 |
+
PREFIX_TUNING = "prefix_tuning"
|
| 26 |
+
PROMPT_TUNING = "prompt_tuning"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class PTuningConfig:
|
| 31 |
+
"""
|
| 32 |
+
Configuration for P-Tuning / Prefix Tuning.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
method: P-tuning method to use
|
| 36 |
+
num_virtual_tokens: Number of virtual/prompt tokens to add
|
| 37 |
+
token_dim: Dimension of token embeddings
|
| 38 |
+
num_transformer_submodules: Number of transformer submodules
|
| 39 |
+
num_attention_heads: Number of attention heads
|
| 40 |
+
num_layers: Number of transformer layers
|
| 41 |
+
encoder_hidden_size: Hidden size for encoder (P-Tuning v1)
|
| 42 |
+
prefix_projection: Whether to project prefix (Prefix Tuning)
|
| 43 |
+
prompt_tuning_init: Initialization strategy for prompt tuning
|
| 44 |
+
prompt_tuning_init_text: Text for initialization if using text init
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
method: PTuningMethod = PTuningMethod.P_TUNING_V2
|
| 48 |
+
|
| 49 |
+
# Core parameters
|
| 50 |
+
num_virtual_tokens: int = 20
|
| 51 |
+
token_dim: int = 768
|
| 52 |
+
num_transformer_submodules: int = 1
|
| 53 |
+
num_attention_heads: int = 12
|
| 54 |
+
num_layers: int = 12
|
| 55 |
+
|
| 56 |
+
# P-Tuning v1 specific
|
| 57 |
+
encoder_hidden_size: int = 512
|
| 58 |
+
|
| 59 |
+
# Prefix Tuning specific
|
| 60 |
+
prefix_projection: bool = True
|
| 61 |
+
|
| 62 |
+
# Prompt Tuning specific
|
| 63 |
+
prompt_tuning_init: str = "RANDOM" # RANDOM or TEXT
|
| 64 |
+
prompt_tuning_init_text: Optional[str] = None
|
| 65 |
+
|
| 66 |
+
# Task type
|
| 67 |
+
task_type: TaskType = TaskType.CAUSAL_LM
|
| 68 |
+
|
| 69 |
+
def to_peft_config(self):
|
| 70 |
+
"""Convert to appropriate PEFT config based on method."""
|
| 71 |
+
if self.method == PTuningMethod.PREFIX_TUNING:
|
| 72 |
+
return PrefixTuningConfig(
|
| 73 |
+
num_virtual_tokens=self.num_virtual_tokens,
|
| 74 |
+
token_dim=self.token_dim,
|
| 75 |
+
num_attention_heads=self.num_attention_heads,
|
| 76 |
+
num_layers=self.num_layers,
|
| 77 |
+
prefix_projection=self.prefix_projection,
|
| 78 |
+
task_type=self.task_type,
|
| 79 |
+
)
|
| 80 |
+
elif self.method == PTuningMethod.PROMPT_TUNING:
|
| 81 |
+
return PromptTuningConfig(
|
| 82 |
+
num_virtual_tokens=self.num_virtual_tokens,
|
| 83 |
+
token_dim=self.token_dim,
|
| 84 |
+
prompt_tuning_init=self.prompt_tuning_init,
|
| 85 |
+
prompt_tuning_init_text=self.prompt_tuning_init_text,
|
| 86 |
+
task_type=self.task_type,
|
| 87 |
+
)
|
| 88 |
+
else: # P-Tuning v1 or v2
|
| 89 |
+
# P-Tuning uses PrefixTuningConfig with specific settings
|
| 90 |
+
return PrefixTuningConfig(
|
| 91 |
+
num_virtual_tokens=self.num_virtual_tokens,
|
| 92 |
+
token_dim=self.token_dim,
|
| 93 |
+
num_attention_heads=self.num_attention_heads,
|
| 94 |
+
num_layers=self.num_layers,
|
| 95 |
+
encoder_hidden_size=self.encoder_hidden_size,
|
| 96 |
+
prefix_projection=self.method == PTuningMethod.P_TUNING_V1,
|
| 97 |
+
task_type=self.task_type,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class PTuningModel(nn.Module):
|
| 102 |
+
"""
|
| 103 |
+
P-Tuning wrapper for transformer models.
|
| 104 |
+
|
| 105 |
+
Adds learnable continuous prompts to the model input
|
| 106 |
+
without modifying the base model weights.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
base_model: Base transformer model
|
| 110 |
+
config: P-tuning configuration
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(self, base_model: nn.Module, config: PTuningConfig):
|
| 114 |
+
super().__init__()
|
| 115 |
+
|
| 116 |
+
self.base_model = base_model
|
| 117 |
+
self.config = config
|
| 118 |
+
|
| 119 |
+
# Get model dimensions
|
| 120 |
+
model_config = base_model.config
|
| 121 |
+
self.token_dim = getattr(model_config, 'hidden_size', config.token_dim)
|
| 122 |
+
self.num_layers = getattr(model_config, 'num_hidden_layers', config.num_layers)
|
| 123 |
+
self.num_attention_heads = getattr(model_config, 'num_attention_heads', config.num_attention_heads)
|
| 124 |
+
|
| 125 |
+
# Update config with actual dimensions
|
| 126 |
+
config.token_dim = self.token_dim
|
| 127 |
+
config.num_layers = self.num_layers
|
| 128 |
+
config.num_attention_heads = self.num_attention_heads
|
| 129 |
+
|
| 130 |
+
# Create virtual tokens
|
| 131 |
+
self._create_virtual_tokens()
|
| 132 |
+
|
| 133 |
+
def _create_virtual_tokens(self):
|
| 134 |
+
"""Create learnable virtual token embeddings."""
|
| 135 |
+
method = self.config.method
|
| 136 |
+
|
| 137 |
+
if method == PTuningMethod.PROMPT_TUNING:
|
| 138 |
+
# Simple prompt embeddings
|
| 139 |
+
self.prompt_embeddings = nn.Embedding(
|
| 140 |
+
self.config.num_virtual_tokens,
|
| 141 |
+
self.token_dim
|
| 142 |
+
)
|
| 143 |
+
nn.init.normal_(self.prompt_embeddings.weight, std=0.02)
|
| 144 |
+
|
| 145 |
+
elif method == PTuningMethod.PREFIX_TUNING:
|
| 146 |
+
# Prefix with projection
|
| 147 |
+
self.prefix_tokens = nn.Parameter(
|
| 148 |
+
torch.randn(
|
| 149 |
+
self.num_layers * 2, # key and value for each layer
|
| 150 |
+
self.config.num_virtual_tokens,
|
| 151 |
+
self.token_dim
|
| 152 |
+
)
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
if self.config.prefix_projection:
|
| 156 |
+
self.prefix_proj = nn.Sequential(
|
| 157 |
+
nn.Linear(self.token_dim, self.token_dim),
|
| 158 |
+
nn.ReLU(),
|
| 159 |
+
nn.Linear(self.token_dim, self.num_layers * 2 * self.token_dim)
|
| 160 |
+
)
|
| 161 |
+
else:
|
| 162 |
+
self.prefix_proj = None
|
| 163 |
+
|
| 164 |
+
else: # P-Tuning v1 or v2
|
| 165 |
+
# Encoder for generating prompts
|
| 166 |
+
self.prompt_encoder = nn.Sequential(
|
| 167 |
+
nn.Linear(self.token_dim, self.config.encoder_hidden_size),
|
| 168 |
+
nn.ReLU(),
|
| 169 |
+
nn.Linear(self.config.encoder_hidden_size,
|
| 170 |
+
self.num_layers * 2 * self.config.num_virtual_tokens * self.token_dim)
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Input embedding for prompt encoder
|
| 174 |
+
self.input_embeds = nn.Embedding(self.config.num_virtual_tokens, self.token_dim)
|
| 175 |
+
nn.init.normal_(self.input_embeds.weight, std=0.02)
|
| 176 |
+
|
| 177 |
+
def get_prompt(self, batch_size: int) -> torch.Tensor:
|
| 178 |
+
"""Generate prompt tensors for the current batch."""
|
| 179 |
+
method = self.config.method
|
| 180 |
+
|
| 181 |
+
if method == PTuningMethod.PROMPT_TUNING:
|
| 182 |
+
# Expand prompt embeddings to batch size
|
| 183 |
+
prompts = self.prompt_embeddings.weight.unsqueeze(0).expand(
|
| 184 |
+
batch_size, -1, -1
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
elif method == PTuningMethod.PREFIX_TUNING:
|
| 188 |
+
prefix = self.prefix_tokens
|
| 189 |
+
|
| 190 |
+
if self.prefix_proj is not None:
|
| 191 |
+
prefix = self.prefix_proj(prefix.view(-1, self.token_dim))
|
| 192 |
+
prefix = prefix.view(
|
| 193 |
+
self.num_layers * 2,
|
| 194 |
+
self.config.num_virtual_tokens,
|
| 195 |
+
self.token_dim
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
prompts = prefix.unsqueeze(1).expand(
|
| 199 |
+
-1, batch_size, -1, -1
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
else: # P-Tuning
|
| 203 |
+
input_ids = torch.arange(self.config.num_virtual_tokens).long()
|
| 204 |
+
input_ids = input_ids.unsqueeze(0).expand(batch_size, -1)
|
| 205 |
+
input_embeds = self.input_embeds(input_ids)
|
| 206 |
+
|
| 207 |
+
prompts = self.prompt_encoder(input_embeds)
|
| 208 |
+
prompts = prompts.view(
|
| 209 |
+
batch_size,
|
| 210 |
+
self.num_layers * 2,
|
| 211 |
+
self.config.num_virtual_tokens,
|
| 212 |
+
self.token_dim
|
| 213 |
+
)
|
| 214 |
+
prompts = prompts.permute(1, 0, 2, 3)
|
| 215 |
+
|
| 216 |
+
return prompts
|
| 217 |
+
|
| 218 |
+
def forward(
|
| 219 |
+
self,
|
| 220 |
+
input_ids: torch.Tensor,
|
| 221 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 222 |
+
labels: Optional[torch.Tensor] = None,
|
| 223 |
+
**kwargs
|
| 224 |
+
) -> Dict[str, torch.Tensor]:
|
| 225 |
+
"""
|
| 226 |
+
Forward pass with virtual prompts.
|
| 227 |
+
|
| 228 |
+
Note: This is a simplified implementation. For production use,
|
| 229 |
+
consider using the PEFT library's P-tuning implementation.
|
| 230 |
+
"""
|
| 231 |
+
batch_size = input_ids.size(0)
|
| 232 |
+
|
| 233 |
+
# Get base model outputs
|
| 234 |
+
outputs = self.base_model(
|
| 235 |
+
input_ids=input_ids,
|
| 236 |
+
attention_mask=attention_mask,
|
| 237 |
+
output_hidden_states=True,
|
| 238 |
+
**kwargs
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
result = {"logits": outputs.logits}
|
| 242 |
+
|
| 243 |
+
if hasattr(outputs, 'loss') and outputs.loss is not None:
|
| 244 |
+
result["loss"] = outputs.loss
|
| 245 |
+
|
| 246 |
+
return result
|
| 247 |
+
|
| 248 |
+
def get_trainable_params(self) -> Dict[str, torch.Tensor]:
|
| 249 |
+
"""Get dictionary of trainable parameters (prompts only)."""
|
| 250 |
+
trainable = {}
|
| 251 |
+
|
| 252 |
+
for name, param in self.named_parameters():
|
| 253 |
+
if param.requires_grad:
|
| 254 |
+
trainable[name] = param
|
| 255 |
+
|
| 256 |
+
return trainable
|
| 257 |
+
|
| 258 |
+
def print_trainable_parameters(self):
|
| 259 |
+
"""Print number of trainable vs total parameters."""
|
| 260 |
+
trainable_params = sum(
|
| 261 |
+
p.numel() for p in self.parameters() if p.requires_grad
|
| 262 |
+
)
|
| 263 |
+
all_params = sum(p.numel() for p in self.parameters())
|
| 264 |
+
|
| 265 |
+
print(f"Trainable params: {trainable_params:,} ({100 * trainable_params / all_params:.2f}%)")
|
| 266 |
+
print(f"All params: {all_params:,}")
|
| 267 |
+
print(f"Frozen params: {all_params - trainable_params:,}")
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def setup_p_tuning(
|
| 271 |
+
model: nn.Module,
|
| 272 |
+
method: str = "p_tuning_v2",
|
| 273 |
+
num_virtual_tokens: int = 20,
|
| 274 |
+
task_type: str = "CAUSAL_LM"
|
| 275 |
+
) -> nn.Module:
|
| 276 |
+
"""
|
| 277 |
+
Setup P-Tuning on a model using PEFT.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
model: Base model to apply P-Tuning to
|
| 281 |
+
method: P-tuning method (p_tuning_v1, p_tuning_v2, prefix_tuning, prompt_tuning)
|
| 282 |
+
num_virtual_tokens: Number of virtual tokens
|
| 283 |
+
task_type: PEFT task type
|
| 284 |
+
|
| 285 |
+
Returns:
|
| 286 |
+
Model with P-Tuning applied
|
| 287 |
+
"""
|
| 288 |
+
config = PTuningConfig(
|
| 289 |
+
method=PTuningMethod(method),
|
| 290 |
+
num_virtual_tokens=num_virtual_tokens,
|
| 291 |
+
task_type=TaskType(task_type)
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
peft_config = config.to_peft_config()
|
| 295 |
+
return get_peft_model(model, peft_config)
|
test_tutorial_examples.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test Suite for Tutorial Code Examples
|
| 3 |
+
Ensures all code examples in tutorials remain functional
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import torch
|
| 10 |
+
from datasets import Dataset
|
| 11 |
+
|
| 12 |
+
# Add project root to path
|
| 13 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestTutorial03:
|
| 17 |
+
"""Test Tutorial 03: Full Fine-Tuning examples"""
|
| 18 |
+
|
| 19 |
+
def test_full_finetuning_basic(self):
|
| 20 |
+
"""Test basic full fine-tuning workflow"""
|
| 21 |
+
from ntf.config import NTFConfig, ModelConfig, TrainingConfig
|
| 22 |
+
from ntf.models import ModelRegistry
|
| 23 |
+
from ntf.finetuning import FullFinetuneTrainer
|
| 24 |
+
|
| 25 |
+
config = NTFConfig(
|
| 26 |
+
model=ModelConfig(name="facebook/opt-125m"),
|
| 27 |
+
training=TrainingConfig(
|
| 28 |
+
output_dir="./test_output",
|
| 29 |
+
num_train_epochs=1,
|
| 30 |
+
per_device_train_batch_size=2,
|
| 31 |
+
)
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
registry = ModelRegistry(config.model)
|
| 35 |
+
model, tokenizer = registry.load_model_and_tokenizer()
|
| 36 |
+
|
| 37 |
+
train_data = Dataset.from_dict({
|
| 38 |
+
"text": ["Hello world", "Test sentence"] * 10
|
| 39 |
+
})
|
| 40 |
+
|
| 41 |
+
trainer = FullFinetuneTrainer(
|
| 42 |
+
model=model,
|
| 43 |
+
config=config.training,
|
| 44 |
+
train_dataset=train_data,
|
| 45 |
+
tokenizer=tokenizer
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
trainer.train()
|
| 49 |
+
|
| 50 |
+
assert os.path.exists("./test_output")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TestTutorial05:
|
| 54 |
+
"""Test Tutorial 05: PEFT/LoRA examples"""
|
| 55 |
+
|
| 56 |
+
def test_lora_setup(self):
|
| 57 |
+
"""Test LoRA adapter setup"""
|
| 58 |
+
from ntf.finetuning import LoRAConfig, PEFTTrainer
|
| 59 |
+
from ntf.models import ModelRegistry
|
| 60 |
+
|
| 61 |
+
registry = ModelRegistry("facebook/opt-125m")
|
| 62 |
+
model, tokenizer = registry.load_model_and_tokenizer()
|
| 63 |
+
|
| 64 |
+
lora_config = LoRAConfig(
|
| 65 |
+
r=8,
|
| 66 |
+
alpha=16,
|
| 67 |
+
dropout=0.05,
|
| 68 |
+
target_modules=["q_proj", "v_proj"],
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
trainer = PEFTTrainer(model, lora_config, tokenizer)
|
| 72 |
+
|
| 73 |
+
trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
|
| 74 |
+
all_params = sum(p.numel() for p in trainer.model.parameters())
|
| 75 |
+
|
| 76 |
+
assert trainable_params < all_params
|
| 77 |
+
assert trainable_params > 0
|
| 78 |
+
|
| 79 |
+
def test_p_tuning_setup(self):
|
| 80 |
+
"""Test P-Tuning setup"""
|
| 81 |
+
from ntf.finetuning import PTuningConfig, PTuningMethod, setup_p_tuning
|
| 82 |
+
from ntf.models import ModelRegistry
|
| 83 |
+
|
| 84 |
+
registry = ModelRegistry("facebook/opt-125m")
|
| 85 |
+
model, tokenizer = registry.load_model_and_tokenizer()
|
| 86 |
+
|
| 87 |
+
config = PTuningConfig(
|
| 88 |
+
method=PTuningMethod.P_TUNING_V2,
|
| 89 |
+
num_virtual_tokens=20,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
peft_model = setup_p_tuning(model, method="p_tuning_v2", num_virtual_tokens=20)
|
| 93 |
+
|
| 94 |
+
assert peft_model is not None
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class TestTutorial04:
|
| 98 |
+
"""Test Tutorial 04: Continual Learning examples"""
|
| 99 |
+
|
| 100 |
+
def test_ewc_regularization(self):
|
| 101 |
+
"""Test EWC regularization setup"""
|
| 102 |
+
from ntf.utils import EWCConfig, EWCRegularizer, ContinualLearningWrapper
|
| 103 |
+
from ntf.models import ModelRegistry
|
| 104 |
+
|
| 105 |
+
registry = ModelRegistry("facebook/opt-125m")
|
| 106 |
+
model, tokenizer = registry.load_model_and_tokenizer()
|
| 107 |
+
|
| 108 |
+
ewc_config = EWCConfig(ewc_lambda=1000.0)
|
| 109 |
+
ewc = EWCRegularizer(model, ewc_config)
|
| 110 |
+
|
| 111 |
+
assert ewc is not None
|
| 112 |
+
assert ewc.config.ewc_lambda == 1000.0
|
| 113 |
+
|
| 114 |
+
def test_si_regularization(self):
|
| 115 |
+
"""Test Synaptic Intelligence regularization"""
|
| 116 |
+
from ntf.utils import SIRegularizer, ContinualLearningWrapper
|
| 117 |
+
from ntf.models import ModelRegistry
|
| 118 |
+
|
| 119 |
+
registry = ModelRegistry("facebook/opt-125m")
|
| 120 |
+
model, _ = registry.load_model_and_tokenizer()
|
| 121 |
+
|
| 122 |
+
wrapper = ContinualLearningWrapper(model, method="si")
|
| 123 |
+
wrapper.apply_si_regularization(c=0.1)
|
| 124 |
+
|
| 125 |
+
assert wrapper.si is not None
|
| 126 |
+
assert wrapper.si.c == 0.1
|
| 127 |
+
|
| 128 |
+
def test_lwf_regularization(self):
|
| 129 |
+
"""Test Learning without Forgetting"""
|
| 130 |
+
from ntf.utils import LwFRegularizer, ContinualLearningWrapper
|
| 131 |
+
from ntf.models import ModelRegistry
|
| 132 |
+
|
| 133 |
+
registry = ModelRegistry("facebook/opt-125m")
|
| 134 |
+
model, _ = registry.load_model_and_tokenizer()
|
| 135 |
+
|
| 136 |
+
wrapper = ContinualLearningWrapper(model, method="lwf")
|
| 137 |
+
wrapper.apply_lwf_regularization(alpha=0.5)
|
| 138 |
+
|
| 139 |
+
assert wrapper.lwf is not None
|
| 140 |
+
assert wrapper.lwf.alpha == 0.1
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class TestMultiTask:
|
| 144 |
+
"""Test Multi-Task Learning (Spec 4.1.1)"""
|
| 145 |
+
|
| 146 |
+
def test_multi_task_model_creation(self):
|
| 147 |
+
"""Test creating multi-task model with multiple heads"""
|
| 148 |
+
from ntf.finetuning import MultiTaskModel, TaskType, MultiTaskTrainer
|
| 149 |
+
from ntf.models import ModelRegistry
|
| 150 |
+
|
| 151 |
+
registry = ModelRegistry("facebook/opt-125m")
|
| 152 |
+
base_model, tokenizer = registry.load_model_and_tokenizer()
|
| 153 |
+
|
| 154 |
+
model = MultiTaskModel(base_model=base_model)
|
| 155 |
+
|
| 156 |
+
model.add_task_head(
|
| 157 |
+
task_name="classification",
|
| 158 |
+
head_type=TaskType.CLASSIFICATION,
|
| 159 |
+
config={"num_labels": 5}
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
model.add_task_head(
|
| 163 |
+
task_name="summarization",
|
| 164 |
+
head_type=TaskType.SEQUENCE_TO_SEQUENCE,
|
| 165 |
+
config={"max_length": 512}
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
assert model.get_num_tasks() == 2
|
| 169 |
+
assert "classification" in model.list_tasks()
|
| 170 |
+
assert "summarization" in model.list_tasks()
|
| 171 |
+
|
| 172 |
+
def test_multi_task_forward(self):
|
| 173 |
+
"""Test forward pass through multi-task model"""
|
| 174 |
+
from ntf.finetuning import MultiTaskModel, TaskType
|
| 175 |
+
from ntf.models import ModelRegistry
|
| 176 |
+
import torch
|
| 177 |
+
|
| 178 |
+
registry = ModelRegistry("facebook/opt-125m")
|
| 179 |
+
base_model, tokenizer = registry.load_model_and_tokenizer()
|
| 180 |
+
|
| 181 |
+
model = MultiTaskModel(base_model=base_model)
|
| 182 |
+
model.add_task_head(
|
| 183 |
+
task_name="classification",
|
| 184 |
+
head_type=TaskType.CLASSIFICATION,
|
| 185 |
+
config={"num_labels": 3}
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
input_ids = torch.randint(0, 1000, (2, 10))
|
| 189 |
+
attention_mask = torch.ones((2, 10))
|
| 190 |
+
|
| 191 |
+
output = model(
|
| 192 |
+
input_ids=input_ids,
|
| 193 |
+
attention_mask=attention_mask,
|
| 194 |
+
task_name="classification"
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
assert "logits" in output
|
| 198 |
+
assert output["logits"].shape[0] == 2
|
| 199 |
+
assert output["logits"].shape[1] == 3
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class TestContinualLearningWrapper:
|
| 203 |
+
"""Test ContinualLearningWrapper API (Spec 4.1.2)"""
|
| 204 |
+
|
| 205 |
+
def test_wrapper_api(self):
|
| 206 |
+
"""Test the unified ContinualLearningWrapper API"""
|
| 207 |
+
from ntf.utils import ContinualLearningWrapper
|
| 208 |
+
from ntf.models import ModelRegistry
|
| 209 |
+
|
| 210 |
+
registry = ModelRegistry("facebook/opt-125m")
|
| 211 |
+
model, _ = registry.load_model_and_tokenizer()
|
| 212 |
+
|
| 213 |
+
wrapper = ContinualLearningWrapper(model, method="ewc")
|
| 214 |
+
|
| 215 |
+
# Test EWC
|
| 216 |
+
wrapper.apply_ewc_regularization(lambda_ewc=0.5)
|
| 217 |
+
assert wrapper.ewc is not None
|
| 218 |
+
|
| 219 |
+
# Test SI
|
| 220 |
+
wrapper2 = ContinualLearningWrapper(model, method="si")
|
| 221 |
+
wrapper2.apply_si_regularization(c=0.1)
|
| 222 |
+
assert wrapper2.si is not None
|
| 223 |
+
|
| 224 |
+
# Test LwF
|
| 225 |
+
wrapper3 = ContinualLearningWrapper(model, method="lwf")
|
| 226 |
+
wrapper3.apply_lwf_regularization(alpha=0.5)
|
| 227 |
+
assert wrapper3.lwf is not None
|
| 228 |
+
|
| 229 |
+
def test_progressive_unfreeze(self):
|
| 230 |
+
"""Test progressive unfreezing strategy"""
|
| 231 |
+
from ntf.utils import ContinualLearningWrapper
|
| 232 |
+
from ntf.models import ModelRegistry
|
| 233 |
+
|
| 234 |
+
registry = ModelRegistry("facebook/opt-125m")
|
| 235 |
+
model, _ = registry.load_model_and_tokenizer()
|
| 236 |
+
|
| 237 |
+
wrapper = ContinualLearningWrapper(model)
|
| 238 |
+
wrapper.progressive_unfreeze(
|
| 239 |
+
start_layers=4,
|
| 240 |
+
unfreeze_every_n_epochs=2,
|
| 241 |
+
max_layers=12
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
assert hasattr(wrapper, 'start_layers')
|
| 245 |
+
assert wrapper.start_layers == 4
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
if __name__ == "__main__":
|
| 249 |
+
pytest.main([__file__, "-v"])
|
utils/__init__.py
CHANGED
|
@@ -1,81 +1,47 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Utilities package for Nexuss Transformer Framework
|
| 3 |
-
"""
|
| 4 |
|
| 5 |
-
from .
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
ModelRegistry,
|
| 25 |
-
create_model_metadata,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
from .metrics import (
|
| 29 |
-
EvaluationResults,
|
| 30 |
-
compute_perplexity,
|
| 31 |
-
compute_accuracy,
|
| 32 |
-
evaluate_model,
|
| 33 |
-
benchmark_throughput,
|
| 34 |
-
compare_models,
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
from .logging import (
|
| 38 |
-
setup_logging,
|
| 39 |
-
get_logger,
|
| 40 |
-
set_log_level,
|
| 41 |
-
DebugLogger,
|
| 42 |
-
validate_config,
|
| 43 |
)
|
| 44 |
|
| 45 |
__all__ = [
|
| 46 |
-
|
| 47 |
-
"
|
| 48 |
-
"
|
| 49 |
-
"
|
| 50 |
-
"
|
| 51 |
-
"
|
| 52 |
-
"
|
| 53 |
-
|
| 54 |
-
"
|
| 55 |
-
"
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
"
|
| 62 |
-
"
|
| 63 |
-
|
| 64 |
-
"
|
| 65 |
-
"
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
"EvaluationResults",
|
| 69 |
-
"compute_perplexity",
|
| 70 |
-
"compute_accuracy",
|
| 71 |
-
"evaluate_model",
|
| 72 |
-
"benchmark_throughput",
|
| 73 |
-
"compare_models",
|
| 74 |
-
|
| 75 |
-
# Logging
|
| 76 |
-
"setup_logging",
|
| 77 |
-
"get_logger",
|
| 78 |
-
"set_log_level",
|
| 79 |
-
"DebugLogger",
|
| 80 |
-
"validate_config",
|
| 81 |
]
|
|
|
|
| 1 |
+
"""Finetuning package - PEFT, LoRA, and layer freezing utilities."""
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
from finetuning.peft_finetune import PEFTTrainer, LoRAConfig, setup_lora
|
| 4 |
+
from finetuning.freeze import LayerFreezer, freeze_layers
|
| 5 |
+
from finetuning.full_finetune import FullFinetuneTrainer, full_finetune
|
| 6 |
+
from finetuning.multi_task import (
|
| 7 |
+
MultiTaskModel,
|
| 8 |
+
MultiTaskTrainer,
|
| 9 |
+
TaskHead,
|
| 10 |
+
TaskType,
|
| 11 |
+
TaskHeadConfig,
|
| 12 |
+
ClassificationHead,
|
| 13 |
+
SequenceToSequenceHead,
|
| 14 |
+
TokenClassificationHead,
|
| 15 |
+
QuestionAnsweringHead,
|
| 16 |
)
|
| 17 |
+
from finetuning.p_tuning import (
|
| 18 |
+
PTuningModel,
|
| 19 |
+
PTuningConfig,
|
| 20 |
+
PTuningMethod,
|
| 21 |
+
setup_p_tuning,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
)
|
| 23 |
|
| 24 |
__all__ = [
|
| 25 |
+
"PEFTTrainer",
|
| 26 |
+
"LoRAConfig",
|
| 27 |
+
"setup_lora",
|
| 28 |
+
"LayerFreezer",
|
| 29 |
+
"freeze_layers",
|
| 30 |
+
"FullFinetuneTrainer",
|
| 31 |
+
"full_finetune",
|
| 32 |
+
# Multi-task learning
|
| 33 |
+
"MultiTaskModel",
|
| 34 |
+
"MultiTaskTrainer",
|
| 35 |
+
"TaskHead",
|
| 36 |
+
"TaskType",
|
| 37 |
+
"TaskHeadConfig",
|
| 38 |
+
"ClassificationHead",
|
| 39 |
+
"SequenceToSequenceHead",
|
| 40 |
+
"TokenClassificationHead",
|
| 41 |
+
"QuestionAnsweringHead",
|
| 42 |
+
# P-Tuning / Prefix Tuning
|
| 43 |
+
"PTuningModel",
|
| 44 |
+
"PTuningConfig",
|
| 45 |
+
"PTuningMethod",
|
| 46 |
+
"setup_p_tuning",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
]
|