File size: 8,929 Bytes
4bfa065 cd66851 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from collections import deque
from typing import List, Dict, Any, Optional, Union
from tqdm import tqdm
from dataclasses import dataclass
@dataclass
class ModalityConfig:
    """Static descriptor for one supported input modality."""
    # Human-readable modality name, e.g. 'text' or 'image'.
    name: str
    # Integer id used to tag segments produced for this modality.
    modality_id: int


class UnifiedMultiModalPreprocessor(nn.Module):
    """Normalizes raw per-modality batch data into a uniform segment format.

    Each produced segment is a dict with keys 'type' (modality name),
    'data' (a batched tensor) and 'modality_id' (integer tag), ready to be
    routed into a multi-modal model. Unknown modalities and unusable input
    yield an empty list rather than raising.
    """

    def __init__(self, model_dim: int = 2048):
        super().__init__()
        # Registry of supported modalities and their integer ids.
        self.modality_configs = {
            'text': ModalityConfig('text', 0),
            'image': ModalityConfig('image', 1),
            'audio': ModalityConfig('audio', 2),
            'video': ModalityConfig('video', 3)
        }

    def process_batch(self, batch_data: Union[torch.Tensor, List[Any]], modality_type: str) -> List[Dict]:
        """Wrap ``batch_data`` in segment dicts; invalid input yields ``[]``.

        A tensor is passed through as-is; a list of same-shaped tensors is
        stacked into one batch tensor (``None`` entries are dropped first).
        """
        if modality_type not in self.modality_configs:
            return []
        config = self.modality_configs[modality_type]
        if isinstance(batch_data, torch.Tensor):
            tensor = batch_data
        elif isinstance(batch_data, list):
            # Drop missing entries, then stack the rest into one batch tensor:
            # list of (C, H, W) tensors -> (B, C, H, W).
            kept = [item for item in batch_data if item is not None]
            if not kept:
                return []
            try:
                tensor = torch.stack(kept)
            except Exception as e:
                # Shapes/dtypes disagreed — report and degrade gracefully.
                print(f"Error stacking modality data: {e}")
                return []
        else:
            return []
        return [{
            'type': modality_type,
            'data': tensor,
            'modality_id': config.modality_id
        }]
