szxllm committed on
Commit
4bfa065
·
verified ·
1 Parent(s): c13bd26

Update continual_learning.py

Browse files
Files changed (1) hide show
  1. continual_learning.py +260 -293
continual_learning.py CHANGED
@@ -1,294 +1,261 @@
1
- """
2
- 持续学习模块
3
- 支持EWC和经验回放
4
- 修复版本:适配 MultiModalDenseTransformer data_loader.py
5
- """
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- import numpy as np
10
- from torch.utils.data import DataLoader
11
- from collections import deque
12
- from typing import List, Dict, Any, Optional, Union
13
- from tqdm import tqdm
14
- from dataclasses import dataclass
15
-
16
- # 假设 model.py 中已有定义,用于类型提示
17
- # from model import MultiModalDenseTransformer
18
-
19
- @dataclass
20
- class ModalityConfig:
21
- name: str
22
- modality_id: int
23
-
24
- class UnifiedMultiModalPreprocessor(nn.Module):
25
- """
26
- 统一多模态预处理器
27
- 职责:仅负责将原始Batch数据格式化为 MultiModalDenseTransformer 接受的 'segments' 结构。
28
- 不再包含编码器,编码工作交由模型自身完成,以确保 EWC 能够捕捉模型参数的梯度。
29
- """
30
- def __init__(self, model_dim: int = 2048):
31
- super().__init__()
32
- self.modality_configs = {
33
- 'text': ModalityConfig('text', 0),
34
- 'image': ModalityConfig('image', 1),
35
- 'audio': ModalityConfig('audio', 2),
36
- 'video': ModalityConfig('video', 3)
37
- }
38
-
39
- def process_batch(self, batch_data: Union[torch.Tensor, List[Any]], modality_type: str) -> List[Dict]:
40
- """
41
- 将特定模态的数据封装为 segment 格式
42
- """
43
- processed_segments = []
44
- if modality_type not in self.modality_configs:
45
- return processed_segments
46
-
47
- config = self.modality_configs[modality_type]
48
-
49
- # 确保数据是 Tensor 格式
50
- if isinstance(batch_data, list):
51
- # 过滤 None
52
- valid_data = [x for x in batch_data if x is not None]
53
- if not valid_data:
54
- return []
55
- # 假设 list 中全是 Tensor,且维度一致,进行堆叠
56
- # 如果是 list of tensor (B, C, H, W) -> stack -> (B, C, H, W)
57
- try:
58
- data_tensor = torch.stack(valid_data)
59
- except Exception as e:
60
- print(f"Error stacking modality data: {e}")
61
- return []
62
- elif isinstance(batch_data, torch.Tensor):
63
- data_tensor = batch_data
64
- else:
65
- return []
66
-
67
- processed_segments.append({
68
- 'type': modality_type,
69
- 'data': data_tensor, # 保持原始数据 (如图片像素),模型内部会encode
70
- 'modality_id': config.modality_id
71
- })
72
- return processed_segments
73
-
74
-
75
- class ExperienceReplayBuffer:
76
- """经验回放缓冲区 - 内存安全版"""
77
- def __init__(self, max_size: int = 10000):
78
- self.buffer = deque(maxlen=max_size)
79
-
80
- def add(self, sample: Dict[str, Any]):
81
- """
82
- 添加样本到buffer
83
- 关键修复:将数据移至 CPU 并 detach,防止显存泄漏
84
- """
85
- safe_sample = {}
86
- for k, v in sample.items():
87
- if isinstance(v, torch.Tensor):
88
- safe_sample[k] = v.detach().cpu()
89
- elif isinstance(v, list):
90
- # 递归处理 list 中的 tensor
91
- safe_sample[k] = [x.detach().cpu() if isinstance(x, torch.Tensor) else x for x in v]
92
- else:
93
- safe_sample[k] = v
94
- self.buffer.append(safe_sample)
95
-
96
- def sample(self, batch_size: int) -> List[Any]:
97
- """从buffer中采样"""
98
- if not self.buffer:
99
- return []
100
-
101
- indices = np.random.choice(
102
- len(self.buffer),
103
- min(len(self.buffer), batch_size),
104
- replace=False
105
- )
106
- return [self.buffer[i] for i in indices]
107
-
108
- def __len__(self):
109
- return len(self.buffer)
110
-
111
- def clear(self):
112
- """清空buffer"""
113
- self.buffer.clear()
114
-
115
-
116
- class EWC:
117
- """弹性权重固化 (Elastic Weight Consolidation)"""
118
- def __init__(
119
- self,
120
- model: nn.Module,
121
- dataloader: DataLoader,
122
- preprocessor: UnifiedMultiModalPreprocessor,
123
- importance: float = 1000.0
124
- ):
125
- self.model = model
126
- self.preprocessor = preprocessor
127
- self.importance = importance
128
- self.device = next(model.parameters()).device
129
-
130
- # 冻结当前参数作为参考
131
- self.params = {
132
- n: p.clone().detach()
133
- for n, p in model.named_parameters()
134
- if p.requires_grad
135
- }
136
-
137
- self.fisher = self._compute_fisher(dataloader)
138
-
139
- def _compute_fisher(self, dataloader: DataLoader) -> Dict[str, torch.Tensor]:
140
- """计算Fisher信息矩阵 (使用 Empirical Fisher)"""
141
- fisher = {
142
- n: torch.zeros_like(p)
143
- for n, p in self.model.named_parameters()
144
- if p.requires_grad
145
- }
146
-
147
- self.model.eval()
148
- num_samples = 0
149
-
150
- # 使用 tqdm 稍微简化输出
151
- pbar = tqdm(dataloader, desc="Computing Fisher Matrix", leave=False)
152
- for batch in pbar:
153
- if batch is None: continue
154
-
155
- self.model.zero_grad()
156
-
157
- # 1. 准备文本输入
158
- instruction_ids = batch['instruction'].to(self.device)
159
- response_ids = batch['response'].to(self.device)
160
- # 拼接: [Instruction, Response]
161
- input_ids = torch.cat([instruction_ids, response_ids], dim=1)
162
-
163
- # 2. 准备多模态输入结构
164
- input_data = {'segments': []}
165
-
166
- # 处理额外的模态数据 (如果有)
167
- # 这里的 batch['modality_data'] 可能是 list (由 collate_fn_v2 生成)
168
- raw_modality_data = batch.get('modality_data')
169
- if raw_modality_data is not None:
170
- # 尝试判断模态类型,如果 dataset 中没有明确指定,默认尝试 'image'
171
- # 实际应用中建议 dataset 返回 'modality_type'
172
- modality_type = batch.get('modality_type', 'image')
173
- if isinstance(modality_type, list): modality_type = modality_type[0]
174
-
175
- # Preprocessor 处理数据堆叠和格式化
176
- mod_segments = self.preprocessor.process_batch(raw_modality_data, modality_type)
177
- # 只有在数据有效时才传给 device
178
- for seg in mod_segments:
179
- seg['data'] = seg['data'].to(self.device)
180
- input_data['segments'].append(seg)
181
-
182
- # 添加文本 Segment
183
- input_data['segments'].append({
184
- 'type': 'text',
185
- 'data': input_ids,
186
- 'modality_id': 0
187
- })
188
-
189
- # 3. 前向传播
190
- output = self.model(input_data)
191
- logits = output['logits'] # (B, Seq_Len, Vocab)
192
-
193
- # 4. 计算 Loss (Standard Causal LM Loss)
194
- # Shift logits and labels
195
- # input_ids: [I1, I2, R1, R2]
196
- # labels: [I2, R1, R2, EOS]
197
- shift_logits = logits[:, :-1, :].contiguous()
198
- shift_labels = input_ids[:, 1:].contiguous()
199
-
200
- # 创建 Mask: 只在 Response 部分计算梯度
201
- # Instruction 长度
202
- inst_len = instruction_ids.shape[1]
203
- loss_mask = torch.ones_like(shift_labels, dtype=torch.float)
204
- if inst_len > 1:
205
- # 掩盖 Instruction 部分 (注意 shift 后的索引偏移)
206
- loss_mask[:, :inst_len-1] = 0.0
207
-
208
- # 计算逐个 Token Loss
209
- loss_fct = nn.CrossEntropyLoss(reduction='none')
210
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
211
-
212
- # 应用 Mask 并求平均
213
- loss = (loss * loss_mask.view(-1)).sum() / (loss_mask.sum() + 1e-6)
214
-
215
- # 5. 反向传播累积梯度平方
216
- loss.backward()
217
-
218
- for n, p in self.model.named_parameters():
219
- if p.grad is not None and n in fisher:
220
- fisher[n] += p.grad.detach() ** 2
221
-
222
- num_samples += input_ids.size(0)
223
-
224
- # 平均化
225
- if num_samples > 0:
226
- for n in fisher:
227
- fisher[n] /= num_samples
228
-
229
- self.model.train()
230
- return fisher
231
-
232
- def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
233
- """计算EWC惩罚项"""
234
- # 兼容性处理:如果传入了 model 参数,优先使用(通常 self.model 就是同一个)
235
- target_model = model if model is not None else self.model
236
-
237
- loss = torch.tensor(0.0, device=self.device)
238
-
239
- for n, p in target_model.named_parameters():
240
- if n in self.params and p.requires_grad:
241
- if n in self.fisher:
242
- loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
243
-
244
- return self.importance * loss
245
-
246
-
247
- class OnlineEWC(EWC):
248
- """在线EWC - 支持持续更新Fisher矩阵"""
249
- def __init__(
250
- self,
251
- model: nn.Module,
252
- preprocessor: UnifiedMultiModalPreprocessor,
253
- importance: float = 1000.0,
254
- gamma: float = 0.9
255
- ):
256
- # 初始时不计算 Fisher,等待 update_fisher 调用
257
- self.model = model
258
- self.preprocessor = preprocessor
259
- self.importance = importance
260
- self.gamma = gamma
261
- self.device = next(model.parameters()).device
262
-
263
- self.params = {}
264
- self.fisher = {}
265
- self.task_count = 0
266
-
267
- def update_fisher(self, dataloader: DataLoader):
268
- """更新Fisher信息矩阵"""
269
- print(f"Updating Online EWC Fisher Matrix (Task {self.task_count + 1})...")
270
- new_fisher = self._compute_fisher(dataloader)
271
-
272
- if self.task_count == 0:
273
- self.fisher = new_fisher
274
- else:
275
- for n in self.fisher:
276
- if n in new_fisher:
277
- # 移动平均更新 Fisher 信息
278
- self.fisher[n] = self.gamma * self.fisher[n] + new_fisher[n]
279
-
280
- # 更新参考参数为当前任务训练后的参数
281
- self.params = {
282
- n: p.clone().detach()
283
- for n, p in self.model.named_parameters()
284
- if p.requires_grad
285
- }
286
-
287
- self.task_count += 1
288
- print(f"Online EWC regularizer updated.")
289
-
290
- def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
291
- """计算EWC惩罚项"""
292
- if self.task_count == 0:
293
- return torch.tensor(0.0, device=self.device)
294
  return super().penalty(model)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from torch.utils.data import DataLoader
