import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from collections import deque
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass

# tqdm is an optional progress-bar dependency; fall back to a no-op wrapper
# so the module still imports (and Fisher computation still runs) without it.
try:
    from tqdm import tqdm
except ImportError:  # pragma: no cover
    def tqdm(iterable, **kwargs):
        return iterable


@dataclass
class ModalityConfig:
    """Static description of one supported input modality."""
    name: str
    modality_id: int


class UnifiedMultiModalPreprocessor(nn.Module):
    """Normalizes raw per-modality batches into a uniform segment format.

    Each processed segment is a dict with keys ``type`` (modality name),
    ``data`` (a stacked tensor) and ``modality_id`` (integer tag).
    """

    def __init__(self, model_dim: int = 2048):
        super().__init__()
        # Kept for interface compatibility; not used by process_batch itself.
        self.model_dim = model_dim
        self.modality_configs = {
            'text': ModalityConfig('text', 0),
            'image': ModalityConfig('image', 1),
            'audio': ModalityConfig('audio', 2),
            'video': ModalityConfig('video', 3),
        }

    def process_batch(
        self,
        batch_data: Union[torch.Tensor, List[Any]],
        modality_type: str,
    ) -> List[Dict]:
        """Wrap ``batch_data`` into a single segment dict for ``modality_type``.

        Returns an empty list (best-effort, no raise) for unknown modalities,
        empty or None-only lists, un-stackable list contents, or unsupported
        input types.
        """
        processed_segments: List[Dict] = []
        if modality_type not in self.modality_configs:
            return processed_segments

        config = self.modality_configs[modality_type]

        if isinstance(batch_data, list):
            # Drop missing entries before stacking.
            valid_data = [x for x in batch_data if x is not None]
            if not valid_data:
                return []
            # Assumes a list of same-shape tensors, e.g. a list of
            # (C, H, W) tensors -> one (B, C, H, W) tensor.
            try:
                data_tensor = torch.stack(valid_data)
            except (RuntimeError, TypeError) as e:
                # Shape mismatch or non-tensor entries: skip this batch.
                print(f"Error stacking modality data: {e}")
                return []
        elif isinstance(batch_data, torch.Tensor):
            data_tensor = batch_data
        else:
            return []

        processed_segments.append({
            'type': modality_type,
            'data': data_tensor,
            'modality_id': config.modality_id
        })
        return processed_segments


class ExperienceReplayBuffer:
    """Fixed-capacity FIFO buffer of training samples kept on CPU."""

    def __init__(self, max_size: int = 10000):
        self.buffer = deque(maxlen=max_size)

    def add(self, sample: Dict[str, Any]):
        """Store a sample, detaching tensors and moving them to CPU.

        Detaching drops autograd graphs; moving to CPU keeps GPU memory
        free while samples sit in the buffer.
        """
        safe_sample = {}
        for k, v in sample.items():
            if isinstance(v, torch.Tensor):
                safe_sample[k] = v.detach().cpu()
            elif isinstance(v, list):
                # Also handle tensors nested one level inside lists.
                safe_sample[k] = [
                    x.detach().cpu() if isinstance(x, torch.Tensor) else x
                    for x in v
                ]
            else:
                safe_sample[k] = v
        self.buffer.append(safe_sample)

    def sample(self, batch_size: int) -> List[Any]:
        """Draw up to ``batch_size`` distinct samples uniformly at random."""
        if not self.buffer:
            return []
        indices = np.random.choice(
            len(self.buffer),
            min(len(self.buffer), batch_size),
            replace=False
        )
        return [self.buffer[i] for i in indices]

    def __len__(self):
        return len(self.buffer)

    def clear(self):
        """Remove all stored samples."""
        self.buffer.clear()


