File size: 3,587 Bytes
48d3037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Alignment Comparison — Base vs SFT vs RLHF
Course: 380 LLM Post-training ch1
"""

import json
import os
import gradio as gr

# Load the comparison cases from data/cases.json next to this script.
# get_case() reads each case's "category", "question", "base_response",
# "sft_response", "rlhf_response" and "teaching_point" keys.
DATA_PATH = os.path.join(
    os.path.dirname(__file__),
    "data",
    "cases.json",
)
with open(DATA_PATH, encoding="utf-8") as f:
    CASES = json.loads(f.read())

# Distinct category keys, alphabetically sorted (offered in the filter dropdown).
CATEGORIES = sorted({case["category"] for case in CASES})

# Display names for known category keys; get_case() falls back to the raw
# key for any category not listed here.
CATEGORY_LABELS = dict(
    safety="Safety",
    helpfulness="Helpfulness",
    honesty="Honesty",
    bias="Bias & Fairness",
    empathy="Empathy",
)


def get_case(index: int, category: str):
    """Fetch one case, after optional category filtering, formatted for the UI.

    Args:
        index: Requested position. It is reduced modulo the number of matching
            cases, so negative and out-of-range values wrap around.
        category: A category key to keep only matching cases, or "All" to
            consider every case.

    Returns:
        A 7-tuple ``(question_md, base, sft, rlhf, teaching_md, idx, total)``
        where ``idx`` is the wrapped position and ``total`` is the number of
        matching cases. If nothing matches, returns ``("", "", "", "", "", 0, 0)``.
    """
    if category == "All":
        pool = CASES
    else:
        pool = [item for item in CASES if item["category"] == category]

    total = len(pool)
    if total == 0:
        return "", "", "", "", "", 0, 0

    pos = index % total
    chosen = pool[pos]

    question_md = f"### Question\n\n{chosen['question']}"
    base = chosen["base_response"]
    sft = chosen["sft_response"]
    rlhf = chosen["rlhf_response"]
    lesson = chosen["teaching_point"]
    # Prefer the friendly category name when one is defined, else the raw key.
    label = CATEGORY_LABELS.get(chosen["category"], chosen["category"])
    teaching_md = f"### Teaching Point\n\n{lesson}\n\n*Category: {label}*"

    return question_md, base, sft, rlhf, teaching_md, pos, total


def navigate(index, category, direction):
    """Move through the filtered cases and build every UI output.

    Args:
        index: Current position, as stored in the app's index state.
        category: Category key to filter by, or "All".
        direction: Step added to ``index`` before lookup (the UI uses -1 for
            previous, 1 for next and 0 to show the current/first case).

    Returns:
        A 7-tuple: question markdown, base / SFT / RLHF responses,
        teaching-point markdown, the new index to store, and the progress text.
    """
    q, base, sft, rlhf, tp, idx, total = get_case(index + direction, category)
    if total:
        progress = f"Case {idx + 1} / {total}"
    else:
        # Nothing matched (e.g. an empty dataset): avoid the misleading "Case 1 / 0".
        progress = "No cases"
    # Store the wrapped index rather than the raw sum. Lookups are taken modulo
    # the case count, so this is equivalent, but the state no longer drifts
    # without bound on repeated clicks.
    return q, base, sft, rlhf, tp, idx, progress


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Build the Gradio interface: a category filter and progress box, three
# side-by-side response panes (base / SFT / RLHF), a teaching-point panel,
# and Previous / Next buttons.
with gr.Blocks(title="Alignment Comparison") as demo:
    gr.Markdown(
        "# Alignment Comparison — Base vs SFT vs RLHF\n"
        "See how different training stages change model behavior on the same question.\n"
        "*Course: 380 LLM Post-training — Why Alignment Matters*"
    )

    # Holds the current position within the filtered case list; every
    # navigation handler writes it back (it is one of `outputs` below).
    state_index = gr.State(0)

    with gr.Row():
        # Choices are the raw category keys (not the CATEGORY_LABELS display names).
        cat_filter = gr.Dropdown(
            ["All"] + CATEGORIES,
            value="All",
            label="Filter by Category",
        )
        # Initial text only; demo.load below replaces it when the page loads.
        progress_text = gr.Textbox(
            value=f"Case 1 / {len(CASES)}", label="Progress", interactive=False
        )

    # Receives the question markdown from the navigation handlers.
    question_md = gr.Markdown()

    # One read-only pane per training stage, shown side by side.
    with gr.Row(equal_height=True):
        with gr.Column():
            gr.Markdown("#### Base Model (Pre-training only)")
            base_out = gr.Textbox(lines=10, interactive=False, show_label=False)
        with gr.Column():
            gr.Markdown("#### SFT (Supervised Fine-Tuning)")
            sft_out = gr.Textbox(lines=10, interactive=False, show_label=False)
        with gr.Column():
            gr.Markdown("#### RLHF (Reinforcement Learning from Human Feedback)")
            rlhf_out = gr.Textbox(lines=10, interactive=False, show_label=False)

    # Receives the teaching point and category label for the current case.
    teaching_md = gr.Markdown()

    with gr.Row():
        prev_btn = gr.Button("Previous", size="sm")
        next_btn = gr.Button("Next Case", variant="primary")

    # Navigation
    # Thin wrappers that fix the step direction for each button.
    def go_prev(idx, cat):
        return navigate(idx, cat, -1)

    def go_next(idx, cat):
        return navigate(idx, cat, 1)

    # Order must match the 7-tuple returned by navigate(): question, base,
    # SFT, RLHF, teaching point, new stored index, progress text.
    outputs = [question_md, base_out, sft_out, rlhf_out, teaching_md, state_index, progress_text]

    prev_btn.click(go_prev, [state_index, cat_filter], outputs)
    next_btn.click(go_next, [state_index, cat_filter], outputs)

    # Switching the filter ignores the stored index and restarts at the
    # first case of the newly selected category.
    def on_filter_change(cat):
        return navigate(0, cat, 0)

    cat_filter.change(on_filter_change, [cat_filter], outputs)

    # Initialize
    # Fill every output with the first case on page load; the category is
    # hard-coded to "All", which matches the dropdown's default value.
    demo.load(lambda: navigate(0, "All", 0), outputs=outputs)

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()