"""
Alignment Comparison — Base vs SFT vs RLHF
Course: 380 LLM Post-training ch1
"""

import json
import os

import gradio as gr

# Load cases
DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "cases.json")
with open(DATA_PATH, encoding="utf-8") as f:
    CASES = json.load(f)

# Distinct category keys found in the data, in alphabetical order.
CATEGORIES = sorted({c["category"] for c in CASES})

# Display names for known category keys; unknown keys are shown as-is.
CATEGORY_LABELS = {
    "safety": "Safety",
    "helpfulness": "Helpfulness",
    "honesty": "Honesty",
    "bias": "Bias & Fairness",
    "empathy": "Empathy",
}


def _format_progress(idx: int, total: int) -> str:
    """Return the progress label for the zero-based ``idx`` out of ``total``."""
    # With nothing to show, "Case 1 / 0" would be misleading.
    if not total:
        return "Case 0 / 0"
    return f"Case {idx + 1} / {total}"


def get_case(index: int, category: str):
    """Return the case at the given index, optionally filtered by category.

    Args:
        index: Position of the wanted case. Any integer is accepted; it is
            wrapped modulo the number of matching cases.
        category: A category key, or ``"All"`` to disable filtering.

    Returns:
        A 7-tuple ``(question_markdown, base_response, sft_response,
        rlhf_response, teaching_markdown, wrapped_index, total)``. When no
        case matches, the five text fields are empty strings and both
        numbers are 0.
    """
    filtered = CASES if category == "All" else [c for c in CASES if c["category"] == category]
    if not filtered:
        return "", "", "", "", "", 0, 0

    # Python's modulo is non-negative for a positive divisor, so stepping
    # back from the first case wraps around to the last one.
    idx = index % len(filtered)
    case = filtered[idx]
    label = CATEGORY_LABELS.get(case["category"], case["category"])
    return (
        f"### Question\n\n{case['question']}",
        case["base_response"],
        case["sft_response"],
        case["rlhf_response"],
        f"### Teaching Point\n\n{case['teaching_point']}\n\n*Category: {label}*",
        idx,
        len(filtered),
    )


def navigate(index, category, direction):
    """Move ``direction`` steps from ``index`` within ``category``.

    Args:
        index: Current case index.
        category: A category key, or ``"All"``.
        direction: Offset added to ``index`` (-1 previous, 1 next, 0 reload).

    Returns:
        A 7-tuple ``(question_markdown, base_response, sft_response,
        rlhf_response, teaching_markdown, index, progress_label)`` in the
        same order as the UI ``outputs`` list.
    """
    q, base, sft, rlhf, tp, idx, total = get_case(index + direction, category)
    # Hand back the wrapped index rather than the raw sum so the stored
    # state stays within [0, total) instead of drifting with every click.
    return q, base, sft, rlhf, tp, idx, _format_progress(idx, total)


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

with gr.Blocks(title="Alignment Comparison") as demo:
    gr.Markdown(
        "# Alignment Comparison — Base vs SFT vs RLHF\n"
        "See how different training stages change model behavior on the same question.\n"
        "*Course: 380 LLM Post-training — Why Alignment Matters*"
    )

    # Index of the case currently shown.
    state_index = gr.State(0)

    with gr.Row():
        cat_filter = gr.Dropdown(
            ["All"] + CATEGORIES,
            value="All",
            label="Filter by Category",
        )
        progress_text = gr.Textbox(
            value=_format_progress(0, len(CASES)),
            label="Progress",
            interactive=False,
        )

    question_md = gr.Markdown()

    with gr.Row(equal_height=True):
        with gr.Column():
            gr.Markdown("#### Base Model (Pre-training only)")
            base_out = gr.Textbox(lines=10, interactive=False, show_label=False)
        with gr.Column():
            gr.Markdown("#### SFT (Supervised Fine-Tuning)")
            sft_out = gr.Textbox(lines=10, interactive=False, show_label=False)
        with gr.Column():
            gr.Markdown("#### RLHF (Reinforcement Learning from Human Feedback)")
            rlhf_out = gr.Textbox(lines=10, interactive=False, show_label=False)

    teaching_md = gr.Markdown()

    with gr.Row():
        prev_btn = gr.Button("Previous", size="sm")
        next_btn = gr.Button("Next Case", variant="primary")

    # Navigation
    def go_prev(idx, cat):
        """Show the case before the current one."""
        return navigate(idx, cat, -1)

    def go_next(idx, cat):
        """Show the case after the current one."""
        return navigate(idx, cat, 1)

    # Order must match the tuple returned by navigate().
    outputs = [question_md, base_out, sft_out, rlhf_out, teaching_md, state_index, progress_text]

    prev_btn.click(go_prev, [state_index, cat_filter], outputs)
    next_btn.click(go_next, [state_index, cat_filter], outputs)

    def on_filter_change(cat):
        """Jump to the first case of the newly selected category."""
        return navigate(0, cat, 0)

    cat_filter.change(on_filter_change, [cat_filter], outputs)

    # Initialize
    demo.load(lambda: navigate(0, "All", 0), outputs=outputs)

if __name__ == "__main__":
    demo.launch()