Nexuss0781 committed on
Commit
f42d9a1
·
1 Parent(s): 5a593b3

Phase 4: Add multi-task learning, P-Tuning, SI/LwF continual learning, automated tests, deployment templates

Browse files
Files changed (5) hide show
  1. continual_learning.py +589 -0
  2. multi_task.py +427 -0
  3. p_tuning.py +295 -0
  4. test_tutorial_examples.py +249 -0
  5. utils/__init__.py +41 -75
continual_learning.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Continual Learning Utilities for Nexuss Transformer Framework
3
+ Mechanisms to avoid catastrophic forgetting during continuous training
4
+ """
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from dataclasses import dataclass, field
10
+ from typing import Optional, List, Dict, Any, Tuple
11
+ from collections import OrderedDict
12
+ import copy
13
+
14
+
15
@dataclass
class EWCConfig:
    """Hyperparameters for Elastic Weight Consolidation (EWC)."""

    # Strength of the EWC regularization term.
    ewc_lambda: float = 1000.0
    # Number of samples used to estimate the Fisher information.
    fisher_samples: int = 200
    # Damping factor added to the Fisher matrix.
    damping: float = 0.1
    # Monte Carlo samples for Fisher estimation.
    mc_samples: int = 1
23
+
24
+
25
@dataclass
class ReplayConfig:
    """Settings for the experience-replay buffer."""

    # Size of the replay buffer (number of retained samples).
    replay_size: int = 1000
    # Ratio of replay data relative to the current batch size.
    replay_ratio: float = 0.5
    # How an overfull buffer is trimmed: "uniform", "recent" or "diverse".
    selection_strategy: str = "uniform"
    # Use reservoir sampling for streaming data.
    reservoir_sampling: bool = True
33
+
34
+
35
@dataclass
class GEMConfig:
    """Settings for Gradient Episodic Memory (GEM)."""

    # Number of examples stored per task.
    memory_size: int = 100
    # Expected number of tasks (one memory slot is pre-created per task).
    num_tasks: int = 5
    # Use quadratic programming for constraint solving.
    use_quadprog: bool = True
42
+
43
+
44
@dataclass
class ContinualLearningConfig:
    """Unified configuration selecting and tuning a continual-learning strategy."""

    # Strategy name. create_continual_learning_wrapper in this file acts on
    # "ewc", "replay", "gem", "lwf" and "si"; any other value (such as the
    # default "none") leaves the trainer untouched.
    strategy: str = "none"
    # Per-strategy settings; each gets its own fresh default instance.
    ewc: Optional[EWCConfig] = field(default_factory=EWCConfig)
    replay: Optional[ReplayConfig] = field(default_factory=ReplayConfig)
    gem: Optional[GEMConfig] = field(default_factory=GEMConfig)

    # LwF (Learning without Forgetting) settings
    # Weight of the distillation loss.
    lwf_alpha: float = 1.0
    # Temperature used for knowledge distillation.
    lwf_temperature: float = 2.0

    # Regularization
    weight_decay: float = 0.01
    grad_clip: float = 1.0
60
+
61
+
62
class EWCRegularizer:
    """Elastic Weight Consolidation (EWC) regularizer.

    Workflow: call :meth:`compute_fisher` once a task has been learned to
    snapshot the parameters and estimate a diagonal Fisher information
    matrix, then add :meth:`compute_ewc_loss` to the training loss of later
    tasks to penalize movement of important parameters.
    """

    def __init__(self, model: nn.Module, config: EWCConfig):
        self.model = model
        self.config = config
        # Diagonal Fisher estimate and parameter snapshot, keyed by parameter name.
        self.fisher: Dict[str, torch.Tensor] = {}
        self.optimal_params: Dict[str, torch.Tensor] = {}

    def _zero(self) -> torch.Tensor:
        """Return a scalar zero located on the model's device (CPU if it has no parameters)."""
        first = next(self.model.parameters(), None)
        return torch.zeros((), device=first.device if first is not None else None)

    def compute_fisher(self, dataloader: DataLoader, device: torch.device):
        """Estimate the diagonal Fisher information and snapshot the parameters.

        Labels are sampled from the model's own predictive distribution
        (``config.mc_samples`` draws per batch) and the squared gradient of
        the resulting negative log-likelihood is accumulated.

        Args:
            dataloader: Iterable of batches; either a dict containing
                ``"input_ids"`` or a bare tensor. The model is called as
                ``model(inputs)`` and must return an object with ``.logits``.
            device: Device the inputs are moved to.
        """
        # Pair names and tensors from the SAME filtered list so the gradients
        # returned by autograd line up with their names even when some
        # parameters are frozen.
        trainable = [(name, param) for name, param in self.model.named_parameters()
                     if param.requires_grad]
        names = [name for name, _ in trainable]
        params = [param for _, param in trainable]
        fisher_dict = {name: torch.zeros_like(param) for name, param in trainable}

        mc_samples = max(int(self.config.mc_samples), 1)
        samples_processed = 0

        # Evaluate without dropout noise; the previous mode is restored afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            for batch in dataloader:
                if not params or samples_processed >= self.config.fisher_samples:
                    break

                inputs = batch["input_ids"].to(device) if isinstance(batch, dict) else batch.to(device)
                logits = self.model(inputs).logits
                log_probs = torch.log_softmax(logits, dim=-1).reshape(-1, logits.size(-1))
                probs = log_probs.detach().exp()
                batch_size = inputs.size(0)

                for draw in range(mc_samples):
                    sampled = torch.multinomial(probs, 1).squeeze(-1)
                    nll = nn.functional.nll_loss(log_probs, sampled)
                    grads = torch.autograd.grad(
                        nll,
                        params,
                        retain_graph=draw < mc_samples - 1,
                        allow_unused=True,
                    )
                    for name, grad in zip(names, grads):
                        if grad is not None:
                            # Weight by batch size so the final division by the
                            # sample count yields a per-sample average.
                            fisher_dict[name] += grad.pow(2) * (batch_size / mc_samples)

                samples_processed += batch_size
        finally:
            self.model.train(was_training)

        n_samples = max(samples_processed, 1)
        self.fisher = {name: tensor / n_samples + self.config.damping
                       for name, tensor in fisher_dict.items()}
        self.optimal_params = {name: param.clone().detach() for name, param in trainable}

    def compute_ewc_loss(self) -> torch.Tensor:
        """Return ``ewc_lambda * sum(F * (theta - theta_star)^2)``.

        Returns a scalar zero when :meth:`compute_fisher` has not been run.
        """
        if not self.fisher or not self.optimal_params:
            return self._zero()

        penalty = None
        for name, param in self.model.named_parameters():
            if not param.requires_grad or name not in self.fisher or name not in self.optimal_params:
                continue
            fisher = self.fisher[name].to(param.device)
            anchor = self.optimal_params[name].to(param.device)
            term = (fisher * (param - anchor).pow(2)).sum()
            # Accumulate out-of-place on the parameters' device; an in-place add
            # into a CPU scalar fails for CUDA parameters.
            penalty = term if penalty is None else penalty + term.to(penalty.device)

        if penalty is None:
            return self._zero()
        return self.config.ewc_lambda * penalty