class EWC:
    """Elastic Weight Consolidation.

    Snapshots the model's parameters and a diagonal empirical Fisher
    information estimate on a reference dataset; ``penalty`` then pulls
    training on new tasks back toward those reference parameters.
    """

    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0
    ):
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.device = next(model.parameters()).device

        # Freeze the current parameters as the reference point theta*.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }
        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader: DataLoader) -> Dict[str, torch.Tensor]:
        """Estimate the diagonal empirical Fisher information matrix.

        Accumulates squared gradients of the causal-LM loss over the
        dataloader and averages by the number of samples seen.

        NOTE(review): each batch is expected to carry 'instruction' and
        'response' token-id tensors and optionally 'modality_data' /
        'modality_type'; the model is expected to accept
        {'segments': [...]} and return a dict containing 'logits' —
        confirm against the actual model/dataset implementations.
        """
        fisher = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        self.model.eval()
        num_samples = 0

        pbar = tqdm(dataloader, desc="Computing Fisher Matrix", leave=False)
        for batch in pbar:
            if batch is None:
                continue

            self.model.zero_grad()

            # 1. Text input: concatenate [Instruction, Response] along seq dim.
            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            # 2. Build the multimodal input structure.
            input_data = {'segments': []}

            # Attach any extra modality data (image/audio/video).
            raw_modality_data = batch.get('modality_data')
            if raw_modality_data is not None:
                modality_type = batch.get('modality_type', 'image')
                if isinstance(modality_type, list):
                    modality_type = modality_type[0]
                mod_segments = self.preprocessor.process_batch(
                    raw_modality_data, modality_type
                )
                for seg in mod_segments:
                    seg['data'] = seg['data'].to(self.device)
                    input_data['segments'].append(seg)

            input_data['segments'].append({
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            })

            # 3. Forward pass.
            output = self.model(input_data)
            logits = output['logits']

            # 4. Standard causal-LM loss with shifted logits/labels:
            #    input_ids: [I1, I2, R1, R2] -> labels: [I2, R1, R2, ...]
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()

            # Mask so only response-token predictions contribute to the loss.
            inst_len = instruction_ids.shape[1]
            loss_mask = torch.ones_like(shift_labels, dtype=torch.float)
            if inst_len > 1:
                # Positions < inst_len-1 predict instruction tokens; drop them.
                loss_mask[:, :inst_len - 1] = 0.0

            # Per-token loss, then masked mean (epsilon guards empty masks).
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            loss = (loss * loss_mask.view(-1)).sum() / (loss_mask.sum() + 1e-6)

            # 5. Backprop and accumulate squared gradients.
            loss.backward()

            for n, p in self.model.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach() ** 2

            num_samples += input_ids.size(0)

        # Average over all samples seen (no-op when the loader was empty).
        if num_samples > 0:
            for n in fisher:
                fisher[n] /= num_samples

        self.model.train()
        return fisher

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Return importance * sum_i F_i * (theta_i - theta*_i)^2.

        Args:
            model: model whose current parameters are penalized; defaults
                to the model captured at construction time.
        """
        target_model = model if model is not None else self.model
        loss = torch.tensor(0.0, device=self.device)
        for n, p in target_model.named_parameters():
            if n in self.params and p.requires_grad:
                if n in self.fisher:
                    loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
        return self.importance * loss


class OnlineEWC(EWC):
    """Online variant of EWC with a running Fisher estimate.

    Instead of one Fisher matrix per task, the estimate is folded in as a
    decayed running sum (decay ``gamma``) after each task via
    ``update_fisher``; no dataloader is needed at construction time.
    """

    def __init__(
        self,
        model: nn.Module,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0,
        gamma: float = 0.9
    ):
        # Intentionally does NOT call EWC.__init__: the base constructor
        # computes the Fisher matrix eagerly from a dataloader, whereas
        # here it is filled in lazily by update_fisher().
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.gamma = gamma
        self.device = next(model.parameters()).device

        self.params = {}
        self.fisher = {}
        self.task_count = 0

    def update_fisher(self, dataloader: DataLoader):
        """Recompute Fisher on ``dataloader`` and fold it into the running estimate."""
        print(f"Updating Online EWC Fisher Matrix (Task {self.task_count + 1})...")
        new_fisher = self._compute_fisher(dataloader)

        if self.task_count == 0:
            self.fisher = new_fisher
        else:
            for n in self.fisher:
                if n in new_fisher:
                    # Decayed running update of the Fisher information.
                    self.fisher[n] = self.gamma * self.fisher[n] + new_fisher[n]

        # Anchor future penalties at the parameters reached after this task.
        self.params = {
            n: p.clone().detach()
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }
        self.task_count += 1
        print(f"Online EWC regularizer updated.")

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """EWC penalty; zero until the first call to update_fisher()."""
        if self.task_count == 0:
            return torch.tensor(0.0, device=self.device)
        return super().penalty(model)