Spaces:
Running on Zero
Running on Zero
| """Live training visualization callbacks for GRPO/SFT notebooks.""" | |
| from __future__ import annotations | |
| import html as _html | |
| import logging | |
| from typing import Any | |
| _logger = logging.getLogger(__name__) | |
| try: | |
| from transformers import TrainerCallback | |
| except ImportError: | |
| TrainerCallback = object # type: ignore[assignment,misc] | |
| try: | |
| from .env_metrics import TRACKER | |
| except ImportError: # pragma: no cover | |
| from training.env_metrics import TRACKER # type: ignore[no-redef] | |
class LiveVisualizationCallback(TrainerCallback):
    """TrainerCallback that plots reward and loss in place during training.

    Updates a single plot via an IPython display handle without clearing
    the cell output. When IPython/matplotlib are unavailable the callback
    still records the logged values but draws nothing.

    Attributes:
        log_steps: Global steps of every log call that carried a reward
            and/or a loss.
        log_rewards: Reward values, in logging order.
        log_losses: Loss values, in logging order.
        log_reward_steps: Global step of each entry of ``log_rewards``.
        log_loss_steps: Global step of each entry of ``log_losses``.
    """

    def __init__(self, **kwargs: Any) -> None:
        # Accept and ignore extra kwargs for backward compat
        _ = kwargs
        self.log_steps: list[int] = []
        self.log_rewards: list[float] = []
        self.log_losses: list[float] = []
        # Each series keeps its own x values: a log call may carry a reward
        # without a loss (or vice versa), so slicing the shared ``log_steps``
        # to a series' length would plot its points at the wrong steps.
        self.log_reward_steps: list[int] = []
        self.log_loss_steps: list[int] = []
        self._plot_handle: Any = None
        # Latest harness failure rate reported by TRACKER; shown in the title.
        self._env_failure_rate = 0.0

    def on_train_begin(self, args, state, control, **kwargs):
        """Create the placeholder output that later plot updates replace."""
        try:
            from IPython.display import HTML, display

            self._plot_handle = display(
                HTML("<em>Waiting for first log...</em>"),
                display_id="viz_plot",
            )
        except Exception as exc:
            # Not running under IPython: keep recording values, skip plotting.
            _logger.debug("Live plot display unavailable: %s", exc)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record reward/loss from a trainer log event and refresh the plot.

        The reward is taken from the first key (in sorted order) containing
        both ``"reward"`` and ``"mean"``, falling back to ``"reward"`` or
        ``"rewards/mean"``. The loss is ``logs["loss"]``.
        """
        if not logs:
            return
        step = state.global_step

        # Find reward (prefer a "...mean" key)
        reward = next(
            (logs[key] for key in sorted(logs) if "reward" in key and "mean" in key),
            None,
        )
        if reward is None:
            reward = next(
                (logs[key] for key in ("reward", "rewards/mean") if key in logs),
                None,
            )
        loss = logs.get("loss")

        if reward is not None:
            self.log_rewards.append(float(reward))
            self.log_reward_steps.append(step)
        if loss is not None:
            self.log_losses.append(float(loss))
            self.log_loss_steps.append(step)
        if reward is not None or loss is not None:
            self.log_steps.append(step)

        # Data-quality guardrail: surface the harness failure rate and log a
        # WARNING when it exceeds 5% ("harness problem, not a model problem").
        try:
            env_summary = TRACKER.summary()
            if env_summary["episodes_started"]:
                self._env_failure_rate = env_summary["failure_rate"]
            TRACKER.warn_if_exceeded()
        except Exception:  # pragma: no cover - never break training on metrics
            pass
        self._update_plot()

    def _update_plot(self) -> None:
        """Redraw reward/loss and push the PNG into the display handle.

        Best effort: any failure is logged at DEBUG level and swallowed so
        plotting never interrupts training.
        """
        if self._plot_handle is None:
            return
        fig = None
        try:
            import base64
            import io

            import matplotlib.pyplot as plt
            from IPython.display import HTML

            fig, ax = plt.subplots(1, 1, figsize=(8, 3.5))
            if self.log_rewards:
                ax.plot(
                    self.log_reward_steps,
                    self.log_rewards,
                    "b-o",
                    markersize=3,
                    label="Reward",
                )
                ax.set_ylabel("Reward")
                ax.legend(loc="upper left")
            if self.log_losses:
                # SFT-only: plot loss on primary axis
                # GRPO: plot loss on secondary axis
                if self.log_rewards:
                    ax2 = ax.twinx()
                    ax2.plot(
                        self.log_loss_steps,
                        self.log_losses,
                        "r-",
                        alpha=0.4,
                        label="Loss",
                    )
                    ax2.set_ylabel("Loss", color="r", alpha=0.6)
                    ax2.legend(loc="upper right")
                else:
                    ax.plot(
                        self.log_loss_steps,
                        self.log_losses,
                        "r-o",
                        markersize=3,
                        label="Loss",
                    )
                    ax.set_ylabel("Loss")
                    ax.legend(loc="upper right")
            ax.set_xlabel("Step")
            latest = self.log_steps[-1] if self.log_steps else 0
            title = f"Training Progress (step {latest})"
            if self._env_failure_rate > 0:
                title += f" | harness-fail {self._env_failure_rate:.1%}"
            ax.set_title(title)
            ax.grid(True, alpha=0.3)
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
            img = base64.b64encode(buf.getvalue()).decode("utf-8")
            self._plot_handle.update(HTML(f'<img src="data:image/png;base64,{img}">'))
        except Exception as exc:
            _logger.debug("Plot update failed: %s", exc)
        finally:
            # Close even when rendering failed; otherwise pyplot keeps every
            # failed figure alive for the rest of the session.
            if fig is not None:
                plt.close(fig)
class SFTMonitorCallback(TrainerCallback):
    """Show greedy sample completions for a few prompts during SFT.

    On trainer log events at global steps that are positive multiples of
    ``eval_every_steps`` (and once more at the end of training) the callback
    generates a first-turn completion for each sample prompt so the user can
    watch the model learn tool-calling patterns in real time. Sampling is
    driven by ``on_log``, so it only happens on steps at which the trainer
    actually logs.

    At the start of training it also shows two collapsible previews: the
    rendered inference prompt and the first training example.
    """

    def __init__(
        self,
        tokenizer: Any,
        sample_prompts: list[list[dict[str, str]]],
        *,
        tools: list[dict] | None = None,
        train_dataset: Any = None,
        eval_every_steps: int = 50,
        max_new_tokens: int = 200,
        max_samples: int = 3,
    ) -> None:
        """
        Args:
            tokenizer: Tokenizer used via ``apply_chat_template``, ``encode``,
                ``decode`` and ``__call__``.
            sample_prompts: Chat message lists; only the first
                ``max_samples`` are kept.
            tools: Tool schemas passed to ``apply_chat_template`` when given.
            train_dataset: Optional dataset whose first row's ``"messages"``
                is previewed at train start.
            eval_every_steps: Sampling interval in global steps; ``<= 0``
                disables periodic sampling (the final sample still runs).
            max_new_tokens: Generation budget per sample.
            max_samples: Maximum number of sample prompts to keep.
        """
        self.tokenizer = tokenizer
        self.sample_prompts = sample_prompts[:max_samples]
        self.tools = tools
        self.train_dataset = train_dataset
        self.eval_every_steps = eval_every_steps
        self.max_new_tokens = max_new_tokens
        self._model: Any = None
        self._display_handle: Any = None
        # Last step sampled from on_log; several log events can share a step.
        self._last_sample_step: int | None = None

    # ------------------------------------------------------------------
    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Capture the model, render prompt previews and create the sample display."""
        self._model = model
        try:
            from IPython.display import HTML, display
        except Exception:
            return
        # Previews are informational only: a failure in one of them (template
        # error, dataset without __len__, ...) must neither hide the other nor
        # prevent the sample display below from being created.
        for render in (self._inference_preview_html, self._training_example_html):
            try:
                snippet = render()
                if snippet is not None:
                    display(HTML(snippet))
            except Exception as exc:
                _logger.debug("SFT preview rendering failed: %s", exc)
        try:
            self._display_handle = display(
                HTML("<em>SFT samples: waiting for first checkpoint...</em>"),
                display_id="sft_samples",
            )
        except Exception as exc:
            _logger.debug("SFT sample display unavailable: %s", exc)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Generate samples when the logged step is a multiple of the interval."""
        step = state.global_step
        # Guard before the modulo: eval_every_steps == 0 would otherwise raise
        # ZeroDivisionError inside the trainer loop.
        if self.eval_every_steps <= 0 or step <= 0:
            return
        if step % self.eval_every_steps != 0:
            return
        if self._model is None or step == self._last_sample_step:
            return
        self._last_sample_step = step
        self._generate_and_display(step)

    def on_train_end(self, args, state, control, **kwargs):
        """Render a final set of samples once training finishes."""
        if self._model is not None:
            self._generate_and_display(state.global_step, final=True)

    # ------------------------------------------------------------------
    def _template_kwargs(self, *, add_generation_prompt: bool) -> dict[str, Any]:
        """Build ``apply_chat_template`` kwargs for rendering to a string."""
        kwargs: dict[str, Any] = {"tokenize": False}
        if add_generation_prompt:
            kwargs["add_generation_prompt"] = True
        # Always use canonical tools (avoid Dataset serialization artifacts)
        if self.tools:
            kwargs["tools"] = self.tools
        return kwargs

    def _inference_preview_html(self) -> str | None:
        """Return a collapsible HTML view of the first rendered sample prompt."""
        if not self.sample_prompts:
            return None
        preview = self.tokenizer.apply_chat_template(
            self.sample_prompts[0],
            **self._template_kwargs(add_generation_prompt=True),
        )
        n_tok = len(self.tokenizer.encode(preview, add_special_tokens=False))
        return (
            "<details><summary>"
            f"<b>Inference prompt</b> ({n_tok} tok)"
            " — system + tools + question, "
            "as seen by model during GRPO generation"
            "</summary>"
            "<pre style='background:#2d2d2d;color:#e0e0e0;"
            "padding:8px;border-radius:4px;font-size:12px;"
            "white-space:pre-wrap;max-height:600px;"
            "overflow-y:auto;'>"
            f"{_html.escape(preview)}</pre></details>"
        )

    def _training_example_html(self) -> str | None:
        """Return a collapsible HTML view of ``train_dataset[0]["messages"]``."""
        if self.train_dataset is None or len(self.train_dataset) == 0:
            return None
        msgs = self.train_dataset[0].get("messages", [])
        if not msgs:
            return None
        rendered_ex = self.tokenizer.apply_chat_template(
            msgs,
            **self._template_kwargs(add_generation_prompt=False),
        )
        n_ex_tok = len(self.tokenizer.encode(rendered_ex, add_special_tokens=False))
        n_turns = sum(1 for m in msgs if m.get("role") == "assistant")
        # Role comes from dataset contents, so escape it like the body.
        last_role = _html.escape(str(msgs[-1].get("role", "?")))
        return (
            "<details><summary>"
            f"<b>SFT training example</b>"
            f" ({n_ex_tok} tok, {n_turns} asst turn)"
            " — history + one assistant tool_call, "
            "exactly what the model learns to predict"
            "</summary>"
            "<pre style='background:#1a1a2e;color:#e0e0e0;"
            "padding:8px;border-radius:4px;font-size:12px;"
            "white-space:pre-wrap;max-height:600px;"
            "overflow-y:auto;'>"
            f"{_html.escape(rendered_ex)}</pre>"
            f"<p style='color:#888;font-size:11px;'>"
            f"Last message role: <b>{last_role}</b> "
            f"| Loss is on this turn only</p>"
            "</details>"
        )

    def _sample_html(self, messages: list[dict[str, str]], device: Any) -> str:
        """Greedy-generate a completion for ``messages`` and format it as HTML."""
        question = str(messages[-1].get("content") or "")[:100]
        rendered = self.tokenizer.apply_chat_template(
            messages,
            **self._template_kwargs(add_generation_prompt=True),
        )
        # The rendered template may already contain the model's special
        # tokens (e.g. BOS); don't let the tokenizer prepend them again.
        inputs = self.tokenizer(rendered, return_tensors="pt", add_special_tokens=False)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]
        out = self._model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
        )
        raw = self.tokenizer.decode(out[0][prompt_len:], skip_special_tokens=False)
        # Show first tool call only (stop at </tool_call>)
        end = raw.find("</tool_call>")
        if end != -1:
            raw = raw[: end + len("</tool_call>")]
        # Round-trip through the tokenizer to drop special tokens from the
        # truncated text.
        text = self.tokenizer.decode(
            self.tokenizer.encode(raw, add_special_tokens=False),
            skip_special_tokens=True,
        ).strip()
        # Whether the rendered prompt contains a "<tools>" section.
        has_tools = "<tools>" in rendered
        badge = (
            "<span style='color:#4caf50;'>✓ tools</span>"
            if has_tools
            else "<span style='color:#f44336;'>✗ no tools</span>"
        )
        return (
            "<pre style='background:#2d2d2d;color:#e0e0e0;"
            "padding:8px;margin:4px 0;border-radius:4px;"
            "font-size:13px;line-height:1.4;'>"
            f"<b style='color:#82aaff;'>Q:</b> "
            f"{_html.escape(question)}"
            f" [{badge}, {prompt_len} tok]\n"
            f"<b style='color:#c3e88d;'>→</b> "
            f"{_html.escape(text)}</pre>"
        )

    def _generate_and_display(self, step: int, final: bool = False) -> None:
        """Generate one completion per sample prompt and refresh the display.

        Best effort: failures are logged at DEBUG level. The model's train
        mode is always restored, even when generation fails.
        """
        model = self._model
        was_training = False
        try:
            import torch

            was_training = bool(model.training)
            model.eval()
            header = "SFT Final Samples" if final else f"SFT Samples (step {step})"
            parts = [f"<h4 style='color:#e0e0e0;'>{header}</h4>"]
            device = next(model.parameters()).device
            with torch.no_grad():
                for messages in self.sample_prompts:
                    parts.append(self._sample_html(messages, device))
            if self._display_handle is not None:
                from IPython.display import HTML

                self._display_handle.update(HTML("\n".join(parts)))
        except Exception as exc:
            _logger.debug("SFT sample generation failed: %s", exc)
        finally:
            # Without this, a failed generation would leave the model in eval
            # mode for the remainder of training.
            if was_training:
                model.train()