| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from torch.utils.data import DataLoader |
| | from collections import deque |
| | from typing import List, Dict, Any, Optional, Union |
| | from tqdm import tqdm |
| | from dataclasses import dataclass |
| |
|
| |
|
@dataclass
class ModalityConfig:
    """Static descriptor for one supported input modality."""
    name: str          # human-readable modality name, e.g. 'text'
    modality_id: int   # integer id used to tag processed segments


class UnifiedMultiModalPreprocessor(nn.Module):
    """Normalizes raw per-modality batches into tagged segment dicts.

    Each successfully processed batch becomes a single segment dict of the
    form ``{'type': <modality>, 'data': <tensor>, 'modality_id': <int>}``.
    """

    def __init__(self, model_dim: int = 2048):
        super().__init__()
        # Registry of the four supported modalities; enumeration order fixes
        # the ids: text=0, image=1, audio=2, video=3.
        self.modality_configs = {
            label: ModalityConfig(label, idx)
            for idx, label in enumerate(['text', 'image', 'audio', 'video'])
        }

    def process_batch(self, batch_data: Union[torch.Tensor, List[Any]], modality_type: str) -> List[Dict]:
        """Wrap *batch_data* into a one-element list of segment dicts.

        Returns an empty list for an unknown modality, for input that is
        neither a tensor nor a list, for a list containing only ``None``,
        or when the list's tensors cannot be stacked.
        """
        config = self.modality_configs.get(modality_type)
        if config is None:
            return []

        if isinstance(batch_data, torch.Tensor):
            stacked = batch_data
        elif isinstance(batch_data, list):
            # Drop placeholder Nones before attempting to stack.
            tensors = [item for item in batch_data if item is not None]
            if not tensors:
                return []
            try:
                stacked = torch.stack(tensors)
            except Exception as e:
                # Best-effort: mismatched shapes/dtypes yield an empty result.
                print(f"Error stacking modality data: {e}")
                return []
        else:
            return []

        return [{
            'type': modality_type,
            'data': stacked,
            'modality_id': config.modality_id
        }]