130
+
131
+
132
class ReplayBuffer:
    """Experience replay buffer for continual learning.

    Keeps at most ``config.replay_size`` samples in ``buffer`` and, when a
    ``task_id`` is given, an additional (unbounded) per-task list in
    ``task_data``.
    """

    def __init__(self, config: ReplayConfig):
        self.config = config
        self.buffer: List[Dict[str, Any]] = []
        self.task_data: Dict[int, List[Dict[str, Any]]] = {}
        # Total number of samples offered so far; reservoir sampling needs the
        # stream length, not the buffer length, to keep every item equally likely.
        self._num_seen = 0

    def add(self, samples: List[Dict[str, Any]], task_id: Optional[int] = None):
        """Add samples to the buffer, evicting according to the configured policy.

        Args:
            samples: Samples to store.
            task_id: When given, the samples are also appended to that task's list.
        """
        capacity = self.config.replay_size
        # Account for items that were placed in ``buffer`` directly.
        self._num_seen = max(self._num_seen, len(self.buffer))

        if self.config.reservoir_sampling:
            for sample in samples:
                self._num_seen += 1
                if len(self.buffer) < capacity:
                    self.buffer.append(sample)
                    continue
                # Classic reservoir step: keep the new item with probability
                # capacity / items_seen.
                slot = int(torch.randint(0, self._num_seen, (1,)).item())
                if slot < capacity:
                    self.buffer[slot] = sample
        else:
            self.buffer.extend(samples)
            self._num_seen += len(samples)

        # Trim if the buffer exceeds its size.
        if len(self.buffer) > capacity:
            if capacity <= 0:
                # ``buffer[-0:]`` would keep everything, so handle this explicitly.
                self.buffer = []
            elif self.config.selection_strategy == "recent":
                self.buffer = self.buffer[-capacity:]
            elif self.config.selection_strategy == "diverse":
                # Simple diversity: keep every n-th item.
                step = len(self.buffer) // capacity
                self.buffer = self.buffer[::step][:capacity]
            else:  # uniform
                keep = torch.randperm(len(self.buffer))[:capacity].tolist()
                self.buffer = [self.buffer[i] for i in keep]

        if task_id is not None:
            self.task_data.setdefault(task_id, []).extend(samples)

    @staticmethod
    def _merge_batch(current_batch: Dict[str, Any],
                     replay_samples: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Append replay samples to every tensor of ``current_batch`` along dim 0.

        Returns ``None`` when the data cannot be merged safely: a non-tensor or
        0-dim batch entry, a key missing from a replay sample, or a per-sample
        shape that differs from the batch's trailing shape.
        """
        merged: Dict[str, Any] = {}
        for key, value in current_batch.items():
            if not isinstance(value, torch.Tensor) or value.dim() == 0:
                return None
            rows = []
            for sample in replay_samples:
                if key not in sample:
                    return None
                try:
                    row = torch.as_tensor(sample[key])
                except (TypeError, ValueError, RuntimeError):
                    return None
                if row.shape != value.shape[1:]:
                    return None
                rows.append(row.to(device=value.device, dtype=value.dtype))
            merged[key] = torch.cat([value, torch.stack(rows)], dim=0)
        return merged

    def get_batch(self, current_batch: Dict[str, Any]) -> Dict[str, Any]:
        """Mix replay data into ``current_batch``.

        Draws ``int(batch_size * replay_ratio)`` random samples from the buffer
        and concatenates them after the current rows. ``current_batch`` is
        returned unchanged when the buffer is empty, no sample is requested,
        or the stored samples are not shape-compatible with the batch.
        """
        if not self.buffer:
            return current_batch

        replay_size = int(current_batch["input_ids"].size(0) * self.config.replay_ratio)
        replay_size = min(replay_size, len(self.buffer))
        if replay_size == 0:
            return current_batch

        picks = torch.randperm(len(self.buffer))[:replay_size].tolist()
        merged = self._merge_batch(current_batch, [self.buffer[i] for i in picks])
        return current_batch if merged is None else merged

    def get_task_buffer(self, task_id: int) -> List[Dict[str, Any]]:
        """Return the samples stored for ``task_id`` (empty list if unknown)."""
        return self.task_data.get(task_id, [])
197
+
198
+
199
class GEMOptimizer:
    """Gradient Episodic Memory (GEM) helper.

    Stores a bounded set of samples and one averaged, flattened gradient per
    task, and projects new gradients so they do not point against the stored
    gradients of earlier tasks.
    """

    def __init__(self, model: nn.Module, config: GEMConfig):
        self.model = model
        self.config = config
        self.memory: Dict[int, List[Dict[str, Any]]] = {i: [] for i in range(config.num_tasks)}
        self.gradient_memory: Dict[int, torch.Tensor] = {}

    def store_in_memory(self, samples: List[Dict[str, Any]], task_id: int):
        """Store samples for ``task_id``, never exceeding ``config.memory_size``.

        When not all samples fit, a random subset fills the remaining space.
        Task ids beyond ``config.num_tasks`` get a slot created on demand.
        """
        slot = self.memory.setdefault(task_id, [])
        available_space = max(self.config.memory_size - len(slot), 0)

        if available_space >= len(samples):
            slot.extend(samples)
        elif available_space > 0:
            chosen = torch.randperm(len(samples))[:available_space].tolist()
            slot.extend(samples[i] for i in chosen)

    def compute_gradient_constraints(self, task_id: int, device: torch.device) -> List[torch.Tensor]:
        """Return the stored gradients of all tasks with an id below ``task_id``.

        ``device`` is accepted for interface compatibility and is not used.
        """
        return [self.gradient_memory[prev]
                for prev in range(task_id)
                if prev in self.gradient_memory]

    def project_gradient(self, gradient: torch.Tensor, constraints: List[torch.Tensor]) -> torch.Tensor:
        """Project ``gradient`` so it no longer conflicts with each constraint.

        Constraints are processed one after another: whenever the running
        result has a negative dot product with a constraint, the component
        along that constraint is removed. This is a sequential projection, not
        a joint quadratic-programming solve. The input tensor is not modified
        and the result has the input's shape.
        """
        if not constraints:
            return gradient

        projected = gradient.clone().reshape(-1)
        for constraint in constraints:
            flat_constraint = constraint.reshape(-1)
            dot_product = torch.dot(projected, flat_constraint)
            if dot_product < 0:
                norm_sq = flat_constraint.pow(2).sum()
                # Skip (near-)zero constraints to avoid dividing by ~0.
                if norm_sq > 1e-8:
                    projected = projected - (dot_product / norm_sq) * flat_constraint

        return projected.reshape(gradient.shape)

    def update_gradient_memory(self, task_id: int, dataloader: DataLoader, device: torch.device):
        """Store the batch-averaged flattened gradient of ``task_id``.

        For every batch the model is called as ``model(inputs)``; the loss is
        ``outputs.loss`` when that attribute exists and is not ``None``,
        otherwise the mean of ``outputs.logits``. Nothing is stored when the
        dataloader yields no batch or the model has no trainable parameter.
        """
        params = [p for p in self.model.parameters() if p.requires_grad]
        if not params:
            return

        total_gradient = None
        count = 0

        # Run in eval mode, then restore whatever mode the caller had set.
        was_training = self.model.training
        self.model.eval()
        try:
            for batch in dataloader:
                inputs = batch["input_ids"].to(device) if isinstance(batch, dict) else batch.to(device)
                outputs = self.model(inputs)

                # Model outputs may expose ``loss`` but leave it as None when
                # no labels were supplied.
                loss = getattr(outputs, "loss", None)
                if loss is None:
                    loss = outputs.logits.mean()

                grads = torch.autograd.grad(loss, params)
                flat_grad = torch.cat([g.flatten() for g in grads])
                total_gradient = flat_grad if total_gradient is None else total_gradient + flat_grad
                count += 1
        finally:
            self.model.train(was_training)

        if count > 0 and total_gradient is not None:
            self.gradient_memory[task_id] = total_gradient / count
285
+
286
+
287
class LwFLoss(nn.Module):
    """Learning without Forgetting loss using knowledge distillation.

    Reads ``lwf_temperature`` and ``lwf_alpha`` from the given config.
    """

    def __init__(self, config: ContinualLearningConfig):
        super().__init__()
        self.config = config
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
        """Compute the temperature-scaled KL distillation loss.

        Args:
            student_logits: Logits of the model being trained; the last
                dimension indexes classes.
            teacher_logits: Logits of the old model, same shape. They are
                treated as constants (no gradient flows into them).

        Returns:
            ``lwf_alpha * T^2 * KL(teacher || student)`` averaged over every
            predicted distribution (i.e. over all leading dimensions).
        """
        T = self.config.lwf_temperature
        num_classes = student_logits.size(-1)

        # Flatten the leading dimensions: 'batchmean' divides by size(0) only,
        # so un-flattened (batch, seq, vocab) logits would scale with seq length.
        student_log_probs = torch.log_softmax(student_logits / T, dim=-1).reshape(-1, num_classes)
        teacher_probs = torch.softmax(teacher_logits.detach() / T, dim=-1).reshape(-1, num_classes)

        # T^2 keeps gradient magnitudes comparable across temperatures.
        kd_loss = self.kl_div(student_log_probs, teacher_probs) * (T ** 2)

        return self.config.lwf_alpha * kd_loss
308
+
309
+
310
class SIRegularizer:
    """Synaptic-Intelligence-style regularizer for continual learning.

    ``trajectory`` accumulates squared per-step parameter changes,
    ``importance`` accumulates ``loss_change / trajectory`` and the penalty
    is ``c * sum(importance * (theta - anchor)^2)``, where ``anchor`` is the
    parameter snapshot taken by :meth:`initialize_trajectory` or
    :meth:`consolidate`.
    """

    def __init__(self, model: nn.Module, c: float = 0.1):
        self.model = model
        self.c = c  # Importance weight
        self.importance: Dict[str, torch.Tensor] = {}
        self.prev_params: Dict[str, torch.Tensor] = {}
        self.trajectory: Dict[str, torch.Tensor] = {}
        # Reference point of the penalty. Kept separate from ``prev_params``,
        # which is refreshed after every step and would make the penalty vanish.
        self.anchor_params: Dict[str, torch.Tensor] = {}

    def _snapshot(self) -> Dict[str, torch.Tensor]:
        """Return detached copies of all trainable parameters, keyed by name."""
        return {name: param.clone().detach()
                for name, param in self.model.named_parameters()
                if param.requires_grad}

    def initialize_trajectory(self):
        """Snapshot the parameters and reset trajectory and importance to zero."""
        self.prev_params = self._snapshot()
        self.anchor_params = self._snapshot()
        self.trajectory = {name: torch.zeros_like(value)
                           for name, value in self.prev_params.items()}
        self.importance = {name: torch.zeros_like(value)
                           for name, value in self.prev_params.items()}

    def update_trajectory(self):
        """Accumulate the squared parameter change since the previous call.

        Intended to be called after each optimizer step.
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if param.requires_grad and name in self.prev_params:
                    delta = param - self.prev_params[name]
                    self.trajectory[name] += delta.pow(2)
                    self.prev_params[name] = param.clone().detach()

    def compute_importance(self, loss_change: float):
        """Update parameter importance from an observed loss improvement.

        Args:
            loss_change: Change in loss; only positive values (treated as a
                loss decrease) update the importance.
        """
        if loss_change <= 0:
            return
        for name in self.importance:
            if name in self.trajectory:
                # The small constant guards against division by zero for
                # parameters that have not moved.
                self.importance[name] += loss_change / (self.trajectory[name] + 1e-8)

    def consolidate(self):
        """Start a new task: re-anchor at the current parameters, reset the trajectory.

        Accumulated importance is kept so earlier tasks remain protected.
        """
        self.prev_params = self._snapshot()
        self.anchor_params = self._snapshot()
        self.trajectory = {name: torch.zeros_like(value)
                           for name, value in self.prev_params.items()}

    def compute_si_loss(self) -> torch.Tensor:
        """Return ``c * sum(importance * (theta - anchor)^2)``.

        Evaluates to zero before :meth:`initialize_trajectory` has been called.
        """
        penalty = None
        for name, param in self.model.named_parameters():
            if not param.requires_grad or name not in self.importance:
                continue
            anchor = self.anchor_params.get(name)
            if anchor is None:
                continue
            term = (self.importance[name].to(param.device)
                    * (param - anchor.to(param.device)).pow(2)).sum()
            penalty = term if penalty is None else penalty + term.to(penalty.device)

        if penalty is None:
            first = next(self.model.parameters(), None)
            return torch.zeros((), device=first.device if first is not None else None)
        return self.c * penalty
370
+
371
+
372
class LwFRegularizer(nn.Module):
    """Learning without Forgetting using knowledge distillation."""

    def __init__(self, alpha: float = 0.5, temperature: float = 2.0):
        super().__init__()
        self.alpha = alpha  # Distillation loss weight
        self.temperature = temperature
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        # Detached outputs of the old model, keyed by task name.
        self.old_outputs: Dict[str, torch.Tensor] = {}

    def store_old_outputs(self, task_name: str, outputs: torch.Tensor):
        """Store (detached) outputs of the old model for later distillation."""
        self.old_outputs[task_name] = outputs.detach()

    def clear_old_outputs(self):
        """Clear all stored old outputs."""
        self.old_outputs.clear()

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: Optional[torch.Tensor] = None,
        task_name: Optional[str] = None
    ) -> torch.Tensor:
        """
        Compute the LwF distillation loss.

        Args:
            student_logits: Current model logits; the last dimension indexes classes.
            teacher_logits: Old model logits. When omitted, the outputs stored
                under ``task_name`` via :meth:`store_old_outputs` are used.
            task_name: Task whose stored outputs serve as the teacher when
                ``teacher_logits`` is not given.

        Returns:
            ``alpha * T^2 * KL(teacher || student)`` averaged over every
            predicted distribution.

        Raises:
            ValueError: If neither ``teacher_logits`` nor stored outputs for
                ``task_name`` are available.
        """
        if teacher_logits is None:
            teacher_logits = self.old_outputs.get(task_name) if task_name is not None else None
            if teacher_logits is None:
                raise ValueError(
                    "teacher_logits not provided and no stored outputs found "
                    f"for task_name={task_name!r}"
                )

        T = self.temperature
        num_classes = student_logits.size(-1)

        # Flatten leading dimensions so 'batchmean' averages over every
        # distribution, not only over dim 0; the teacher is a fixed target.
        student_log_probs = torch.log_softmax(student_logits / T, dim=-1).reshape(-1, num_classes)
        teacher_probs = torch.softmax(teacher_logits.detach() / T, dim=-1).reshape(-1, num_classes)

        kd_loss = self.kl_div(student_log_probs, teacher_probs) * (T ** 2)

        return self.alpha * kd_loss
