File size: 1,777 Bytes
3c7c02f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from pathlib import Path

import anywidget
import traitlets


# Directory containing this module; the front-end assets live next to it.
_ASSET_DIR = Path(__file__).parent

# Read the widget's JavaScript and stylesheet once, at import time.
# The encoding is passed explicitly so decoding does not depend on the
# platform's locale default (assumes the assets are UTF-8 — TODO confirm).
_JS_SOURCE = (_ASSET_DIR / "widget.js").read_text(encoding="utf-8")
_CSS_SOURCE = (_ASSET_DIR / "widget.css").read_text(encoding="utf-8")


class GrpoGdpoWidget(anywidget.AnyWidget):
    """Widget for exploring how GRPO and GDPO advantage calculations differ.

    Each rollout carries three binary rewards (correctness, style,
    conciseness). Users toggle them to see where the two normalization
    approaches diverge.
    """

    # Front-end sources, taken from the module-level asset constants.
    _esm = _JS_SOURCE
    _css = _CSS_SOURCE

    # One dict of reward flags per rollout; the trait is tagged ``sync=True``.
    rewards = traitlets.List(
        traitlets.Dict(),
        default_value=[
            {"correctness": 1, "style": 0, "conciseness": 0},
            {"correctness": 1, "style": 0, "conciseness": 1},
            {"correctness": 0, "style": 1, "conciseness": 1},
            {"correctness": 0, "style": 1, "conciseness": 1},
            {"correctness": 1, "style": 0, "conciseness": 0},
            {"correctness": 0, "style": 1, "conciseness": 0},
            {"correctness": 1, "style": 1, "conciseness": 0},
            {"correctness": 0, "style": 0, "conciseness": 1},
            {"correctness": 1, "style": 1, "conciseness": 1},
            {"correctness": 0, "style": 0, "conciseness": 0},
            {"correctness": 1, "style": 0, "conciseness": 1},
            {"correctness": 0, "style": 1, "conciseness": 0},
        ],
    ).tag(sync=True)

    def add_rollout(self):
        """Append a rollout whose three rewards are all zero."""
        blank = {"correctness": 0, "style": 0, "conciseness": 0}
        # Assign a fresh list instead of appending in place; presumably this
        # is what lets traitlets register the change — verify if altering.
        self.rewards = [*self.rewards, blank]

    def remove_rollout(self, index=-1):
        """Drop the rollout at *index* (the last one by default).

        Does nothing when two or fewer rollouts are present.
        """
        if len(self.rewards) <= 2:
            return
        remaining = list(self.rewards)
        remaining.pop(index)
        self.rewards = remaining