class ExperienceReplayBuffer:
    """Fixed-capacity FIFO store of training samples for experience replay.

    Tensors are detached and moved to CPU on insertion so the buffer never
    keeps autograd graphs alive or pins device memory.
    """

    def __init__(self, max_size: int = 10000):
        # deque(maxlen=...) silently evicts the oldest sample once full.
        self.buffer = deque(maxlen=max_size)

    @staticmethod
    def _detach_cpu(value):
        """Return a CPU, graph-free copy of a tensor; pass other values through."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu()
        return value

    def add(self, sample: Dict[str, Any]):
        """Store one sample, sanitizing tensors (also inside list values)."""
        stored = {}
        for key, value in sample.items():
            if isinstance(value, list):
                # One level deep is enough for the sample layout used here.
                stored[key] = [self._detach_cpu(item) for item in value]
            else:
                stored[key] = self._detach_cpu(value)
        self.buffer.append(stored)

    def sample(self, batch_size: int) -> List[Any]:
        """Draw up to ``batch_size`` distinct samples uniformly at random."""
        if not self.buffer:
            return []
        count = min(len(self.buffer), batch_size)
        chosen = np.random.choice(len(self.buffer), count, replace=False)
        return [self.buffer[idx] for idx in chosen]

    def __len__(self):
        return len(self.buffer)

    def clear(self):
        """Drop every stored sample."""
        self.buffer.clear()
class EWC:
    """Elastic Weight Consolidation.

    Snapshots the model's trainable parameters at construction time and
    estimates each parameter's importance with an empirical Fisher matrix
    computed over ``dataloader``; ``penalty()`` then yields a quadratic
    regularizer that discourages drift on important weights.
    """
    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0
    ):
        """Capture reference parameters and compute the Fisher matrix.

        Args:
            model: Model whose parameters are to be consolidated. Its forward
                is called with a ``{'segments': [...]}`` dict and must return
                a dict containing a ``'logits'`` tensor.
            dataloader: Batches used to estimate the Fisher information.
                Batches are dicts with 'instruction' / 'response' id tensors
                and optionally 'modality_data' / 'modality_type'.
                TODO confirm exact batch schema against the dataset code.
            preprocessor: Converts raw modality data into model segments.
            importance: Scale factor applied to the penalty term.
        """
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.device = next(model.parameters()).device
        # Freeze the current parameters as the reference point the penalty
        # pulls back towards.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }
        self.fisher = self._compute_fisher(dataloader)
    def _compute_fisher(self, dataloader: DataLoader) -> Dict[str, torch.Tensor]:
        """Compute the Fisher information matrix (empirical Fisher).

        Accumulates squared gradients of the response-masked causal-LM loss
        over the dataloader, then normalizes by the number of sequences seen.
        """
        fisher = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }
        self.model.eval()
        num_samples = 0
        # tqdm keeps the progress output compact.
        pbar = tqdm(dataloader, desc="Computing Fisher Matrix", leave=False)
        for batch in pbar:
            if batch is None: continue
            self.model.zero_grad()
            # 1. Prepare the text input.
            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            # Concatenate along sequence dim: [Instruction, Response]
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)
            # 2. Build the multi-modal input structure expected by the model.
            input_data = {'segments': []}
            # Attach any extra modality data carried by the batch.
            raw_modality_data = batch.get('modality_data')
            if raw_modality_data is not None:
                modality_type = batch.get('modality_type', 'image')
                # A collated batch may carry one type string per sample;
                # assumes the batch is homogeneous — TODO confirm.
                if isinstance(modality_type, list): modality_type = modality_type[0]
                mod_segments = self.preprocessor.process_batch(raw_modality_data, modality_type)
                for seg in mod_segments:
                    seg['data'] = seg['data'].to(self.device)
                    input_data['segments'].append(seg)
            # The text segment is always appended last.
            input_data['segments'].append({
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            })
            output = self.model(input_data)
            logits = output['logits']
            # 4. Compute the loss (standard causal-LM loss).
            # Shift logits and labels:
            # input_ids: [I1, I2, R1, R2]
            # labels:    [I2, R1, R2, EOS]
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            # Build a mask so the loss (and hence the gradients) only covers
            # response tokens, not the instruction prefix.
            # Instruction length:
            inst_len = instruction_ids.shape[1]
            loss_mask = torch.ones_like(shift_labels, dtype=torch.float)
            if inst_len > 1:
                loss_mask[:, :inst_len-1] = 0.0
            # Per-token loss (no reduction, so the mask can be applied).
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # Apply the mask and average; epsilon guards an all-masked batch.
            loss = (loss * loss_mask.view(-1)).sum() / (loss_mask.sum() + 1e-6)
            # 5. Backprop and accumulate squared gradients per parameter.
            loss.backward()
            for n, p in self.model.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach() ** 2
            num_samples += input_ids.size(0)
        # Normalize by the total number of sequences seen.
        if num_samples > 0:
            for n in fisher:
                fisher[n] /= num_samples
        self.model.train()
        return fisher
    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Return the scaled quadratic EWC penalty for ``model``.

        Args:
            model: Model to evaluate; defaults to the wrapped model.

        Returns:
            ``importance * sum_i F_i * (p_i - p_i*)^2`` over trainable
            parameters present in both the snapshot and the Fisher matrix.
        """
        target_model = model if model is not None else self.model
        loss = torch.tensor(0.0, device=self.device)
        for n, p in target_model.named_parameters():
            if n in self.params and p.requires_grad:
                if n in self.fisher:
                    loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
        return self.importance * loss
class OnlineEWC(EWC):
    """Online variant of EWC.

    Instead of keeping one Fisher matrix per task, a single running estimate
    is decayed by ``gamma`` and refreshed after each task via
    :meth:`update_fisher`, so memory stays constant as tasks accumulate.
    """

    def __init__(
        self,
        model: nn.Module,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0,
        gamma: float = 0.9
    ):
        # Deliberately does NOT call EWC.__init__: the Fisher estimate is
        # built lazily through update_fisher() rather than eagerly from a
        # dataloader at construction time.
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.gamma = gamma
        self.device = next(model.parameters()).device
        self.params = {}
        self.fisher = {}
        self.task_count = 0

    def update_fisher(self, dataloader: DataLoader):
        """Fold a freshly computed Fisher estimate into the running one."""
        print(f"Updating Online EWC Fisher Matrix (Task {self.task_count + 1})...")
        fresh = self._compute_fisher(dataloader)
        if self.task_count == 0:
            self.fisher = fresh
        else:
            # Exponentially decayed accumulation of parameter importance.
            for name in self.fisher:
                if name in fresh:
                    self.fisher[name] = self.gamma * self.fisher[name] + fresh[name]
        # Re-anchor the reference point at the parameters reached after
        # training on the task that was just consolidated.
        self.params = {
            name: weight.clone().detach()
            for name, weight in self.model.named_parameters()
            if weight.requires_grad
        }
        self.task_count += 1
        print(f"Online EWC regularizer updated.")

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """EWC penalty; zero until at least one task has been consolidated."""
        if self.task_count == 0:
            return torch.tensor(0.0, device=self.device)
        return super().penalty(model)