417
+
418
+
419
def create_continual_learning_wrapper(trainer, config: ContinualLearningConfig):
    """
    Attach continual-learning helpers to an existing trainer.

    Depending on ``config.strategy`` the trainer gains one attribute:
    ``ewc_regularizer`` ("ewc"), ``replay_buffer`` ("replay"),
    ``gem_optimizer`` ("gem"), ``lwf_loss`` ("lwf") or ``si_regularizer``
    ("si"). For "ewc" and "si" ``trainer.compute_loss`` is additionally
    wrapped so the regularization penalty is added to the loss; for "si"
    ``trainer.optimizer.step`` is wrapped as well (when an optimizer already
    exists) to track the parameter trajectory. Any other strategy leaves the
    trainer untouched.

    Args:
        trainer: Object exposing ``model`` and ``compute_loss(model, inputs,
            return_outputs)``. With ``return_outputs=True`` the wrapped
            function is expected to return a ``(loss, outputs)`` pair.
        config: Continual-learning configuration.

    Returns:
        The same trainer instance, modified in place.
    """

    def _attach_penalty(penalty_fn):
        """Wrap trainer.compute_loss so ``penalty_fn()`` is added to the loss."""
        base_compute_loss = trainer.compute_loss

        def compute_loss_with_penalty(model, inputs, return_outputs=False, **kwargs):
            # Extra keyword arguments are forwarded unchanged so callers that
            # pass additional options keep working.
            result = base_compute_loss(model, inputs, return_outputs, **kwargs)
            if return_outputs:
                loss, outputs = result
                return loss + penalty_fn().to(loss.device), outputs
            return result + penalty_fn().to(result.device)

        trainer.compute_loss = compute_loss_with_penalty

    if config.strategy == "ewc":
        trainer.ewc_regularizer = EWCRegularizer(trainer.model, config.ewc)
        _attach_penalty(trainer.ewc_regularizer.compute_ewc_loss)

    elif config.strategy == "replay":
        trainer.replay_buffer = ReplayBuffer(config.replay)
        # Mixing replay data into batches depends on the trainer's data
        # loading mechanism and is left to the caller.

    elif config.strategy == "gem":
        trainer.gem_optimizer = GEMOptimizer(trainer.model, config.gem)
        # Gradient projection depends on the trainer's optimization loop and
        # is left to the caller.

    elif config.strategy == "lwf":
        trainer.lwf_loss = LwFLoss(config)
        # Producing teacher outputs for distillation depends on the training
        # setup and is left to the caller.

    elif config.strategy == "si":
        # NOTE(review): the SI strength is taken from config.weight_decay
        # because the config has no dedicated SI field - confirm this is intended.
        trainer.si_regularizer = SIRegularizer(trainer.model, c=config.weight_decay)
        trainer.si_regularizer.initialize_trajectory()
        _attach_penalty(trainer.si_regularizer.compute_si_loss)

        # Track the trajectory after every optimizer step. The optimizer may
        # be missing or still None at this point; then nothing is hooked.
        optimizer = getattr(trainer, 'optimizer', None)
        if optimizer is not None:
            original_step = optimizer.step

            def step_with_trajectory(*args, **kwargs):
                result = original_step(*args, **kwargs)
                trainer.si_regularizer.update_trajectory()
                return result

            optimizer.step = step_with_trajectory

    return trainer
