# Page header (scraped): jeffliulab — "Initial deploy", commit 48d3037 (verified)
"""
Alignment Comparison — Base vs SFT vs RLHF
Course: 380 LLM Post-training ch1
"""
import json
import os
import gradio as gr
# Case data is read once at import time from data/cases.json, located in the
# same directory as this file.
DATA_PATH = os.path.join(
    os.path.dirname(__file__),
    "data",
    "cases.json",
)
with open(DATA_PATH, encoding="utf-8") as cases_file:
    CASES = json.load(cases_file)

# Distinct category values found in the loaded cases, in sorted order.
CATEGORIES = sorted({case["category"] for case in CASES})

# Display names for category values; get_case falls back to the raw value
# for any category that has no entry here.
CATEGORY_LABELS = {
    "safety": "Safety",
    "helpfulness": "Helpfulness",
    "honesty": "Honesty",
    "bias": "Bias & Fairness",
    "empathy": "Empathy",
}
def get_case(index: int, category: str):
    """Look up one case and format it for display.

    Args:
        index: Position of the wanted case. It is wrapped with modulo over
            the number of matching cases, so out-of-range and negative
            values cycle around.
        category: Category value to keep, or ``"All"`` to use every case.

    Returns:
        A 7-tuple ``(question_markdown, base_response, sft_response,
        rlhf_response, teaching_point_markdown, wrapped_index, total)``.
        When no case matches ``category`` the result is five empty strings
        followed by ``0, 0``.
    """
    if category == "All":
        pool = CASES
    else:
        pool = [entry for entry in CASES if entry["category"] == category]

    if not pool:
        return "", "", "", "", "", 0, 0

    total = len(pool)
    position = index % total
    selected = pool[position]

    question_block = f"### Question\n\n{selected['question']}"
    base_text = selected["base_response"]
    sft_text = selected["sft_response"]
    rlhf_text = selected["rlhf_response"]
    # Unknown categories are shown using their raw value.
    teaching_block = f"### Teaching Point\n\n{selected['teaching_point']}\n\n*Category: {CATEGORY_LABELS.get(selected['category'], selected['category'])}*"

    return (
        question_block,
        base_text,
        sft_text,
        rlhf_text,
        teaching_block,
        position,
        total,
    )
def navigate(index, category, direction):
    """Step from ``index`` by ``direction`` and build the refreshed UI values.

    Args:
        index: Index of the case currently shown.
        category: Category value to filter by, or ``"All"`` for every case.
        direction: Offset added to ``index`` before the lookup (the callers
            in this file pass -1, 0 or 1).

    Returns:
        A 7-tuple ``(question, base, sft, rlhf, teaching_point, index,
        progress)``. ``index`` is the wrapped position reported by
        ``get_case`` and ``progress`` is a ``"Case i / total"`` string.
    """
    q, base, sft, rlhf, tp, idx, total = get_case(index + direction, category)
    # get_case reports total == 0 when nothing matches the filter; show
    # "Case 0 / 0" in that situation instead of the misleading "Case 1 / 0".
    shown = idx + 1 if total else 0
    progress = f"Case {shown} / {total}"
    # Return the wrapped index so the stored state always equals the position
    # of the displayed case rather than drifting negative or growing forever.
    return q, base, sft, rlhf, tp, idx, progress
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Layout, top to bottom: title, category filter + progress row, question,
# three side-by-side response columns, teaching point, navigation buttons.
with gr.Blocks(title="Alignment Comparison") as demo:
    gr.Markdown(
        "# Alignment Comparison — Base vs SFT vs RLHF\n"
        "See how different training stages change model behavior on the same question.\n"
        "*Course: 380 LLM Post-training — Why Alignment Matters*"
    )
    # Index of the case currently shown; starts at the first case.
    state_index = gr.State(0)
    with gr.Row():
        # "All" disables filtering; the other choices are the category values
        # found in the loaded cases.
        cat_filter = gr.Dropdown(
            ["All"] + CATEGORIES,
            value="All",
            label="Filter by Category",
        )
        # Read-only "Case i / total" indicator, rewritten on every navigation.
        progress_text = gr.Textbox(
            value=f"Case 1 / {len(CASES)}", label="Progress", interactive=False
        )
    question_md = gr.Markdown()
    # One read-only text box per training stage, shown next to each other.
    with gr.Row(equal_height=True):
        with gr.Column():
            gr.Markdown("#### Base Model (Pre-training only)")
            base_out = gr.Textbox(lines=10, interactive=False, show_label=False)
        with gr.Column():
            gr.Markdown("#### SFT (Supervised Fine-Tuning)")
            sft_out = gr.Textbox(lines=10, interactive=False, show_label=False)
        with gr.Column():
            gr.Markdown("#### RLHF (Reinforcement Learning from Human Feedback)")
            rlhf_out = gr.Textbox(lines=10, interactive=False, show_label=False)
    teaching_md = gr.Markdown()
    with gr.Row():
        prev_btn = gr.Button("Previous", size="sm")
        next_btn = gr.Button("Next Case", variant="primary")
    # Navigation
    def go_prev(idx, cat):
        """Show the case one step before ``idx`` within category ``cat``."""
        return navigate(idx, cat, -1)
    def go_next(idx, cat):
        """Show the case one step after ``idx`` within category ``cat``."""
        return navigate(idx, cat, 1)
    # Every handler writes these seven components, in the same order as the
    # tuple returned by navigate().
    outputs = [question_md, base_out, sft_out, rlhf_out, teaching_md, state_index, progress_text]
    prev_btn.click(go_prev, [state_index, cat_filter], outputs)
    next_btn.click(go_next, [state_index, cat_filter], outputs)
    def on_filter_change(cat):
        """Jump to the first case of the newly selected category ``cat``."""
        return navigate(0, cat, 0)
    cat_filter.change(on_filter_change, [cat_filter], outputs)
    # Initialize
    # Fill the view with the first case of the unfiltered list on load.
    demo.load(lambda: navigate(0, "All", 0), outputs=outputs)
# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()