|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from ..protocol import DataProto |
|
|
|
|
|
|
|
|
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
    """Collapse each list of per-step metric samples into its arithmetic mean.

    Args:
        metrics: mapping from metric name to a list of numeric samples.

    Returns:
        Mapping from the same metric names to the mean of each sample list
        (as produced by ``np.mean``).
    """
    reduced: Dict[str, Any] = {}
    for name, samples in metrics.items():
        reduced[name] = np.mean(samples)
    return reduced
|
|
|
|
|
|
|
|
def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]:
    """Summarize a rollout batch into flat scalar metrics.

    Reads tensors from ``batch.batch``: per-token scores/rewards, advantages,
    returns, responses, the attention mask, and (when ``use_critic``) critic
    values. Sequence-level scores/rewards are token sums; advantage/return/value
    statistics are taken only over response-token positions.

    Args:
        batch: rollout batch holding the tensors listed above.
        use_critic: when True, additionally report value statistics and the
            value-function explained variance.

    Returns:
        Dict of python floats keyed like ``"critic/score/mean"``,
        ``"response_length/clip_ratio"``, etc.
    """
    tensors = batch.batch

    # Sequence-level score/reward: sum over the token dimension.
    score_per_seq = tensors["token_level_scores"].sum(-1)
    reward_per_seq = tensors["token_level_rewards"].sum(-1)

    # The attention mask covers [prompt | response]; split it at the
    # (fixed) maximum response length taken from the responses tensor.
    resp_len_max = tensors["responses"].size(-1)
    attn = tensors["attention_mask"]
    mask_prompt = attn[:, :-resp_len_max].bool()
    mask_resp = attn[:, -resp_len_max:].bool()

    prompt_len_max = mask_prompt.size(-1)
    len_prompt = mask_prompt.sum(-1).float()
    len_resp = mask_resp.sum(-1).float()

    # Keep only response-position entries for advantage/return statistics.
    adv_valid = torch.masked_select(tensors["advantages"], mask_resp)
    ret_valid = torch.masked_select(tensors["returns"], mask_resp)

    def _stats(prefix: str, values: torch.Tensor) -> Dict[str, Any]:
        # mean/max/min triplet for one tensor, as python floats.
        return {
            f"{prefix}/mean": torch.mean(values).detach().item(),
            f"{prefix}/max": torch.max(values).detach().item(),
            f"{prefix}/min": torch.min(values).detach().item(),
        }

    metrics: Dict[str, Any] = {}
    metrics.update(_stats("critic/score", score_per_seq))
    metrics.update(_stats("critic/rewards", reward_per_seq))
    metrics.update(_stats("critic/advantages", adv_valid))
    metrics.update(_stats("critic/returns", ret_valid))

    if use_critic:
        vals_valid = torch.masked_select(tensors["values"], mask_resp)
        var_diff = torch.var(ret_valid - vals_valid)
        var_ret = torch.var(ret_valid)
        metrics.update(_stats("critic/values", vals_valid))
        # Explained variance of the value function; 1e-5 guards div-by-zero.
        metrics["critic/vf_explained_var"] = (1.0 - var_diff / (var_ret + 1e-5)).detach().item()

    # Length statistics; clip_ratio is the fraction of sequences that hit
    # the maximum possible length (i.e. were likely truncated).
    metrics.update(_stats("response_length", len_resp))
    metrics["response_length/clip_ratio"] = (
        torch.mean(torch.eq(len_resp, resp_len_max).float()).detach().item()
    )
    metrics.update(_stats("prompt_length", len_prompt))
    metrics["prompt_length/clip_ratio"] = (
        torch.mean(torch.eq(len_prompt, prompt_len_max).float()).detach().item()
    )
    return metrics
|
|
|
|
|
|
|
|
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
    """Turn raw per-section wall-clock timings into loggable metrics.

    Emits every raw timing as ``timing_s/<name>`` and, for sections with a
    known token count, a normalized ``timing_per_token_ms/<name>``.

    Args:
        batch: rollout batch; reads ``batch.batch["response_mask"]`` and
            ``batch.meta_info["global_token_num"]``.
        timing_raw: section name -> elapsed seconds.

    Returns:
        Dict of timing metrics (seconds, and milliseconds-per-token where
        a token count is defined for the section).
    """
    response_tokens = torch.sum(batch.batch["response_mask"]).item()
    overall_tokens = sum(batch.meta_info["global_token_num"])

    # Sections that only touch response tokens vs. the full sequence.
    section_tokens: Dict[str, Any] = {}
    for section in ("gen", "reward"):
        section_tokens[section] = response_tokens
    for section in ("ref", "old", "values", "adv", "update_critic", "update_actor"):
        section_tokens[section] = overall_tokens

    metrics: Dict[str, Any] = {}
    for section, seconds in timing_raw.items():
        metrics[f"timing_s/{section}"] = seconds
    for section in timing_raw.keys() & section_tokens.keys():
        metrics[f"timing_per_token_ms/{section}"] = timing_raw[section] * 1000 / section_tokens[section]
    return metrics
|
|
|
|
|
|
|
|
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
    """Report step-level throughput statistics.

    NOTE(review): the name looks like a typo for "compute_throughput_metrics",
    but it is public API here — kept unchanged for caller compatibility.

    Args:
        batch: rollout batch; reads ``batch.meta_info["global_token_num"]``.
        timing_raw: must contain the full step duration under ``"step"``.
        n_gpus: number of GPUs the step ran on; throughput is per GPU.

    Returns:
        Dict with total token count, step time (s), and tokens/s/GPU.
    """
    token_count = sum(batch.meta_info["global_token_num"])
    step_seconds = timing_raw["step"]

    metrics: Dict[str, Any] = {}
    metrics["perf/total_num_tokens"] = token_count
    metrics["perf/time_per_step"] = step_seconds
    metrics["perf/throughput"] = token_count / (step_seconds * n_gpus)
    return metrics
|
|
|