486
+
487
+
488
class ContinualLearningWrapper:
    """
    High-level wrapper for applying continual learning methods.

    Provides a unified API for EWC, SI, and LwF regularization plus a simple
    progressive-unfreezing schedule.

    Args:
        model: Model to wrap
        method: Continual learning method (ewc, si, lwf); any other value
            creates no regularizer up front.
    """

    def __init__(self, model: nn.Module, method: str = "ewc"):
        self.model = model
        self.method = method
        self.ewc = None
        self.si = None
        self.lwf = None

        if method == "ewc":
            self.ewc = EWCRegularizer(model, EWCConfig())
        elif method == "si":
            self.si = SIRegularizer(model)
            self.si.initialize_trajectory()
        elif method == "lwf":
            self.lwf = LwFRegularizer()

    def apply_ewc_regularization(self, lambda_ewc: float = 0.5):
        """Enable EWC (or update its strength) and return ``self`` for chaining."""
        if self.ewc is None:
            self.ewc = EWCRegularizer(self.model, EWCConfig(ewc_lambda=lambda_ewc))
        else:
            self.ewc.config.ewc_lambda = lambda_ewc
        return self

    def apply_si_regularization(self, c: float = 0.1):
        """Enable Synaptic Intelligence (or update its strength) and return ``self``."""
        if self.si is None:
            self.si = SIRegularizer(self.model, c=c)
            self.si.initialize_trajectory()
        else:
            self.si.c = c
        return self

    def apply_lwf_regularization(self, alpha: float = 0.5):
        """Enable Learning without Forgetting (or update alpha) and return ``self``."""
        if self.lwf is None:
            self.lwf = LwFRegularizer(alpha=alpha)
        else:
            self.lwf.alpha = alpha
        return self

    def compute_fisher(self, dataloader: DataLoader, device: torch.device):
        """Compute the Fisher information for EWC; a no-op when EWC is not enabled."""
        if self.ewc is not None:
            self.ewc.compute_fisher(dataloader, device)

    def get_regularization_loss(self) -> torch.Tensor:
        """Return the sum of the penalties of all enabled regularizers (EWC and/or SI).

        Returns a scalar zero when neither is enabled.
        """
        penalties = []
        if self.ewc is not None:
            penalties.append(self.ewc.compute_ewc_loss())
        if self.si is not None:
            penalties.append(self.si.compute_si_loss())

        if not penalties:
            return torch.tensor(0.0)
        total = penalties[0]
        for extra in penalties[1:]:
            total = total + extra.to(total.device)
        return total

    def progressive_unfreeze(
        self,
        start_layers: int = 4,
        unfreeze_every_n_epochs: int = 2,
        max_layers: Optional[int] = None
    ):
        """
        Start a progressive unfreezing schedule.

        All parameters are frozen first, then the last ``start_layers``
        modules are unfrozen. Call :meth:`step_epoch` after every epoch to
        unfreeze further modules.

        Args:
            start_layers: Number of layers to keep unfrozen initially
            unfreeze_every_n_epochs: Epochs between unfreezing
            max_layers: Maximum layers to unfreeze (None = all)
        """
        self.start_layers = start_layers
        self.unfreeze_every_n_epochs = unfreeze_every_n_epochs
        self.max_layers = max_layers
        self.current_epoch = 0

        # Freeze everything first; without this, "unfreezing" a fully
        # trainable model would change nothing.
        for param in self.model.parameters():
            param.requires_grad = False
        self._unfreeze_layers(start_layers)

    def _unfreeze_layers(self, num_layers: int):
        """Unfreeze the last ``num_layers`` entries of ``model.modules()``.

        ``modules()`` lists container modules as well as leaf modules, so
        containers count towards ``num_layers``.
        """
        if num_layers <= 0:
            # ``layers[-0:]`` would select every module.
            return
        layers = list(self.model.modules())
        for layer in layers[-num_layers:]:
            for param in layer.parameters():
                param.requires_grad = True

    def step_epoch(self):
        """Advance the schedule; call at the end of each epoch.

        Every ``unfreeze_every_n_epochs`` epochs the number of unfrozen
        modules grows by two, capped at ``max_layers`` when that is set.
        Does nothing before :meth:`progressive_unfreeze` has been called.
        """
        if not hasattr(self, 'unfreeze_every_n_epochs'):
            return
        self.current_epoch += 1
        interval = self.unfreeze_every_n_epochs
        if interval <= 0 or self.current_epoch % interval != 0:
            return
        target = self.start_layers + (self.current_epoch // interval) * 2
        if self.max_layers is not None:
            # Clamp instead of skipping so the schedule can reach max_layers.
            target = min(target, self.max_layers)
        self._unfreeze_layers(target)
multi_task.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-Task Learning Implementation for NTF
3
+ Supports task-specific heads for different fine-tuning objectives
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Dict, List, Optional, Any, Union
9
+ from dataclasses import dataclass, field
10
+ from enum import Enum
11
+
12
+
13
class TaskType(str, Enum):
    """Task types understood by the multi-task framework.

    Members are also plain strings, so they compare equal to their values.
    """

    CLASSIFICATION = "classification"
    SEQUENCE_TO_SEQUENCE = "sequence_to_sequence"
    TOKEN_CLASSIFICATION = "token_classification"
    QUESTION_ANSWERING = "question_answering"
    GENERATION = "generation"
20
+
21
+
22
@dataclass
class TaskHeadConfig:
    """Describes one task-specific head: its name, type and constructor options."""

    task_name: str
    head_type: TaskType
    # Extra keyword arguments for the head; a fresh dict per instance.
    config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Accept the task type as a plain string and normalize it to the enum.
        raw_type = self.head_type
        if isinstance(raw_type, str):
            self.head_type = TaskType(raw_type)
33
+
34
+
35
class ClassificationHead(nn.Module):
    """Classification head for sequence classification tasks.

    Pools the token representations into one vector per sequence, applies
    dropout and projects to ``num_labels`` logits.
    """

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        dropout: float = 0.1,
        **kwargs
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return classification logits.

        Args:
            hidden_states: Token representations, indexed as
                ``[batch, position, feature]`` by this method.
            attention_mask: Optional ``[batch, position]`` mask. When given,
                the masked mean over positions is used; otherwise the
                representation at the last position is used.
        """
        if attention_mask is not None:
            # Mean pooling with mask, accumulated in float32.
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            sum_embeddings = (hidden_states * mask_expanded).sum(1)
            sum_mask = mask_expanded.sum(1).clamp(min=1e-9)
            # Cast back so half/bfloat16 models do not feed a float32 tensor
            # into a lower-precision classifier.
            pooled_output = (sum_embeddings / sum_mask).to(hidden_states.dtype)
        else:
            pooled_output = hidden_states[:, -1, :]  # Last token

        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)
63
+
64
+
65
class SequenceToSequenceHead(nn.Module):
    """Sequence-to-sequence head: projects every position to vocabulary logits."""

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        max_length: int = 512,
        **kwargs
    ):
        super().__init__()
        self.output_projection = nn.Linear(hidden_size, vocab_size)
        # Stored for callers; not used by forward().
        self.max_length = max_length
        self.vocab_size = vocab_size

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return per-position vocabulary logits.

        ``attention_mask`` is accepted so all heads can be called uniformly
        (the multi-task model always passes it); it does not affect the result.
        """
        return self.output_projection(hidden_states)
82
+
83
+
84
class TokenClassificationHead(nn.Module):
    """Token-level classification head (NER, POS tagging, etc.)."""

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        dropout: float = 0.1,
        **kwargs
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return per-token label logits.

        ``attention_mask`` is accepted so all heads can be called uniformly
        (the multi-task model always passes it); it does not affect the result.
        """
        hidden_states = self.dropout(hidden_states)
        return self.classifier(hidden_states)
102
+
103
+
104
class QuestionAnsweringHead(nn.Module):
    """Head for extractive question answering (start/end span logits)."""

    def __init__(
        self,
        hidden_size: int,
        dropout: float = 0.1,
        **kwargs
    ):
        super().__init__()
        # The dropout argument is now honored, matching the other heads.
        self.dropout = nn.Dropout(dropout)
        self.qa_outputs = nn.Linear(hidden_size, 2)  # start and end logits

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> tuple:
        """Return ``(start_logits, end_logits)``, each without the trailing size-2 axis.

        ``attention_mask`` is accepted so all heads can be called uniformly
        (the multi-task model always passes it); it does not affect the result.
        """
        logits = self.qa_outputs(self.dropout(hidden_states))
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)
120
+
121
+
122
class TaskHead(nn.Module):
    """Wrapper that builds and calls the concrete head for one task."""

    # Task type -> head implementation. GENERATION has no entry, so it is
    # rejected with a ValueError at construction time.
    HEAD_CLASSES = {
        TaskType.CLASSIFICATION: ClassificationHead,
        TaskType.SEQUENCE_TO_SEQUENCE: SequenceToSequenceHead,
        TaskType.TOKEN_CLASSIFICATION: TokenClassificationHead,
        TaskType.QUESTION_ANSWERING: QuestionAnsweringHead,
    }

    def __init__(self, config: TaskHeadConfig, hidden_size: int, vocab_size: Optional[int] = None):
        """Build the head described by ``config``.

        Args:
            config: Task name, head type and extra constructor options.
            hidden_size: Feature size of the hidden states fed to the head.
            vocab_size: Passed to the head constructor when not ``None``.

        Raises:
            ValueError: If no head class is registered for the task type.
        """
        import inspect

        super().__init__()
        self.config = config
        self.task_name = config.task_name
        self.head_type = config.head_type

        # Copy so the caller's config dict is not modified.
        head_config = dict(config.config)
        head_config["hidden_size"] = hidden_size

        if vocab_size is not None:
            head_config["vocab_size"] = vocab_size

        head_class = self.HEAD_CLASSES.get(self.head_type)
        if head_class is None:
            raise ValueError(f"Unsupported task type: {self.head_type}")

        self.head = head_class(**head_config)
        # Mirror the label count (None for heads without one) so callers can
        # read it from the wrapper.
        self.num_labels = getattr(self.head, "num_labels", None)

        # Record which keyword arguments the head's forward() understands.
        forward_params = inspect.signature(self.head.forward).parameters
        self._forward_kwarg_names = frozenset(forward_params)
        self._forward_takes_var_kwargs = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in forward_params.values()
        )

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Call the wrapped head.

        Keyword arguments the head's ``forward`` does not declare (for
        example ``attention_mask`` on heads that do not pool) are dropped
        rather than raising ``TypeError``.
        """
        if not self._forward_takes_var_kwargs:
            kwargs = {key: value for key, value in kwargs.items()
                      if key in self._forward_kwarg_names}
        return self.head(hidden_states, **kwargs)
152
+
153
+
154
+ class MultiTaskModel(nn.Module):
155
+ """
156
+ Multi-task model with task-specific heads sharing a common base.
157
+
158
+ Args:
159
+ base_model: Base transformer model
160
+ base_model_name: Name or path of base model
161
+ """
162
+
163
+ def __init__(self, base_model=None, base_model_name: Optional[str] = None):
164
+ super().__init__()
165
+
166
+ if base_model is None and base_model_name is None:
167
+ raise ValueError("Must provide either base_model or base_model_name")
168
+
169
+ if base_model is None:
170
+ from transformers import AutoModelForCausalLM
171
+ self.base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
172
+ else:
173
+ self.base_model = base_model
174
+
175
+ # Get hidden size from base model
176
+ self.hidden_size = getattr(self.base_model.config, 'hidden_size', 768)
177
+ self.vocab_size = getattr(self.base_model.config, 'vocab_size', None)
178
+
179
+ # Task heads registry
180
+ self.task_heads: Dict[str, TaskHead] = nn.ModuleDict()
181
+ self.active_task: Optional[str] = None
182
+
183
+ # Task weights for balanced training
184
+ self.task_weights: Dict[str, float] = {}
185
+
186
+ def add_task_head(
187
+ self,
188
+ task_name: str,
189
+ head_type: Union[str, TaskType],
190
+ config: Optional[Dict[str, Any]] = None
191
+ ):
192
+ """
193
+ Add a task-specific head to the model.
194
+
195
+ Args:
196
+ task_name: Unique name for this task
197
+ head_type: Type of task (classification, seq2seq, etc.)
198
+ config: Task-specific configuration
199
+ """
200
+ if config is None:
201
+ config = {}
202
+
203
+ task_config = TaskHeadConfig(
204
+ task_name=task_name,
205
+ head_type=head_type,
206
+ config=config
207
+ )
208
+
209
+ task_head = TaskHead(task_config, self.hidden_size, self.vocab_size)
210
+ self.task_heads[task_name] = task_head
211
+ self.task_weights[task_name] = 1.0 # Default equal weight
212
+
213
+ def set_task_weights(self, weights: Dict[str, float]):
214
+ """Set weights for each task in multi-task training."""
215
+ for task_name, weight in weights.items():
216
+ if task_name in self.task_heads:
217
+ self.task_weights[task_name] = weight
218
+
219
+ def set_active_task(self, task_name: str):
220
+ """Set the currently active task for single-task inference."""
221
+ if task_name not in self.task_heads:
222
+ raise ValueError(f"Task '{task_name}' not found. Available: {list(self.task_heads.keys())}")
223
+ self.active_task = task_name
224
+
225
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    task_name: Optional[str] = None,
    **kwargs
) -> Dict[str, torch.Tensor]:
    """
    Run the base model, then the head of the selected task.

    Task resolution order: explicit ``task_name``, then ``self.active_task``,
    then the only registered head if exactly one exists.

    Args:
        input_ids: Input token IDs
        attention_mask: Attention mask (also passed on to the task head)
        labels: Optional labels; when given, a cross-entropy loss is added
        task_name: Task to use (overrides active_task)
        **kwargs: Forwarded unchanged to the base model

    Returns:
        Dictionary with "logits" and, when labels are provided, "loss"

    Raises:
        ValueError: If no task can be resolved or the task has no head.
    """
    selected = task_name or self.active_task
    if selected is None:
        # Fall back to the sole head; anything other than exactly one is ambiguous.
        if len(self.task_heads) != 1:
            raise ValueError("No task specified and multiple heads available")
        selected = list(self.task_heads.keys())[0]

    if selected not in self.task_heads:
        raise ValueError(f"Task '{selected}' not found")

    encoder_out = self.base_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        **kwargs
    )
    # The head consumes the last entry of the hidden-state sequence.
    final_hidden = encoder_out.hidden_states[-1]

    task_head = self.task_heads[selected]
    logits = task_head(final_hidden, attention_mask=attention_mask)

    if labels is None:
        return {"logits": logits}

    # Flatten logits/labels per head type; every branch uses cross-entropy.
    if task_head.head_type == TaskType.CLASSIFICATION:
        flat_logits = logits.view(-1, task_head.num_labels)
        flat_labels = labels.view(-1)
    elif task_head.head_type == TaskType.SEQUENCE_TO_SEQUENCE:
        # Position t predicts label t+1: drop the last logit and the first label.
        shifted_logits = logits[..., :-1, :].contiguous()
        shifted_labels = labels[..., 1:].contiguous()
        flat_logits = shifted_logits.view(-1, shifted_logits.size(-1))
        flat_labels = shifted_labels.view(-1)
    else:
        flat_logits = logits.view(-1, logits.size(-1))
        flat_labels = labels.view(-1)

    criterion = nn.CrossEntropyLoss()
    return {"logits": logits, "loss": criterion(flat_logits, flat_labels)}
