|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from torch.utils.data import DataLoader |
|
|
from collections import deque |
|
|
from typing import List, Dict, Any, Optional, Union |
|
|
from tqdm import tqdm |
|
|
from dataclasses import dataclass |
|
|
|
|
|
|
|
|
@dataclass
class ModalityConfig:
    """Descriptor pairing a modality's name with its integer segment tag."""
    name: str          # human-readable modality key, e.g. 'text', 'image'
    modality_id: int   # integer id attached to processed segments
|
|
|
|
|
class UnifiedMultiModalPreprocessor(nn.Module):
    """Wraps raw per-modality batches into uniform segment dictionaries.

    Each known modality ('text', 'image', 'audio', 'video') maps to a
    fixed integer id that is attached to every segment it produces.
    """

    def __init__(self, model_dim: int = 2048):
        super().__init__()
        # Static registry: modality name -> ModalityConfig(name, id).
        self.modality_configs = {
            'text': ModalityConfig('text', 0),
            'image': ModalityConfig('image', 1),
            'audio': ModalityConfig('audio', 2),
            'video': ModalityConfig('video', 3)
        }

    def process_batch(self, batch_data: Union[torch.Tensor, List[Any]], modality_type: str) -> List[Dict]:
        """Convert `batch_data` into a one-element list of segment dicts.

        Returns [] for unknown modalities, empty/all-None lists,
        un-stackable lists, or inputs of an unsupported type.
        """
        config = self.modality_configs.get(modality_type)
        if config is None:
            return []

        if isinstance(batch_data, torch.Tensor):
            tensor = batch_data
        elif isinstance(batch_data, list):
            # Drop missing entries before stacking.
            items = [entry for entry in batch_data if entry is not None]
            if not items:
                return []
            try:
                tensor = torch.stack(items)
            except Exception as e:
                # Best-effort: report and skip batches with mismatched shapes.
                print(f"Error stacking modality data: {e}")
                return []
        else:
            return []

        return [{
            'type': modality_type,
            'data': tensor,
            'modality_id': config.modality_id
        }]
|
|
|
|
|
|
|
|
class ExperienceReplayBuffer:
    """Bounded FIFO store of past training samples for rehearsal.

    Oldest samples are evicted automatically once `max_size` is reached.
    """

    def __init__(self, max_size: int = 10000):
        # deque with maxlen drops the oldest entry on overflow.
        self.buffer = deque(maxlen=max_size)

    def add(self, sample: Dict[str, Any]):
        """Store a copy of `sample` with tensors detached and moved to CPU."""
        def to_cpu(value):
            if isinstance(value, torch.Tensor):
                return value.detach().cpu()
            if isinstance(value, list):
                # One level deep: only tensor elements are converted.
                return [x.detach().cpu() if isinstance(x, torch.Tensor) else x for x in value]
            return value

        self.buffer.append({key: to_cpu(value) for key, value in sample.items()})

    def sample(self, batch_size: int) -> List[Any]:
        """Draw up to `batch_size` distinct samples uniformly at random."""
        count = len(self.buffer)
        if count == 0:
            return []

        picked = np.random.choice(
            len(self.buffer),
            min(len(self.buffer), batch_size),
            replace=False
        )
        return [self.buffer[idx] for idx in picked]

    def __len__(self):
        return len(self.buffer)

    def clear(self):
        """Drop every stored sample."""
        self.buffer.clear()
