<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Logging&quot;,&quot;local&quot;:&quot;logging&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;PPO Logging&quot;,&quot;local&quot;:&quot;ppo-logging&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Crucial values&quot;,&quot;local&quot;:&quot;crucial-values&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/trl/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/start.183b226a.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/scheduler.85c25b89.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/singletons.98fe034d.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/paths.eb9df337.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/index.c142fe32.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/0.5efac18d.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/20.8c72bfbb.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/CodeBlock.a5e95a57.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/EditOnGithub.a592e7aa.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Logging&quot;,&quot;local&quot;:&quot;logging&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;PPO Logging&quot;,&quot;local&quot;:&quot;ppo-logging&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Crucial values&quot;,&quot;local&quot;:&quot;crucial-values&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="logging" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#logging"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Logging</span></h1> <p data-svelte-h="svelte-1cil2uv">As reinforcement learning algorithms are historically challenging to debug, it’s important to pay careful attention to logging.
The TRL <a href="/docs/trl/main/en/ppo_trainer#trl.PPOTrainer">PPOTrainer</a> can log a wide range of training metrics to <code>wandb</code> or <code>tensorboard</code>.</p> <p data-svelte-h="svelte-uvaohc">To enable logging, pass one of these two options to <a href="/docs/trl/main/en/ppo_trainer#trl.PPOConfig">PPOConfig</a> at initialization:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->config = PPOConfig(
    model_name=args.model_name,
    log_with=<span class="hljs-string">&quot;wandb&quot;</span>,  <span class="hljs-comment"># or &quot;tensorboard&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-9gqjy4">If you want to log with tensorboard, also pass the keyword argument <code>project_kwargs={&quot;logging_dir&quot;: PATH_TO_LOGS}</code> to <code>PPOConfig</code>.</p> <h2 class="relative group"><a id="ppo-logging" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ppo-logging"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PPO Logging</span></h2> <p data-svelte-h="svelte-50qxhy">Here is a brief explanation of the logged metrics:</p> <p data-svelte-h="svelte-1oysg7j">Key metrics to monitor. In general, we want to maximize the reward, keep the KL divergence low, and maximize entropy:</p> <ol data-svelte-h="svelte-1c10na6"><li><code>env/reward_mean</code>: The average reward obtained from the environment. Also logged as <code>ppo/mean_scores</code>, which is used specifically to monitor the reward model.</li> <li><code>env/reward_std</code>: The standard deviation of the reward obtained from the environment. 
Also logged as <code>ppo/std_scores</code>, which is used specifically to monitor the reward model.</li> <li><code>env/reward_dist</code>: The histogram distribution of the reward obtained from the environment.</li> <li><code>objective/kl</code>: The mean Kullback-Leibler (KL) divergence between the current policy and the reference model. It measures how far the policy has drifted from the reference model and is used to compute the KL penalty in the objective function.</li> <li><code>objective/kl_dist</code>: The histogram distribution of <code>objective/kl</code>.</li> <li><code>objective/kl_coef</code>: The coefficient of the KL divergence term in the objective function.</li> <li><code>ppo/mean_non_score_reward</code>: The negated <strong>KL penalty</strong>, i.e. <code>-objective/kl_coef</code> times the KL divergence. It is added to the score to form the total reward used for optimization and keeps the new policy from deviating too far from the reference policy.</li> <li><code>objective/entropy</code>: The entropy of the model’s policy, estimated by <code>-logprobs.sum(-1).mean()</code>. High entropy means the model’s actions are more random, which can be beneficial for exploration.</li></ol> <p data-svelte-h="svelte-ma4ajt">Training stats:</p> <ol data-svelte-h="svelte-1eo0hbb"><li><code>ppo/learning_rate</code>: The learning rate for the PPO algorithm.</li> <li><code>ppo/policy/entropy</code>: The entropy of the model’s policy, calculated by <code>pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)</code>. It measures the randomness of the policy.</li> <li><code>ppo/policy/clipfrac</code>: The fraction of probability ratios (new policy / old policy) that fell outside the clipping range in the PPO objective. 
This can be used to monitor the optimization process.</li> <li><code>ppo/policy/approxkl</code>: The approximate KL divergence between the old and new policies, measured by <code>0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)</code>, corresponding to the <code>k2</code> estimator in <a href="http://joschu.net/blog/kl-approx.html" rel="nofollow">http://joschu.net/blog/kl-approx.html</a>.</li> <li><code>ppo/policy/policykl</code>: Similar to <code>ppo/policy/approxkl</code>, but measured by <code>masked_mean(old_logprobs - logprobs, mask)</code>, corresponding to the <code>k1</code> estimator in <a href="http://joschu.net/blog/kl-approx.html" rel="nofollow">http://joschu.net/blog/kl-approx.html</a>.</li> <li><code>ppo/policy/ratio</code>: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.</li> <li><code>ppo/policy/advantages_mean</code>: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.</li> <li><code>ppo/policy/advantages</code>: The histogram distribution of the advantage estimates.</li> <li><code>ppo/returns/mean</code>: The mean of the TD(λ) returns, calculated by <code>returns = advantage + values</code>, another indicator of model performance. 
See <a href="https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/" rel="nofollow">https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/</a> for more details.</li> <li><code>ppo/returns/var</code>: The variance of the TD(λ) returns, calculated by <code>returns = advantage + values</code>, another indicator of model performance.</li> <li><code>ppo/val/mean</code>: The mean of the values, used to monitor the value function’s performance.</li> <li><code>ppo/val/var</code>: The variance of the values, used to monitor the value function’s performance.</li> <li><code>ppo/val/var_explained</code>: The explained variance for the value function, used to monitor the value function’s performance.</li> <li><code>ppo/val/clipfrac</code>: The fraction of the value function’s predicted values that are clipped.</li> <li><code>ppo/val/vpred</code>: The predicted values from the value function.</li> <li><code>ppo/val/error</code>: The mean squared error between <code>ppo/val/vpred</code> and the returns, used to monitor the value function’s performance.</li> <li><code>ppo/loss/policy</code>: The policy loss for the Proximal Policy Optimization (PPO) algorithm.</li> <li><code>ppo/loss/value</code>: The loss for the value function in the PPO algorithm. It quantifies how well the value function estimates the expected future rewards.</li> <li><code>ppo/loss/total</code>: The total loss for the PPO algorithm. 
It is the sum of the policy loss and the value function loss (the latter scaled by a value-loss coefficient).</li></ol> <p data-svelte-h="svelte-1csidm3">Stats on queries, responses, and logprobs:</p> <ol data-svelte-h="svelte-pvawsw"><li><code>tokens/queries_len_mean</code>: The average query length, in tokens.</li> <li><code>tokens/queries_len_std</code>: The standard deviation of the query length, in tokens.</li> <li><code>tokens/queries_dist</code>: The histogram distribution of the query length, in tokens.</li> <li><code>tokens/responses_len_mean</code>: The average response length, in tokens.</li> <li><code>tokens/responses_len_std</code>: The standard deviation of the response length, in tokens.</li> <li><code>tokens/responses_dist</code>: The histogram distribution of the response length, in tokens. (The naming is inconsistent with the other length metrics; <code>tokens/responses_len_dist</code> would be more consistent.)</li> <li><code>objective/logprobs</code>: The histogram distribution of the log probabilities of the actions taken by the model.</li> <li><code>objective/ref_logprobs</code>: The histogram distribution of the log probabilities of the actions taken by the reference model.</li></ol> <h3 class="relative group"><a id="crucial-values" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#crucial-values"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 
11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Crucial values</span></h3> <p data-svelte-h="svelte-189bxez">Many values are logged during training; these are the most important ones:</p> <ol data-svelte-h="svelte-wjaroo"><li><code>env/reward_mean</code>, <code>env/reward_std</code>, <code>env/reward_dist</code>: The properties of the reward distribution produced by the “environment”, i.e. the reward model.</li> <li><code>ppo/mean_non_score_reward</code>: The mean negated KL penalty during training. It reflects how far the new policy has moved away from the reference model over the batch in the current step.</li></ol> <p data-svelte-h="svelte-1yomxw6">The following values are useful to monitor for stability. When they diverge or collapse to 0, try tuning the hyperparameters:</p> <ol data-svelte-h="svelte-zapvl9"><li><code>ppo/loss/value</code>: It spikes or becomes NaN when training is going badly.</li> <li><code>ppo/policy/ratio</code>: A <code>ratio</code> of 1 is the baseline: it means a token is equally likely to be sampled under the new and old policies. A very high ratio, such as 200, means the token is 200 times more likely to be sampled under the new policy than under the old one. This is a sign that the new policy has moved too far from the old policy, which will likely cause overoptimization and collapse training later on.</li> <li><code>ppo/policy/clipfrac</code> and <code>ppo/policy/approxkl</code>: If <code>ratio</code> is too high, it gets clipped, which also results in a high <code>clipfrac</code> and a high <code>approxkl</code>.</li> <li><code>objective/kl</code>: It should stay positive so that the policy stays close to the reference policy.</li> <li><code>objective/kl_coef</code>: The KL coefficient, which <code>AdaptiveKLController</code> adjusts to keep the KL near its target. It often increases before numerical instabilities.</li></ol> <p></p>
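<p>To make the <code>ppo/policy/*</code> formulas above concrete, here is a minimal, dependency-free sketch. It assumes per-token log-probabilities from the old and new policies and a 0/1 mask are available as plain Python lists; the helpers <code>masked_mean</code> and <code>policy_stats</code> and the clip range of 0.2 are illustrative assumptions, not part of the TRL API. The clip fraction follows the description of <code>ppo/policy/clipfrac</code> above; a real trainer may count clipped tokens differently.</p>

```python
import math


def masked_mean(values, mask):
    """Mean of `values` over positions where `mask` is 1."""
    total = sum(v * m for v, m in zip(values, mask))
    return total / sum(mask)


def policy_stats(logprobs, old_logprobs, mask, cliprange=0.2):
    """Hypothetical helper: k1/k2 KL estimates, ratios and clip fraction."""
    # log(p_new / p_old) for every token
    log_ratio = [new - old for new, old in zip(logprobs, old_logprobs)]
    # k2 estimator: 0.5 * E[(log p_new - log p_old)^2]  (ppo/policy/approxkl)
    approxkl = 0.5 * masked_mean([d ** 2 for d in log_ratio], mask)
    # k1 estimator: E[log p_old - log p_new]  (ppo/policy/policykl)
    policykl = masked_mean([-d for d in log_ratio], mask)
    # Probability ratio new / old  (ppo/policy/ratio)
    ratio = [math.exp(d) for d in log_ratio]
    # Share of unmasked tokens whose ratio lies outside [1 - cliprange, 1 + cliprange]
    outside = [1.0 if abs(r - 1.0) > cliprange else 0.0 for r in ratio]
    clipfrac = masked_mean(outside, mask)
    return {"approxkl": approxkl, "policykl": policykl, "ratio": ratio, "clipfrac": clipfrac}


# Toy batch: the second token became more likely; the third is padding (mask 0).
stats = policy_stats(
    logprobs=[-1.0, -0.5, -2.0],
    old_logprobs=[-1.0, -1.0, -2.0],
    mask=[1, 1, 0],
)
print(stats["approxkl"], stats["policykl"], stats["clipfrac"])
```

<p>Here the second token's ratio is <code>exp(0.5)</code>, about 1.65, which lies outside the assumed clip range, so half of the unmasked tokens count as clipped.</p>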
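<p>The <code>ppo/policy/entropy</code> formula relies on the identity <code>H = logsumexp(z) - sum(softmax(z) * z)</code>, which follows from <code>log p_i = z_i - logsumexp(z)</code> and avoids taking the log of tiny probabilities. The stdlib-only sketch below, with hypothetical helper names, checks it against the textbook definition <code>-sum(p * log p)</code> for one token's logits.</p>

```python
import math


def logsumexp(logits):
    """Numerically stable log(sum(exp(z)))."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))


def softmax(logits):
    lse = logsumexp(logits)
    return [math.exp(z - lse) for z in logits]


def entropy_from_logits(logits):
    """Entropy as logsumexp(z) - sum(p * z), mirroring the ppo/policy/entropy formula."""
    pd = softmax(logits)
    return logsumexp(logits) - sum(p * z for p, z in zip(pd, logits))


def entropy_direct(logits):
    """Textbook definition: -sum(p * log p)."""
    return -sum(p * math.log(p) for p in softmax(logits) if p > 0)


logits = [2.0, 1.0, 0.1]
print(entropy_from_logits(logits), entropy_direct(logits))
# Equal logits give the maximum entropy, log(vocab_size).
print(entropy_from_logits([0.0, 0.0, 0.0, 0.0]), math.log(4))
```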
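<p>The relationship <code>returns = advantage + values</code> and the <code>ppo/val/var_explained</code> metric can be sketched as follows. This assumes a single sequence with no bootstrap value after the last token and the standard definition of explained variance, <code>1 - Var(returns - values) / Var(returns)</code>; the function names and the default <code>gamma</code> and <code>lam</code> values are illustrative, and details such as advantage whitening and value clipping are omitted.</p>

```python
def gae_returns(rewards, values, gamma=1.0, lam=0.95):
    """GAE advantages and TD(lambda) returns for one sequence.

    Assumes no bootstrap value after the final token (the episode ends there).
    """
    advantages = [0.0] * len(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD error
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    # returns = advantage + values, as described for ppo/returns/mean
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)


def explained_variance(values, returns):
    """1 - Var(returns - values) / Var(returns); 1 is perfect, 0 or less is no better than a constant."""
    var_returns = variance(returns)
    if var_returns == 0:
        return float("nan")  # undefined when the returns are constant
    residuals = [r - v for r, v in zip(returns, values)]
    return 1.0 - variance(residuals) / var_returns


# With gamma = lam = 1, the returns reduce to the sum of future rewards.
rewards = [0.0, 0.0, 1.0]  # e.g. the score arrives on the last token
values = [0.5, 0.6, 0.8]
advantages, returns = gae_returns(rewards, values, gamma=1.0, lam=1.0)
print(advantages, returns)
print(explained_variance(values=[1.0, 2.0, 3.0], returns=[1.0, 2.0, 3.0]))
```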
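<p>The sketch below illustrates how the KL penalty behind <code>ppo/mean_non_score_reward</code> and the adaptive coefficient behind <code>objective/kl_coef</code> interact. It assumes a common scheme in which every token receives <code>-kl_coef * (logprob - ref_logprob)</code> and the reward-model score is added at the last token, plus a proportional controller that nudges the coefficient toward a target KL. <code>compute_rewards</code> and <code>SketchAdaptiveKL</code> are hypothetical names; the controller is modeled on, not copied from, TRL's <code>AdaptiveKLController</code>.</p>

```python
def compute_rewards(score, logprobs, ref_logprobs, kl_coef):
    """Per-token rewards for one response under the assumed KL-penalty scheme."""
    # Per-token KL estimate: log p_policy - log p_ref
    kl = [lp - ref for lp, ref in zip(logprobs, ref_logprobs)]
    # Negated KL penalty on every token (the "non-score" reward)
    non_score_reward = [-kl_coef * k for k in kl]
    rewards = list(non_score_reward)
    rewards[-1] += score  # reward-model score only on the last token
    return rewards, sum(kl), sum(non_score_reward)


class SketchAdaptiveKL:
    """Proportional controller that nudges the KL coefficient toward a target KL."""

    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Relative error, clipped to [-0.2, 0.2] to bound each adjustment
        error = max(min(current_kl / self.target - 1.0, 0.2), -0.2)
        self.value *= 1.0 + error * n_steps / self.horizon


rewards, kl, non_score = compute_rewards(
    score=1.0, logprobs=[-1.0, -2.0], ref_logprobs=[-1.5, -2.0], kl_coef=0.2
)
print(rewards, kl, non_score)

ctl = SketchAdaptiveKL(init_kl_coef=0.2, target=6.0, horizon=10000)
ctl.update(current_kl=12.0, n_steps=256)  # KL above target -> coefficient grows
print(ctl.value)
```

<p>Because the coefficient grows whenever the observed KL exceeds the target, a steadily rising <code>objective/kl_coef</code> signals that the policy keeps drifting away from the reference model.</p>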
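<p>The stability heuristics listed above can be turned into an automated check. This is a hypothetical helper, not a TRL feature: it assumes each step's stats arrive as a flat dictionary keyed by the metric names above, with <code>ppo/policy/ratio</code> given as a list of per-token ratios, and its <code>max_ratio</code> threshold is an arbitrary illustrative value.</p>

```python
import math


def stability_warnings(stats, max_ratio=10.0):
    """Hypothetical helper: apply the stability heuristics above to one step's stats."""
    warnings = []
    value_loss = stats.get("ppo/loss/value")
    if value_loss is not None and not math.isfinite(value_loss):
        warnings.append("ppo/loss/value is NaN or infinite")
    ratios = stats.get("ppo/policy/ratio", [])
    if ratios and max(ratios) > max_ratio:
        warnings.append(f"ppo/policy/ratio reached {max(ratios):.1f}")
    kl = stats.get("objective/kl")
    if kl is not None and kl < 0:
        warnings.append("objective/kl is negative")
    return warnings


healthy = {"ppo/loss/value": 0.12, "ppo/policy/ratio": [0.9, 1.0, 1.1], "objective/kl": 2.3}
broken = {"ppo/loss/value": float("nan"), "ppo/policy/ratio": [1.0, 200.0], "objective/kl": -0.4}
print(stability_warnings(healthy))
print(stability_warnings(broken))
```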
<script>
{
__sveltekit_5yobsv = {
assets: "/docs/trl/main/en",
base: "/docs/trl/main/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/trl/main/en/_app/immutable/entry/start.183b226a.js"),
import("/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 20],
data,
form: null,
error: null
});
});
}
</script>