290
+
291
def get_num_tasks(self) -> int:
    """Count the task heads currently registered on this model."""
    head_count = len(self.task_heads)
    return head_count
294
+
295
def list_tasks(self) -> List[str]:
    """Return the names of all registered task heads as a new list."""
    return [*self.task_heads.keys()]
298
+
299
+
300
class MultiTaskTrainer:
    """
    Trainer for multi-task learning with task-weighted loss.

    Each dataset is iterated directly and every item is treated as one
    ready-made batch: a mapping whose values are moved to the device when
    they expose ``.to``. Tasks are processed one after another, in the
    order of ``task_datasets``.

    Args:
        model: MultiTaskModel instance
        task_datasets: Dictionary mapping task names to iterables of batches
        task_weights: Optional dictionary of task weights
        tokenizer: Stored on the trainer; not used by the trainer itself
        device: Target device; defaults to CUDA when available, else CPU
    """

    def __init__(
        self,
        model: "MultiTaskModel",
        task_datasets: Dict[str, Any],
        task_weights: Optional[Dict[str, float]] = None,
        tokenizer=None,
        device: Optional[torch.device] = None
    ):
        self.model = model
        self.task_datasets = task_datasets
        self.tokenizer = tokenizer
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set task weights
        if task_weights:
            self.model.set_task_weights(task_weights)

        # Move model to device
        self.model.to(self.device)

    def _to_device(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move every value of ``batch`` that has a ``.to`` method to the device."""
        return {
            key: value.to(self.device) if hasattr(value, 'to') else value
            for key, value in batch.items()
        }

    def train_epoch(
        self,
        optimizer: torch.optim.Optimizer,
        batch_sizes: Optional[Dict[str, int]] = None,
        gradient_accumulation_steps: int = 1
    ) -> Dict[str, float]:
        """
        Train one epoch across all tasks.

        Args:
            optimizer: Optimizer for training
            batch_sizes: Accepted for interface compatibility; currently
                unused because datasets must already yield batches
            gradient_accumulation_steps: Number of batches whose gradients are
                accumulated (each scaled by 1/steps) before an optimizer update

        Returns:
            Dictionary of average unweighted loss per task

        Raises:
            ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
        """
        if gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")

        self.model.train()
        task_losses = {task: 0.0 for task in self.task_datasets}
        task_counts = {task: 0 for task in self.task_datasets}

        # Number of backward passes accumulated since the last optimizer step.
        pending = 0

        # Tasks are processed sequentially, one full dataset after another.
        for task_name, dataset in self.task_datasets.items():
            weight = self.model.task_weights.get(task_name, 1.0)

            for batch in dataset:
                inputs = self._to_device(batch)

                # Clear gradients only at the start of an accumulation window.
                if pending == 0:
                    optimizer.zero_grad()

                outputs = self.model(
                    input_ids=inputs.get("input_ids"),
                    attention_mask=inputs.get("attention_mask"),
                    labels=inputs.get("labels"),
                    task_name=task_name
                )

                raw_loss = outputs["loss"]
                (raw_loss * weight / gradient_accumulation_steps).backward()
                pending += 1

                if pending == gradient_accumulation_steps:
                    optimizer.step()
                    pending = 0

                # Record the unweighted loss directly; dividing the weighted
                # loss by the weight would fail for a weight of zero.
                task_losses[task_name] += raw_loss.item()
                task_counts[task_name] += 1

        # Apply gradients left over from an incomplete accumulation window.
        if pending:
            optimizer.step()

        # Average losses (tasks with no batches report 0.0)
        return {
            task: task_losses[task] / max(task_counts[task], 1)
            for task in task_losses
        }

    def evaluate(
        self,
        eval_datasets: Dict[str, Any],
        metrics_fn: Optional[Dict[str, callable]] = None
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate model on all tasks.

        Args:
            eval_datasets: Evaluation datasets per task (iterables of batches)
            metrics_fn: Accepted for interface compatibility; currently unused

        Returns:
            Dictionary mapping each task to {"loss": mean batch loss,
            "count": number of batches}
        """
        self.model.eval()
        results = {}

        with torch.no_grad():
            for task_name, dataset in eval_datasets.items():
                task_results = {"loss": 0.0, "count": 0}

                for batch in dataset:
                    inputs = self._to_device(batch)

                    outputs = self.model(
                        input_ids=inputs.get("input_ids"),
                        attention_mask=inputs.get("attention_mask"),
                        labels=inputs.get("labels"),
                        task_name=task_name
                    )

                    task_results["loss"] += outputs["loss"].item()
                    task_results["count"] += 1

                if task_results["count"] > 0:
                    task_results["loss"] /= task_results["count"]

                results[task_name] = task_results

        return results
p_tuning.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ P-Tuning / Prefix Tuning Implementation for NTF
3
+ Parameter-efficient tuning using learnable continuous prompts
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Dict, List, Optional, Any, Union
9
+ from dataclasses import dataclass, field
10
+ from enum import Enum
11
+
12
+ from peft import (
13
+ PrefixTuningConfig,
14
+ PromptTuningConfig,
15
+ P_TUNING_TASK_TYPE,
16
+ get_peft_model,
17
+ TaskType,
18
+ )
19
+
20
+
21
class PTuningMethod(str, Enum):
    """
    Identifiers for the supported prompt-based tuning variants.

    Members are also ``str`` instances, so each one compares equal to its
    string value and can be built from it, e.g. ``PTuningMethod("prefix_tuning")``.
    """

    # P-Tuning variants
    P_TUNING_V1 = "p_tuning_v1"
    P_TUNING_V2 = "p_tuning_v2"

    # Other prompt-learning methods
    PREFIX_TUNING = "prefix_tuning"
    PROMPT_TUNING = "prompt_tuning"
27
+
28
+
29
@dataclass
class PTuningConfig:
    """
    Settings for P-Tuning / Prefix Tuning / Prompt Tuning.

    ``to_peft_config`` translates these fields into a PEFT config object;
    which fields are forwarded depends on ``method``.

    Args:
        method: P-tuning method to use
        num_virtual_tokens: Number of virtual/prompt tokens to add
        token_dim: Dimension of token embeddings
        num_transformer_submodules: Number of transformer submodules
            (stored only; not forwarded by ``to_peft_config``)
        num_attention_heads: Number of attention heads
        num_layers: Number of transformer layers
        encoder_hidden_size: Hidden size for encoder (P-Tuning)
        prefix_projection: Whether to project prefix (Prefix Tuning)
        prompt_tuning_init: Initialization strategy for prompt tuning
        prompt_tuning_init_text: Text for initialization if using text init
        task_type: PEFT task type
    """

    method: PTuningMethod = PTuningMethod.P_TUNING_V2

    # Core parameters
    num_virtual_tokens: int = 20
    token_dim: int = 768
    num_transformer_submodules: int = 1
    num_attention_heads: int = 12
    num_layers: int = 12

    # P-Tuning v1 specific
    encoder_hidden_size: int = 512

    # Prefix Tuning specific
    prefix_projection: bool = True

    # Prompt Tuning specific
    prompt_tuning_init: str = "RANDOM"  # RANDOM or TEXT
    prompt_tuning_init_text: Optional[str] = None

    # Task type
    task_type: TaskType = TaskType.CAUSAL_LM

    def to_peft_config(self):
        """Build the PEFT config object that corresponds to ``self.method``."""
        # Arguments every variant passes to PEFT.
        shared = dict(
            num_virtual_tokens=self.num_virtual_tokens,
            token_dim=self.token_dim,
            task_type=self.task_type,
        )

        if self.method == PTuningMethod.PROMPT_TUNING:
            return PromptTuningConfig(
                prompt_tuning_init=self.prompt_tuning_init,
                prompt_tuning_init_text=self.prompt_tuning_init_text,
                **shared,
            )

        # All remaining methods are expressed through PrefixTuningConfig.
        prefix_kwargs = dict(
            shared,
            num_attention_heads=self.num_attention_heads,
            num_layers=self.num_layers,
        )

        if self.method == PTuningMethod.PREFIX_TUNING:
            return PrefixTuningConfig(
                prefix_projection=self.prefix_projection,
                **prefix_kwargs,
            )

        # P-Tuning v1 or v2: projection is enabled only for v1.
        return PrefixTuningConfig(
            encoder_hidden_size=self.encoder_hidden_size,
            prefix_projection=self.method == PTuningMethod.P_TUNING_V1,
            **prefix_kwargs,
        )
99
+
100
+
101
class PTuningModel(nn.Module):
    """
    P-Tuning wrapper for transformer models.

    Owns learnable continuous prompt parameters next to a base model.
    ``get_prompt`` materialises the prompts for a batch. ``forward`` does
    not inject them into the base model: it delegates to the base model
    and repackages the output. The base model's parameters are not frozen
    here; freeze them in the caller if only the prompts should be trained.

    Args:
        base_model: Base transformer model (must expose ``.config``)
        config: P-tuning configuration; its ``token_dim``, ``num_layers`` and
            ``num_attention_heads`` are overwritten in place with the values
            found on ``base_model.config`` when present
    """

    def __init__(self, base_model: nn.Module, config: "PTuningConfig"):
        super().__init__()

        self.base_model = base_model
        self.config = config

        # Get model dimensions, falling back to the configured values
        model_config = base_model.config
        self.token_dim = getattr(model_config, 'hidden_size', config.token_dim)
        self.num_layers = getattr(model_config, 'num_hidden_layers', config.num_layers)
        self.num_attention_heads = getattr(model_config, 'num_attention_heads', config.num_attention_heads)

        # Update config with actual dimensions
        config.token_dim = self.token_dim
        config.num_layers = self.num_layers
        config.num_attention_heads = self.num_attention_heads

        # Create virtual tokens
        self._create_virtual_tokens()

    def _create_virtual_tokens(self):
        """Create the learnable prompt parameters for the configured method."""
        method = self.config.method

        if method == PTuningMethod.PROMPT_TUNING:
            # Simple prompt embeddings: one vector per virtual token
            self.prompt_embeddings = nn.Embedding(
                self.config.num_virtual_tokens,
                self.token_dim
            )
            nn.init.normal_(self.prompt_embeddings.weight, std=0.02)

        elif method == PTuningMethod.PREFIX_TUNING:
            # One prefix per (layer, key/value) slot
            self.prefix_tokens = nn.Parameter(
                torch.randn(
                    self.num_layers * 2,  # key and value for each layer
                    self.config.num_virtual_tokens,
                    self.token_dim
                )
            )

            if self.config.prefix_projection:
                # prefix_tokens already holds one row per layer slot, so the
                # projection must map token_dim -> token_dim. A wider output
                # cannot be reshaped back to (layers*2, tokens, dim).
                self.prefix_proj = nn.Sequential(
                    nn.Linear(self.token_dim, self.token_dim),
                    nn.ReLU(),
                    nn.Linear(self.token_dim, self.token_dim)
                )
            else:
                self.prefix_proj = None

        else:  # P-Tuning v1 or v2
            # Encoder maps each virtual-token embedding to one vector per
            # (layer, key/value) slot: token_dim -> layers*2*token_dim.
            self.prompt_encoder = nn.Sequential(
                nn.Linear(self.token_dim, self.config.encoder_hidden_size),
                nn.ReLU(),
                nn.Linear(self.config.encoder_hidden_size,
                          self.num_layers * 2 * self.token_dim)
            )

            # Input embedding for prompt encoder
            self.input_embeds = nn.Embedding(self.config.num_virtual_tokens, self.token_dim)
            nn.init.normal_(self.input_embeds.weight, std=0.02)

    def get_prompt(self, batch_size: int) -> torch.Tensor:
        """
        Generate prompt tensors for the current batch.

        Returns:
            For prompt tuning, a tensor of shape
            ``(batch_size, num_virtual_tokens, token_dim)``; for all other
            methods, ``(num_layers * 2, batch_size, num_virtual_tokens, token_dim)``.
            The batch dimension is an expanded view, not a copy.
        """
        method = self.config.method
        num_tokens = self.config.num_virtual_tokens

        if method == PTuningMethod.PROMPT_TUNING:
            # Expand prompt embeddings to batch size
            return self.prompt_embeddings.weight.unsqueeze(0).expand(
                batch_size, -1, -1
            )

        if method == PTuningMethod.PREFIX_TUNING:
            prefix = self.prefix_tokens
            if self.prefix_proj is not None:
                # Linear layers act on the last dimension; the shape is kept.
                prefix = self.prefix_proj(prefix)
        else:  # P-Tuning
            # Build the indices on the embedding's device so the lookup also
            # works after the module has been moved off the CPU.
            token_ids = torch.arange(
                num_tokens, device=self.input_embeds.weight.device
            )
            encoded = self.prompt_encoder(self.input_embeds(token_ids))
            # (tokens, layers*2*dim) -> (layers*2, tokens, dim)
            prefix = encoded.view(
                num_tokens, self.num_layers * 2, self.token_dim
            ).permute(1, 0, 2)

        return prefix.unsqueeze(1).expand(-1, batch_size, -1, -1)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the base model.

        Note: This is a simplified implementation that does not inject the
        virtual prompts into the base model. For production use, consider
        using the PEFT library's P-tuning implementation.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            labels: Optional labels, forwarded to the base model when given
            **kwargs: Forwarded unchanged to the base model

        Returns:
            Dictionary with "logits" and, if the base model output carries a
            non-None ``loss``, "loss"
        """
        # Forward labels only when supplied, so base models whose forward has
        # no ``labels`` parameter keep working for inference.
        if labels is not None:
            kwargs["labels"] = labels

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        result = {"logits": outputs.logits}

        if hasattr(outputs, 'loss') and outputs.loss is not None:
            result["loss"] = outputs.loss

        return result

    def get_trainable_params(self) -> Dict[str, torch.Tensor]:
        """
        Get dictionary of all parameters with ``requires_grad=True``.

        This includes base-model parameters unless they were frozen elsewhere.
        """
        return {
            name: param
            for name, param in self.named_parameters()
            if param.requires_grad
        }

    def print_trainable_parameters(self):
        """Print number of trainable vs total parameters."""
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        all_params = sum(p.numel() for p in self.parameters())
        # Guard against a parameter-free module.
        percent = 100 * trainable_params / all_params if all_params else 0.0

        print(f"Trainable params: {trainable_params:,} ({percent:.2f}%)")
        print(f"All params: {all_params:,}")
        print(f"Frozen params: {all_params - trainable_params:,}")
268
+
269
+
270
def setup_p_tuning(
    model: nn.Module,
    method: str = "p_tuning_v2",
    num_virtual_tokens: int = 20,
    task_type: str = "CAUSAL_LM"
) -> nn.Module:
    """
    Setup P-Tuning on a model using PEFT.

    When ``model`` exposes a ``config`` attribute, ``hidden_size``,
    ``num_hidden_layers`` and ``num_attention_heads`` are read from it
    (the same lookup ``PTuningModel`` performs), so the generated PEFT config
    matches the model and not the ``PTuningConfig`` defaults.

    Args:
        model: Base model to apply P-Tuning to
        method: P-tuning method (p_tuning_v1, p_tuning_v2, prefix_tuning, prompt_tuning)
        num_virtual_tokens: Number of virtual tokens
        task_type: PEFT task type

    Returns:
        Model with P-Tuning applied

    Raises:
        ValueError: If ``method`` or ``task_type`` is not a valid enum value.
    """
    config = PTuningConfig(
        method=PTuningMethod(method),
        num_virtual_tokens=num_virtual_tokens,
        task_type=TaskType(task_type)
    )

    # Prefer the model's real dimensions; keep the defaults for any
    # attribute the model config does not define.
    model_config = getattr(model, "config", None)
    if model_config is not None:
        config.token_dim = getattr(model_config, "hidden_size", config.token_dim)
        config.num_layers = getattr(model_config, "num_hidden_layers", config.num_layers)
        config.num_attention_heads = getattr(
            model_config, "num_attention_heads", config.num_attention_heads
        )

    peft_config = config.to_peft_config()
    return get_peft_model(model, peft_config)
test_tutorial_examples.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test Suite for Tutorial Code Examples
3
+ Ensures all code examples in tutorials remain functional
4
+ """
5
+
6
+ import pytest
7
+ import os
8
+ import sys
9
+ import torch
10
+ from datasets import Dataset
11
+
12
+ # Add project root to path
13
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
+
15
+
16
class TestTutorial03:
    """Test Tutorial 03: Full Fine-Tuning examples"""

    def test_full_finetuning_basic(self, tmp_path):
        """Test basic full fine-tuning workflow"""
        from ntf.config import NTFConfig, ModelConfig, TrainingConfig
        from ntf.models import ModelRegistry
        from ntf.finetuning import FullFinetuneTrainer

        # Use a fresh per-test directory: a fixed ./test_output left over from
        # an earlier run would satisfy the final assertion even if training
        # wrote nothing, and would pollute the working directory.
        output_dir = tmp_path / "test_output"

        config = NTFConfig(
            model=ModelConfig(name="facebook/opt-125m"),
            training=TrainingConfig(
                output_dir=str(output_dir),
                num_train_epochs=1,
                per_device_train_batch_size=2,
            )
        )

        registry = ModelRegistry(config.model)
        model, tokenizer = registry.load_model_and_tokenizer()

        train_data = Dataset.from_dict({
            "text": ["Hello world", "Test sentence"] * 10
        })

        trainer = FullFinetuneTrainer(
            model=model,
            config=config.training,
            train_dataset=train_data,
            tokenizer=tokenizer
        )

        trainer.train()

        assert output_dir.exists()
51
+
52
+
53
class TestTutorial05:
    """Test Tutorial 05: PEFT/LoRA examples"""

    def test_lora_setup(self):
        """LoRA wrapping leaves some, but not all, parameters trainable."""
        from ntf.finetuning import LoRAConfig, PEFTTrainer
        from ntf.models import ModelRegistry

        model, tokenizer = ModelRegistry("facebook/opt-125m").load_model_and_tokenizer()

        adapter_settings = LoRAConfig(
            r=8,
            alpha=16,
            dropout=0.05,
            target_modules=["q_proj", "v_proj"],
        )

        trainer = PEFTTrainer(model, adapter_settings, tokenizer)

        params = list(trainer.model.parameters())
        trainable_count = sum(p.numel() for p in params if p.requires_grad)
        total_count = sum(p.numel() for p in params)

        assert trainable_count < total_count
        assert trainable_count > 0

    def test_p_tuning_setup(self):
        """setup_p_tuning returns a model object for the p_tuning_v2 method."""
        from ntf.finetuning import PTuningConfig, PTuningMethod, setup_p_tuning
        from ntf.models import ModelRegistry

        model, tokenizer = ModelRegistry("facebook/opt-125m").load_model_and_tokenizer()

        # Constructed only to check that the config class accepts these
        # arguments; setup_p_tuning builds its own config internally.
        config = PTuningConfig(
            method=PTuningMethod.P_TUNING_V2,
            num_virtual_tokens=20,
        )

        tuned_model = setup_p_tuning(model, method="p_tuning_v2", num_virtual_tokens=20)

        assert tuned_model is not None
95
+
96
+
97
class TestTutorial04:
    """Test Tutorial 04: Continual Learning examples"""

    def test_ewc_regularization(self):
        """Test EWC regularization setup"""
        from ntf.utils import EWCConfig, EWCRegularizer, ContinualLearningWrapper
        from ntf.models import ModelRegistry

        registry = ModelRegistry("facebook/opt-125m")
        model, tokenizer = registry.load_model_and_tokenizer()

        ewc_config = EWCConfig(ewc_lambda=1000.0)
        ewc = EWCRegularizer(model, ewc_config)

        assert ewc is not None
        assert ewc.config.ewc_lambda == 1000.0

    def test_si_regularization(self):
        """Test Synaptic Intelligence regularization"""
        from ntf.utils import SIRegularizer, ContinualLearningWrapper
        from ntf.models import ModelRegistry

        registry = ModelRegistry("facebook/opt-125m")
        model, _ = registry.load_model_and_tokenizer()

        wrapper = ContinualLearningWrapper(model, method="si")
        wrapper.apply_si_regularization(c=0.1)

        assert wrapper.si is not None
        assert wrapper.si.c == 0.1

    def test_lwf_regularization(self):
        """Test Learning without Forgetting"""
        from ntf.utils import LwFRegularizer, ContinualLearningWrapper
        from ntf.models import ModelRegistry

        registry = ModelRegistry("facebook/opt-125m")
        model, _ = registry.load_model_and_tokenizer()

        lwf_alpha = 0.5
        wrapper = ContinualLearningWrapper(model, method="lwf")
        wrapper.apply_lwf_regularization(alpha=lwf_alpha)

        assert wrapper.lwf is not None
        # The expected value must be the alpha passed above; the previous
        # assertion compared against 0.1, which was never configured.
        assert wrapper.lwf.alpha == lwf_alpha
141
+
142
+
143
class TestMultiTask:
    """Test Multi-Task Learning (Spec 4.1.1)"""

    def test_multi_task_model_creation(self):
        """Registering two heads makes both visible through the model API."""
        from ntf.finetuning import MultiTaskModel, TaskType, MultiTaskTrainer
        from ntf.models import ModelRegistry

        base_model, tokenizer = ModelRegistry("facebook/opt-125m").load_model_and_tokenizer()

        multi_task_model = MultiTaskModel(base_model=base_model)

        multi_task_model.add_task_head(
            task_name="classification",
            head_type=TaskType.CLASSIFICATION,
            config={"num_labels": 5}
        )
        multi_task_model.add_task_head(
            task_name="summarization",
            head_type=TaskType.SEQUENCE_TO_SEQUENCE,
            config={"max_length": 512}
        )

        assert multi_task_model.get_num_tasks() == 2
        registered = multi_task_model.list_tasks()
        assert "classification" in registered
        assert "summarization" in registered

    def test_multi_task_forward(self):
        """A classification forward pass yields logits of shape (batch, num_labels, ...)."""
        from ntf.finetuning import MultiTaskModel, TaskType
        from ntf.models import ModelRegistry
        import torch

        base_model, tokenizer = ModelRegistry("facebook/opt-125m").load_model_and_tokenizer()

        multi_task_model = MultiTaskModel(base_model=base_model)
        multi_task_model.add_task_head(
            task_name="classification",
            head_type=TaskType.CLASSIFICATION,
            config={"num_labels": 3}
        )

        # Batch of 2 random sequences, 10 tokens each, with no padding.
        batch_shape = (2, 10)
        token_ids = torch.randint(0, 1000, batch_shape)
        mask = torch.ones(batch_shape)

        output = multi_task_model(
            input_ids=token_ids,
            attention_mask=mask,
            task_name="classification"
        )

        assert "logits" in output
        logits_shape = output["logits"].shape
        assert logits_shape[0] == 2
        assert logits_shape[1] == 3
200
+
201
+
202
class TestContinualLearningWrapper:
    """Test ContinualLearningWrapper API (Spec 4.1.2)"""

    def test_wrapper_api(self):
        """Each regularisation method populates its attribute on the wrapper."""
        from ntf.utils import ContinualLearningWrapper
        from ntf.models import ModelRegistry

        model, _ = ModelRegistry("facebook/opt-125m").load_model_and_tokenizer()

        # (method / attribute name, apply-method name, keyword arguments),
        # exercised in order: EWC, then SI, then LwF.
        cases = [
            ("ewc", "apply_ewc_regularization", {"lambda_ewc": 0.5}),
            ("si", "apply_si_regularization", {"c": 0.1}),
            ("lwf", "apply_lwf_regularization", {"alpha": 0.5}),
        ]

        for method, apply_name, apply_kwargs in cases:
            wrapper = ContinualLearningWrapper(model, method=method)
            getattr(wrapper, apply_name)(**apply_kwargs)
            assert getattr(wrapper, method) is not None

    def test_progressive_unfreeze(self):
        """progressive_unfreeze records the requested start layer count."""
        from ntf.utils import ContinualLearningWrapper
        from ntf.models import ModelRegistry

        model, _ = ModelRegistry("facebook/opt-125m").load_model_and_tokenizer()

        wrapper = ContinualLearningWrapper(model)
        unfreeze_settings = {
            "start_layers": 4,
            "unfreeze_every_n_epochs": 2,
            "max_layers": 12,
        }
        wrapper.progressive_unfreeze(**unfreeze_settings)

        assert hasattr(wrapper, 'start_layers')
        assert wrapper.start_layers == 4
246
+
247
+
248
# Allow running this file directly: invoke pytest on it in verbose mode.
if __name__ == "__main__":
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
utils/__init__.py CHANGED
@@ -1,81 +1,47 @@
1
- """
2
- Utilities package for Nexuss Transformer Framework
3
- """
4
 
5
- from .continual_learning import (
6
- EWCConfig,
7
- ReplayConfig,
8
- GEMConfig,
9
- ContinualLearningConfig,
10
- EWCRegularizer,
11
- ReplayBuffer,
12
- GEMOptimizer,
13
- LwFLoss,
14
- create_continual_learning_wrapper,
15
- SIRegularizer,
16
- LwFRegularizer,
17
- ContinualLearningWrapper,
18
  )
19
-
20
- from .versioning import (
21
- ModelStage,
22
- ModelVersion,
23
- ModelMetadata,
24
- ModelRegistry,
25
- create_model_metadata,
26
- )
27
-
28
- from .metrics import (
29
- EvaluationResults,
30
- compute_perplexity,
31
- compute_accuracy,
32
- evaluate_model,
33
- benchmark_throughput,
34
- compare_models,
35
- )
36
-
37
- from .logging import (
38
- setup_logging,
39
- get_logger,
40
- set_log_level,
41
- DebugLogger,
42
- validate_config,
43
  )
44
 
45
  __all__ = [
46
- # Continual Learning
47
- "EWCConfig",
48
- "ReplayConfig",
49
- "GEMConfig",
50
- "ContinualLearningConfig",
51
- "EWCRegularizer",
52
- "ReplayBuffer",
53
- "GEMOptimizer",
54
- "LwFLoss",
55
- "create_continual_learning_wrapper",
56
- "SIRegularizer",
57
- "LwFRegularizer",
58
- "ContinualLearningWrapper",
59
-
60
- # Versioning
61
- "ModelStage",
62
- "ModelVersion",
63
- "ModelMetadata",
64
- "ModelRegistry",
65
- "create_model_metadata",
66
-
67
- # Metrics
68
- "EvaluationResults",
69
- "compute_perplexity",
70
- "compute_accuracy",
71
- "evaluate_model",
72
- "benchmark_throughput",
73
- "compare_models",
74
-
75
- # Logging
76
- "setup_logging",
77
- "get_logger",
78
- "set_log_level",
79
- "DebugLogger",
80
- "validate_config",
81
  ]
 
1
+ """Finetuning package - PEFT, LoRA, and layer freezing utilities."""
 
 
2
 
3
+ from finetuning.peft_finetune import PEFTTrainer, LoRAConfig, setup_lora
4
+ from finetuning.freeze import LayerFreezer, freeze_layers
5
+ from finetuning.full_finetune import FullFinetuneTrainer, full_finetune
6
+ from finetuning.multi_task import (
7
+ MultiTaskModel,
8
+ MultiTaskTrainer,
9
+ TaskHead,
10
+ TaskType,
11
+ TaskHeadConfig,
12
+ ClassificationHead,
13
+ SequenceToSequenceHead,
14
+ TokenClassificationHead,
15
+ QuestionAnsweringHead,
16
  )
17
+ from finetuning.p_tuning import (
18
+ PTuningModel,
19
+ PTuningConfig,
20
+ PTuningMethod,
21
+ setup_p_tuning,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
 
24
  __all__ = [
25
+ "PEFTTrainer",
26
+ "LoRAConfig",
27
+ "setup_lora",
28
+ "LayerFreezer",
29
+ "freeze_layers",
30
+ "FullFinetuneTrainer",
31
+ "full_finetune",
32
+ # Multi-task learning
33
+ "MultiTaskModel",
34
+ "MultiTaskTrainer",
35
+ "TaskHead",
36
+ "TaskType",
37
+ "TaskHeadConfig",
38
+ "ClassificationHead",
39
+ "SequenceToSequenceHead",
40
+ "TokenClassificationHead",
41
+ "QuestionAnsweringHead",
42
+ # P-Tuning / Prefix Tuning
43
+ "PTuningModel",
44
+ "PTuningConfig",
45
+ "PTuningMethod",
46
+ "setup_p_tuning",
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  ]