|
|
|
|
|
|
|
|
class EWC:
    """Elastic Weight Consolidation (Kirkpatrick et al., 2017).

    Captures a snapshot of the model's trainable parameters plus a
    diagonal empirical-Fisher estimate of their importance, then exposes
    `penalty()` — a quadratic term that discourages later training from
    moving important parameters away from the snapshot.
    """
    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0
    ):
        """Snapshot `model` and compute the Fisher matrix over `dataloader`.

        Args:
            model: model to consolidate; its forward is called with the
                `{'segments': [...]}` dict built in `_compute_fisher` and
                must return a dict containing a 'logits' entry.
            dataloader: yields the batches used for the Fisher estimate.
            preprocessor: converts raw modality data into segment dicts.
            importance: global scale applied to the penalty term.
        """
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        # All tensors created below live on the model's device.
        self.device = next(model.parameters()).device

        # Anchor point: detached copies of every trainable parameter.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }

        # Diagonal importance weights, keyed like self.params.
        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader: DataLoader) -> Dict[str, torch.Tensor]:
        """Estimate the diagonal empirical Fisher matrix.

        One forward/backward pass per batch with a causal-LM loss over
        response tokens; squared gradients are accumulated per parameter
        and normalized by the number of examples seen.
        """
        fisher = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        self.model.eval()
        num_samples = 0

        pbar = tqdm(dataloader, desc="Computing Fisher Matrix", leave=False)
        for batch in pbar:
            if batch is None: continue

            # Fresh gradients for each batch so squares are not mixed.
            self.model.zero_grad()

            # assumes batch['instruction'] / batch['response'] are
            # (batch, seq) token-id tensors — TODO confirm against the
            # dataset/collate function.
            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)

            # Full sequence: instruction followed by response.
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            input_data = {'segments': []}

            # Optional non-text modality travels as extra segments placed
            # ahead of the text segment.
            raw_modality_data = batch.get('modality_data')
            if raw_modality_data is not None:
                modality_type = batch.get('modality_type', 'image')
                if isinstance(modality_type, list): modality_type = modality_type[0]

                mod_segments = self.preprocessor.process_batch(raw_modality_data, modality_type)
                for seg in mod_segments:
                    seg['data'] = seg['data'].to(self.device)
                    input_data['segments'].append(seg)

            input_data['segments'].append({
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            })

            # Model contract: returns a dict with 'logits' — presumably
            # (batch, seq, vocab) aligned with input_ids; verify against
            # the model implementation.
            output = self.model(input_data)
            logits = output['logits']

            # Standard causal-LM shift: position t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()

            # Mask out instruction positions so only response tokens
            # contribute to the loss (and hence to the Fisher estimate).
            inst_len = instruction_ids.shape[1]
            loss_mask = torch.ones_like(shift_labels, dtype=torch.float)
            if inst_len > 1:
                loss_mask[:, :inst_len-1] = 0.0

            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Masked mean; +1e-6 guards against an all-masked batch.
            loss = (loss * loss_mask.view(-1)).sum() / (loss_mask.sum() + 1e-6)

            loss.backward()

            # Empirical Fisher diagonal: accumulate squared gradients.
            for n, p in self.model.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach() ** 2

            num_samples += input_ids.size(0)

        # NOTE(review): per-batch gradients (of a batch-mean loss) are
        # summed and then divided by the total example count, not the
        # batch count — the scale differs from a strict per-example
        # Fisher; confirm this is intended.
        if num_samples > 0:
            for n in fisher:
                fisher[n] /= num_samples

        self.model.train()
        return fisher

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Return importance * sum_i F_i * (theta_i - theta*_i)^2.

        Args:
            model: model to evaluate the penalty on; defaults to the
                model captured at construction time.
        """
        target_model = model if model is not None else self.model

        loss = torch.tensor(0.0, device=self.device)

        # Only parameters that were snapshotted AND have a Fisher entry
        # contribute; everything else is free to move.
        for n, p in target_model.named_parameters():
            if n in self.params and p.requires_grad:
                if n in self.fisher:
                    loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()

        return self.importance * loss
|
|
|
|
|
|
|
|
class OnlineEWC(EWC):
    """Online EWC: one exponentially-decayed Fisher estimate across tasks.

    Instead of storing a Fisher matrix per task, the running estimate is
    updated after each task as F <- gamma * F_old + F_new, and the
    parameter anchor is re-snapshotted.
    """

    def __init__(
        self,
        model: nn.Module,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0,
        gamma: float = 0.9
    ):
        """Set up the regularizer without computing any Fisher yet.

        Deliberately does NOT call EWC.__init__: no dataloader exists at
        construction time, so `params`/`fisher` are filled lazily by
        `update_fisher()` after each task.

        Args:
            model: model to consolidate (see EWC for the forward contract).
            preprocessor: converts raw modality data into segment dicts.
            importance: global scale applied to the penalty term.
            gamma: decay applied to the accumulated Fisher at each update.
        """
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.gamma = gamma
        self.device = next(model.parameters()).device

        self.params = {}
        self.fisher = {}
        self.task_count = 0

    def update_fisher(self, dataloader: DataLoader):
        """Recompute the Fisher on `dataloader` and fold it into the
        running estimate, then re-snapshot the parameter anchor."""
        print(f"Updating Online EWC Fisher Matrix (Task {self.task_count + 1})...")
        new_fisher = self._compute_fisher(dataloader)

        if self.task_count == 0:
            self.fisher = new_fisher
        else:
            for n, f_new in new_fisher.items():
                if n in self.fisher:
                    # Standard online-EWC decay of the accumulated estimate.
                    self.fisher[n] = self.gamma * self.fisher[n] + f_new
                else:
                    # FIX: parameters that became trainable after the first
                    # task were previously dropped from the merge entirely;
                    # keep their fresh estimate instead.
                    self.fisher[n] = f_new
            # Parameters present only in the old estimate are kept as-is
            # (unchanged from the original behavior).

        # Anchor point for the quadratic penalty: snapshot of the current
        # (post-task) trainable parameters.
        self.params = {
            n: p.clone().detach()
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        self.task_count += 1
        print("Online EWC regularizer updated.")

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """EWC penalty; zero until the first `update_fisher` call."""
        if self.task_count == 0:
            return torch.tensor(0.0, device=self.device)
        return super().penalty(model)