Spaces:
Sleeping
Sleeping
| """ | |
| Alignment Comparison — Base vs SFT vs RLHF | |
| Course: 380 LLM Post-training ch1 | |
| """ | |
| import json | |
| import os | |
| import gradio as gr | |
# Load the comparison cases shipped alongside this script (data/cases.json).
DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "cases.json")
with open(DATA_PATH, encoding="utf-8") as f:
    CASES = json.load(f)
# NOTE(review): CASES is presumably a list of dicts. Each one needs at least a
# "category" key (read below); the UI also reads "question", "base_response",
# "sft_response", "rlhf_response" and "teaching_point". The schema is not
# validated here — confirm against data/cases.json.

# Distinct category keys present in the data, sorted so the dropdown order is stable.
CATEGORIES = sorted({c["category"] for c in CASES})

# Human-readable display names for the known category keys.
CATEGORY_LABELS = {
    "safety": "Safety",
    "helpfulness": "Helpfulness",
    "honesty": "Honesty",
    "bias": "Bias & Fairness",
    "empathy": "Empathy",
}
def get_case(index: int, category: str):
    """Look up one comparison case, optionally restricted to a category.

    Args:
        index: Requested position. It is reduced modulo the number of matching
            cases, so negative or out-of-range values wrap around.
        category: A category key, or "All" to consider every case.

    Returns:
        A 7-tuple ``(question_md, base, sft, rlhf, teaching_md, idx, total)``.
        ``idx`` is the wrapped position and ``total`` is the number of matching
        cases. When nothing matches, the result is five empty strings followed
        by ``0, 0``.
    """
    if category == "All":
        pool = CASES
    else:
        pool = [entry for entry in CASES if entry["category"] == category]

    total = len(pool)
    if total == 0:
        return "", "", "", "", "", 0, 0

    position = index % total
    entry = pool[position]
    # Fields are read in the same order as they appear in the returned tuple.
    question_md = f"### Question\n\n{entry['question']}"
    base_text = entry["base_response"]
    sft_text = entry["sft_response"]
    rlhf_text = entry["rlhf_response"]
    # Unknown category keys fall back to the raw key as their display name.
    teaching_md = (
        f"### Teaching Point\n\n{entry['teaching_point']}\n\n"
        f"*Category: {CATEGORY_LABELS.get(entry['category'], entry['category'])}*"
    )
    return question_md, base_text, sft_text, rlhf_text, teaching_md, position, total
def navigate(index, category, direction):
    """Step through the (optionally filtered) case list and render the result.

    Args:
        index: Current position, as held in the UI's ``gr.State``.
        category: A category key, or "All".
        direction: Offset to apply: -1 for previous, 1 for next, 0 to re-render.

    Returns:
        ``(question_md, base, sft, rlhf, teaching_md, new_index, progress)``, in
        the order of the UI ``outputs`` list. ``new_index`` is the wrapped
        position actually displayed.
    """
    q, base, sft, rlhf, tp, idx, total = get_case(index + direction, category)
    # get_case reports "no matching cases" as total == 0. Without this check
    # the readout would show the nonsensical "Case 1 / 0".
    progress = f"Case {idx + 1} / {total}" if total else "No cases match this filter"
    # Return the wrapped idx rather than index + direction. The stored state
    # then mirrors the displayed case instead of growing without bound on
    # repeated clicks. Navigation is unchanged: (a % n + d) % n == (a + d) % n.
    return q, base, sft, rlhf, tp, idx, progress
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
| with gr.Blocks(title="Alignment Comparison") as demo: | |
| gr.Markdown( | |
| "# Alignment Comparison — Base vs SFT vs RLHF\n" | |
| "See how different training stages change model behavior on the same question.\n" | |
| "*Course: 380 LLM Post-training — Why Alignment Matters*" | |
| ) | |
| state_index = gr.State(0) | |
| with gr.Row(): | |
| cat_filter = gr.Dropdown( | |
| ["All"] + CATEGORIES, | |
| value="All", | |
| label="Filter by Category", | |
| ) | |
| progress_text = gr.Textbox( | |
| value=f"Case 1 / {len(CASES)}", label="Progress", interactive=False | |
| ) | |
| question_md = gr.Markdown() | |
| with gr.Row(equal_height=True): | |
| with gr.Column(): | |
| gr.Markdown("#### Base Model (Pre-training only)") | |
| base_out = gr.Textbox(lines=10, interactive=False, show_label=False) | |
| with gr.Column(): | |
| gr.Markdown("#### SFT (Supervised Fine-Tuning)") | |
| sft_out = gr.Textbox(lines=10, interactive=False, show_label=False) | |
| with gr.Column(): | |
| gr.Markdown("#### RLHF (Reinforcement Learning from Human Feedback)") | |
| rlhf_out = gr.Textbox(lines=10, interactive=False, show_label=False) | |
| teaching_md = gr.Markdown() | |
| with gr.Row(): | |
| prev_btn = gr.Button("Previous", size="sm") | |
| next_btn = gr.Button("Next Case", variant="primary") | |
| # Navigation | |
| def go_prev(idx, cat): | |
| return navigate(idx, cat, -1) | |
| def go_next(idx, cat): | |
| return navigate(idx, cat, 1) | |
| outputs = [question_md, base_out, sft_out, rlhf_out, teaching_md, state_index, progress_text] | |
| prev_btn.click(go_prev, [state_index, cat_filter], outputs) | |
| next_btn.click(go_next, [state_index, cat_filter], outputs) | |
| def on_filter_change(cat): | |
| return navigate(0, cat, 0) | |
| cat_filter.change(on_filter_change, [cat_filter], outputs) | |
| # Initialize | |
| demo.load(lambda: navigate(0, "All", 0), outputs=outputs) | |
| if __name__ == "__main__": | |
| demo.launch() | |