| |
|
| |
|
class ExperienceReplayBuffer:
    """Fixed-capacity FIFO store of training samples for rehearsal.

    Once full, the oldest sample is evicted automatically (deque maxlen).
    Tensors are detached and moved to CPU on insertion so stored samples
    hold no autograd graph or GPU memory.
    """

    def __init__(self, max_size: int = 10000):
        self.buffer = deque(maxlen=max_size)

    @staticmethod
    def _detach(value):
        """Return *value* with any tensors (also inside lists) detached on CPU."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu()
        if isinstance(value, list):
            return [v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in value]
        return value

    def add(self, sample: Dict[str, Any]):
        """Append a graph-free CPU copy of *sample* to the buffer."""
        self.buffer.append({key: self._detach(val) for key, val in sample.items()})

    def sample(self, batch_size: int) -> List[Any]:
        """Draw up to *batch_size* distinct samples uniformly at random."""
        if not self.buffer:
            return []

        picked = np.random.choice(
            len(self.buffer),
            min(len(self.buffer), batch_size),
            replace=False
        )
        return [self.buffer[idx] for idx in picked]

    def __len__(self):
        return len(self.buffer)

    def clear(self):
        """Drop all stored samples."""
        self.buffer.clear()
| |
|
| |
|
class EWC:
    """Elastic Weight Consolidation.

    Snapshots the model's trainable parameters and estimates a diagonal
    empirical Fisher information matrix from `dataloader`; `penalty()` then
    quadratically penalizes drift away from that snapshot, weighted per
    parameter by the Fisher estimate and scaled by `importance`.
    """
    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0
    ):
        # Fisher is computed eagerly here, so construct this right after the
        # model has finished training on the task to be protected.
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.device = next(model.parameters()).device

        # Anchor point: frozen copy of the trainable parameters at task end.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }

        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader: DataLoader) -> Dict[str, torch.Tensor]:
        """Compute the Fisher information matrix (empirical, diagonal).

        Accumulates squared gradients of the masked language-modeling loss
        over every batch, then averages by the number of samples seen.
        """
        fisher = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        self.model.eval()
        num_samples = 0

        pbar = tqdm(dataloader, desc="Computing Fisher Matrix", leave=False)
        for batch in pbar:
            if batch is None: continue

            # Fresh gradients per batch — fisher accumulates squared grads below.
            self.model.zero_grad()

            # NOTE(review): assumes each batch carries pre-tokenized
            # 'instruction' and 'response' id tensors of shape (batch, seq) —
            # confirm against the dataset / collate_fn.
            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)

            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            input_data = {'segments': []}

            # Optional non-text modality attached to the batch; defaults to
            # 'image' when a type isn't given.
            raw_modality_data = batch.get('modality_data')
            if raw_modality_data is not None:
                modality_type = batch.get('modality_type', 'image')
                if isinstance(modality_type, list): modality_type = modality_type[0]

                mod_segments = self.preprocessor.process_batch(raw_modality_data, modality_type)
                for seg in mod_segments:
                    seg['data'] = seg['data'].to(self.device)
                    input_data['segments'].append(seg)

            # Text segment is appended last, after any modality segments.
            input_data['segments'].append({
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            })

            # The model is expected to accept {'segments': [...]} and return a
            # dict containing 'logits' (used as (batch, seq, vocab) below).
            output = self.model(input_data)
            logits = output['logits']

            # Standard causal-LM shift: position i predicts token i+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()

            # Mask out instruction positions so the Fisher reflects only the
            # response tokens: after the shift, positions < inst_len-1 still
            # predict instruction tokens.
            inst_len = instruction_ids.shape[1]
            loss_mask = torch.ones_like(shift_labels, dtype=torch.float)
            if inst_len > 1:
                loss_mask[:, :inst_len-1] = 0.0

            # Per-token loss so the mask can be applied before reduction.
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Masked mean over response tokens; epsilon guards an all-zero mask.
            loss = (loss * loss_mask.view(-1)).sum() / (loss_mask.sum() + 1e-6)

            loss.backward()

            # Empirical Fisher: accumulate squared gradients per batch.
            # NOTE(review): gradients come from the batch-mean loss while the
            # normalizer below counts individual samples — a batch-level
            # approximation of the per-sample Fisher; confirm this matches the
            # intended estimator.
            for n, p in self.model.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach() ** 2

            num_samples += input_ids.size(0)

        # Average over all samples seen (skipped when the loader was empty).
        if num_samples > 0:
            for n in fisher:
                fisher[n] /= num_samples

        self.model.train()
        return fisher

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Return importance * sum_i F_i * (theta_i - theta*_i)^2.

        `model` defaults to the model captured at construction; a different
        module sharing the same parameter names may be scored instead.
        Parameters absent from the stored snapshot or Fisher are skipped.
        """
        target_model = model if model is not None else self.model

        loss = torch.tensor(0.0, device=self.device)

        for n, p in target_model.named_parameters():
            if n in self.params and p.requires_grad:
                if n in self.fisher:
                    loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()

        return self.importance * loss
| |
|
| |
|
class OnlineEWC(EWC):
    """Online EWC: one exponentially-decayed Fisher across a task stream.

    Unlike the base class, no Fisher is computed at construction; call
    `update_fisher()` after finishing each task. Old importance decays by
    `gamma` each time a new task's Fisher is folded in.
    """

    def __init__(
        self,
        model: nn.Module,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0,
        gamma: float = 0.9
    ):
        # Intentionally does NOT call EWC.__init__: the anchor parameters and
        # Fisher are filled in lazily by update_fisher() after each task.
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.gamma = gamma
        self.device = next(model.parameters()).device

        self.params = {}
        self.fisher = {}
        self.task_count = 0

    def update_fisher(self, dataloader: DataLoader):
        """Fold the latest task's Fisher into the running estimate and re-anchor."""
        print(f"Updating Online EWC Fisher Matrix (Task {self.task_count + 1})...")
        fresh = self._compute_fisher(dataloader)

        if self.task_count == 0:
            self.fisher = fresh
        else:
            # Decay the accumulated importance, then add the new task's
            # Fisher; entries missing from the new estimate are kept as-is.
            self.fisher = {
                n: self.gamma * f + fresh[n] if n in fresh else f
                for n, f in self.fisher.items()
            }

        # Re-anchor on the parameters as they stand after this task.
        self.params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

        self.task_count += 1
        print(f"Online EWC regularizer updated.")

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Zero before the first update; otherwise the base EWC quadratic penalty."""
        if self.task_count > 0:
            return super().penalty(model)
        return torch.tensor(0.0, device=self.device)