6
+ from collections import deque
7
+ from typing import List, Dict, Any, Optional, Union
8
+ from tqdm import tqdm
9
+ from dataclasses import dataclass
10
+
11
+
12
@dataclass
class ModalityConfig:
    """Name and integer id of one supported modality."""
    name: str
    modality_id: int


class UnifiedMultiModalPreprocessor(nn.Module):
    """Formats raw batch data into the 'segments' structure the model consumes.

    Holds no encoders of its own — encoding is left to the model so that
    EWC can observe gradients on the model's parameters.
    """

    def __init__(self, model_dim: int = 2048):
        super().__init__()
        # text -> 0, image -> 1, audio -> 2, video -> 3
        self.modality_configs = {
            name: ModalityConfig(name, idx)
            for idx, name in enumerate(('text', 'image', 'audio', 'video'))
        }

    def process_batch(self, batch_data: Union[torch.Tensor, List[Any]], modality_type: str) -> List[Dict]:
        """Wrap one modality's batch as a single segment dict.

        Returns an empty list for unknown modalities, empty/None-only lists,
        non-tensor inputs, or un-stackable tensor lists.
        """
        config = self.modality_configs.get(modality_type)
        if config is None:
            return []

        if isinstance(batch_data, torch.Tensor):
            stacked = batch_data
        elif isinstance(batch_data, list):
            # Drop missing entries before stacking.
            kept = [item for item in batch_data if item is not None]
            if not kept:
                return []
            try:
                # Assumes homogeneous tensor shapes, e.g. list of (C, H, W) -> (B, C, H, W).
                stacked = torch.stack(kept)
            except Exception as e:
                print(f"Error stacking modality data: {e}")
                return []
        else:
            return []

        return [{
            'type': modality_type,
            'data': stacked,  # raw data (e.g. pixels); the model encodes it internally
            'modality_id': config.modality_id,
        }]
