"""
Weights & Biases Logging for GRPO Training.

Provides comprehensive logging for debugging and monitoring GRPO fine-tuning:
- Per-step reward metrics (each reward function's mean)
- Sample tables showing question → response → reward breakdown
- Run configuration and hyperparameters
- Summary statistics at training end

Usage:
    from prolewiki_llm.wandb_logging import (
        init_wandb_logging,
        WandbSampleLogger,
        create_logging_reward,
    )

    # Initialize
    run = init_wandb_logging(project="marxist-grpo", config={...})

    # Create logger and reward function
    sample_logger = WandbSampleLogger(log_every_n_steps=10)
    logging_reward = create_logging_reward(sample_logger)

    # Use in GRPOTrainer
    trainer = GRPOTrainer(
        reward_funcs=[..., logging_reward],
        ...
    )
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence

# Cached import state: wandb is an optional dependency, imported lazily.
_WANDB_AVAILABLE: bool | None = None
_wandb_module: Any = None


def _get_wandb() -> Any:
    """Lazily import and return the wandb module (None if not installed)."""
    global _WANDB_AVAILABLE, _wandb_module

    if _WANDB_AVAILABLE is None:
        try:
            import wandb

            _wandb_module = wandb
            _WANDB_AVAILABLE = True
        except ImportError:
            _WANDB_AVAILABLE = False
            _wandb_module = None

    return _wandb_module


def is_wandb_available() -> bool:
    """Check if wandb is installed and available."""
    _get_wandb()
    return _WANDB_AVAILABLE is True
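
# Example (sketch): callers can branch on availability so training scripts
# work with or without wandb installed. `trainer_config` is a hypothetical
# dict defined by the caller, not by this module.
#
#     if is_wandb_available():
#         run = init_wandb_logging(project="marxist-grpo", config=trainer_config)
#     else:
#         run = None  # logging helpers below degrade to no-ops
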
def init_wandb_logging(
    project: str,
    config: dict[str, Any],
    name: str | None = None,
    tags: list[str] | None = None,
    notes: str | None = None,
    mode: str = "online",
) -> Any:
    """
    Initialize Weights & Biases logging for GRPO training.

    Args:
        project: W&B project name (e.g., "marxist-grpo")
        config: Dictionary of hyperparameters and settings
        name: Optional run name (auto-generated if None)
        tags: Optional list of tags for filtering runs
        notes: Optional notes about this run
        mode: "online", "offline", or "disabled"

    Returns:
        wandb.Run object (or None if wandb is unavailable)

    Example:
        run = init_wandb_logging(
            project="marxist-grpo",
            config={
                "model": "DeepSeek-R1-0528-Qwen3-8B",
                "learning_rate": 5e-6,
                "batch_size": 2,
                "max_steps": 250,
            },
            tags=["grpo", "marxist", "v1"],
        )
    """
    wandb = _get_wandb()
    if wandb is None:
        print("[WandbLogging] wandb not installed. Install with: pip install wandb")
        return None

    run = wandb.init(
        project=project,
        config=config,
        name=name,
        tags=tags or ["grpo", "marxist-leninist"],
        notes=notes,
        mode=mode,
    )

    # Register summary definitions (mean/min/max) for every reward metric.
    _define_reward_metrics(run)

    print(f"[WandbLogging] Initialized run: {run.name}")
    print(f"[WandbLogging] View at: {run.url}")

    return run


def _define_reward_metrics(run: Any) -> None:
    """Define reward metrics with min/max/mean summaries."""
    reward_metrics = [
        "rewards/format_exact",
        "rewards/format_approx",
        "rewards/semantic_similarity",
        "rewards/terminology",
        "rewards/nli_coherence",
        "rewards/self_consistency",
        "rewards/structural_coherence",
        "rewards/topic_relevance",
        "rewards/interconnection_depth",
        "rewards/completeness",
        "rewards/total",
    ]

    for metric in reward_metrics:
        # Mean summary for the main series, plus explicit min/max companions.
        run.define_metric(metric, summary="mean")
        run.define_metric(f"{metric}_min", summary="min")
        run.define_metric(f"{metric}_max", summary="max")


@dataclass
class RewardSample:
    """A single sample with its reward breakdown."""

    step: int
    question: str
    response: str
    ground_truth: str
    rewards: dict[str, float]

    @property
    def total_reward(self) -> float:
        """Sum of all rewards."""
        return sum(self.rewards.values())


@dataclass
class WandbSampleLogger:
    """
    Logs sample tables to W&B for debugging reward functions.

    Accumulates samples during training and logs them as a wandb.Table
    every N steps. This lets you inspect actual model outputs and
    understand why specific rewards were assigned.

    Example table:
        | step | question | response | ground_truth | format | nli | topic | depth | total |
        |------|----------|----------|--------------|--------|-----|-------|-------|-------|
        | 50   | What is..| The bour.| Revisionism..| 3.0    | 2.5 | 1.5   | 1.0   | 8.0   |
    """

    log_every_n_steps: int = 10
    max_samples_per_log: int = 4
    _samples: list[RewardSample] = field(default_factory=list)
    _step_counter: int = field(default=0)
    _table_columns: list[str] = field(
        default_factory=lambda: [
            "step",
            "question",
            "response",
            "ground_truth",
            "format_exact",
            "format_approx",
            "nli_coherence",
            "topic_relevance",
            "depth",
            "completeness",
            "total",
        ]
    )

    def add_sample(
        self,
        step: int,
        question: str,
        response: str,
        ground_truth: str,
        rewards: dict[str, float],
    ) -> None:
        """Add a sample to the buffer, truncating long fields for display."""
        sample = RewardSample(
            step=step,
            question=question[:500],
            response=response[:500],
            ground_truth=ground_truth[:300],
            rewards=rewards,
        )
        self._samples.append(sample)

        # Keep the buffer bounded so long runs don't accumulate memory.
        max_buffer = self.max_samples_per_log * 3
        if len(self._samples) > max_buffer:
            self._samples = self._samples[-max_buffer:]

    def should_log(self, step: int) -> bool:
        """Check if we should log at this step."""
        return step > 0 and step % self.log_every_n_steps == 0

    def log_table(self, step: int) -> None:
        """Log accumulated samples as a wandb.Table."""
        wandb = _get_wandb()
        if wandb is None or not self._samples:
            return

        # Only the most recent samples are tabled.
        samples_to_log = self._samples[-self.max_samples_per_log :]

        table = wandb.Table(columns=self._table_columns)

        for sample in samples_to_log:
            row = [
                sample.step,
                sample.question,
                sample.response,
                sample.ground_truth,
                sample.rewards.get("format_exact", 0.0),
                sample.rewards.get("format_approx", 0.0),
                sample.rewards.get("nli_coherence", 0.0),
                sample.rewards.get("topic_relevance", 0.0),
                sample.rewards.get("interconnection_depth", 0.0),
                sample.rewards.get("completeness", 0.0),
                sample.total_reward,
            ]
            table.add_data(*row)

        wandb.log({"samples": table}, step=step)
        print(f"[WandbLogging] Logged {len(samples_to_log)} samples at step {step}")

    def clear(self) -> None:
        """Clear the sample buffer."""
        self._samples.clear()
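
# Example (sketch): driving the sample logger by hand, outside a trainer.
# The reward names and values below are illustrative.
#
#     logger = WandbSampleLogger(log_every_n_steps=10, max_samples_per_log=2)
#     for step in range(1, 31):
#         logger.add_sample(
#             step=step,
#             question="What is surplus value?",
#             response="<think>...</think> Surplus value is ...",
#             ground_truth="Surplus value is the unpaid portion ...",
#             rewards={"format_exact": 3.0, "nli_coherence": 2.5},
#         )
#         if logger.should_log(step):
#             logger.log_table(step)  # logs the newest max_samples_per_log samples
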
def log_reward_metrics(
    step: int,
    reward_scores: dict[str, list[float]],
) -> None:
    """
    Log reward metrics to wandb.

    Args:
        step: Training step number
        reward_scores: Dict mapping reward name to list of scores,
            e.g., {"format_exact": [3.0, 3.0, 0.0, 3.0]}
    """
    wandb = _get_wandb()
    if wandb is None:
        return

    metrics: dict[str, float] = {}

    for name, scores in reward_scores.items():
        if not scores:
            continue

        mean_score = sum(scores) / len(scores)
        min_score = min(scores)
        max_score = max(scores)

        metrics[f"rewards/{name}"] = mean_score
        metrics[f"rewards/{name}_min"] = min_score
        metrics[f"rewards/{name}_max"] = max_score

    # Per-sample totals: sum each sample's scores across all reward functions.
    if reward_scores:
        all_totals = []
        num_samples = len(next(iter(reward_scores.values())))
        for i in range(num_samples):
            total = sum(scores[i] for scores in reward_scores.values() if i < len(scores))
            all_totals.append(total)

        if all_totals:
            metrics["rewards/total"] = sum(all_totals) / len(all_totals)
            metrics["rewards/total_min"] = min(all_totals)
            metrics["rewards/total_max"] = max(all_totals)

    wandb.log(metrics, step=step)
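
# Example (sketch): logging two reward series for a batch of four samples.
# This produces rewards/format_exact{,_min,_max}, rewards/nli_coherence{,_min,_max},
# and rewards/total{,_min,_max} at step 50. The values are illustrative.
#
#     log_reward_metrics(
#         step=50,
#         reward_scores={
#             "format_exact": [3.0, 3.0, 0.0, 3.0],
#             "nli_coherence": [2.5, 1.0, 0.5, 2.0],
#         },
#     )
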
# Module-level call counter; GRPOTrainer invokes the reward function once
# per batch, so this approximates the training step.
_LOGGING_STEP = 0


def create_logging_reward(
    sample_logger: WandbSampleLogger | None = None,
    compute_all_rewards: bool = True,
) -> Callable[..., list[float]]:
    """
    Create a reward function that logs metrics and samples to wandb.

    This replaces debug_print_reward with comprehensive wandb logging.
    The returned function computes ALL individual rewards internally,
    logs them to wandb, and returns [0.0] * len(completions) (no training effect).

    Args:
        sample_logger: WandbSampleLogger instance for sample table logging
        compute_all_rewards: If True, compute and log all reward functions

    Returns:
        A reward function compatible with GRPOTrainer

    Example:
        sample_logger = WandbSampleLogger(log_every_n_steps=10)
        logging_reward = create_logging_reward(sample_logger)

        trainer = GRPOTrainer(
            reward_funcs=[..., logging_reward],
            ...
        )
    """

    def logging_reward(
        prompts: Sequence[Sequence[dict[str, str]]],
        completions: Sequence[Sequence[dict[str, str]]],
        answer: Sequence[str],
        **kwargs: object,
    ) -> list[float]:
        """Log rewards and samples to wandb. Returns 0.0 (no training effect)."""
        global _LOGGING_STEP
        _LOGGING_STEP += 1
        step = _LOGGING_STEP

        wandb = _get_wandb()
        if wandb is None or wandb.run is None:
            # Fallback: occasional console print when wandb is unavailable.
            if step % 10 == 0:
                print(f"[Step {step}] Q: {prompts[0][-1]['content'][:80]}...")
            return [0.0] * len(completions)

        # Compute every individual reward and log the aggregate metrics.
        if compute_all_rewards:
            reward_scores = _compute_all_reward_scores(prompts, completions, answer, **kwargs)
            log_reward_metrics(step, reward_scores)
        else:
            reward_scores = {}

        # Periodically log a sample table with per-reward breakdowns.
        if sample_logger and sample_logger.should_log(step):
            for i in range(min(sample_logger.max_samples_per_log, len(prompts))):
                question = prompts[i][-1]["content"]
                response = completions[i][0]["content"]
                truth = answer[i] if i < len(answer) else ""

                # Slice out this sample's score from each reward list.
                sample_rewards = {
                    name: scores[i] if i < len(scores) else 0.0
                    for name, scores in reward_scores.items()
                }

                sample_logger.add_sample(
                    step=step,
                    question=question,
                    response=response,
                    ground_truth=truth,
                    rewards=sample_rewards,
                )

            sample_logger.log_table(step)

        return [0.0] * len(completions)

    return logging_reward
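
# Example (sketch): smoke-testing the returned reward function with the
# chat-style structures GRPOTrainer passes (each prompt is a list of
# messages; each completion is a single-message list). Without an active
# wandb run this simply returns zeros.
#
#     reward_fn = create_logging_reward(compute_all_rewards=False)
#     zeros = reward_fn(
#         prompts=[[{"role": "user", "content": "Define imperialism."}]],
#         completions=[[{"role": "assistant", "content": "<think>...</think> ..."}]],
#         answer=["Imperialism is ..."],
#     )
#     assert zeros == [0.0]
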
def _compute_all_reward_scores(
    prompts: Sequence[Sequence[dict[str, str]]],
    completions: Sequence[Sequence[dict[str, str]]],
    answer: Sequence[str],
    **kwargs: object,
) -> dict[str, list[float]]:
    """
    Compute all reward function scores for logging.

    Returns dict mapping reward name to list of scores.
    """
    # Local import so the reward stack is only loaded when logging runs.
    from prolewiki_llm.grpo_rewards import (
        completeness_reward,
        interconnection_depth_reward,
        match_format_approximately,
        match_format_exactly,
        nli_coherence_reward,
        topic_relevance_reward,
    )

    reward_scores: dict[str, list[float]] = {}

    # Each reward is computed independently: a failure in one falls back
    # to zeros instead of aborting the whole logging pass.
    try:
        reward_scores["format_exact"] = match_format_exactly(completions, **kwargs)
    except Exception as e:
        print(f"[WandbLogging] Error in format_exact: {e}")
        reward_scores["format_exact"] = [0.0] * len(completions)

    try:
        reward_scores["format_approx"] = match_format_approximately(completions, **kwargs)
    except Exception as e:
        print(f"[WandbLogging] Error in format_approx: {e}")
        reward_scores["format_approx"] = [0.0] * len(completions)

    try:
        reward_scores["nli_coherence"] = nli_coherence_reward(completions, answer, **kwargs)
    except Exception as e:
        print(f"[WandbLogging] Error in nli_coherence: {e}")
        reward_scores["nli_coherence"] = [0.0] * len(completions)

    try:
        reward_scores["topic_relevance"] = topic_relevance_reward(prompts, completions, **kwargs)
    except Exception as e:
        print(f"[WandbLogging] Error in topic_relevance: {e}")
        reward_scores["topic_relevance"] = [0.0] * len(completions)

    try:
        reward_scores["interconnection_depth"] = interconnection_depth_reward(completions, **kwargs)
    except Exception as e:
        print(f"[WandbLogging] Error in interconnection_depth: {e}")
        reward_scores["interconnection_depth"] = [0.0] * len(completions)

    try:
        reward_scores["completeness"] = completeness_reward(completions, answer, **kwargs)
    except Exception as e:
        print(f"[WandbLogging] Error in completeness: {e}")
        reward_scores["completeness"] = [0.0] * len(completions)

    return reward_scores


def finish_wandb_logging(summary: dict[str, Any] | None = None) -> None:
    """
    Finish the wandb run with optional summary statistics.

    Args:
        summary: Optional dict of final summary metrics
    """
    wandb = _get_wandb()
    if wandb is None or wandb.run is None:
        return

    if summary:
        for key, value in summary.items():
            wandb.run.summary[key] = value

    wandb.finish()
    print("[WandbLogging] Run finished.")


def log_model_checkpoint(
    checkpoint_path: str,
    metadata: dict[str, Any] | None = None,
) -> None:
    """
    Log a model checkpoint as a wandb artifact.

    Args:
        checkpoint_path: Path to the checkpoint directory
        metadata: Optional metadata about the checkpoint
    """
    wandb = _get_wandb()
    if wandb is None or wandb.run is None:
        return

    artifact = wandb.Artifact(
        name=f"checkpoint-{wandb.run.name}",
        type="model",
        metadata=metadata or {},
    )
    artifact.add_dir(checkpoint_path)
    wandb.log_artifact(artifact)
    print(f"[WandbLogging] Logged checkpoint: {checkpoint_path}")
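
# Minimal end-to-end sketch: exercises the full logging lifecycle
# (init → reward logging → finish) without a trainer or W&B account by
# using mode="disabled". The prompt/completion payloads are illustrative.
if __name__ == "__main__":
    init_wandb_logging(project="marxist-grpo", config={"demo": True}, mode="disabled")

    demo_logger = WandbSampleLogger(log_every_n_steps=1)
    demo_reward = create_logging_reward(demo_logger, compute_all_rewards=False)

    scores = demo_reward(
        prompts=[[{"role": "user", "content": "What is commodity fetishism?"}]],
        completions=[[{"role": "assistant", "content": "<think>...</think> ..."}]],
        answer=["Commodity fetishism is ..."],
    )
    print(f"demo reward scores: {scores}")

    finish_wandb_logging(summary={"demo_complete": True})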