Auraithm committed on
Commit
898c2ea
·
verified ·
1 Parent(s): 89225da

Update modeling_sdar.py

Browse files
Files changed (1) hide show
  1. modeling_sdar.py +2 -152
modeling_sdar.py CHANGED
@@ -77,10 +77,8 @@ def modify_padded_position_ids_2d(position_ids: torch.LongTensor) -> torch.LongT
77
  使用完全向量化的 PyTorch 操作修改一个 batch 的 packed position_ids。
78
  这个函数假设输入是一个 2D Tensor,形状为 (batch_size, sequence_length)。
79
  它会独立地处理 batch 中的每一行。
80
-
81
  Args:
82
  position_ids: 二维 PyTorch Tensor, shape (batch_size, sequence_length).
83
-
84
  Returns:
85
  修改后的 position_ids Tensor, shape (batch_size, sequence_length).
86
  """
@@ -108,7 +106,6 @@ def modify_padded_position_ids_2d(position_ids: torch.LongTensor) -> torch.LongT
108
  def calculate_token_nums(position_ids: torch.Tensor):
109
  """
110
  使用 PyTorch 高效计算一个批次中每个打包序列的长度。
111
-
112
  Args:
113
  position_ids (torch.Tensor): 一个 2D Tensor,形状为 (batch_size, sequence_length)。
114
  例如:tensor([[0,1,2,3,4,0,1,2,3,4,5,0,1,2,3,0,0,0]])
