# MultiModal / continual_learning.py
# Continual-learning utilities: multi-modal batch preprocessing, experience
# replay, and (Online) Elastic Weight Consolidation regularization.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from collections import deque
from typing import List, Dict, Any, Optional, Union
from tqdm import tqdm
from dataclasses import dataclass
@dataclass
class ModalityConfig:
    """Descriptor for one supported input modality."""
    # Human-readable modality name, e.g. 'text' or 'image'.
    name: str
    # Integer id used to tag processed segments of this modality.
    modality_id: int
class UnifiedMultiModalPreprocessor(nn.Module):
    """Normalizes raw per-modality batch data into a uniform segment format.

    Each supported modality is registered with a ModalityConfig; process_batch
    wraps the (possibly stacked) tensor data together with its modality id so
    downstream models receive one consistent structure.
    """
    def __init__(self, model_dim: int = 2048):
        super().__init__()
        # Registry of supported modalities and their integer ids.
        # NOTE(review): model_dim is currently unused — kept for interface
        # compatibility with callers; confirm whether projection layers are
        # intended here.
        self.modality_configs = {
            'text': ModalityConfig('text', 0),
            'image': ModalityConfig('image', 1),
            'audio': ModalityConfig('audio', 2),
            'video': ModalityConfig('video', 3)
        }
    def process_batch(self, batch_data: Union[torch.Tensor, List[Any]], modality_type: str) -> List[Dict]:
        """Wrap batch_data into a one-element list of segment dicts.

        Returns [] for unknown modalities, empty/None-only lists, non-tensor
        input, or when stacking heterogeneous tensors fails.
        """
        config = self.modality_configs.get(modality_type)
        if config is None:
            return []
        if isinstance(batch_data, torch.Tensor):
            data_tensor = batch_data
        elif isinstance(batch_data, list):
            # Drop None entries before stacking.
            kept = [item for item in batch_data if item is not None]
            if not kept:
                return []
            # Assumes list entries are tensors of identical shape,
            # e.g. list of (C, H, W) -> stacked (B, C, H, W).
            try:
                data_tensor = torch.stack(kept)
            except Exception as e:
                print(f"Error stacking modality data: {e}")
                return []
        else:
            return []
        return [{
            'type': modality_type,
            'data': data_tensor,
            'modality_id': config.modality_id
        }]
class ExperienceReplayBuffer:
    """Fixed-capacity FIFO store of past samples for experience replay."""
    def __init__(self, max_size: int = 10000):
        # deque with maxlen evicts the oldest sample automatically once full.
        self.buffer = deque(maxlen=max_size)
    def add(self, sample: Dict[str, Any]):
        """Store a detached, CPU-resident copy of sample (breaks autograd graph)."""
        detached = {}
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                detached[key] = value.detach().cpu()
            elif isinstance(value, list):
                # Also detach tensors nested one level deep inside lists.
                detached[key] = [
                    item.detach().cpu() if isinstance(item, torch.Tensor) else item
                    for item in value
                ]
            else:
                detached[key] = value
        self.buffer.append(detached)
    def sample(self, batch_size: int) -> List[Any]:
        """Draw up to batch_size distinct samples uniformly at random."""
        size = len(self.buffer)
        if size == 0:
            return []
        picks = np.random.choice(size, min(size, batch_size), replace=False)
        return [self.buffer[idx] for idx in picks]
    def __len__(self):
        return len(self.buffer)
    def clear(self):
        """Drop all stored samples."""
        self.buffer.clear()
class EWC:
    """Elastic Weight Consolidation (EWC) regularizer.

    Snapshots the model's trainable parameters and estimates their importance
    via a diagonal empirical Fisher matrix computed over a dataloader; the
    penalty() term then discourages drift on important weights when training
    subsequent tasks.
    """
    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0
    ):
        """Snapshot parameters and compute the Fisher estimate immediately.

        importance scales the quadratic penalty returned by penalty().
        """
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.device = next(model.parameters()).device
        # Snapshot the current (post-task) parameters as the anchor point.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }
        self.fisher = self._compute_fisher(dataloader)
    def _compute_fisher(self, dataloader: DataLoader) -> Dict[str, torch.Tensor]:
        """Estimate the diagonal empirical Fisher information matrix.

        Accumulates squared gradients of the response-token LM loss over the
        dataloader, normalized by the number of samples seen.
        NOTE(review): assumes each batch dict provides 'instruction' and
        'response' token-id tensors of shape (batch, seq), and that
        model(input_data) returns a dict containing 'logits' — confirm
        against the dataset collator and model implementation.
        """
        fisher = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }
        self.model.eval()
        num_samples = 0
        # tqdm just provides lightweight progress output.
        pbar = tqdm(dataloader, desc="Computing Fisher Matrix", leave=False)
        for batch in pbar:
            if batch is None: continue
            self.model.zero_grad()
            # 1. Prepare the text input.
            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            # Concatenate: [Instruction, Response]
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)
            # 2. Build the multi-modal input structure.
            input_data = {'segments': []}
            # Attach any extra modality data present in the batch.
            raw_modality_data = batch.get('modality_data')
            if raw_modality_data is not None:
                modality_type = batch.get('modality_type', 'image')
                # Collated batches may carry the type as a list; take the first entry.
                if isinstance(modality_type, list): modality_type = modality_type[0]
                mod_segments = self.preprocessor.process_batch(raw_modality_data, modality_type)
                for seg in mod_segments:
                    seg['data'] = seg['data'].to(self.device)
                    input_data['segments'].append(seg)
            input_data['segments'].append({
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            })
            output = self.model(input_data)
            logits = output['logits']
            # 4. Standard causal-LM loss: shift logits vs. labels by one token.
            # input_ids: [I1, I2, R1, R2]
            # labels:    [I2, R1, R2, EOS]
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            # Mask so gradients flow only through the response tokens
            # (the instruction prefix is excluded from the loss).
            inst_len = instruction_ids.shape[1]
            loss_mask = torch.ones_like(shift_labels, dtype=torch.float)
            if inst_len > 1:
                loss_mask[:, :inst_len-1] = 0.0
            # Per-token loss so the mask can be applied before averaging.
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # Apply the mask; epsilon guards against an all-masked batch.
            loss = (loss * loss_mask.view(-1)).sum() / (loss_mask.sum() + 1e-6)
            # 5. Backprop and accumulate squared gradients (empirical Fisher).
            loss.backward()
            for n, p in self.model.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach() ** 2
            num_samples += input_ids.size(0)
        # Normalize by the number of samples seen.
        if num_samples > 0:
            for n in fisher:
                fisher[n] /= num_samples
        self.model.train()
        return fisher
    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Return importance * sum_i F_i * (p_i - p*_i)^2 over trainable params.

        If model is None the penalty is evaluated on the snapshot's own model.
        """
        target_model = model if model is not None else self.model
        loss = torch.tensor(0.0, device=self.device)
        for n, p in target_model.named_parameters():
            if n in self.params and p.requires_grad:
                if n in self.fisher:
                    loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
        return self.importance * loss
class OnlineEWC(EWC):
    """Online variant of EWC keeping a single running Fisher estimate.

    Instead of one Fisher matrix per task, the estimate is folded across
    tasks with an exponential moving average controlled by gamma.
    """
    def __init__(
        self,
        model: nn.Module,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0,
        gamma: float = 0.9
    ):
        # Deliberately does NOT call EWC.__init__: the Fisher estimate is
        # built lazily via update_fisher() after each task instead of at
        # construction time.
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.gamma = gamma
        self.device = next(model.parameters()).device
        self.params = {}
        self.fisher = {}
        self.task_count = 0
    def update_fisher(self, dataloader: DataLoader):
        """Fold the just-finished task's Fisher estimate into the running one."""
        print(f"Updating Online EWC Fisher Matrix (Task {self.task_count + 1})...")
        fresh = self._compute_fisher(dataloader)
        if self.task_count == 0:
            self.fisher = fresh
        else:
            # Running update: gamma * old + new (decays stale importance).
            for name, old in self.fisher.items():
                if name in fresh:
                    self.fisher[name] = self.gamma * old + fresh[name]
        # Re-anchor the reference parameters at their post-task values.
        self.params = {
            name: param.clone().detach()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        self.task_count += 1
        print(f"Online EWC regularizer updated.")
    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """EWC penalty; zero until at least one task has been consolidated."""
        if self.task_count == 0:
            return torch.tensor(0.0, device=self.device)
        return super().penalty(model)