57
+
58
+
59
class ExperienceReplayBuffer:
    """Bounded FIFO buffer of past samples for experience replay.

    Tensors are detached and moved to CPU on insertion so stored samples
    hold neither GPU memory nor autograd graphs alive.
    """

    def __init__(self, max_size: int = 10000):
        # deque evicts the oldest sample automatically once full.
        self.buffer = deque(maxlen=max_size)

    @staticmethod
    def _to_cpu(value):
        """Detach tensors (including tensors inside lists) and move them to CPU."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu()
        if isinstance(value, list):
            return [item.detach().cpu() if isinstance(item, torch.Tensor) else item
                    for item in value]
        return value

    def add(self, sample: Dict[str, Any]):
        """Store a CPU-resident, detached copy of *sample*."""
        self.buffer.append({key: self._to_cpu(val) for key, val in sample.items()})

    def sample(self, batch_size: int) -> List[Any]:
        """Draw up to *batch_size* distinct samples uniformly at random."""
        if not self.buffer:
            return []
        count = min(len(self.buffer), batch_size)
        chosen = np.random.choice(len(self.buffer), count, replace=False)
        return [self.buffer[idx] for idx in chosen]

    def __len__(self):
        return len(self.buffer)

    def clear(self):
        """Drop every stored sample."""
        self.buffer.clear()
93
+
94
+
95
class EWC:
    """Elastic Weight Consolidation (Kirkpatrick et al., 2017).

    Snapshots the model's trainable parameters and estimates a diagonal
    empirical Fisher information matrix over `dataloader`; `penalty()` then
    yields ``importance * sum_i F_i * (theta_i - theta*_i)^2``, to be added
    to the task loss so that weights important for past tasks stay anchored.
    """

    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0
    ):
        """
        Args:
            model: model whose parameters are to be consolidated.
            dataloader: batches used to estimate the Fisher matrix. Each batch
                is expected to carry 'instruction' / 'response' token-id
                tensors and optionally 'modality_data' / 'modality_type'
                (assumption based on the keys read below — confirm against
                the project's collate function).
            preprocessor: formats raw modality data into model 'segments'.
            importance: scale factor applied to the quadratic penalty.
        """
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.device = next(model.parameters()).device

        # Reference (anchor) parameters for the quadratic penalty.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }

        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader: DataLoader) -> Dict[str, torch.Tensor]:
        """Estimate the diagonal empirical Fisher: mean of squared gradients.

        Returns a dict mapping parameter names to accumulated squared
        gradients of the masked causal-LM loss, averaged over the number of
        processed samples.
        """
        fisher = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        # FIX: remember the incoming mode so it can be restored afterwards,
        # instead of unconditionally forcing train() on exit.
        was_training = self.model.training
        self.model.eval()
        num_samples = 0

        pbar = tqdm(dataloader, desc="Computing Fisher Matrix", leave=False)
        for batch in pbar:
            if batch is None:
                continue

            self.model.zero_grad()

            # 1. Text input: concatenate [instruction, response] token ids.
            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            # 2. Multimodal input structure.
            input_data = {'segments': []}

            raw_modality_data = batch.get('modality_data')
            if raw_modality_data is not None:
                # Fall back to 'image' when the dataset does not tag the type.
                modality_type = batch.get('modality_type', 'image')
                if isinstance(modality_type, list):
                    modality_type = modality_type[0]

                mod_segments = self.preprocessor.process_batch(raw_modality_data, modality_type)
                for seg in mod_segments:
                    seg['data'] = seg['data'].to(self.device)
                    input_data['segments'].append(seg)

            input_data['segments'].append({
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            })

            # 3. Forward pass.
            output = self.model(input_data)
            logits = output['logits']  # expected (B, seq_len, vocab)

            # 4. Standard causal-LM loss (shifted), masked to response tokens.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()

            inst_len = instruction_ids.shape[1]
            loss_mask = torch.ones_like(shift_labels, dtype=torch.float)
            if inst_len > 1:
                # Mask out the instruction portion (index shifted by 1).
                loss_mask[:, :inst_len - 1] = 0.0

            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss = (loss * loss_mask.view(-1)).sum() / (loss_mask.sum() + 1e-6)

            # 5. Backprop and accumulate squared gradients.
            loss.backward()
            for n, p in self.model.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach() ** 2

            num_samples += input_ids.size(0)

        # Average over processed samples.
        if num_samples > 0:
            for n in fisher:
                fisher[n] /= num_samples

        # FIX: don't leave the stale gradients of the last backward() on the
        # model — a following optimizer.step() would silently consume them.
        self.model.zero_grad()
        # FIX: restore the mode the model arrived in (original forced train()).
        if was_training:
            self.model.train()
        return fisher

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Return ``importance * sum_i F_i * (p_i - p*_i)^2``.

        Args:
            model: optional model to evaluate against the stored anchors;
                defaults to the model captured at construction time.
        """
        target_model = model if model is not None else self.model

        loss = torch.tensor(0.0, device=self.device)
        for n, p in target_model.named_parameters():
            # Only parameters that were trainable at snapshot time contribute.
            if n in self.params and p.requires_grad and n in self.fisher:
                loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()

        return self.importance * loss
214
+
215
+
216
class OnlineEWC(EWC):
    """Online EWC: one running Fisher estimate, refreshed after every task."""

    def __init__(
        self,
        model: nn.Module,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0,
        gamma: float = 0.9
    ):
        # Deliberately does NOT call EWC.__init__: no Fisher matrix is
        # computed up front — callers invoke update_fisher() after each task.
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.gamma = gamma
        self.device = next(model.parameters()).device

        self.params = {}
        self.fisher = {}
        self.task_count = 0

    def update_fisher(self, dataloader: DataLoader):
        """Fold a fresh Fisher estimate into the running one and re-anchor params."""
        print(f"Updating Online EWC Fisher Matrix (Task {self.task_count + 1})...")
        new_fisher = self._compute_fisher(dataloader)

        if self.task_count == 0:
            self.fisher = new_fisher
        else:
            # Exponentially decayed accumulation of Fisher information.
            for name, running in self.fisher.items():
                if name in new_fisher:
                    self.fisher[name] = self.gamma * running + new_fisher[name]

        # Anchor to the parameters reached after training on this task.
        self.params = {
            name: param.clone().detach()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

        self.task_count += 1
        print(f"Online EWC regularizer updated.")

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """EWC penalty; zero until the first update_fisher() call."""
        if self.task_count == 0:
            return torch.tensor(0.0, device=self.device)
        return super().penalty(model)