Update modeling_sdar.py
Browse files- modeling_sdar.py +2 -152
modeling_sdar.py
CHANGED
|
@@ -77,10 +77,8 @@ def modify_padded_position_ids_2d(position_ids: torch.LongTensor) -> torch.LongT
|
|
| 77 |
使用完全向量化的 PyTorch 操作修改一个 batch 的 packed position_ids。
|
| 78 |
这个函数假设输入是一个 2D Tensor,形状为 (batch_size, sequence_length)。
|
| 79 |
它会独立地处理 batch 中的每一行。
|
| 80 |
-
|
| 81 |
Args:
|
| 82 |
position_ids: 二维 PyTorch Tensor, shape (batch_size, sequence_length).
|
| 83 |
-
|
| 84 |
Returns:
|
| 85 |
修改后的 position_ids Tensor, shape (batch_size, sequence_length).
|
| 86 |
"""
|
|
@@ -108,7 +106,6 @@ def modify_padded_position_ids_2d(position_ids: torch.LongTensor) -> torch.LongT
|
|
| 108 |
def calculate_token_nums(position_ids: torch.Tensor):
|
| 109 |
"""
|
| 110 |
使用 PyTorch 高效计算一个批次中每个打包序列的长度。
|
| 111 |
-
|
| 112 |
Args:
|
| 113 |
position_ids (torch.Tensor): 一个 2D Tensor,形状为 (batch_size, sequence_length)。
|
| 114 |
例如:tensor([[0,1,2,3,4,0,1,2,3,4,5,0,1,2,3,0,0,0]])
|
|
@@ -162,11 +159,9 @@ def forward_add_noise_packed(
|
|
| 162 |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 163 |
"""
|
| 164 |
为一批打包(packed)序列的 token ID 添加噪声。
|
| 165 |
-
|
| 166 |
此函数保留了为每个逻辑样本(在每个批次项内拼接)生成独立随机噪声率的逻辑。
|
| 167 |
它会随机将一部分 token 的 ID 替换为 mask_id。
|
| 168 |
这个过程会避开被 prompt_mask 标记的位置。
|
| 169 |
-
|
| 170 |
Args:
|
| 171 |
inputs_ids (torch.Tensor):
|
| 172 |
输入的 token ID 张量,形状为 (bsz, total_tokens)。
|
|
@@ -182,7 +177,6 @@ def forward_add_noise_packed(
|
|
| 182 |
微小值,用于防止噪声率 t 恰好为 0,确保 p_mask > 0。
|
| 183 |
max_tries (int):
|
| 184 |
为确保至少一个非 prompt token 被 mask,对每个批次项尝试的最大次数。
|
| 185 |
-
|
| 186 |
Returns:
|
| 187 |
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 188 |
- noisy_input_ids (torch.Tensor):
|
|
@@ -290,13 +284,11 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
|
|
| 290 |
- **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
|
| 291 |
- **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
|
| 292 |
- **Block Causal Mask (M_BC)**: Attention to update x0
|
| 293 |
-
|
| 294 |
Args:
|
| 295 |
b, h: Batch and head indices (ignored for mask logic).
|
| 296 |
q_idx, kv_idx: Query and Key indices.
|
| 297 |
seq_len: Total sequence length.
|
| 298 |
block_size: Defines the block structure.
|
| 299 |
-
|
| 300 |
Returns:
|
| 301 |
A boolean attention mask.
|
| 302 |
"""
|
|
@@ -410,7 +402,6 @@ def rotate_half(x):
|
|
| 410 |
|
| 411 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 412 |
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 413 |
-
|
| 414 |
Args:
|
| 415 |
q (`torch.Tensor`): The query tensor.
|
| 416 |
k (`torch.Tensor`): The key tensor.
|
|
@@ -970,7 +961,6 @@ class SDARModel(SDARPreTrainedModel):
|
|
| 970 |
"""
|
| 971 |
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 972 |
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 973 |
-
|
| 974 |
Args:
|
| 975 |
attention_mask (`torch.Tensor`):
|
| 976 |
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
|
@@ -1160,7 +1150,6 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
|
|
| 1160 |
compute_rl_loss: bool = False,
|
| 1161 |
p_mask: Optional[torch.Tensor] = None,
|
| 1162 |
adv: Optional[torch.Tensor] = None,
|
| 1163 |
-
adv_optimization: bool = False,
|
| 1164 |
logp_old_tok: Optional[torch.Tensor] = None,
|
| 1165 |
logp_ref_tok: Optional[torch.Tensor] = None,
|
| 1166 |
is_real: Optional[torch.Tensor] = None,
|
|
@@ -1237,12 +1226,6 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
|
|
| 1237 |
|
| 1238 |
# 选出 logits — 保持原样
|
| 1239 |
logits_p = logits[p_to_keep_real] # (N, V)
|
| 1240 |
-
N = p_to_keep_real.sum().item()
|
| 1241 |
-
total_response_tokens = (labels != -100).sum().item()
|
| 1242 |
-
total_p_mask = p_mask.sum().item()
|
| 1243 |
-
total_masked_indices = masked_indices.sum().item()
|
| 1244 |
-
total_is_real = is_real_tensor.sum().item() if is_real_tensor.dim() > 0 else (1 if is_real_tensor.item() else 0)
|
| 1245 |
-
|
| 1246 |
|
| 1247 |
# log_softmax
|
| 1248 |
log_probs_p = torch.nn.functional.log_softmax(logits_p, dim=-1)
|
|
@@ -1260,133 +1243,7 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
|
|
| 1260 |
|
| 1261 |
# advantage 处理
|
| 1262 |
adv_tensor = adv.to(device) if torch.is_tensor(adv) else torch.tensor(adv, dtype=torch.float, device=device)
|
| 1263 |
-
|
| 1264 |
-
if adv_optimization:
|
| 1265 |
-
# token级别优化:对相同前缀取最大advantage(剪枝优化版本)
|
| 1266 |
-
response_mask = (labels != -100) # (B, L)
|
| 1267 |
-
bsz, seq_len = input_ids.shape
|
| 1268 |
-
|
| 1269 |
-
# 预计算每个样本的response起始位置
|
| 1270 |
-
response_starts = torch.full((bsz,), seq_len, dtype=torch.long, device=device)
|
| 1271 |
-
for b in range(bsz):
|
| 1272 |
-
if response_mask[b].any():
|
| 1273 |
-
response_starts[b] = response_mask[b].long().argmax()
|
| 1274 |
-
|
| 1275 |
-
# 剪枝1: 找出已经是最大advantage的样本,直接填充不参与比较
|
| 1276 |
-
max_adv_value = adv_tensor.max()
|
| 1277 |
-
is_max_adv = (adv_tensor == max_adv_value) # (B,) bool
|
| 1278 |
-
|
| 1279 |
-
# 创建优化后的 advantage map (B, L),确保dtype与adv_tensor一致
|
| 1280 |
-
optimized_adv = torch.zeros_like(labels, dtype=adv_tensor.dtype)
|
| 1281 |
-
|
| 1282 |
-
# 对于已是最大advantage的样本,直接填充
|
| 1283 |
-
for b in range(bsz):
|
| 1284 |
-
if is_max_adv[b]:
|
| 1285 |
-
optimized_adv[b][response_mask[b]] = max_adv_value
|
| 1286 |
-
|
| 1287 |
-
# 统计信息
|
| 1288 |
-
total_response_tokens = 0
|
| 1289 |
-
updated_tokens = 0
|
| 1290 |
-
skipped_tokens = 0
|
| 1291 |
-
original_adv_sum = 0.0
|
| 1292 |
-
optimized_adv_sum = 0.0
|
| 1293 |
-
|
| 1294 |
-
# 按position处理,批量比较前缀
|
| 1295 |
-
for pos in range(seq_len):
|
| 1296 |
-
valid_samples = response_mask[:, pos] # (B,)
|
| 1297 |
-
if not valid_samples.any():
|
| 1298 |
-
continue
|
| 1299 |
-
|
| 1300 |
-
# 剪枝2: 排除已是最大advantage的样本
|
| 1301 |
-
valid_samples = valid_samples & ~is_max_adv
|
| 1302 |
-
if not valid_samples.any():
|
| 1303 |
-
# 所有样本都是最大值,统计后跳过
|
| 1304 |
-
max_count = (response_mask[:, pos] & is_max_adv).sum().item()
|
| 1305 |
-
total_response_tokens += max_count
|
| 1306 |
-
skipped_tokens += max_count
|
| 1307 |
-
original_adv_sum += max_adv_value.item() * max_count
|
| 1308 |
-
optimized_adv_sum += max_adv_value.item() * max_count
|
| 1309 |
-
continue
|
| 1310 |
-
|
| 1311 |
-
# 获取所有需要处理的样本索引
|
| 1312 |
-
valid_indices = valid_samples.nonzero(as_tuple=True)[0] # (N,)
|
| 1313 |
-
|
| 1314 |
-
for b in valid_indices:
|
| 1315 |
-
b_item = b.item()
|
| 1316 |
-
response_start = response_starts[b_item].item()
|
| 1317 |
-
prefix_len = pos + 1 - response_start
|
| 1318 |
-
|
| 1319 |
-
if prefix_len <= 0:
|
| 1320 |
-
optimized_adv[b_item, pos] = adv_tensor[b_item]
|
| 1321 |
-
continue
|
| 1322 |
-
|
| 1323 |
-
# 找出所有response起始位置相同且在pos位置有效的样本(包括已是最大值的)
|
| 1324 |
-
same_start_mask = (response_starts == response_start) & response_mask[:, pos]
|
| 1325 |
-
same_start_indices = same_start_mask.nonzero(as_tuple=True)[0]
|
| 1326 |
-
|
| 1327 |
-
if len(same_start_indices) == 1:
|
| 1328 |
-
# 只有自己,不需要比较
|
| 1329 |
-
optimized_adv[b_item, pos] = adv_tensor[b_item]
|
| 1330 |
-
total_response_tokens += 1
|
| 1331 |
-
original_adv_sum += adv_tensor[b_item].item()
|
| 1332 |
-
optimized_adv_sum += adv_tensor[b_item].item()
|
| 1333 |
-
continue
|
| 1334 |
-
|
| 1335 |
-
# 剪枝3: 如果候选中有最大advantage样本,可以直接用最大值
|
| 1336 |
-
has_max_in_candidates = (same_start_mask & is_max_adv).any()
|
| 1337 |
-
|
| 1338 |
-
prefix_end = pos + 1
|
| 1339 |
-
current_prefix = input_ids[b_item, response_start:prefix_end]
|
| 1340 |
-
|
| 1341 |
-
# 批量比较:提取所有候选样本的前缀
|
| 1342 |
-
prefixes = input_ids[same_start_indices, response_start:prefix_end] # (M, prefix_len)
|
| 1343 |
-
|
| 1344 |
-
# 使用广播比较:(M, prefix_len) vs (prefix_len,)
|
| 1345 |
-
matches = (prefixes == current_prefix.unsqueeze(0)).all(dim=1) # (M,)
|
| 1346 |
-
|
| 1347 |
-
# 找到匹配的样本
|
| 1348 |
-
matching_indices = same_start_indices[matches]
|
| 1349 |
-
|
| 1350 |
-
# 在相同前缀的样本中取最大 advantage
|
| 1351 |
-
original_adv_value = adv_tensor[b_item].item()
|
| 1352 |
-
if matching_indices.numel() > 0:
|
| 1353 |
-
# 剪枝4: 如果匹配中有最大值样本,直接用最大值
|
| 1354 |
-
if has_max_in_candidates and is_max_adv[matching_indices].any():
|
| 1355 |
-
max_adv = max_adv_value
|
| 1356 |
-
else:
|
| 1357 |
-
max_adv = adv_tensor[matching_indices].max()
|
| 1358 |
-
|
| 1359 |
-
optimized_adv[b_item, pos] = max_adv
|
| 1360 |
-
# 统计
|
| 1361 |
-
if abs(max_adv.item() - original_adv_value) > 1e-6:
|
| 1362 |
-
updated_tokens += 1
|
| 1363 |
-
original_adv_sum += original_adv_value
|
| 1364 |
-
optimized_adv_sum += max_adv.item()
|
| 1365 |
-
else:
|
| 1366 |
-
optimized_adv[b_item, pos] = adv_tensor[b_item]
|
| 1367 |
-
original_adv_sum += original_adv_value
|
| 1368 |
-
optimized_adv_sum += original_adv_value
|
| 1369 |
-
|
| 1370 |
-
total_response_tokens += 1
|
| 1371 |
-
|
| 1372 |
-
# 输出统计信息
|
| 1373 |
-
if total_response_tokens > 0:
|
| 1374 |
-
update_ratio = updated_tokens / total_response_tokens
|
| 1375 |
-
skip_ratio = skipped_tokens / total_response_tokens
|
| 1376 |
-
avg_original = original_adv_sum / total_response_tokens
|
| 1377 |
-
avg_optimized = optimized_adv_sum / total_response_tokens
|
| 1378 |
-
print(f"[Adv Optimization] Total: {total_response_tokens}, "
|
| 1379 |
-
f"Updated: {updated_tokens} ({update_ratio:.2%}), "
|
| 1380 |
-
f"Skipped: {skipped_tokens} ({skip_ratio:.2%}), "
|
| 1381 |
-
f"Avg adv: {avg_original:.4f} -> {avg_optimized:.4f} "
|
| 1382 |
-
f"(+{avg_optimized - avg_original:.4f})")
|
| 1383 |
-
|
| 1384 |
-
# 使用优化后的 advantage
|
| 1385 |
-
adv_expanded = optimized_adv
|
| 1386 |
-
else:
|
| 1387 |
-
# 不优化:直接使用原始 advantage
|
| 1388 |
-
adv_expanded = adv_tensor.unsqueeze(1).expand_as(p_mask)
|
| 1389 |
-
|
| 1390 |
adv_p = adv_expanded[masked_indices][p_to_keep_real]
|
| 1391 |
|
| 1392 |
# old logp
|
|
@@ -1394,20 +1251,13 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
|
|
| 1394 |
logp_old_p = logp_old_tok.to(device)[masked_indices][p_to_keep_real]
|
| 1395 |
else:
|
| 1396 |
logp_old_p = logp_p.detach()
|
| 1397 |
-
|
| 1398 |
# ratio/exp
|
| 1399 |
ratio_p = (logp_p - logp_old_p).clamp(-10.0, 10.0).exp()
|
| 1400 |
clipped = ratio_p.clamp(1 - ppo_eps, 1 + ppo_eps+0.08)
|
| 1401 |
surrogate_p = torch.minimum(ratio_p * adv_p, clipped * adv_p)
|
| 1402 |
-
# 输出离1最远的ratio值
|
| 1403 |
-
# if not torch.allclose(ratio_p, torch.ones_like(ratio_p)):
|
| 1404 |
-
furthest_value = ratio_p[torch.abs(ratio_p - 1).argmax()]
|
| 1405 |
-
# print(f"Furthest ratio from 1: {furthest_value.item()}")
|
| 1406 |
|
| 1407 |
# Policy loss: use mean or sum based on loss_mean parameter
|
| 1408 |
-
num_masked = masked_indices.sum().item()
|
| 1409 |
-
num_loss_elements = surrogate_p.numel()
|
| 1410 |
-
print(f"masked_indices.sum()={num_masked}, surrogate_p.numel()={num_loss_elements}")
|
| 1411 |
if loss_mean:
|
| 1412 |
policy_loss = -surrogate_p.mean()
|
| 1413 |
else:
|
|
|
|
| 77 |
使用完全向量化的 PyTorch 操作修改一个 batch 的 packed position_ids。
|
| 78 |
这个函数假设输入是一个 2D Tensor,形状为 (batch_size, sequence_length)。
|
| 79 |
它会独立地处理 batch 中的每一行。
|
|
|
|
| 80 |
Args:
|
| 81 |
position_ids: 二维 PyTorch Tensor, shape (batch_size, sequence_length).
|
|
|
|
| 82 |
Returns:
|
| 83 |
修改后的 position_ids Tensor, shape (batch_size, sequence_length).
|
| 84 |
"""
|
|
|
|
| 106 |
def calculate_token_nums(position_ids: torch.Tensor):
|
| 107 |
"""
|
| 108 |
使用 PyTorch 高效计算一个批次中每个打包序列的长度。
|
|
|
|
| 109 |
Args:
|
| 110 |
position_ids (torch.Tensor): 一个 2D Tensor,形状为 (batch_size, sequence_length)。
|
| 111 |
例如:tensor([[0,1,2,3,4,0,1,2,3,4,5,0,1,2,3,0,0,0]])
|
|
|
|
| 159 |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 160 |
"""
|
| 161 |
为一批打包(packed)序列的 token ID 添加噪声。
|
|
|
|
| 162 |
此函数保留了为每个逻辑样本(在每个批次项内拼接)生成独立随机噪声率的逻辑。
|
| 163 |
它会随机将一部分 token 的 ID 替换为 mask_id。
|
| 164 |
这个过程会避开被 prompt_mask 标记的位置。
|
|
|
|
| 165 |
Args:
|
| 166 |
inputs_ids (torch.Tensor):
|
| 167 |
输入的 token ID 张量,形状为 (bsz, total_tokens)。
|
|
|
|
| 177 |
微小值,用于防止噪声率 t 恰好为 0,确保 p_mask > 0。
|
| 178 |
max_tries (int):
|
| 179 |
为确保至少一个非 prompt token 被 mask,对每个批次项尝试的最大次数。
|
|
|
|
| 180 |
Returns:
|
| 181 |
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 182 |
- noisy_input_ids (torch.Tensor):
|
|
|
|
| 284 |
- **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
|
| 285 |
- **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
|
| 286 |
- **Block Causal Mask (M_BC)**: Attention to update x0
|
|
|
|
| 287 |
Args:
|
| 288 |
b, h: Batch and head indices (ignored for mask logic).
|
| 289 |
q_idx, kv_idx: Query and Key indices.
|
| 290 |
seq_len: Total sequence length.
|
| 291 |
block_size: Defines the block structure.
|
|
|
|
| 292 |
Returns:
|
| 293 |
A boolean attention mask.
|
| 294 |
"""
|
|
|
|
| 402 |
|
| 403 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 404 |
"""Applies Rotary Position Embedding to the query and key tensors.
|
|
|
|
| 405 |
Args:
|
| 406 |
q (`torch.Tensor`): The query tensor.
|
| 407 |
k (`torch.Tensor`): The key tensor.
|
|
|
|
| 961 |
"""
|
| 962 |
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 963 |
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
|
|
|
| 964 |
Args:
|
| 965 |
attention_mask (`torch.Tensor`):
|
| 966 |
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
|
|
|
| 1150 |
compute_rl_loss: bool = False,
|
| 1151 |
p_mask: Optional[torch.Tensor] = None,
|
| 1152 |
adv: Optional[torch.Tensor] = None,
|
|
|
|
| 1153 |
logp_old_tok: Optional[torch.Tensor] = None,
|
| 1154 |
logp_ref_tok: Optional[torch.Tensor] = None,
|
| 1155 |
is_real: Optional[torch.Tensor] = None,
|
|
|
|
| 1226 |
|
| 1227 |
# 选出 logits — 保持原样
|
| 1228 |
logits_p = logits[p_to_keep_real] # (N, V)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1229 |
|
| 1230 |
# log_softmax
|
| 1231 |
log_probs_p = torch.nn.functional.log_softmax(logits_p, dim=-1)
|
|
|
|
| 1243 |
|
| 1244 |
# advantage 处理
|
| 1245 |
adv_tensor = adv.to(device) if torch.is_tensor(adv) else torch.tensor(adv, dtype=torch.float, device=device)
|
| 1246 |
+
adv_expanded = adv_tensor.unsqueeze(1).expand_as(p_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1247 |
adv_p = adv_expanded[masked_indices][p_to_keep_real]
|
| 1248 |
|
| 1249 |
# old logp
|
|
|
|
| 1251 |
logp_old_p = logp_old_tok.to(device)[masked_indices][p_to_keep_real]
|
| 1252 |
else:
|
| 1253 |
logp_old_p = logp_p.detach()
|
| 1254 |
+
|
| 1255 |
# ratio/exp
|
| 1256 |
ratio_p = (logp_p - logp_old_p).clamp(-10.0, 10.0).exp()
|
| 1257 |
clipped = ratio_p.clamp(1 - ppo_eps, 1 + ppo_eps+0.08)
|
| 1258 |
surrogate_p = torch.minimum(ratio_p * adv_p, clipped * adv_p)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1259 |
|
| 1260 |
# Policy loss: use mean or sum based on loss_mean parameter
|
|
|
|
|
|
|
|
|
|
| 1261 |
if loss_mean:
|
| 1262 |
policy_loss = -surrogate_p.mean()
|
| 1263 |
else:
|