@@ -162,11 +159,9 @@ def forward_add_noise_packed(
162
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
163
  """
164
  为一批打包(packed)序列的 token ID 添加噪声。
165
-
166
  此函数保留了为每个逻辑样本(在每个批次项内拼接)生成独立随机噪声率的逻辑。
167
  它会随机将一部分 token 的 ID 替换为 mask_id。
168
  这个过程会避开被 prompt_mask 标记的位置。
169
-
170
  Args:
171
  inputs_ids (torch.Tensor):
172
  输入的 token ID 张量,形状为 (bsz, total_tokens)。
@@ -182,7 +177,6 @@ def forward_add_noise_packed(
182
  微小值,用于防止噪声率 t 恰好为 0,确保 p_mask > 0。
183
  max_tries (int):
184
  为确保至少一个非 prompt token 被 mask,对每个批次项尝试的最大次数。
185
-
186
  Returns:
187
  Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
188
  - noisy_input_ids (torch.Tensor):
@@ -290,13 +284,11 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
290
  - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
291
  - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
292
  - **Block Causal Mask (M_BC)**: Attention to update x0
293
-
294
  Args:
295
  b, h: Batch and head indices (ignored for mask logic).
296
  q_idx, kv_idx: Query and Key indices.
297
  seq_len: Total sequence length.
298
  block_size: Defines the block structure.
299
-
300
  Returns:
301
  A boolean attention mask.
302
  """
@@ -410,7 +402,6 @@ def rotate_half(x):
410
 
411
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
412
  """Applies Rotary Position Embedding to the query and key tensors.
413
-
414
  Args:
415
  q (`torch.Tensor`): The query tensor.
416
  k (`torch.Tensor`): The key tensor.
@@ -970,7 +961,6 @@ class SDARModel(SDARPreTrainedModel):
970
  """
971
  Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
972
  `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
973
-
974
  Args:
975
  attention_mask (`torch.Tensor`):
976
  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
@@ -1160,7 +1150,6 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
1160
  compute_rl_loss: bool = False,
1161
  p_mask: Optional[torch.Tensor] = None,
1162
  adv: Optional[torch.Tensor] = None,
1163
- adv_optimization: bool = False,
1164
  logp_old_tok: Optional[torch.Tensor] = None,
1165
  logp_ref_tok: Optional[torch.Tensor] = None,
1166
  is_real: Optional[torch.Tensor] = None,
@@ -1237,12 +1226,6 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
1237
 
1238
  # 选出 logits — 保持原样
1239
  logits_p = logits[p_to_keep_real] # (N, V)
1240
- N = p_to_keep_real.sum().item()
1241
- total_response_tokens = (labels != -100).sum().item()
1242
- total_p_mask = p_mask.sum().item()
1243
- total_masked_indices = masked_indices.sum().item()
1244
- total_is_real = is_real_tensor.sum().item() if is_real_tensor.dim() > 0 else (1 if is_real_tensor.item() else 0)
1245
-
1246
 
1247
  # log_softmax
1248
  log_probs_p = torch.nn.functional.log_softmax(logits_p, dim=-1)
@@ -1260,133 +1243,7 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
1260
 
1261
  # advantage 处理
1262
  adv_tensor = adv.to(device) if torch.is_tensor(adv) else torch.tensor(adv, dtype=torch.float, device=device)
1263
- adv_optimization=False
1264
- if adv_optimization:
1265
- # token级别优化:对相同前缀取最大advantage(剪枝优化版本)
1266
- response_mask = (labels != -100) # (B, L)
1267
- bsz, seq_len = input_ids.shape
1268
-
1269
- # 预计算每个样本的response起始位置
1270
- response_starts = torch.full((bsz,), seq_len, dtype=torch.long, device=device)
1271
- for b in range(bsz):
1272
- if response_mask[b].any():
1273
- response_starts[b] = response_mask[b].long().argmax()
1274
-
1275
- # 剪枝1: 找出已经是最大advantage的样本,直接填充不参与比较
1276
- max_adv_value = adv_tensor.max()
1277
- is_max_adv = (adv_tensor == max_adv_value) # (B,) bool
1278
-
1279
- # 创建优化后的 advantage map (B, L),确保dtype与adv_tensor一致
1280
- optimized_adv = torch.zeros_like(labels, dtype=adv_tensor.dtype)
1281
-
1282
- # 对于已是最大advantage的样本,直接填充
1283
- for b in range(bsz):
1284
- if is_max_adv[b]:
1285
- optimized_adv[b][response_mask[b]] = max_adv_value
1286
-
1287
- # 统计信息
1288
- total_response_tokens = 0
1289
- updated_tokens = 0
1290
- skipped_tokens = 0
1291
- original_adv_sum = 0.0
1292
- optimized_adv_sum = 0.0
1293
-
1294
- # 按position处理,批量比较前缀
1295
- for pos in range(seq_len):
1296
- valid_samples = response_mask[:, pos] # (B,)
1297
- if not valid_samples.any():
1298
- continue
1299
-
1300
- # 剪枝2: 排除已是最大advantage的样本
1301
- valid_samples = valid_samples & ~is_max_adv
1302
- if not valid_samples.any():
1303
- # 所有样本都是最大值,统计后跳过
1304
- max_count = (response_mask[:, pos] & is_max_adv).sum().item()
1305
- total_response_tokens += max_count
1306
- skipped_tokens += max_count
1307
- original_adv_sum += max_adv_value.item() * max_count
1308
- optimized_adv_sum += max_adv_value.item() * max_count
1309
- continue
1310
-
1311
- # 获取所有需要处理的样本索引
1312
- valid_indices = valid_samples.nonzero(as_tuple=True)[0] # (N,)
1313
-
1314
- for b in valid_indices:
1315
- b_item = b.item()
1316
- response_start = response_starts[b_item].item()
1317
- prefix_len = pos + 1 - response_start
1318
-
1319
- if prefix_len <= 0:
1320
- optimized_adv[b_item, pos] = adv_tensor[b_item]
1321
- continue
1322
-
1323
- # 找出所有response起始位置相同且在pos位置有效的样本(包括已是最大值的)
1324
- same_start_mask = (response_starts == response_start) & response_mask[:, pos]
1325
- same_start_indices = same_start_mask.nonzero(as_tuple=True)[0]
1326
-
1327
- if len(same_start_indices) == 1:
1328
- # 只有自己,不需要比较
1329
- optimized_adv[b_item, pos] = adv_tensor[b_item]
1330
- total_response_tokens += 1
1331
- original_adv_sum += adv_tensor[b_item].item()
1332
- optimized_adv_sum += adv_tensor[b_item].item()
1333
- continue
1334
-
1335
- # 剪枝3: 如果候选中有最大advantage样本,可以直接用最大值
1336
- has_max_in_candidates = (same_start_mask & is_max_adv).any()
1337
-
1338
- prefix_end = pos + 1
1339
- current_prefix = input_ids[b_item, response_start:prefix_end]
1340
-
1341
- # 批量比较:提取所有候选样本的前缀
1342
- prefixes = input_ids[same_start_indices, response_start:prefix_end] # (M, prefix_len)
1343
-
1344
- # 使用广播比较:(M, prefix_len) vs (prefix_len,)
1345
- matches = (prefixes == current_prefix.unsqueeze(0)).all(dim=1) # (M,)
1346
-
1347
- # 找到匹配的样本
1348
- matching_indices = same_start_indices[matches]
1349
-
1350
- # 在相同前缀的样本中取最大 advantage
1351
- original_adv_value = adv_tensor[b_item].item()
1352
- if matching_indices.numel() > 0:
1353
- # 剪枝4: 如果匹配中有最大值样本,直接用最大值
1354
- if has_max_in_candidates and is_max_adv[matching_indices].any():
1355
- max_adv = max_adv_value
1356
- else:
1357
- max_adv = adv_tensor[matching_indices].max()
1358
-
1359
- optimized_adv[b_item, pos] = max_adv
1360
- # 统计
1361
- if abs(max_adv.item() - original_adv_value) > 1e-6:
1362
- updated_tokens += 1
1363
- original_adv_sum += original_adv_value
1364
- optimized_adv_sum += max_adv.item()
1365
- else:
1366
- optimized_adv[b_item, pos] = adv_tensor[b_item]
1367
- original_adv_sum += original_adv_value
1368
- optimized_adv_sum += original_adv_value
1369
-
1370
- total_response_tokens += 1
1371
-
1372
- # 输出统计信息
1373
- if total_response_tokens > 0:
1374
- update_ratio = updated_tokens / total_response_tokens
1375
- skip_ratio = skipped_tokens / total_response_tokens
1376
- avg_original = original_adv_sum / total_response_tokens
1377
- avg_optimized = optimized_adv_sum / total_response_tokens
1378
- print(f"[Adv Optimization] Total: {total_response_tokens}, "
1379
- f"Updated: {updated_tokens} ({update_ratio:.2%}), "
1380
- f"Skipped: {skipped_tokens} ({skip_ratio:.2%}), "
1381
- f"Avg adv: {avg_original:.4f} -> {avg_optimized:.4f} "
1382
- f"(+{avg_optimized - avg_original:.4f})")
1383
-
1384
- # 使用优化后的 advantage
1385
- adv_expanded = optimized_adv
1386
- else:
1387
- # 不优化:直接使用原始 advantage
1388
- adv_expanded = adv_tensor.unsqueeze(1).expand_as(p_mask)
1389
-
1390
  adv_p = adv_expanded[masked_indices][p_to_keep_real]
1391
 
1392
  # old logp
@@ -1394,20 +1251,13 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
1394
  logp_old_p = logp_old_tok.to(device)[masked_indices][p_to_keep_real]
1395
  else:
1396
  logp_old_p = logp_p.detach()
1397
-
1398
  # ratio/exp
1399
  ratio_p = (logp_p - logp_old_p).clamp(-10.0, 10.0).exp()
1400
  clipped = ratio_p.clamp(1 - ppo_eps, 1 + ppo_eps+0.08)
1401
  surrogate_p = torch.minimum(ratio_p * adv_p, clipped * adv_p)
1402
- # 输出离1最远的ratio值
1403
- # if not torch.allclose(ratio_p, torch.ones_like(ratio_p)):
1404
- furthest_value = ratio_p[torch.abs(ratio_p - 1).argmax()]
1405
- # print(f"Furthest ratio from 1: {furthest_value.item()}")
1406
 
1407
  # Policy loss: use mean or sum based on loss_mean parameter
1408
- num_masked = masked_indices.sum().item()
1409
- num_loss_elements = surrogate_p.numel()
1410
- print(f"masked_indices.sum()={num_masked}, surrogate_p.numel()={num_loss_elements}")
1411
  if loss_mean:
1412
  policy_loss = -surrogate_p.mean()
1413
  else:
 
77
  使用完全向量化的 PyTorch 操作修改一个 batch 的 packed position_ids。
78
  这个函数假设输入是一个 2D Tensor,形状为 (batch_size, sequence_length)。
79
  它会独立地处理 batch 中的每一行。
 
80
  Args:
81
  position_ids: 二维 PyTorch Tensor, shape (batch_size, sequence_length).
 
82
  Returns:
83
  修改后的 position_ids Tensor, shape (batch_size, sequence_length).
84
  """
 
106
  def calculate_token_nums(position_ids: torch.Tensor):
107
  """
108
  使用 PyTorch 高效计算一个批次中每个打包序列的长度。
 
109
  Args:
110
  position_ids (torch.Tensor): 一个 2D Tensor,形状为 (batch_size, sequence_length)。
111
  例如:tensor([[0,1,2,3,4,0,1,2,3,4,5,0,1,2,3,0,0,0]])
 
159
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
160
  """
161
  为一批打包(packed)序列的 token ID 添加噪声。
 
162
  此函数保留了为每个逻辑样本(在每个批次项内拼接)生成独立随机噪声率的逻辑。
163
  它会随机将一部分 token 的 ID 替换为 mask_id。
164
  这个过程会避开被 prompt_mask 标记的位置。
 
165
  Args:
166
  inputs_ids (torch.Tensor):
167
  输入的 token ID 张量,形状为 (bsz, total_tokens)。
 
177
  微小值,用于防止噪声率 t 恰好为 0,确保 p_mask > 0。
178
  max_tries (int):
179
  为确保至少一个非 prompt token 被 mask,对每个批次项尝试的最大次数。
 
180
  Returns:
181
  Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
182
  - noisy_input_ids (torch.Tensor):
 
284
  - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
285
  - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
286
  - **Block Causal Mask (M_BC)**: Attention to update x0
 
287
  Args:
288
  b, h: Batch and head indices (ignored for mask logic).
289
  q_idx, kv_idx: Query and Key indices.
290
  seq_len: Total sequence length.
291
  block_size: Defines the block structure.
 
292
  Returns:
293
  A boolean attention mask.
294
  """
 
402
 
403
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
404
  """Applies Rotary Position Embedding to the query and key tensors.
 
405
  Args:
406
  q (`torch.Tensor`): The query tensor.
407
  k (`torch.Tensor`): The key tensor.
 
961
  """
962
  Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
963
  `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
 
964
  Args:
965
  attention_mask (`torch.Tensor`):
966
  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
 
1150
  compute_rl_loss: bool = False,
1151
  p_mask: Optional[torch.Tensor] = None,
1152
  adv: Optional[torch.Tensor] = None,
 
1153
  logp_old_tok: Optional[torch.Tensor] = None,
1154
  logp_ref_tok: Optional[torch.Tensor] = None,
1155
  is_real: Optional[torch.Tensor] = None,
 
1226
 
1227
  # 选出 logits — 保持原样
1228
  logits_p = logits[p_to_keep_real] # (N, V)
 
 
 
 
 
 
1229
 
1230
  # log_softmax
1231
  log_probs_p = torch.nn.functional.log_softmax(logits_p, dim=-1)
 
1243
 
1244
  # advantage 处理
1245
  adv_tensor = adv.to(device) if torch.is_tensor(adv) else torch.tensor(adv, dtype=torch.float, device=device)
1246
+ adv_expanded = adv_tensor.unsqueeze(1).expand_as(p_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1247
  adv_p = adv_expanded[masked_indices][p_to_keep_real]
1248
 
1249
  # old logp
 
1251
  logp_old_p = logp_old_tok.to(device)[masked_indices][p_to_keep_real]
1252
  else:
1253
  logp_old_p = logp_p.detach()
1254
+
1255
  # ratio/exp
1256
  ratio_p = (logp_p - logp_old_p).clamp(-10.0, 10.0).exp()
1257
  clipped = ratio_p.clamp(1 - ppo_eps, 1 + ppo_eps+0.08)
1258
  surrogate_p = torch.minimum(ratio_p * adv_p, clipped * adv_p)
 
 
 
 
1259
 
1260
  # Policy loss: use mean or sum based on loss_mean parameter
 
 
 
1261
  if loss_mean:
1262
  policy_loss = -surrogate_p.mean()
1263
  else: