from pathlib import Path

import anywidget
import traitlets

_ASSET_DIR = Path(__file__).parent
# Decode the bundled front-end assets explicitly as UTF-8. Without an
# ``encoding`` argument ``read_text`` falls back to the locale's default
# encoding, which can mis-decode (or fail on) non-ASCII characters on
# platforms whose default is not UTF-8.
_JS_SOURCE = (_ASSET_DIR / "widget.js").read_text(encoding="utf-8")
_CSS_SOURCE = (_ASSET_DIR / "widget.css").read_text(encoding="utf-8")


class GrpoGdpoWidget(anywidget.AnyWidget):
    """Interactive widget for comparing GRPO vs GDPO advantage calculations.

    Allows users to toggle binary rewards (correctness, style, conciseness)
    and see how the two normalization approaches differ.

    The ``rewards`` trait holds one dict per rollout, keyed by
    ``"correctness"``, ``"style"`` and ``"conciseness"``. It is tagged
    ``sync=True``; twelve rollouts are provided by default.
    """

    # Front-end module and stylesheet, loaded from the files next to this one.
    _esm = _JS_SOURCE
    _css = _CSS_SOURCE

    rewards = traitlets.List(
        traitlets.Dict(),
        default_value=[
            {"correctness": 1, "style": 0, "conciseness": 0},
            {"correctness": 1, "style": 0, "conciseness": 1},
            {"correctness": 0, "style": 1, "conciseness": 1},
            {"correctness": 0, "style": 1, "conciseness": 1},
            {"correctness": 1, "style": 0, "conciseness": 0},
            {"correctness": 0, "style": 1, "conciseness": 0},
            {"correctness": 1, "style": 1, "conciseness": 0},
            {"correctness": 0, "style": 0, "conciseness": 1},
            {"correctness": 1, "style": 1, "conciseness": 1},
            {"correctness": 0, "style": 0, "conciseness": 0},
            {"correctness": 1, "style": 0, "conciseness": 1},
            {"correctness": 0, "style": 1, "conciseness": 0},
        ],
    ).tag(sync=True)

    def add_rollout(self) -> None:
        """Append a new rollout with every reward set to 0."""
        # Assign a fresh list instead of appending in place: the trait is
        # replaced as a whole, so observers see a change on assignment.
        self.rewards = self.rewards + [
            {"correctness": 0, "style": 0, "conciseness": 0}
        ]

    def remove_rollout(self, index: int = -1) -> None:
        """Remove the rollout at ``index`` (default: the last one).

        The call is a silent no-op when two or fewer rollouts are present,
        so the widget never drops below two rollouts through this method.

        Args:
            index: Position of the rollout to remove; negative values count
                from the end, as with ``list.pop``.

        Raises:
            IndexError: If ``index`` is out of range for the current list.
        """
        if len(self.rewards) > 2:
            # Work on a copy and reassign so the trait is replaced rather
            # than mutated in place.
            remaining = list(self.rewards)
            remaining.pop(index)
            self.rewards = remaining