| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
|
|
| import torch |
|
|
| from verl.protocol import DataProto |
|
|
# Module-level logger for this file. Note that it is keyed by the file path
# (__file__) rather than the more conventional module name (__name__).
logger = logging.getLogger(__file__)
|
|
|
|
def calculate_token_list_diff(tensor1: torch.Tensor, tensor2: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Count, per row, the masked positions where two tensors disagree.

    Args:
        tensor1: first tensor; counts are summed over dim 1, so it is expected to be
            at least 2-D, indexed as (row, position).
        tensor2: second tensor, same shape as ``tensor1``. Moved to ``tensor1``'s
            device if needed.
        mask: same shape as ``tensor1``; only positions where ``mask == 1`` (or
            ``True`` for a bool mask) are compared. Moved to ``tensor1``'s device
            if needed.

    Returns:
        A 1-D ``torch.long`` tensor of length ``tensor1.shape[0]`` on ``tensor1``'s
        device:
          - normally, the number of differing masked positions in each row;
          - all zeros if either tensor is empty;
          - all ones if the shapes of ``tensor1``, ``tensor2`` and ``mask`` disagree
            (every row is flagged as differing, and a warning is logged).
    """
    if tensor1.numel() == 0 or tensor2.numel() == 0:
        return torch.zeros(tensor1.shape[0], dtype=torch.long, device=tensor1.device)
    if tensor1.shape != tensor2.shape or mask.shape != tensor1.shape:
        logger.warning(
            "dim of tensor1, tensor2, mask is not equal, tensor1.shape=%s, tensor2.shape=%s, mask.shape=%s",
            tuple(tensor1.shape),
            tuple(tensor2.shape),
            tuple(mask.shape),
        )
        # Same 1-D long shape as the other return paths, flagging each row as differing.
        return torch.ones(tensor1.shape[0], dtype=torch.long, device=tensor1.device)

    # Align devices so the element-wise comparison below is valid.
    if tensor2.device != tensor1.device:
        tensor2 = tensor2.to(tensor1.device)
    if mask.device != tensor1.device:
        mask = mask.to(tensor1.device)

    # `mask == 1` accepts both 0/1 integer masks and bool masks (True == 1).
    valid_diff_mask = (tensor1 != tensor2) & (mask == 1)
    # Summing a bool tensor yields int64 counts.
    return valid_diff_mask.sum(dim=1)
|
|
|
|
def pearson_correlation_coefficient(tensor1: torch.Tensor, tensor2: torch.Tensor, mask: torch.Tensor) -> float:
    """Pearson correlation coefficient between the masked elements of two tensors.

    Args:
        tensor1: first tensor.
        tensor2: second tensor, same shape as ``tensor1``. Moved to ``tensor1``'s
            device if needed.
        mask: same shape as the tensors; non-zero / ``True`` entries select the
            elements to correlate. Non-bool masks (e.g. 0/1 integer masks) are
            accepted and converted to bool.

    Returns:
        The coefficient as a Python float, or ``0.0`` if the shapes of ``tensor1``,
        ``tensor2`` and ``mask`` disagree. As with ``torch.corrcoef``, the result may
        be NaN when fewer than two elements are selected or a selection has zero
        variance.
    """
    if tensor1.shape != tensor2.shape or mask.shape != tensor1.shape:
        return 0.0
    tensor2 = tensor2.to(tensor1.device)
    # masked_select requires a bool mask on the same device.
    mask = mask.to(device=tensor1.device, dtype=torch.bool)
    # corrcoef needs a floating dtype. Promote to at least float32 so half/bf16/int
    # inputs work, while float64 inputs keep their precision.
    compute_dtype = torch.promote_types(torch.promote_types(tensor1.dtype, tensor2.dtype), torch.float32)
    mt1 = torch.masked_select(tensor1, mask).to(compute_dtype)
    mt2 = torch.masked_select(tensor2, mask).to(compute_dtype)
    # corrcoef returns the 2x2 correlation matrix; [0, 1] is the cross term.
    result = torch.corrcoef(torch.stack([mt1, mt2], dim=0))
    return result[0, 1].item()
|
|
|
|
def calculate_log_prob_diff(log_probs1: torch.Tensor, log_probs2: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return |log_probs1 - log_probs2| at the positions selected by ``mask``.

    The absolute difference is computed element-wise over the full tensors, and
    the entries where ``mask`` is True are then gathered into a 1-D tensor
    (``masked_select`` semantics, so ``mask`` must be a bool tensor).
    """
    abs_gap = (log_probs1 - log_probs2).abs()
    return abs_gap.masked_select(mask)
|
|
|
|
def calculate_debug_metrics(data: DataProto) -> dict:
    """
    Calculate rollout vs. actor probability diffs, for debugging purposes.

    Both recorded log-probs are exponentiated to probabilities and then compared
    on the response positions selected by the mask.

    Args:
        data: DataProto
            the data batch to calculate; ``data.batch`` is read for:
            rollout_log_probs: log_probs recorded when the rollout forwards the tokens
            old_log_probs (actor log probs): log_probs recorded when the actor forwards the tokens
            response_mask or attention_mask: optional, masks out unrelated tokens. The first
                one found is used, and only its last ``response_length`` columns are kept.
                If neither is present, every position is used.
            responses: the response tokens; only ``responses.size(1)`` (the response
                length) is used
    Returns:
        dict: metrics
            "training/rollout_probs_diff_valid": 1 -> input is valid, 0 -> input is invalid
                (mismatched shapes, or the mask selects no token). When 0, every other
                metric is reported as 0.0.
            "training/rollout_probs_diff_max": max of |actor prob - rollout prob| over selected tokens
            "training/rollout_probs_diff_mean": mean of |actor prob - rollout prob| over selected tokens
            "training/rollout_probs_diff_std": std of |actor prob - rollout prob| over selected tokens
                (0.0 when only one token is selected)
            "training/rollout_actor_probs_pearson_corr": Pearson corrcoef of actor vs. rollout
                probs, reference to https://arxiv.org/pdf/2506.13585
    """
    rollout_old_log_probs = data.batch["rollout_log_probs"]
    actor_old_log_probs = data.batch["old_log_probs"]
    if "response_mask" in data.batch:
        logger.debug("response mask found, use it to mask log probs")
        log_prob_mask = data.batch["response_mask"]
    elif "attention_mask" in data.batch:
        log_prob_mask = data.batch["attention_mask"]
    else:
        logger.warning(f"no mask info found, use all log probs, {(data.batch.keys())=}")
        log_prob_mask = torch.ones_like(rollout_old_log_probs)
    response_length = data.batch["responses"].size(1)

    # Keep only the trailing response_length columns of the mask (a no-op for a mask
    # that is already response-length). An explicit start index is used because
    # `[:, -response_length:]` would select the whole mask when response_length == 0.
    mask_start = max(log_prob_mask.size(1) - response_length, 0)
    response_mask_bool = log_prob_mask[:, mask_start:].bool()

    # Same keys as the valid result, so downstream metric logging always sees a stable key set.
    invalid_metrics = {
        "training/rollout_probs_diff_valid": 0,
        "training/rollout_probs_diff_max": 0.0,
        "training/rollout_probs_diff_mean": 0.0,
        "training/rollout_probs_diff_std": 0.0,
        "training/rollout_actor_probs_pearson_corr": 0.0,
    }
    if (
        actor_old_log_probs.shape != rollout_old_log_probs.shape
        or response_mask_bool.shape != actor_old_log_probs.shape
    ):
        logger.warning(
            "cannot compare rollout vs. actor log probs, shape mismatch: "
            "old_log_probs=%s, rollout_log_probs=%s, mask=%s",
            tuple(actor_old_log_probs.shape),
            tuple(rollout_old_log_probs.shape),
            tuple(response_mask_bool.shape),
        )
        return invalid_metrics

    actor_probs = torch.exp(actor_old_log_probs)
    rollout_probs = torch.exp(rollout_old_log_probs)
    rollout_probs_diff = calculate_log_prob_diff(actor_probs, rollout_probs, response_mask_bool)
    if rollout_probs_diff.numel() == 0:
        # torch.max would raise on an empty tensor.
        logger.warning("cannot compare rollout vs. actor log probs: the mask selects no token")
        return invalid_metrics

    pearson_corrcoef = pearson_correlation_coefficient(actor_probs, rollout_probs, response_mask_bool)
    # torch.std applies Bessel's correction, which yields NaN for a single sample.
    diff_std = torch.std(rollout_probs_diff).item() if rollout_probs_diff.numel() > 1 else 0.0
    return {
        "training/rollout_probs_diff_valid": 1,
        "training/rollout_probs_diff_max": torch.max(rollout_probs_diff).item(),
        "training/rollout_probs_diff_mean": torch.mean(rollout_probs_diff).item(),
        "training/rollout_probs_diff_std": diff_std,
        "training/rollout_actor_probs_pearson_corr": pearson_corrcoef,
    }
|
|